mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-13 01:23:41 +08:00
merge upstream
Signed-off-by: wozulong <>
This commit is contained in:
@@ -227,6 +227,8 @@ const (
|
|||||||
ChannelTypeZhipu_v4 = 26
|
ChannelTypeZhipu_v4 = 26
|
||||||
ChannelTypePerplexity = 27
|
ChannelTypePerplexity = 27
|
||||||
ChannelTypeLingYiWanWu = 31
|
ChannelTypeLingYiWanWu = 31
|
||||||
|
ChannelTypeAws = 33
|
||||||
|
ChannelTypeCohere = 34
|
||||||
)
|
)
|
||||||
|
|
||||||
var ChannelBaseURLs = []string{
|
var ChannelBaseURLs = []string{
|
||||||
@@ -262,4 +264,7 @@ var ChannelBaseURLs = []string{
|
|||||||
"", //29
|
"", //29
|
||||||
"", //30
|
"", //30
|
||||||
"https://api.lingyiwanwu.com", //31
|
"https://api.lingyiwanwu.com", //31
|
||||||
|
"", //32
|
||||||
|
"", //33
|
||||||
|
"https://api.cohere.ai", //34
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
// TODO: when a new api is enabled, check the pricing here
|
// TODO: when a new api is enabled, check the pricing here
|
||||||
// 1 === $0.002 / 1K tokens
|
// 1 === $0.002 / 1K tokens
|
||||||
// 1 === ¥0.014 / 1k tokens
|
// 1 === ¥0.014 / 1k tokens
|
||||||
|
|
||||||
var DefaultModelRatio = map[string]float64{
|
var DefaultModelRatio = map[string]float64{
|
||||||
//"midjourney": 50,
|
//"midjourney": 50,
|
||||||
"gpt-4-gizmo-*": 15,
|
"gpt-4-gizmo-*": 15,
|
||||||
@@ -73,11 +74,14 @@ var DefaultModelRatio = map[string]float64{
|
|||||||
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
"ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens
|
||||||
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
|
||||||
"PaLM-2": 1,
|
"PaLM-2": 1,
|
||||||
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
|
||||||
"gemini-1.0-pro-vision-001": 1,
|
"gemini-1.0-pro-vision-001": 1,
|
||||||
"gemini-1.0-pro-001": 1,
|
"gemini-1.0-pro-001": 1,
|
||||||
"gemini-1.5-pro": 1,
|
"gemini-1.5-pro-latest": 1,
|
||||||
|
"gemini-1.0-pro-latest": 1,
|
||||||
|
"gemini-1.0-pro-vision-latest": 1,
|
||||||
|
"gemini-ultra": 1,
|
||||||
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
|
||||||
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
|
||||||
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
|
||||||
@@ -102,6 +106,12 @@ var DefaultModelRatio = map[string]float64{
|
|||||||
"yi-34b-chat-0205": 0.018,
|
"yi-34b-chat-0205": 0.018,
|
||||||
"yi-34b-chat-200k": 0.0864,
|
"yi-34b-chat-200k": 0.0864,
|
||||||
"yi-vl-plus": 0.0432,
|
"yi-vl-plus": 0.0432,
|
||||||
|
"command": 0.5,
|
||||||
|
"command-nightly": 0.5,
|
||||||
|
"command-light": 0.5,
|
||||||
|
"command-light-nightly": 0.5,
|
||||||
|
"command-r": 0.25,
|
||||||
|
"command-r-plus ": 1.5,
|
||||||
}
|
}
|
||||||
|
|
||||||
var DefaultModelPrice = map[string]float64{
|
var DefaultModelPrice = map[string]float64{
|
||||||
@@ -223,6 +233,16 @@ func GetCompletionRatio(name string) float64 {
|
|||||||
if strings.HasPrefix(name, "gemini-") {
|
if strings.HasPrefix(name, "gemini-") {
|
||||||
return 3
|
return 3
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(name, "command") {
|
||||||
|
switch name {
|
||||||
|
case "command-r":
|
||||||
|
return 3
|
||||||
|
case "command-r-plus":
|
||||||
|
return 5
|
||||||
|
default:
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
switch name {
|
switch name {
|
||||||
case "llama2-70b-4096":
|
case "llama2-70b-4096":
|
||||||
return 0.8 / 0.7
|
return 0.8 / 0.7
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
package constant
|
package constant
|
||||||
|
|
||||||
var MjNotifyEnabled = false
|
var MjNotifyEnabled = false
|
||||||
|
|
||||||
var MjModeClearEnabled = false
|
var MjModeClearEnabled = false
|
||||||
|
var MjForwardUrlEnabled = true
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MjErrorUnknown = 5
|
MjErrorUnknown = 5
|
||||||
|
|||||||
8
constant/payment.go
Normal file
8
constant/payment.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package constant
|
||||||
|
|
||||||
|
var PayAddress = ""
|
||||||
|
var CustomCallbackAddress = ""
|
||||||
|
var EpayId = ""
|
||||||
|
var EpayKey = ""
|
||||||
|
var Price = 7.3
|
||||||
|
var MinTopUp = 1
|
||||||
@@ -86,7 +86,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
||||||
err := relaycommon.RelayErrorHandler(resp)
|
err := relaycommon.RelayErrorHandler(resp)
|
||||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,11 +10,11 @@ import (
|
|||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -233,6 +233,12 @@ func GetAllMidjourney(c *gin.Context) {
|
|||||||
if logs == nil {
|
if logs == nil {
|
||||||
logs = make([]*model.Midjourney, 0)
|
logs = make([]*model.Midjourney, 0)
|
||||||
}
|
}
|
||||||
|
if constant.MjForwardUrlEnabled {
|
||||||
|
for i, midjourney := range logs {
|
||||||
|
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
|
logs[i] = midjourney
|
||||||
|
}
|
||||||
|
}
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"success": true,
|
"success": true,
|
||||||
"message": "",
|
"message": "",
|
||||||
@@ -259,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) {
|
|||||||
if logs == nil {
|
if logs == nil {
|
||||||
logs = make([]*model.Midjourney, 0)
|
logs = make([]*model.Midjourney, 0)
|
||||||
}
|
}
|
||||||
if !strings.Contains(common.ServerAddress, "localhost") {
|
if constant.MjForwardUrlEnabled {
|
||||||
for i, midjourney := range logs {
|
for i, midjourney := range logs {
|
||||||
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId
|
||||||
logs[i] = midjourney
|
logs[i] = midjourney
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) {
|
|||||||
var options []*model.Option
|
var options []*model.Option
|
||||||
common.OptionMapRWMutex.Lock()
|
common.OptionMapRWMutex.Lock()
|
||||||
for k, v := range common.OptionMap {
|
for k, v := range common.OptionMap {
|
||||||
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") {
|
if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
options = append(options, &model.Option{
|
options = append(options, &model.Option{
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"one-api/relay/constant"
|
"one-api/relay/constant"
|
||||||
relayconstant "one-api/relay/constant"
|
relayconstant "one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
@@ -42,7 +43,7 @@ func Relay(c *gin.Context) {
|
|||||||
group := c.GetString("group")
|
group := c.GetString("group")
|
||||||
originalModel := c.GetString("original_model")
|
originalModel := c.GetString("original_model")
|
||||||
openaiErr := relayHandler(c, relayMode)
|
openaiErr := relayHandler(c, relayMode)
|
||||||
retryLogStr := fmt.Sprintf("重试:%d", channelId)
|
useChannel := []int{channelId}
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
go processChannelError(c, channelId, openaiErr)
|
go processChannelError(c, channelId, openaiErr)
|
||||||
} else {
|
} else {
|
||||||
@@ -55,7 +56,7 @@ func Relay(c *gin.Context) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
channelId = channel.Id
|
channelId = channel.Id
|
||||||
retryLogStr += fmt.Sprintf("->%d", channel.Id)
|
useChannel = append(useChannel, channelId)
|
||||||
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
|
||||||
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
|
||||||
|
|
||||||
@@ -66,7 +67,10 @@ func Relay(c *gin.Context) {
|
|||||||
go processChannelError(c, channelId, openaiErr)
|
go processChannelError(c, channelId, openaiErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
common.LogInfo(c.Request.Context(), retryLogStr)
|
if len(useChannel) > 1 {
|
||||||
|
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
|
||||||
|
common.LogInfo(c.Request.Context(), retryLogStr)
|
||||||
|
}
|
||||||
|
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
if openaiErr.StatusCode == http.StatusTooManyRequests {
|
||||||
@@ -105,6 +109,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt
|
|||||||
if openaiErr.StatusCode == http.StatusBadRequest {
|
if openaiErr.StatusCode == http.StatusBadRequest {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if openaiErr.StatusCode == 408 {
|
||||||
|
// azure处理超时不重试
|
||||||
|
return false
|
||||||
|
}
|
||||||
if openaiErr.LocalError {
|
if openaiErr.LocalError {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -217,7 +217,8 @@ func GetAllUsers(c *gin.Context) {
|
|||||||
|
|
||||||
func SearchUsers(c *gin.Context) {
|
func SearchUsers(c *gin.Context) {
|
||||||
keyword := c.Query("keyword")
|
keyword := c.Query("keyword")
|
||||||
users, err := model.SearchUsers(keyword)
|
group := c.Query("group")
|
||||||
|
users, err := model.SearchUsers(keyword, group)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
@@ -453,7 +454,7 @@ func UpdateUser(c *gin.Context) {
|
|||||||
updatedUser.Password = "" // rollback to what it should be
|
updatedUser.Password = "" // rollback to what it should be
|
||||||
}
|
}
|
||||||
updatePassword := updatedUser.Password != ""
|
updatePassword := updatedUser.Password != ""
|
||||||
if err := updatedUser.Update(updatePassword); err != nil {
|
if err := updatedUser.Edit(updatePassword); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
@@ -740,7 +741,7 @@ func ManageUser(c *gin.Context) {
|
|||||||
user.Role = common.RoleCommonUser
|
user.Role = common.RoleCommonUser
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := user.UpdateAll(false); err != nil {
|
if err := user.Update(false); err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"success": false,
|
"success": false,
|
||||||
"message": err.Error(),
|
"message": err.Error(),
|
||||||
|
|||||||
@@ -32,6 +32,21 @@ type GeneralOpenAIRequest struct {
|
|||||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OpenAITools struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Function OpenAIFunction `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OpenAIFunction struct {
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Parameters any `json:"parameters,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||||
|
return int64(r.MaxTokens)
|
||||||
|
}
|
||||||
|
|
||||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||||
if r.Input == nil {
|
if r.Input == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -54,13 +54,29 @@ type OpenAIEmbeddingResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponseChoice struct {
|
type ChatCompletionsStreamResponseChoice struct {
|
||||||
Delta struct {
|
Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta"`
|
||||||
Content string `json:"content"`
|
FinishReason *string `json:"finish_reason,omitempty"`
|
||||||
Role string `json:"role,omitempty"`
|
Index int `json:"index,omitempty"`
|
||||||
ToolCalls any `json:"tool_calls,omitempty"`
|
}
|
||||||
} `json:"delta"`
|
|
||||||
FinishReason *string `json:"finish_reason,omitempty"`
|
type ChatCompletionsStreamResponseChoiceDelta struct {
|
||||||
Index int `json:"index,omitempty"`
|
Content string `json:"content"`
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ToolCall struct {
|
||||||
|
// Index is not nil only in chat completion chunk object
|
||||||
|
Index *int `json:"index,omitempty"`
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type any `json:"type"`
|
||||||
|
Function FunctionCall `json:"function"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FunctionCall struct {
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
// call function with arguments in JSON format
|
||||||
|
Arguments string `json:"arguments,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChatCompletionsStreamResponse struct {
|
type ChatCompletionsStreamResponse struct {
|
||||||
|
|||||||
9
go.mod
9
go.mod
@@ -5,6 +5,9 @@ go 1.18
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4
|
||||||
github.com/gin-contrib/cors v1.4.0
|
github.com/gin-contrib/cors v1.4.0
|
||||||
github.com/gin-contrib/gzip v0.0.6
|
github.com/gin-contrib/gzip v0.0.6
|
||||||
github.com/gin-contrib/sessions v0.0.5
|
github.com/gin-contrib/sessions v0.0.5
|
||||||
@@ -15,6 +18,8 @@ require (
|
|||||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||||
github.com/google/uuid v1.3.0
|
github.com/google/uuid v1.3.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
github.com/jinzhu/copier v0.4.0
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/pkoukk/tiktoken-go v0.1.6
|
github.com/pkoukk/tiktoken-go v0.1.6
|
||||||
github.com/samber/lo v1.39.0
|
github.com/samber/lo v1.39.0
|
||||||
github.com/shirou/gopsutil v3.21.11+incompatible
|
github.com/shirou/gopsutil v3.21.11+incompatible
|
||||||
@@ -29,6 +34,10 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect
|
||||||
|
github.com/aws/smithy-go v1.20.2 // indirect
|
||||||
github.com/bytedance/sonic v1.9.1 // indirect
|
github.com/bytedance/sonic v1.9.1 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
|
|||||||
18
go.sum
18
go.sum
@@ -2,6 +2,20 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc
|
|||||||
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI=
|
||||||
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA=
|
||||||
|
github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to=
|
||||||
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs=
|
||||||
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0=
|
||||||
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU=
|
||||||
|
github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg=
|
||||||
|
github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q=
|
||||||
|
github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E=
|
||||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
|
||||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
|
||||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
|
||||||
@@ -83,6 +97,8 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI=
|
|||||||
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
|
||||||
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
|
||||||
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
|
||||||
|
github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
|
||||||
|
github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
|
||||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||||
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||||
@@ -126,6 +142,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO
|
|||||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw=
|
||||||
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
|||||||
@@ -177,6 +177,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
}
|
}
|
||||||
c.Set("auto_ban", ban)
|
c.Set("auto_ban", ban)
|
||||||
c.Set("model_mapping", channel.GetModelMapping())
|
c.Set("model_mapping", channel.GetModelMapping())
|
||||||
|
c.Set("status_code_mapping", channel.GetStatusCodeMapping())
|
||||||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||||||
c.Set("base_url", channel.GetBaseURL())
|
c.Set("base_url", channel.GetBaseURL())
|
||||||
// TODO: api_version统一
|
// TODO: api_version统一
|
||||||
|
|||||||
@@ -25,9 +25,6 @@ var token2UserId = make(map[string]int)
|
|||||||
var token2UserIdLock sync.RWMutex
|
var token2UserIdLock sync.RWMutex
|
||||||
|
|
||||||
func cacheSetToken(token *Token) error {
|
func cacheSetToken(token *Token) error {
|
||||||
if !common.RedisEnabled {
|
|
||||||
return token.SelectUpdate()
|
|
||||||
}
|
|
||||||
jsonBytes, err := json.Marshal(token)
|
jsonBytes, err := json.Marshal(token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -168,7 +165,11 @@ func CacheUpdateUserQuota(id int) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
return cacheSetUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cacheSetUserQuota(id int, quota int) error {
|
||||||
|
err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -25,8 +25,10 @@ type Channel struct {
|
|||||||
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
Group string `json:"group" gorm:"type:varchar(64);default:'default'"`
|
||||||
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"`
|
||||||
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
//MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"`
|
||||||
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"`
|
||||||
|
Priority *int64 `json:"priority" gorm:"bigint;default:0"`
|
||||||
|
AutoBan *int `json:"auto_ban" gorm:"default:1"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) {
|
||||||
@@ -153,6 +155,13 @@ func (channel *Channel) GetModelMapping() string {
|
|||||||
return *channel.ModelMapping
|
return *channel.ModelMapping
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (channel *Channel) GetStatusCodeMapping() string {
|
||||||
|
if channel.StatusCodeMapping == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *channel.StatusCodeMapping
|
||||||
|
}
|
||||||
|
|
||||||
func (channel *Channel) Insert() error {
|
func (channel *Channel) Insert() error {
|
||||||
var err error
|
var err error
|
||||||
err = DB.Create(channel).Error
|
err = DB.Create(channel).Error
|
||||||
|
|||||||
@@ -98,6 +98,7 @@ func InitOptionMap() {
|
|||||||
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
|
||||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled)
|
||||||
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled)
|
||||||
|
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled)
|
||||||
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled)
|
||||||
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled)
|
||||||
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
|
||||||
@@ -207,6 +208,8 @@ func updateOptionMap(key string, value string) (err error) {
|
|||||||
constant.MjNotifyEnabled = boolValue
|
constant.MjNotifyEnabled = boolValue
|
||||||
case "MjModeClearEnabled":
|
case "MjModeClearEnabled":
|
||||||
constant.MjModeClearEnabled = boolValue
|
constant.MjModeClearEnabled = boolValue
|
||||||
|
case "MjForwardUrlEnabled":
|
||||||
|
constant.MjForwardUrlEnabled = boolValue
|
||||||
case "CheckSensitiveEnabled":
|
case "CheckSensitiveEnabled":
|
||||||
constant.CheckSensitiveEnabled = boolValue
|
constant.CheckSensitiveEnabled = boolValue
|
||||||
case "CheckSensitiveOnPromptEnabled":
|
case "CheckSensitiveOnPromptEnabled":
|
||||||
|
|||||||
@@ -102,6 +102,11 @@ func GetTokenById(id int) (*Token, error) {
|
|||||||
token := Token{Id: id}
|
token := Token{Id: id}
|
||||||
var err error = nil
|
var err error = nil
|
||||||
err = DB.First(&token, "id = ?", id).Error
|
err = DB.First(&token, "id = ?", id).Error
|
||||||
|
if err != nil {
|
||||||
|
if common.RedisEnabled {
|
||||||
|
go cacheSetToken(&token)
|
||||||
|
}
|
||||||
|
}
|
||||||
return &token, err
|
return &token, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -76,25 +76,34 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) {
|
|||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func SearchUsers(keyword string) ([]*User, error) {
|
func SearchUsers(keyword string, group string) ([]*User, error) {
|
||||||
var users []*User
|
var users []*User
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
// 尝试将关键字转换为整数ID
|
// 尝试将关键字转换为整数ID
|
||||||
keywordInt, err := strconv.Atoi(keyword)
|
keywordInt, err := strconv.Atoi(keyword)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// 如果转换成功,按照ID搜索用户
|
// 如果转换成功,按照ID和可选的组别搜索用户
|
||||||
err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error
|
query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt)
|
||||||
|
if group != "" {
|
||||||
|
query = query.Where("`group` = ?", group) // 使用反引号包围group
|
||||||
|
}
|
||||||
|
err = query.Find(&users).Error
|
||||||
if err != nil || len(users) > 0 {
|
if err != nil || len(users) > 0 {
|
||||||
// 如果依据ID找到用户或者发生错误,返回结果或错误
|
|
||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索
|
err = nil
|
||||||
err = DB.Unscoped().Omit("password").
|
|
||||||
Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").
|
query := DB.Unscoped().Omit("password")
|
||||||
Find(&users).Error
|
likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?"
|
||||||
|
if group != "" {
|
||||||
|
query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group)
|
||||||
|
} else {
|
||||||
|
query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||||
|
}
|
||||||
|
err = query.Find(&users).Error
|
||||||
|
|
||||||
return users, err
|
return users, err
|
||||||
}
|
}
|
||||||
@@ -252,7 +261,7 @@ func (user *User) Update(updatePassword bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *User) UpdateAll(updatePassword bool) error {
|
func (user *User) Edit(updatePassword bool) error {
|
||||||
var err error
|
var err error
|
||||||
if updatePassword {
|
if updatePassword {
|
||||||
user.Password, err = common.Password2Hash(user.Password)
|
user.Password, err = common.Password2Hash(user.Password)
|
||||||
@@ -262,7 +271,13 @@ func (user *User) UpdateAll(updatePassword bool) error {
|
|||||||
}
|
}
|
||||||
newUser := *user
|
newUser := *user
|
||||||
DB.First(&user, user.Id)
|
DB.First(&user, user.Id)
|
||||||
err = DB.Model(user).Select("*").Updates(newUser).Error
|
err = DB.Model(user).Updates(map[string]interface{}{
|
||||||
|
"username": newUser.Username,
|
||||||
|
"password": newUser.Password,
|
||||||
|
"display_name": newUser.DisplayName,
|
||||||
|
"group": newUser.Group,
|
||||||
|
"quota": newUser.Quota,
|
||||||
|
}).Error
|
||||||
if err == nil {
|
if err == nil {
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
|
||||||
@@ -451,6 +466,11 @@ func ValidateAccessToken(token string) (user *User) {
|
|||||||
|
|
||||||
func GetUserQuota(id int) (quota int, err error) {
|
func GetUserQuota(id int) (quota int, err error) {
|
||||||
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
|
||||||
|
if err != nil {
|
||||||
|
if common.RedisEnabled {
|
||||||
|
go cacheSetUserQuota(id, quota)
|
||||||
|
}
|
||||||
|
}
|
||||||
return quota, err
|
return quota, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
79
relay/channel/aws/adaptor.go
Normal file
79
relay/channel/aws/adaptor.go
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel/claude"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
RequestModeCompletion = 1
|
||||||
|
RequestModeMessage = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
RequestMode int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
|
if strings.HasPrefix(info.UpstreamModelName, "claude-3") {
|
||||||
|
a.RequestMode = RequestModeMessage
|
||||||
|
} else {
|
||||||
|
a.RequestMode = RequestModeCompletion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
|
if request == nil {
|
||||||
|
return nil, errors.New("request is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
var claudeReq *claude.ClaudeRequest
|
||||||
|
var err error
|
||||||
|
if a.RequestMode == RequestModeCompletion {
|
||||||
|
claudeReq = claude.RequestOpenAI2ClaudeComplete(*request)
|
||||||
|
} else {
|
||||||
|
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
|
||||||
|
}
|
||||||
|
c.Set("request_model", request.Model)
|
||||||
|
c.Set("converted_request", claudeReq)
|
||||||
|
return claudeReq, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
if info.IsStream {
|
||||||
|
err, usage = awsStreamHandler(c, info, a.RequestMode)
|
||||||
|
} else {
|
||||||
|
err, usage = awsHandler(c, info, a.RequestMode)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() (models []string) {
|
||||||
|
for n := range awsModelIDMap {
|
||||||
|
models = append(models, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
12
relay/channel/aws/constants.go
Normal file
12
relay/channel/aws/constants.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
var awsModelIDMap = map[string]string{
|
||||||
|
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||||
|
"claude-2.0": "anthropic.claude-v2",
|
||||||
|
"claude-2.1": "anthropic.claude-v2:1",
|
||||||
|
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
|
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
|
||||||
|
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelName = "aws"
|
||||||
14
relay/channel/aws/dto.go
Normal file
14
relay/channel/aws/dto.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import "one-api/relay/channel/claude"
|
||||||
|
|
||||||
|
type AwsClaudeRequest struct {
|
||||||
|
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||||
|
AnthropicVersion string `json:"anthropic_version"`
|
||||||
|
Messages []claude.ClaudeMessage `json:"messages"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
TopK int `json:"top_k,omitempty"`
|
||||||
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
|
}
|
||||||
211
relay/channel/aws/relay-aws.go
Normal file
211
relay/channel/aws/relay-aws.go
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/jinzhu/copier"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
relaymodel "one-api/dto"
|
||||||
|
"one-api/relay/channel/claude"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
|
||||||
|
awsSecret := strings.Split(info.ApiKey, "|")
|
||||||
|
if len(awsSecret) != 3 {
|
||||||
|
return nil, errors.New("invalid aws secret key")
|
||||||
|
}
|
||||||
|
ak := awsSecret[0]
|
||||||
|
sk := awsSecret[1]
|
||||||
|
region := awsSecret[2]
|
||||||
|
client := bedrockruntime.New(bedrockruntime.Options{
|
||||||
|
Region: region,
|
||||||
|
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
|
||||||
|
})
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode {
|
||||||
|
return &relaymodel.OpenAIErrorWithStatusCode{
|
||||||
|
StatusCode: http.StatusInternalServerError,
|
||||||
|
Error: relaymodel.OpenAIError{
|
||||||
|
Message: fmt.Sprintf("%s", err.Error()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func awsModelID(requestModel string) (string, error) {
|
||||||
|
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
|
||||||
|
return awsModelID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", errors.Errorf("model %s not found", requestModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
|
awsCli, err := newAwsClient(c, info)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeReq_, ok := c.Get("converted_request")
|
||||||
|
if !ok {
|
||||||
|
return wrapErr(errors.New("request not found")), nil
|
||||||
|
}
|
||||||
|
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||||
|
awsClaudeReq := &AwsClaudeRequest{
|
||||||
|
AnthropicVersion: "bedrock-2023-05-31",
|
||||||
|
}
|
||||||
|
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "InvokeModel")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeResponse := new(claude.ClaudeResponse)
|
||||||
|
err = json.Unmarshal(awsResp.Body, claudeResponse)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "unmarshal response")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse)
|
||||||
|
usage := relaymodel.Usage{
|
||||||
|
PromptTokens: claudeResponse.Usage.InputTokens,
|
||||||
|
CompletionTokens: claudeResponse.Usage.OutputTokens,
|
||||||
|
TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens,
|
||||||
|
}
|
||||||
|
openaiResp.Usage = usage
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, openaiResp)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) {
|
||||||
|
awsCli, err := newAwsClient(c, info)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "newAwsClient")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsModelId, err := awsModelID(c.GetString("request_model"))
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "awsModelID")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeReq_, ok := c.Get("converted_request")
|
||||||
|
if !ok {
|
||||||
|
return wrapErr(errors.New("request not found")), nil
|
||||||
|
}
|
||||||
|
claudeReq := claudeReq_.(*claude.ClaudeRequest)
|
||||||
|
|
||||||
|
awsClaudeReq := &AwsClaudeRequest{
|
||||||
|
AnthropicVersion: "bedrock-2023-05-31",
|
||||||
|
}
|
||||||
|
if err = copier.Copy(awsClaudeReq, claudeReq); err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "copy request")), nil
|
||||||
|
}
|
||||||
|
awsReq.Body, err = json.Marshal(awsClaudeReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "marshal request")), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return wrapErr(errors.Wrap(err, "InvokeModelWithResponseStream")), nil
|
||||||
|
}
|
||||||
|
stream := awsResp.GetStream()
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
var usage relaymodel.Usage
|
||||||
|
var id string
|
||||||
|
var model string
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
event, ok := <-stream.Events()
|
||||||
|
if !ok {
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := event.(type) {
|
||||||
|
case *types.ResponseStreamMemberChunk:
|
||||||
|
claudeResp := new(claude.ClaudeResponse)
|
||||||
|
err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp)
|
||||||
|
if claudeUsage != nil {
|
||||||
|
usage.PromptTokens += claudeUsage.InputTokens
|
||||||
|
usage.CompletionTokens += claudeUsage.OutputTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
if response == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if response.Id != "" {
|
||||||
|
id = response.Id
|
||||||
|
}
|
||||||
|
if response.Model != "" {
|
||||||
|
model = response.Model
|
||||||
|
}
|
||||||
|
response.Id = id
|
||||||
|
response.Model = model
|
||||||
|
|
||||||
|
jsonStr, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
return true
|
||||||
|
case *types.UnknownUnionMember:
|
||||||
|
fmt.Println("unknown tag:", v.Tag)
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
fmt.Println("union is nil or unknown type")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
@@ -53,9 +53,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen
|
|||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
if a.RequestMode == RequestModeCompletion {
|
if a.RequestMode == RequestModeCompletion {
|
||||||
return requestOpenAI2ClaudeComplete(*request), nil
|
return RequestOpenAI2ClaudeComplete(*request), nil
|
||||||
} else {
|
} else {
|
||||||
return requestOpenAI2ClaudeMessage(*request)
|
return RequestOpenAI2ClaudeMessage(*request)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,16 +24,15 @@ type ClaudeMessage struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ClaudeRequest struct {
|
type ClaudeRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt,omitempty"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
System string `json:"system,omitempty"`
|
System string `json:"system,omitempty"`
|
||||||
Messages []ClaudeMessage `json:"messages,omitempty"`
|
Messages []ClaudeMessage `json:"messages,omitempty"`
|
||||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
TopP float64 `json:"top_p,omitempty"`
|
TopK int `json:"top_k,omitempty"`
|
||||||
TopK int `json:"top_k,omitempty"`
|
|
||||||
//ClaudeMetadata `json:"metadata,omitempty"`
|
//ClaudeMetadata `json:"metadata,omitempty"`
|
||||||
Stream bool `json:"stream,omitempty"`
|
Stream bool `json:"stream,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,25 +20,25 @@ func stopReasonClaude2OpenAI(reason string) string {
|
|||||||
case "end_turn":
|
case "end_turn":
|
||||||
return "stop"
|
return "stop"
|
||||||
case "max_tokens":
|
case "max_tokens":
|
||||||
return "length"
|
return "max_tokens"
|
||||||
default:
|
default:
|
||||||
return reason
|
return reason
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
|
||||||
claudeRequest := ClaudeRequest{
|
claudeRequest := ClaudeRequest{
|
||||||
Model: textRequest.Model,
|
Model: textRequest.Model,
|
||||||
Prompt: "",
|
Prompt: "",
|
||||||
MaxTokensToSample: textRequest.MaxTokens,
|
MaxTokens: textRequest.MaxTokens,
|
||||||
StopSequences: nil,
|
StopSequences: nil,
|
||||||
Temperature: textRequest.Temperature,
|
Temperature: textRequest.Temperature,
|
||||||
TopP: textRequest.TopP,
|
TopP: textRequest.TopP,
|
||||||
TopK: textRequest.TopK,
|
TopK: textRequest.TopK,
|
||||||
Stream: textRequest.Stream,
|
Stream: textRequest.Stream,
|
||||||
}
|
}
|
||||||
if claudeRequest.MaxTokensToSample == 0 {
|
if claudeRequest.MaxTokens == 0 {
|
||||||
claudeRequest.MaxTokensToSample = 1000000
|
claudeRequest.MaxTokens = 4096
|
||||||
}
|
}
|
||||||
prompt := ""
|
prompt := ""
|
||||||
for _, message := range textRequest.Messages {
|
for _, message := range textRequest.Messages {
|
||||||
@@ -57,7 +57,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
|
|||||||
return &claudeRequest
|
return &claudeRequest
|
||||||
}
|
}
|
||||||
|
|
||||||
func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
|
||||||
claudeRequest := ClaudeRequest{
|
claudeRequest := ClaudeRequest{
|
||||||
Model: textRequest.Model,
|
Model: textRequest.Model,
|
||||||
MaxTokens: textRequest.MaxTokens,
|
MaxTokens: textRequest.MaxTokens,
|
||||||
@@ -70,8 +70,39 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
|||||||
if claudeRequest.MaxTokens == 0 {
|
if claudeRequest.MaxTokens == 0 {
|
||||||
claudeRequest.MaxTokens = 4096
|
claudeRequest.MaxTokens = 4096
|
||||||
}
|
}
|
||||||
|
formatMessages := make([]dto.Message, 0)
|
||||||
|
var lastMessage *dto.Message
|
||||||
|
for i, message := range textRequest.Messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
if i != 0 {
|
||||||
|
message.Role = "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if message.Role == "" {
|
||||||
|
message.Role = "user"
|
||||||
|
}
|
||||||
|
fmtMessage := dto.Message{
|
||||||
|
Role: message.Role,
|
||||||
|
Content: message.Content,
|
||||||
|
}
|
||||||
|
if lastMessage != nil && lastMessage.Role == message.Role {
|
||||||
|
if lastMessage.IsStringContent() && message.IsStringContent() {
|
||||||
|
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
|
||||||
|
fmtMessage.Content = content
|
||||||
|
// delete last message
|
||||||
|
formatMessages = formatMessages[:len(formatMessages)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fmtMessage.Content == nil {
|
||||||
|
content, _ := json.Marshal("...")
|
||||||
|
fmtMessage.Content = content
|
||||||
|
}
|
||||||
|
formatMessages = append(formatMessages, fmtMessage)
|
||||||
|
lastMessage = &message
|
||||||
|
}
|
||||||
|
|
||||||
claudeMessages := make([]ClaudeMessage, 0)
|
claudeMessages := make([]ClaudeMessage, 0)
|
||||||
for _, message := range textRequest.Messages {
|
for _, message := range formatMessages {
|
||||||
if message.Role == "system" {
|
if message.Role == "system" {
|
||||||
claudeRequest.System = message.StringContent()
|
claudeRequest.System = message.StringContent()
|
||||||
} else {
|
} else {
|
||||||
@@ -122,7 +153,7 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
|
|||||||
return &claudeRequest, nil
|
return &claudeRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) {
|
||||||
var response dto.ChatCompletionsStreamResponse
|
var response dto.ChatCompletionsStreamResponse
|
||||||
var claudeUsage *ClaudeUsage
|
var claudeUsage *ClaudeUsage
|
||||||
response.Object = "chat.completion.chunk"
|
response.Object = "chat.completion.chunk"
|
||||||
@@ -149,6 +180,8 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
|||||||
choice.FinishReason = &finishReason
|
choice.FinishReason = &finishReason
|
||||||
}
|
}
|
||||||
claudeUsage = &claudeResponse.Usage
|
claudeUsage = &claudeResponse.Usage
|
||||||
|
} else if claudeResponse.Type == "message_stop" {
|
||||||
|
return nil, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if claudeUsage == nil {
|
if claudeUsage == nil {
|
||||||
@@ -158,7 +191,7 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
|
|||||||
return &response, claudeUsage
|
return &response, claudeUsage
|
||||||
}
|
}
|
||||||
|
|
||||||
func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse {
|
||||||
choices := make([]dto.OpenAITextResponseChoice, 0)
|
choices := make([]dto.OpenAITextResponseChoice, 0)
|
||||||
fullTextResponse := dto.OpenAITextResponse{
|
fullTextResponse := dto.OpenAITextResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
||||||
@@ -242,7 +275,10 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
|
if response == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
responseText += claudeResponse.Completion
|
responseText += claudeResponse.Completion
|
||||||
responseId = response.Id
|
responseId = response.Id
|
||||||
@@ -317,7 +353,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT
|
|||||||
StatusCode: resp.StatusCode,
|
StatusCode: resp.StatusCode,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse)
|
fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
|
||||||
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil
|
||||||
|
|||||||
52
relay/channel/cohere/adaptor.go
Normal file
52
relay/channel/cohere/adaptor.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package cohere
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Adaptor struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||||
|
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||||
|
channel.SetupApiRequestHeader(info, c, req)
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||||
|
return requestOpenAI2Cohere(*request), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||||
|
return channel.DoApiRequest(a, c, info, requestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
if info.IsStream {
|
||||||
|
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetModelList() []string {
|
||||||
|
return ModelList
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Adaptor) GetChannelName() string {
|
||||||
|
return ChannelName
|
||||||
|
}
|
||||||
7
relay/channel/cohere/constant.go
Normal file
7
relay/channel/cohere/constant.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
package cohere
|
||||||
|
|
||||||
|
var ModelList = []string{
|
||||||
|
"command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly",
|
||||||
|
}
|
||||||
|
|
||||||
|
var ChannelName = "cohere"
|
||||||
44
relay/channel/cohere/dto.go
Normal file
44
relay/channel/cohere/dto.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
package cohere
|
||||||
|
|
||||||
|
type CohereRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
ChatHistory []ChatHistory `json:"chat_history"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Stream bool `json:"stream"`
|
||||||
|
MaxTokens int64 `json:"max_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ChatHistory struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CohereResponse struct {
|
||||||
|
IsFinished bool `json:"is_finished"`
|
||||||
|
EventType string `json:"event_type"`
|
||||||
|
Text string `json:"text,omitempty"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
Response *CohereResponseResult `json:"response"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CohereResponseResult struct {
|
||||||
|
ResponseId string `json:"response_id"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
Text string `json:"text"`
|
||||||
|
Meta CohereMeta `json:"meta"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CohereMeta struct {
|
||||||
|
//Tokens CohereTokens `json:"tokens"`
|
||||||
|
BilledUnits CohereBilledUnits `json:"billed_units"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CohereBilledUnits struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type CohereTokens struct {
|
||||||
|
InputTokens int `json:"input_tokens"`
|
||||||
|
OutputTokens int `json:"output_tokens"`
|
||||||
|
}
|
||||||
189
relay/channel/cohere/relay-cohere.go
Normal file
189
relay/channel/cohere/relay-cohere.go
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
package cohere
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
||||||
|
cohereReq := CohereRequest{
|
||||||
|
Model: textRequest.Model,
|
||||||
|
ChatHistory: []ChatHistory{},
|
||||||
|
Message: "",
|
||||||
|
Stream: textRequest.Stream,
|
||||||
|
MaxTokens: textRequest.GetMaxTokens(),
|
||||||
|
}
|
||||||
|
if cohereReq.MaxTokens == 0 {
|
||||||
|
cohereReq.MaxTokens = 4000
|
||||||
|
}
|
||||||
|
for _, msg := range textRequest.Messages {
|
||||||
|
if msg.Role == "user" {
|
||||||
|
cohereReq.Message = msg.StringContent()
|
||||||
|
} else {
|
||||||
|
var role string
|
||||||
|
if msg.Role == "assistant" {
|
||||||
|
role = "CHATBOT"
|
||||||
|
} else if msg.Role == "system" {
|
||||||
|
role = "SYSTEM"
|
||||||
|
} else {
|
||||||
|
role = "USER"
|
||||||
|
}
|
||||||
|
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
|
||||||
|
Role: role,
|
||||||
|
Message: msg.StringContent(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &cohereReq
|
||||||
|
}
|
||||||
|
|
||||||
|
func stopReasonCohere2OpenAI(reason string) string {
|
||||||
|
switch reason {
|
||||||
|
case "COMPLETE":
|
||||||
|
return "stop"
|
||||||
|
case "MAX_TOKENS":
|
||||||
|
return "max_tokens"
|
||||||
|
default:
|
||||||
|
return reason
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
usage := &dto.Usage{}
|
||||||
|
responseText := ""
|
||||||
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
|
if atEOF && len(data) == 0 {
|
||||||
|
return 0, nil, nil
|
||||||
|
}
|
||||||
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||||
|
return i + 1, data[0:i], nil
|
||||||
|
}
|
||||||
|
if atEOF {
|
||||||
|
return len(data), data, nil
|
||||||
|
}
|
||||||
|
return 0, nil, nil
|
||||||
|
})
|
||||||
|
dataChan := make(chan string)
|
||||||
|
stopChan := make(chan bool)
|
||||||
|
go func() {
|
||||||
|
for scanner.Scan() {
|
||||||
|
data := scanner.Text()
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
stopChan <- true
|
||||||
|
}()
|
||||||
|
service.SetEventStreamHeaders(c)
|
||||||
|
c.Stream(func(w io.Writer) bool {
|
||||||
|
select {
|
||||||
|
case data := <-dataChan:
|
||||||
|
data = strings.TrimSuffix(data, "\r")
|
||||||
|
var cohereResp CohereResponse
|
||||||
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var openaiResp dto.ChatCompletionsStreamResponse
|
||||||
|
openaiResp.Id = responseId
|
||||||
|
openaiResp.Created = createdTime
|
||||||
|
openaiResp.Object = "chat.completion.chunk"
|
||||||
|
openaiResp.Model = modelName
|
||||||
|
if cohereResp.IsFinished {
|
||||||
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
||||||
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||||
|
{
|
||||||
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
|
||||||
|
Index: 0,
|
||||||
|
FinishReason: &finishReason,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if cohereResp.Response != nil {
|
||||||
|
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
|
||||||
|
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
||||||
|
{
|
||||||
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: cohereResp.Text,
|
||||||
|
},
|
||||||
|
Index: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
responseText += cohereResp.Text
|
||||||
|
}
|
||||||
|
jsonStr, err := json.Marshal(openaiResp)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling stream response: " + err.Error())
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
||||||
|
return true
|
||||||
|
case <-stopChan:
|
||||||
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
|
if usage.PromptTokens == 0 {
|
||||||
|
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens)
|
||||||
|
}
|
||||||
|
return nil, usage
|
||||||
|
}
|
||||||
|
|
||||||
|
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
createdTime := common.GetTimestamp()
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
var cohereResp CohereResponseResult
|
||||||
|
err = json.Unmarshal(responseBody, &cohereResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
usage := dto.Usage{}
|
||||||
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
||||||
|
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
||||||
|
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
||||||
|
|
||||||
|
var openaiResp dto.TextResponse
|
||||||
|
openaiResp.Id = cohereResp.ResponseId
|
||||||
|
openaiResp.Created = createdTime
|
||||||
|
openaiResp.Object = "chat.completion"
|
||||||
|
openaiResp.Model = modelName
|
||||||
|
openaiResp.Usage = usage
|
||||||
|
|
||||||
|
content, _ := json.Marshal(cohereResp.Text)
|
||||||
|
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
||||||
|
{
|
||||||
|
Index: 0,
|
||||||
|
Message: dto.Message{Content: content, Role: "assistant"},
|
||||||
|
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonResponse, err := json.Marshal(openaiResp)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||||
|
}
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
return nil, &usage
|
||||||
|
}
|
||||||
@@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
if info.RelayMode == relayconstant.RelayModeEmbeddings {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
||||||
@@ -41,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest {
|
|||||||
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest {
|
||||||
return &OllamaEmbeddingRequest{
|
return &OllamaEmbeddingRequest{
|
||||||
Model: request.Model,
|
Model: request.Model,
|
||||||
Prompt: request.Input,
|
Prompt: strings.Join(request.ParseInput(), " "),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
|
var toolCount int
|
||||||
|
err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,9 +16,10 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
|
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) {
|
||||||
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
//checkSensitive := constant.ShouldCheckCompletionSensitive()
|
||||||
var responseTextBuilder strings.Builder
|
var responseTextBuilder strings.Builder
|
||||||
|
toolCount := 0
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||||
if atEOF && len(data) == 0 {
|
if atEOF && len(data) == 0 {
|
||||||
@@ -68,6 +69,15 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||||
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
if len(choice.Delta.ToolCalls) > toolCount {
|
||||||
|
toolCount = len(choice.Delta.ToolCalls)
|
||||||
|
}
|
||||||
|
for _, tool := range choice.Delta.ToolCalls {
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Name)
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,6 +85,15 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|||||||
for _, streamResponse := range streamResponses {
|
for _, streamResponse := range streamResponses {
|
||||||
for _, choice := range streamResponse.Choices {
|
for _, choice := range streamResponse.Choices {
|
||||||
responseTextBuilder.WriteString(choice.Delta.Content)
|
responseTextBuilder.WriteString(choice.Delta.Content)
|
||||||
|
if choice.Delta.ToolCalls != nil {
|
||||||
|
if len(choice.Delta.ToolCalls) > toolCount {
|
||||||
|
toolCount = len(choice.Delta.ToolCalls)
|
||||||
|
}
|
||||||
|
for _, tool := range choice.Delta.ToolCalls {
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Name)
|
||||||
|
responseTextBuilder.WriteString(tool.Function.Arguments)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -123,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
|
|||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
err := resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount
|
||||||
}
|
}
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
return nil, responseTextBuilder.String()
|
return nil, responseTextBuilder.String(), toolCount
|
||||||
}
|
}
|
||||||
|
|
||||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
|
|||||||
@@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
var toolCount int
|
||||||
|
err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ const (
|
|||||||
APITypeZhipu_v4
|
APITypeZhipu_v4
|
||||||
APITypeOllama
|
APITypeOllama
|
||||||
APITypePerplexity
|
APITypePerplexity
|
||||||
|
APITypeAws
|
||||||
|
APITypeCohere
|
||||||
|
|
||||||
APITypeDummy // this one is only for count, do not add any channel after this
|
APITypeDummy // this one is only for count, do not add any channel after this
|
||||||
)
|
)
|
||||||
@@ -49,6 +51,10 @@ func ChannelType2APIType(channelType int) int {
|
|||||||
apiType = APITypeOllama
|
apiType = APITypeOllama
|
||||||
case common.ChannelTypePerplexity:
|
case common.ChannelTypePerplexity:
|
||||||
apiType = APITypePerplexity
|
apiType = APITypePerplexity
|
||||||
|
case common.ChannelTypeAws:
|
||||||
|
apiType = APITypeAws
|
||||||
|
case common.ChannelTypeCohere:
|
||||||
|
apiType = APITypeCohere
|
||||||
}
|
}
|
||||||
return apiType
|
return apiType
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,15 +20,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var availableVoices = []string{
|
|
||||||
"alloy",
|
|
||||||
"echo",
|
|
||||||
"fable",
|
|
||||||
"onyx",
|
|
||||||
"nova",
|
|
||||||
"shimmer",
|
|
||||||
}
|
|
||||||
|
|
||||||
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
||||||
tokenId := c.GetInt("token_id")
|
tokenId := c.GetInt("token_id")
|
||||||
channelType := c.GetInt("channel")
|
channelType := c.GetInt("channel")
|
||||||
@@ -59,9 +50,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if audioRequest.Voice == "" {
|
if audioRequest.Voice == "" {
|
||||||
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
|
return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
if !common.StringsContains(availableVoices, audioRequest.Voice) {
|
|
||||||
return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
@@ -100,6 +88,22 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
succeed := false
|
||||||
|
defer func() {
|
||||||
|
if succeed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if preConsumedQuota > 0 {
|
||||||
|
// we need to roll back the pre-consumed quota
|
||||||
|
defer func() {
|
||||||
|
go func() {
|
||||||
|
// negative means add quota back for token & user
|
||||||
|
returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota)
|
||||||
|
}()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// map model name
|
// map model name
|
||||||
modelMapping := c.GetString("model_mapping")
|
modelMapping := c.GetString("model_mapping")
|
||||||
if modelMapping != "" {
|
if modelMapping != "" {
|
||||||
@@ -163,6 +167,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode {
|
|||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
return relaycommon.RelayErrorHandler(resp)
|
return relaycommon.RelayErrorHandler(resp)
|
||||||
}
|
}
|
||||||
|
succeed = true
|
||||||
|
|
||||||
var audioResponse dto.AudioResponse
|
var audioResponse dto.AudioResponse
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
}
|
}
|
||||||
|
|
||||||
if imageRequest.Model == "" {
|
if imageRequest.Model == "" {
|
||||||
imageRequest.Model = "dall-e-2"
|
imageRequest.Model = "dall-e-3"
|
||||||
}
|
}
|
||||||
if imageRequest.Size == "" {
|
if imageRequest.Size == "" {
|
||||||
imageRequest.Size = "1024x1024"
|
imageRequest.Size = "1024x1024"
|
||||||
@@ -186,7 +186,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC
|
|||||||
}
|
}
|
||||||
if quota != 0 {
|
if quota != 0 {
|
||||||
tokenName := c.GetString("token_name")
|
tokenName := c.GetString("token_name")
|
||||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
quality := "normal"
|
||||||
|
if imageRequest.Quality == "hd" {
|
||||||
|
quality = "hd"
|
||||||
|
}
|
||||||
|
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality)
|
||||||
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false)
|
||||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||||
channelId := c.GetInt("channel_id")
|
channelId := c.GetInt("channel_id")
|
||||||
|
|||||||
@@ -110,11 +110,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
|
|||||||
midjourneyTask.StartTime = originTask.StartTime
|
midjourneyTask.StartTime = originTask.StartTime
|
||||||
midjourneyTask.FinishTime = originTask.FinishTime
|
midjourneyTask.FinishTime = originTask.FinishTime
|
||||||
midjourneyTask.ImageUrl = ""
|
midjourneyTask.ImageUrl = ""
|
||||||
if originTask.ImageUrl != "" {
|
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled {
|
||||||
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId
|
||||||
if originTask.Status != "SUCCESS" {
|
if originTask.Status != "SUCCESS" {
|
||||||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
midjourneyTask.ImageUrl = originTask.ImageUrl
|
||||||
}
|
}
|
||||||
midjourneyTask.Status = originTask.Status
|
midjourneyTask.Status = originTask.Status
|
||||||
midjourneyTask.FailReason = originTask.FailReason
|
midjourneyTask.FailReason = originTask.FailReason
|
||||||
|
|||||||
@@ -154,20 +154,28 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
|
|||||||
requestBody = bytes.NewBuffer(jsonData)
|
requestBody = bytes.NewBuffer(jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
statusCodeMappingStr := c.GetString("status_code_mapping")
|
||||||
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
}
|
}
|
||||||
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||||
return service.RelayErrorHandler(resp)
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||||
|
openaiErr := service.RelayErrorHandler(resp)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
|
||||||
if openaiErr != nil {
|
if openaiErr != nil {
|
||||||
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
|
||||||
|
// reset status code 重置状态码
|
||||||
|
service.ResetStatusCode(openaiErr, statusCodeMappingStr)
|
||||||
return openaiErr
|
return openaiErr
|
||||||
}
|
}
|
||||||
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
|
||||||
@@ -181,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
|||||||
checkSensitive := constant.ShouldCheckPromptSensitive()
|
checkSensitive := constant.ShouldCheckPromptSensitive()
|
||||||
switch info.RelayMode {
|
switch info.RelayMode {
|
||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive)
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package relay
|
|||||||
import (
|
import (
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
"one-api/relay/channel/ali"
|
"one-api/relay/channel/ali"
|
||||||
|
"one-api/relay/channel/aws"
|
||||||
"one-api/relay/channel/baidu"
|
"one-api/relay/channel/baidu"
|
||||||
"one-api/relay/channel/claude"
|
"one-api/relay/channel/claude"
|
||||||
|
"one-api/relay/channel/cohere"
|
||||||
"one-api/relay/channel/gemini"
|
"one-api/relay/channel/gemini"
|
||||||
"one-api/relay/channel/ollama"
|
"one-api/relay/channel/ollama"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
@@ -45,6 +47,10 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
|||||||
return &ollama.Adaptor{}
|
return &ollama.Adaptor{}
|
||||||
case constant.APITypePerplexity:
|
case constant.APITypePerplexity:
|
||||||
return &perplexity.Adaptor{}
|
return &perplexity.Adaptor{}
|
||||||
|
case constant.APITypeAws:
|
||||||
|
return &aws.Adaptor{}
|
||||||
|
case constant.APITypeCohere:
|
||||||
|
return &cohere.Adaptor{}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) {
|
||||||
|
if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
statusCodeMapping := make(map[string]string)
|
||||||
|
err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if openaiErr.StatusCode == http.StatusOK {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
codeStr := strconv.Itoa(openaiErr.StatusCode)
|
||||||
|
if _, ok := statusCodeMapping[codeStr]; ok {
|
||||||
|
intCode, _ := strconv.Atoi(statusCodeMapping[codeStr])
|
||||||
|
openaiErr.StatusCode = intCode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) {
|
|||||||
return tiles*170 + 85, nil
|
return tiles*170 + 85, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) {
|
||||||
|
tkm := 0
|
||||||
|
msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err, b
|
||||||
|
}
|
||||||
|
tkm += msgTokens
|
||||||
|
if request.Tools != nil {
|
||||||
|
toolsData, _ := json.Marshal(request.Tools)
|
||||||
|
var openaiTools []dto.OpenAITools
|
||||||
|
err := json.Unmarshal(toolsData, &openaiTools)
|
||||||
|
if err != nil {
|
||||||
|
return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false
|
||||||
|
}
|
||||||
|
countStr := ""
|
||||||
|
for _, tool := range openaiTools {
|
||||||
|
countStr = tool.Function.Name
|
||||||
|
if tool.Function.Description != "" {
|
||||||
|
countStr += tool.Function.Description
|
||||||
|
}
|
||||||
|
if tool.Function.Parameters != nil {
|
||||||
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolTokens, err, _ := CountTokenInput(countStr, model, false)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err, false
|
||||||
|
}
|
||||||
|
tkm += 8
|
||||||
|
tkm += toolTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
return tkm, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) {
|
||||||
//recover when panic
|
//recover when panic
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
@@ -138,48 +173,31 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo
|
|||||||
tokenNum += tokensPerMessage
|
tokenNum += tokensPerMessage
|
||||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||||
if len(message.Content) > 0 {
|
if len(message.Content) > 0 {
|
||||||
var arrayContent []dto.MediaMessage
|
if message.IsStringContent() {
|
||||||
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
|
stringContent := message.StringContent()
|
||||||
var stringContent string
|
if checkSensitive {
|
||||||
if err := json.Unmarshal(message.Content, &stringContent); err != nil {
|
contains, words := SensitiveWordContains(stringContent)
|
||||||
return 0, err, false
|
if contains {
|
||||||
} else {
|
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
||||||
if checkSensitive {
|
return 0, err, true
|
||||||
contains, words := SensitiveWordContains(stringContent)
|
|
||||||
if contains {
|
|
||||||
err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", "))
|
|
||||||
return 0, err, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
|
||||||
if message.Name != nil {
|
|
||||||
tokenNum += tokensPerName
|
|
||||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, stringContent)
|
||||||
|
if message.Name != nil {
|
||||||
|
tokenNum += tokensPerName
|
||||||
|
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
var err error
|
||||||
|
arrayContent := message.ParseContent()
|
||||||
for _, m := range arrayContent {
|
for _, m := range arrayContent {
|
||||||
if m.Type == "image_url" {
|
if m.Type == "image_url" {
|
||||||
var imageTokenNum int
|
var imageTokenNum int
|
||||||
if model == "glm-4v" {
|
if model == "glm-4v" {
|
||||||
imageTokenNum = 1047
|
imageTokenNum = 1047
|
||||||
} else {
|
} else {
|
||||||
if str, ok := m.ImageUrl.(string); ok {
|
imageUrl := m.ImageUrl.(dto.MessageImageUrl)
|
||||||
imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"})
|
imageTokenNum, err = getImageToken(&imageUrl)
|
||||||
} else {
|
|
||||||
imageUrlMap := m.ImageUrl.(map[string]interface{})
|
|
||||||
detail, ok := imageUrlMap["detail"]
|
|
||||||
if ok {
|
|
||||||
imageUrlMap["detail"] = detail.(string)
|
|
||||||
} else {
|
|
||||||
imageUrlMap["detail"] = "auto"
|
|
||||||
}
|
|
||||||
imageUrl := dto.MessageImageUrl{
|
|
||||||
Url: imageUrlMap["url"].(string),
|
|
||||||
Detail: imageUrlMap["detail"].(string),
|
|
||||||
}
|
|
||||||
imageTokenNum, err = getImageToken(&imageUrl)
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err, false
|
return 0, err, false
|
||||||
}
|
}
|
||||||
@@ -211,6 +229,23 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) {
|
|||||||
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
return CountTokenInput(fmt.Sprintf("%v", input), model, check)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||||
|
tokens := 0
|
||||||
|
for _, message := range messages {
|
||||||
|
tkm, _, _ := CountTokenInput(message.Delta.Content, model, false)
|
||||||
|
tokens += tkm
|
||||||
|
if message.Delta.ToolCalls != nil {
|
||||||
|
for _, tool := range message.Delta.ToolCalls {
|
||||||
|
tkm, _, _ := CountTokenInput(tool.Function.Name, model, false)
|
||||||
|
tokens += tkm
|
||||||
|
tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false)
|
||||||
|
tokens += tkm
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
func CountAudioToken(text string, model string, check bool) (int, error, bool) {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
contains, words := SensitiveWordContains(text)
|
contains, words := SensitiveWordContains(text)
|
||||||
|
|||||||
@@ -208,7 +208,6 @@ const LoginForm = () => {
|
|||||||
</Text>
|
</Text>
|
||||||
</div>
|
</div>
|
||||||
{status.github_oauth ||
|
{status.github_oauth ||
|
||||||
status.linuxdo_oauth ||
|
|
||||||
status.wechat_login ||
|
status.wechat_login ||
|
||||||
status.telegram_oauth ? (
|
status.telegram_oauth ? (
|
||||||
<>
|
<>
|
||||||
@@ -226,7 +225,6 @@ const LoginForm = () => {
|
|||||||
<Button
|
<Button
|
||||||
type='primary'
|
type='primary'
|
||||||
icon={<IconGithubLogo />}
|
icon={<IconGithubLogo />}
|
||||||
style={{ margin: '0 5px' }}
|
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
onGitHubOAuthClicked(status.github_client_id)
|
onGitHubOAuthClicked(status.github_client_id)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ const OperationSetting = () => {
|
|||||||
SensitiveWords: '',
|
SensitiveWords: '',
|
||||||
MjNotifyEnabled: '',
|
MjNotifyEnabled: '',
|
||||||
MjModeClearEnabled: '',
|
MjModeClearEnabled: '',
|
||||||
|
MjForwardUrlEnabled: '',
|
||||||
DrawingEnabled: '',
|
DrawingEnabled: '',
|
||||||
DataExportEnabled: '',
|
DataExportEnabled: '',
|
||||||
DataExportDefaultTime: 'hour',
|
DataExportDefaultTime: 'hour',
|
||||||
@@ -322,6 +323,12 @@ const OperationSetting = () => {
|
|||||||
name='MjNotifyEnabled'
|
name='MjNotifyEnabled'
|
||||||
onChange={handleInputChange}
|
onChange={handleInputChange}
|
||||||
/>
|
/>
|
||||||
|
<Form.Checkbox
|
||||||
|
checked={inputs.MjForwardUrlEnabled === 'true'}
|
||||||
|
label='开启之后将上游地址替换为服务器地址'
|
||||||
|
name='MjForwardUrlEnabled'
|
||||||
|
onChange={handleInputChange}
|
||||||
|
/>
|
||||||
<Form.Checkbox
|
<Form.Checkbox
|
||||||
checked={inputs.MjModeClearEnabled === 'true'}
|
checked={inputs.MjModeClearEnabled === 'true'}
|
||||||
label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
|
label='开启之后会清除用户提示词中的--fast、--relax以及--turbo参数'
|
||||||
|
|||||||
@@ -253,6 +253,8 @@ const UsersTable = () => {
|
|||||||
const [activePage, setActivePage] = useState(1);
|
const [activePage, setActivePage] = useState(1);
|
||||||
const [searchKeyword, setSearchKeyword] = useState('');
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
const [searching, setSearching] = useState(false);
|
const [searching, setSearching] = useState(false);
|
||||||
|
const [searchGroup, setSearchGroup] = useState('');
|
||||||
|
const [groupOptions, setGroupOptions] = useState([]);
|
||||||
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
|
const [userCount, setUserCount] = useState(ITEMS_PER_PAGE);
|
||||||
const [showAddUser, setShowAddUser] = useState(false);
|
const [showAddUser, setShowAddUser] = useState(false);
|
||||||
const [showEditUser, setShowEditUser] = useState(false);
|
const [showEditUser, setShowEditUser] = useState(false);
|
||||||
@@ -316,6 +318,7 @@ const UsersTable = () => {
|
|||||||
.catch((reason) => {
|
.catch((reason) => {
|
||||||
showError(reason);
|
showError(reason);
|
||||||
});
|
});
|
||||||
|
fetchGroups().then();
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
const manageUser = async (username, action, record) => {
|
const manageUser = async (username, action, record) => {
|
||||||
@@ -370,15 +373,17 @@ const UsersTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const searchUsers = async () => {
|
const searchUsers = async (searchKeyword, searchGroup) => {
|
||||||
if (searchKeyword === '') {
|
if (searchKeyword === '' && searchGroup === '') {
|
||||||
// if keyword is blank, load files instead.
|
// if keyword is blank, load files instead.
|
||||||
await loadUsers(0);
|
await loadUsers(0);
|
||||||
setActivePage(1);
|
setActivePage(1);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
setSearching(true);
|
setSearching(true);
|
||||||
const res = await API.get(`/api/user/search?keyword=${searchKeyword}`);
|
const res = await API.get(
|
||||||
|
`/api/user/search?keyword=${searchKeyword}&group=${searchGroup}`,
|
||||||
|
);
|
||||||
const { success, message, data } = res.data;
|
const { success, message, data } = res.data;
|
||||||
if (success) {
|
if (success) {
|
||||||
setUsers(data);
|
setUsers(data);
|
||||||
@@ -439,6 +444,25 @@ const UsersTable = () => {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const fetchGroups = async () => {
|
||||||
|
try {
|
||||||
|
let res = await API.get(`/api/group/`);
|
||||||
|
// add 'all' option
|
||||||
|
// res.data.data.unshift('all');
|
||||||
|
if (res === undefined) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setGroupOptions(
|
||||||
|
res.data.data.map((group) => ({
|
||||||
|
label: group,
|
||||||
|
value: group,
|
||||||
|
})),
|
||||||
|
);
|
||||||
|
} catch (error) {
|
||||||
|
showError(error.message);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<AddUser
|
<AddUser
|
||||||
@@ -452,17 +476,44 @@ const UsersTable = () => {
|
|||||||
handleClose={closeEditUser}
|
handleClose={closeEditUser}
|
||||||
editingUser={editingUser}
|
editingUser={editingUser}
|
||||||
></EditUser>
|
></EditUser>
|
||||||
<Form onSubmit={searchUsers}>
|
<Form
|
||||||
<Form.Input
|
onSubmit={() => {
|
||||||
label='搜索关键字'
|
searchUsers(searchKeyword, searchGroup);
|
||||||
icon='search'
|
}}
|
||||||
field='keyword'
|
labelPosition='left'
|
||||||
iconPosition='left'
|
>
|
||||||
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
|
<div style={{ display: 'flex' }}>
|
||||||
value={searchKeyword}
|
<Space>
|
||||||
loading={searching}
|
<Form.Input
|
||||||
onChange={(value) => handleKeywordChange(value)}
|
label='搜索关键字'
|
||||||
/>
|
icon='search'
|
||||||
|
field='keyword'
|
||||||
|
iconPosition='left'
|
||||||
|
placeholder='搜索用户的 ID,用户名,显示名称,以及邮箱地址 ...'
|
||||||
|
value={searchKeyword}
|
||||||
|
loading={searching}
|
||||||
|
onChange={(value) => handleKeywordChange(value)}
|
||||||
|
/>
|
||||||
|
<Form.Select
|
||||||
|
field='group'
|
||||||
|
label='分组'
|
||||||
|
optionList={groupOptions}
|
||||||
|
onChange={(value) => {
|
||||||
|
setSearchGroup(value);
|
||||||
|
searchUsers(searchKeyword, value);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
label='查询'
|
||||||
|
type='primary'
|
||||||
|
htmlType='submit'
|
||||||
|
className='btn-margin-right'
|
||||||
|
style={{ marginRight: 8 }}
|
||||||
|
>
|
||||||
|
查询
|
||||||
|
</Button>
|
||||||
|
</Space>
|
||||||
|
</div>
|
||||||
</Form>
|
</Form>
|
||||||
|
|
||||||
<Table
|
<Table
|
||||||
|
|||||||
@@ -22,6 +22,13 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'indigo',
|
color: 'indigo',
|
||||||
label: 'Anthropic Claude',
|
label: 'Anthropic Claude',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 33,
|
||||||
|
text: 'AWS Claude',
|
||||||
|
value: 33,
|
||||||
|
color: 'indigo',
|
||||||
|
label: 'AWS Claude',
|
||||||
|
},
|
||||||
{
|
{
|
||||||
key: 3,
|
key: 3,
|
||||||
text: 'Azure OpenAI',
|
text: 'Azure OpenAI',
|
||||||
@@ -43,6 +50,13 @@ export const CHANNEL_OPTIONS = [
|
|||||||
color: 'orange',
|
color: 'orange',
|
||||||
label: 'Google Gemini',
|
label: 'Google Gemini',
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
key: 34,
|
||||||
|
text: 'Cohere',
|
||||||
|
value: 34,
|
||||||
|
color: 'purple',
|
||||||
|
label: 'Cohere',
|
||||||
|
},
|
||||||
{
|
{
|
||||||
key: 15,
|
key: 15,
|
||||||
text: '百度文心千帆',
|
text: '百度文心千帆',
|
||||||
|
|||||||
@@ -166,13 +166,12 @@ export const modelColorMap = {
|
|||||||
'dall-e': 'rgb(147,112,219)', // 深紫色
|
'dall-e': 'rgb(147,112,219)', // 深紫色
|
||||||
'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
|
'dall-e-2': 'rgb(147,112,219)', // 介于紫色和蓝色之间的色调
|
||||||
'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调
|
'dall-e-3': 'rgb(153,50,204)', // 介于紫罗兰和洋红之间的色调
|
||||||
midjourney: 'rgb(136,43,180)', // 介于紫罗兰和洋红之间的色调
|
|
||||||
'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色
|
'gpt-3.5-turbo': 'rgb(184,227,167)', // 浅绿色
|
||||||
'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
|
'gpt-3.5-turbo-0301': 'rgb(131,220,131)', // 亮绿色
|
||||||
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
|
'gpt-3.5-turbo-0613': 'rgb(60,179,113)', // 海洋绿
|
||||||
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
|
'gpt-3.5-turbo-1106': 'rgb(32,178,170)', // 浅海洋绿
|
||||||
'gpt-3.5-turbo-16k': 'rgb(252,200,149)', // 淡橙色
|
'gpt-3.5-turbo-16k': 'rgb(149,252,206)', // 淡橙色
|
||||||
'gpt-3.5-turbo-16k-0613': 'rgb(255,181,119)', // 淡桃色
|
'gpt-3.5-turbo-16k-0613': 'rgb(119,255,214)', // 淡桃色
|
||||||
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
|
'gpt-3.5-turbo-instruct': 'rgb(175,238,238)', // 粉蓝色
|
||||||
'gpt-4': 'rgb(135,206,235)', // 天蓝色
|
'gpt-4': 'rgb(135,206,235)', // 天蓝色
|
||||||
'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
|
'gpt-4-0314': 'rgb(70,130,180)', // 钢蓝色
|
||||||
@@ -203,6 +202,10 @@ export const modelColorMap = {
|
|||||||
'tts-1-hd': 'rgb(255,215,0)', // 金色
|
'tts-1-hd': 'rgb(255,215,0)', // 金色
|
||||||
'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别)
|
'tts-1-hd-1106': 'rgb(255,223,0)', // 金黄色(略有区别)
|
||||||
'whisper-1': 'rgb(245,245,220)', // 米色
|
'whisper-1': 'rgb(245,245,220)', // 米色
|
||||||
|
'claude-3-opus-20240229': 'rgb(255,132,31)', // 橙红色
|
||||||
|
'claude-3-sonnet-20240229': 'rgb(253,135,93)', // 橙色
|
||||||
|
'claude-3-haiku-20240307': 'rgb(255,175,146)', // 浅橙色
|
||||||
|
'claude-2.1': 'rgb(255,209,190)', // 浅橙色(略有区别)
|
||||||
};
|
};
|
||||||
|
|
||||||
export function stringToColor(str) {
|
export function stringToColor(str) {
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import {
|
|||||||
Checkbox,
|
Checkbox,
|
||||||
Banner,
|
Banner,
|
||||||
} from '@douyinfe/semi-ui';
|
} from '@douyinfe/semi-ui';
|
||||||
|
import { Divider } from 'semantic-ui-react';
|
||||||
|
|
||||||
const MODEL_MAPPING_EXAMPLE = {
|
const MODEL_MAPPING_EXAMPLE = {
|
||||||
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
|
'gpt-3.5-turbo-0301': 'gpt-3.5-turbo',
|
||||||
@@ -29,6 +30,10 @@ const MODEL_MAPPING_EXAMPLE = {
|
|||||||
'gpt-4-32k-0314': 'gpt-4-32k',
|
'gpt-4-32k-0314': 'gpt-4-32k',
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const STATUS_CODE_MAPPING_EXAMPLE = {
|
||||||
|
400: '500',
|
||||||
|
};
|
||||||
|
|
||||||
function type2secretPrompt(type) {
|
function type2secretPrompt(type) {
|
||||||
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
// inputs.type === 15 ? '按照如下格式输入:APIKey|SecretKey' : (inputs.type === 18 ? '按照如下格式输入:APPID|APISecret|APIKey' : '请输入渠道对应的鉴权密钥')
|
||||||
switch (type) {
|
switch (type) {
|
||||||
@@ -40,6 +45,8 @@ function type2secretPrompt(type) {
|
|||||||
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
return '按照如下格式输入:APIKey-AppId,例如:fastgpt-0sp2gtvfdgyi4k30jwlgwf1i-64f335d84283f05518e9e041';
|
||||||
case 23:
|
case 23:
|
||||||
return '按照如下格式输入:AppId|SecretId|SecretKey';
|
return '按照如下格式输入:AppId|SecretId|SecretKey';
|
||||||
|
case 33:
|
||||||
|
return '按照如下格式输入:Ak|Sk|Region';
|
||||||
default:
|
default:
|
||||||
return '请输入渠道对应的鉴权密钥';
|
return '请输入渠道对应的鉴权密钥';
|
||||||
}
|
}
|
||||||
@@ -58,9 +65,11 @@ const EditChannel = (props) => {
|
|||||||
type: 1,
|
type: 1,
|
||||||
key: '',
|
key: '',
|
||||||
openai_organization: '',
|
openai_organization: '',
|
||||||
|
max_input_tokens: 0,
|
||||||
base_url: '',
|
base_url: '',
|
||||||
other: '',
|
other: '',
|
||||||
model_mapping: '',
|
model_mapping: '',
|
||||||
|
status_code_mapping: '',
|
||||||
models: [],
|
models: [],
|
||||||
auto_ban: 1,
|
auto_ban: 1,
|
||||||
test_model: '',
|
test_model: '',
|
||||||
@@ -81,6 +90,7 @@ const EditChannel = (props) => {
|
|||||||
if (name === 'type' && inputs.models.length === 0) {
|
if (name === 'type' && inputs.models.length === 0) {
|
||||||
let localModels = [];
|
let localModels = [];
|
||||||
switch (value) {
|
switch (value) {
|
||||||
|
case 33:
|
||||||
case 14:
|
case 14:
|
||||||
localModels = [
|
localModels = [
|
||||||
'claude-instant-1.2',
|
'claude-instant-1.2',
|
||||||
@@ -136,7 +146,24 @@ const EditChannel = (props) => {
|
|||||||
localModels = ['hunyuan'];
|
localModels = ['hunyuan'];
|
||||||
break;
|
break;
|
||||||
case 24:
|
case 24:
|
||||||
localModels = ['gemini-pro', 'gemini-pro-vision'];
|
localModels = [
|
||||||
|
'gemini-1.0-pro-001',
|
||||||
|
'gemini-1.0-pro-vision-001',
|
||||||
|
'gemini-1.5-pro',
|
||||||
|
'gemini-1.5-pro-latest',
|
||||||
|
'gemini-pro',
|
||||||
|
'gemini-pro-vision',
|
||||||
|
];
|
||||||
|
break;
|
||||||
|
case 34:
|
||||||
|
localModels = [
|
||||||
|
'command-r',
|
||||||
|
'command-r-plus',
|
||||||
|
'command-light',
|
||||||
|
'command-light-nightly',
|
||||||
|
'command',
|
||||||
|
'command-nightly',
|
||||||
|
];
|
||||||
break;
|
break;
|
||||||
case 25:
|
case 25:
|
||||||
localModels = [
|
localModels = [
|
||||||
@@ -658,18 +685,22 @@ const EditChannel = (props) => {
|
|||||||
autoComplete='new-password'
|
autoComplete='new-password'
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<div style={{ marginTop: 10 }}>
|
{inputs.type === 1 && (
|
||||||
<Typography.Text strong>组织:</Typography.Text>
|
<>
|
||||||
</div>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Input
|
<Typography.Text strong>组织:</Typography.Text>
|
||||||
label='组织,可选,不填则为默认组织'
|
</div>
|
||||||
name='openai_organization'
|
<Input
|
||||||
placeholder='请输入组织org-xxx'
|
label='组织,可选,不填则为默认组织'
|
||||||
onChange={(value) => {
|
name='openai_organization'
|
||||||
handleInputChange('openai_organization', value);
|
placeholder='请输入组织org-xxx'
|
||||||
}}
|
onChange={(value) => {
|
||||||
value={inputs.openai_organization}
|
handleInputChange('openai_organization', value);
|
||||||
/>
|
}}
|
||||||
|
value={inputs.openai_organization}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
<div style={{ marginTop: 10 }}>
|
<div style={{ marginTop: 10 }}>
|
||||||
<Typography.Text strong>默认测试模型:</Typography.Text>
|
<Typography.Text strong>默认测试模型:</Typography.Text>
|
||||||
</div>
|
</div>
|
||||||
@@ -745,6 +776,50 @@ const EditChannel = (props) => {
|
|||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
|
<div style={{ marginTop: 10 }}>
|
||||||
|
<Typography.Text strong>
|
||||||
|
状态码复写(仅影响本地判断,不修改返回到上游的状态码):
|
||||||
|
</Typography.Text>
|
||||||
|
</div>
|
||||||
|
<TextArea
|
||||||
|
placeholder={`此项可选,用于复写返回的状态码,比如将claude渠道的400错误复写为500(用于重试),请勿滥用该功能,例如:\n${JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2)}`}
|
||||||
|
name='status_code_mapping'
|
||||||
|
onChange={(value) => {
|
||||||
|
handleInputChange('status_code_mapping', value);
|
||||||
|
}}
|
||||||
|
autosize
|
||||||
|
value={inputs.status_code_mapping}
|
||||||
|
autoComplete='new-password'
|
||||||
|
/>
|
||||||
|
<Typography.Text
|
||||||
|
style={{
|
||||||
|
color: 'rgba(var(--semi-blue-5), 1)',
|
||||||
|
userSelect: 'none',
|
||||||
|
cursor: 'pointer',
|
||||||
|
}}
|
||||||
|
onClick={() => {
|
||||||
|
handleInputChange(
|
||||||
|
'status_code_mapping',
|
||||||
|
JSON.stringify(STATUS_CODE_MAPPING_EXAMPLE, null, 2),
|
||||||
|
);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
填入模板
|
||||||
|
</Typography.Text>
|
||||||
|
{/*<div style={{ marginTop: 10 }}>*/}
|
||||||
|
{/* <Typography.Text strong>*/}
|
||||||
|
{/* 最大请求token(0表示不限制):*/}
|
||||||
|
{/* </Typography.Text>*/}
|
||||||
|
{/*</div>*/}
|
||||||
|
{/*<Input*/}
|
||||||
|
{/* label='最大请求token'*/}
|
||||||
|
{/* name='max_input_tokens'*/}
|
||||||
|
{/* placeholder='默认为0,表示不限制'*/}
|
||||||
|
{/* onChange={(value) => {*/}
|
||||||
|
{/* handleInputChange('max_input_tokens', value);*/}
|
||||||
|
{/* }}*/}
|
||||||
|
{/* value={inputs.max_input_tokens}*/}
|
||||||
|
{/*/>*/}
|
||||||
</Spin>
|
</Spin>
|
||||||
</SideSheet>
|
</SideSheet>
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ const Detail = (props) => {
|
|||||||
stack: true,
|
stack: true,
|
||||||
legends: {
|
legends: {
|
||||||
visible: true,
|
visible: true,
|
||||||
|
selectMode: 'single',
|
||||||
},
|
},
|
||||||
title: {
|
title: {
|
||||||
visible: true,
|
visible: true,
|
||||||
@@ -216,6 +217,8 @@ const Detail = (props) => {
|
|||||||
} else if (dataExportDefaultTime === 'week') {
|
} else if (dataExportDefaultTime === 'week') {
|
||||||
timeGranularity = 604800;
|
timeGranularity = 604800;
|
||||||
}
|
}
|
||||||
|
// sort created_at
|
||||||
|
data.sort((a, b) => a.created_at - b.created_at);
|
||||||
data.forEach((item) => {
|
data.forEach((item) => {
|
||||||
item['created_at'] =
|
item['created_at'] =
|
||||||
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
|
Math.floor(item['created_at'] / timeGranularity) * timeGranularity;
|
||||||
|
|||||||
Reference in New Issue
Block a user