diff --git a/common/constants.go b/common/constants.go index 82a1fa5..226bd42 100644 --- a/common/constants.go +++ b/common/constants.go @@ -227,6 +227,8 @@ const ( ChannelTypeZhipu_v4 = 26 ChannelTypePerplexity = 27 ChannelTypeLingYiWanWu = 31 + ChannelTypeAws = 33 + ChannelTypeCohere = 34 ) var ChannelBaseURLs = []string{ @@ -262,4 +264,7 @@ var ChannelBaseURLs = []string{ "", //29 "", //30 "https://api.lingyiwanwu.com", //31 + "", //32 + "", //33 + "https://api.cohere.ai", //34 } diff --git a/common/model-ratio.go b/common/model-ratio.go index 6df0311..bc06de1 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -12,6 +12,7 @@ import ( // TODO: when a new api is enabled, check the pricing here // 1 === $0.002 / 1K tokens // 1 === ¥0.014 / 1k tokens + var DefaultModelRatio = map[string]float64{ //"midjourney": 50, "gpt-4-gizmo-*": 15, @@ -73,11 +74,14 @@ var DefaultModelRatio = map[string]float64{ "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, - "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens - "gemini-1.0-pro-vision-001": 1, - "gemini-1.0-pro-001": 1, - "gemini-1.5-pro": 1, + "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens + "gemini-1.0-pro-vision-001": 1, + "gemini-1.0-pro-001": 1, + "gemini-1.5-pro-latest": 1, + "gemini-1.0-pro-latest": 1, + "gemini-1.0-pro-vision-latest": 1, + "gemini-ultra": 1, "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens "chatglm_std": 0.3572, // ¥0.005 / 1k tokens @@ -102,6 +106,12 @@ var DefaultModelRatio = map[string]float64{ "yi-34b-chat-0205": 0.018, "yi-34b-chat-200k": 0.0864, "yi-vl-plus": 0.0432, + "command": 0.5, + "command-nightly": 0.5, + "command-light": 0.5, + "command-light-nightly": 0.5, + "command-r": 0.25, + "command-r-plus 
": 1.5, } var DefaultModelPrice = map[string]float64{ @@ -223,6 +233,16 @@ func GetCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gemini-") { return 3 } + if strings.HasPrefix(name, "command") { + switch name { + case "command-r": + return 3 + case "command-r-plus": + return 5 + default: + return 2 + } + } switch name { case "llama2-70b-4096": return 0.8 / 0.7 diff --git a/constant/midjourney.go b/constant/midjourney.go index 6d0b5ac..cd38d5f 100644 --- a/constant/midjourney.go +++ b/constant/midjourney.go @@ -1,8 +1,8 @@ package constant var MjNotifyEnabled = false - var MjModeClearEnabled = false +var MjForwardUrlEnabled = true const ( MjErrorUnknown = 5 diff --git a/constant/payment.go b/constant/payment.go new file mode 100644 index 0000000..da1e0dd --- /dev/null +++ b/constant/payment.go @@ -0,0 +1,8 @@ +package constant + +var PayAddress = "" +var CustomCallbackAddress = "" +var EpayId = "" +var EpayKey = "" +var Price = 7.3 +var MinTopUp = 1 diff --git a/controller/channel-test.go b/controller/channel-test.go index e407193..f66e0d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -86,7 +86,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if err != nil { return err, nil } - if resp.StatusCode != http.StatusOK { + if resp != nil && resp.StatusCode != http.StatusOK { err := relaycommon.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } diff --git a/controller/midjourney.go b/controller/midjourney.go index b5b832b..3d779c8 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -10,11 +10,11 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" "one-api/dto" "one-api/model" "one-api/service" "strconv" - "strings" "time" ) @@ -233,6 +233,12 @@ func GetAllMidjourney(c *gin.Context) { if logs == nil { logs = make([]*model.Midjourney, 0) } + if constant.MjForwardUrlEnabled { + for i, 
midjourney := range logs { + midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId + logs[i] = midjourney + } + } c.JSON(200, gin.H{ "success": true, "message": "", @@ -259,7 +265,7 @@ func GetUserMidjourney(c *gin.Context) { if logs == nil { logs = make([]*model.Midjourney, 0) } - if !strings.Contains(common.ServerAddress, "localhost") { + if constant.MjForwardUrlEnabled { for i, midjourney := range logs { midjourney.ImageUrl = common.ServerAddress + "/mj/image/" + midjourney.MjId logs[i] = midjourney diff --git a/controller/option.go b/controller/option.go index d2272b7..d43a08a 100644 --- a/controller/option.go +++ b/controller/option.go @@ -14,7 +14,7 @@ func GetOptions(c *gin.Context) { var options []*model.Option common.OptionMapRWMutex.Lock() for k, v := range common.OptionMap { - if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { + if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") || strings.HasSuffix(k, "Key") { continue } options = append(options, &model.Option{ diff --git a/controller/relay.go b/controller/relay.go index 0fd9d7a..3a4ae72 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -15,6 +15,7 @@ import ( "one-api/relay/constant" relayconstant "one-api/relay/constant" "one-api/service" + "strings" ) func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { @@ -42,7 +43,7 @@ func Relay(c *gin.Context) { group := c.GetString("group") originalModel := c.GetString("original_model") openaiErr := relayHandler(c, relayMode) - retryLogStr := fmt.Sprintf("重试:%d", channelId) + useChannel := []int{channelId} if openaiErr != nil { go processChannelError(c, channelId, openaiErr) } else { @@ -55,7 +56,7 @@ func Relay(c *gin.Context) { break } channelId = channel.Id - retryLogStr += fmt.Sprintf("->%d", channel.Id) + useChannel = append(useChannel, channelId) common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) 
middleware.SetupContextForSelectedChannel(c, channel, originalModel) @@ -66,7 +67,10 @@ func Relay(c *gin.Context) { go processChannelError(c, channelId, openaiErr) } } - common.LogInfo(c.Request.Context(), retryLogStr) + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c.Request.Context(), retryLogStr) + } if openaiErr != nil { if openaiErr.StatusCode == http.StatusTooManyRequests { @@ -105,6 +109,10 @@ func shouldRetry(c *gin.Context, channelId int, openaiErr *dto.OpenAIErrorWithSt if openaiErr.StatusCode == http.StatusBadRequest { return false } + if openaiErr.StatusCode == 408 { + // azure处理超时不重试 + return false + } if openaiErr.LocalError { return false } diff --git a/controller/user.go b/controller/user.go index c305cd4..4048713 100644 --- a/controller/user.go +++ b/controller/user.go @@ -217,7 +217,8 @@ func GetAllUsers(c *gin.Context) { func SearchUsers(c *gin.Context) { keyword := c.Query("keyword") - users, err := model.SearchUsers(keyword) + group := c.Query("group") + users, err := model.SearchUsers(keyword, group) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -453,7 +454,7 @@ func UpdateUser(c *gin.Context) { updatedUser.Password = "" // rollback to what it should be } updatePassword := updatedUser.Password != "" - if err := updatedUser.Update(updatePassword); err != nil { + if err := updatedUser.Edit(updatePassword); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -740,7 +741,7 @@ func ManageUser(c *gin.Context) { user.Role = common.RoleCommonUser } - if err := user.UpdateAll(false); err != nil { + if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/dto/text_request.go b/dto/text_request.go index 936660e..0f696fc 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -32,6 +32,21 @@ type 
GeneralOpenAIRequest struct { TopLogProbs int `json:"top_logprobs,omitempty"` } +type OpenAITools struct { + Type string `json:"type"` + Function OpenAIFunction `json:"function"` +} + +type OpenAIFunction struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` +} + +func (r GeneralOpenAIRequest) GetMaxTokens() int64 { + return int64(r.MaxTokens) +} + func (r GeneralOpenAIRequest) ParseInput() []string { if r.Input == nil { return nil diff --git a/dto/text_response.go b/dto/text_response.go index 98275fe..a589d75 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -54,13 +54,29 @@ type OpenAIEmbeddingResponse struct { } type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - Role string `json:"role,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` - Index int `json:"index,omitempty"` + Delta ChatCompletionsStreamResponseChoiceDelta `json:"delta"` + FinishReason *string `json:"finish_reason,omitempty"` + Index int `json:"index,omitempty"` +} + +type ChatCompletionsStreamResponseChoiceDelta struct { + Content string `json:"content"` + Role string `json:"role,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` +} + +type ToolCall struct { + // Index is not nil only in chat completion chunk object + Index *int `json:"index,omitempty"` + ID string `json:"id"` + Type any `json:"type"` + Function FunctionCall `json:"function"` +} + +type FunctionCall struct { + Name string `json:"name,omitempty"` + // call function with arguments in JSON format + Arguments string `json:"arguments,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/go.mod b/go.mod index 958700a..a6e41cc 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,9 @@ go 1.18 require ( github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 + 
github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 + github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/gin-contrib/cors v1.4.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -15,6 +18,8 @@ require ( github.com/golang-jwt/jwt v3.2.2+incompatible github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 + github.com/jinzhu/copier v0.4.0 + github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.6 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible @@ -29,6 +34,10 @@ require ( require ( github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/smithy-go v1.20.2 // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index 3f8223c..bfd81ef 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,20 @@ github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0 h1:onfun1RA+Kc github.com/anknown/ahocorasick v0.0.0-20190904063843-d75dbd5169c0/go.mod h1:4yg+jNTYlDEzBjhGS96v+zjyA3lfXlFd5CiTLIkPBLI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6 h1:HblK3eJHq54yET63qPCTJnks3loDse5xRmmqHgHzwoI= github.com/anknown/darts v0.0.0-20151216065714-83ff685239e6/go.mod h1:pbiaLIeYLUbgMY1kwEAdwO6UKD5ZNwdPGQlwokS9fe8= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream 
v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76wYsSZIZZQYBxkmMEjvL6GHy8XU= +github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -83,6 +97,8 @@ github.com/jackc/pgx/v5 v5.5.1 h1:5I9etrGkLrN+2XPCsi6XLlV5DITbSL/xBZdmAxFcXPI= github.com/jackc/pgx/v5 v5.5.1/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= 
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= @@ -126,6 +142,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/middleware/distributor.go b/middleware/distributor.go index 108c783..ae5707f 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -177,6 +177,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode } c.Set("auto_ban", ban) c.Set("model_mapping", channel.GetModelMapping()) + c.Set("status_code_mapping", channel.GetStatusCodeMapping()) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) c.Set("base_url", channel.GetBaseURL()) // TODO: api_version统一 diff --git a/model/cache.go b/model/cache.go index f8ac584..c5fdc6d 100644 --- a/model/cache.go +++ b/model/cache.go @@ -25,9 +25,6 @@ var token2UserId = make(map[string]int) var token2UserIdLock sync.RWMutex func cacheSetToken(token *Token) error { - if !common.RedisEnabled { - return token.SelectUpdate() - } jsonBytes, err := json.Marshal(token) if err != nil { return err @@ -168,7 +165,11 @@ func CacheUpdateUserQuota(id int) error { if err != nil { return err 
} - err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) + return cacheSetUserQuota(id, quota) +} + +func cacheSetUserQuota(id int, quota int) error { + err := common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) return err } diff --git a/model/channel.go b/model/channel.go index 3e30ad4..5b35851 100644 --- a/model/channel.go +++ b/model/channel.go @@ -25,8 +25,10 @@ type Channel struct { Group string `json:"group" gorm:"type:varchar(64);default:'default'"` UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` ModelMapping *string `json:"model_mapping" gorm:"type:varchar(1024);default:''"` - Priority *int64 `json:"priority" gorm:"bigint;default:0"` - AutoBan *int `json:"auto_ban" gorm:"default:1"` + //MaxInputTokens *int `json:"max_input_tokens" gorm:"default:0"` + StatusCodeMapping *string `json:"status_code_mapping" gorm:"type:varchar(1024);default:''"` + Priority *int64 `json:"priority" gorm:"bigint;default:0"` + AutoBan *int `json:"auto_ban" gorm:"default:1"` } func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Channel, error) { @@ -153,6 +155,13 @@ func (channel *Channel) GetModelMapping() string { return *channel.ModelMapping } +func (channel *Channel) GetStatusCodeMapping() string { + if channel.StatusCodeMapping == nil { + return "" + } + return *channel.StatusCodeMapping +} + func (channel *Channel) Insert() error { var err error err = DB.Create(channel).Error diff --git a/model/option.go b/model/option.go index da6cfcb..18082b2 100644 --- a/model/option.go +++ b/model/option.go @@ -98,6 +98,7 @@ func InitOptionMap() { common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) common.OptionMap["MjModeClearEnabled"] = 
strconv.FormatBool(constant.MjModeClearEnabled) + common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) @@ -207,6 +208,8 @@ func updateOptionMap(key string, value string) (err error) { constant.MjNotifyEnabled = boolValue case "MjModeClearEnabled": constant.MjModeClearEnabled = boolValue + case "MjForwardUrlEnabled": + constant.MjForwardUrlEnabled = boolValue case "CheckSensitiveEnabled": constant.CheckSensitiveEnabled = boolValue case "CheckSensitiveOnPromptEnabled": diff --git a/model/token.go b/model/token.go index 08909cb..1bbf6c4 100644 --- a/model/token.go +++ b/model/token.go @@ -102,6 +102,11 @@ func GetTokenById(id int) (*Token, error) { token := Token{Id: id} var err error = nil err = DB.First(&token, "id = ?", id).Error + if err != nil { + if common.RedisEnabled { + go cacheSetToken(&token) + } + } return &token, err } diff --git a/model/user.go b/model/user.go index aa9060d..3f75182 100644 --- a/model/user.go +++ b/model/user.go @@ -76,25 +76,34 @@ func GetAllUsers(startIdx int, num int) (users []*User, err error) { return users, err } -func SearchUsers(keyword string) ([]*User, error) { +func SearchUsers(keyword string, group string) ([]*User, error) { var users []*User var err error // 尝试将关键字转换为整数ID keywordInt, err := strconv.Atoi(keyword) if err == nil { - // 如果转换成功,按照ID搜索用户 - err = DB.Unscoped().Omit("password").Where("id = ?", keywordInt).Find(&users).Error + // 如果转换成功,按照ID和可选的组别搜索用户 + query := DB.Unscoped().Omit("password").Where("`id` = ?", keywordInt) + if group != "" { + query = query.Where("`group` = ?", group) // 使用反引号包围group + } + err = query.Find(&users).Error if err != nil || 
len(users) > 0 { - // 如果依据ID找到用户或者发生错误,返回结果或错误 return users, err } } - // 如果ID转换失败或者没有找到用户,依据其他字段进行模糊搜索 - err = DB.Unscoped().Omit("password"). - Where("username LIKE ? OR email LIKE ? OR display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%"). - Find(&users).Error + err = nil + + query := DB.Unscoped().Omit("password") + likeCondition := "`username` LIKE ? OR `email` LIKE ? OR `display_name` LIKE ?" + if group != "" { + query = query.Where("("+likeCondition+") AND `group` = ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%", group) + } else { + query = query.Where(likeCondition, "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") + } + err = query.Find(&users).Error return users, err } @@ -252,7 +261,7 @@ func (user *User) Update(updatePassword bool) error { return err } -func (user *User) UpdateAll(updatePassword bool) error { +func (user *User) Edit(updatePassword bool) error { var err error if updatePassword { user.Password, err = common.Password2Hash(user.Password) @@ -262,7 +271,13 @@ func (user *User) UpdateAll(updatePassword bool) error { } newUser := *user DB.First(&user, user.Id) - err = DB.Model(user).Select("*").Updates(newUser).Error + err = DB.Model(user).Updates(map[string]interface{}{ + "username": newUser.Username, + "password": newUser.Password, + "display_name": newUser.DisplayName, + "group": newUser.Group, + "quota": newUser.Quota, + }).Error if err == nil { if common.RedisEnabled { _ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second) @@ -451,6 +466,11 @@ func ValidateAccessToken(token string) (user *User) { func GetUserQuota(id int) (quota int, err error) { err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error + if err != nil { + if common.RedisEnabled { + go cacheSetUserQuota(id, quota) + } + } return quota, err } diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go new file mode 100644 index 0000000..23c69db --- /dev/null 
+++ b/relay/channel/aws/adaptor.go @@ -0,0 +1,79 @@ +package aws + +import ( + "errors" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel/claude" + relaycommon "one-api/relay/common" + "strings" +) + +const ( + RequestModeCompletion = 1 + RequestModeMessage = 2 +) + +type Adaptor struct { + RequestMode int +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { + if strings.HasPrefix(info.UpstreamModelName, "claude-3") { + a.RequestMode = RequestModeMessage + } else { + a.RequestMode = RequestModeCompletion + } +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return "", nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + + var claudeReq *claude.ClaudeRequest + var err error + if a.RequestMode == RequestModeCompletion { + claudeReq = claude.RequestOpenAI2ClaudeComplete(*request) + } else { + claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request) + } + c.Set("request_model", request.Model) + c.Set("converted_request", claudeReq) + return claudeReq, err +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return nil, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = awsStreamHandler(c, info, a.RequestMode) + } else { + err, usage = awsHandler(c, info, a.RequestMode) + } + return +} + +func (a *Adaptor) GetModelList() (models []string) { + for n := range awsModelIDMap { + models = append(models, n) + } + + return +} + +func (a *Adaptor) GetChannelName() string { 
+ return ChannelName +} diff --git a/relay/channel/aws/constants.go b/relay/channel/aws/constants.go new file mode 100644 index 0000000..0b03785 --- /dev/null +++ b/relay/channel/aws/constants.go @@ -0,0 +1,12 @@ +package aws + +var awsModelIDMap = map[string]string{ + "claude-instant-1.2": "anthropic.claude-instant-v1", + "claude-2.0": "anthropic.claude-v2", + "claude-2.1": "anthropic.claude-v2:1", + "claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0", + "claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0", + "claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0", +} + +var ChannelName = "aws" diff --git a/relay/channel/aws/dto.go b/relay/channel/aws/dto.go new file mode 100644 index 0000000..7450908 --- /dev/null +++ b/relay/channel/aws/dto.go @@ -0,0 +1,14 @@ +package aws + +import "one-api/relay/channel/claude" + +type AwsClaudeRequest struct { + // AnthropicVersion should be "bedrock-2023-05-31" + AnthropicVersion string `json:"anthropic_version"` + Messages []claude.ClaudeMessage `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} diff --git a/relay/channel/aws/relay-aws.go b/relay/channel/aws/relay-aws.go new file mode 100644 index 0000000..bf64f03 --- /dev/null +++ b/relay/channel/aws/relay-aws.go @@ -0,0 +1,211 @@ +package aws + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" + "github.com/pkg/errors" + "io" + "net/http" + "one-api/common" + relaymodel "one-api/dto" + "one-api/relay/channel/claude" + relaycommon "one-api/relay/common" + "strings" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" + "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" +) + +func 
newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) { + awsSecret := strings.Split(info.ApiKey, "|") + if len(awsSecret) != 3 { + return nil, errors.New("invalid aws secret key") + } + ak := awsSecret[0] + sk := awsSecret[1] + region := awsSecret[2] + client := bedrockruntime.New(bedrockruntime.Options{ + Region: region, + Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")), + }) + + return client, nil +} + +func wrapErr(err error) *relaymodel.OpenAIErrorWithStatusCode { + return &relaymodel.OpenAIErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.OpenAIError{ + Message: fmt.Sprintf("%s", err.Error()), + }, + } +} + +func awsModelID(requestModel string) (string, error) { + if awsModelID, ok := awsModelIDMap[requestModel]; ok { + return awsModelID, nil + } + + return "", errors.Errorf("model %s not found", requestModel) +} + +func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { + awsCli, err := newAwsClient(c, info) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString("request_model")) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get("converted_request") + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*claude.ClaudeRequest) + awsClaudeReq := &AwsClaudeRequest{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, 
"marshal request")), nil + } + + awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, "InvokeModel")), nil + } + + claudeResponse := new(claude.ClaudeResponse) + err = json.Unmarshal(awsResp.Body, claudeResponse) + if err != nil { + return wrapErr(errors.Wrap(err, "unmarshal response")), nil + } + + openaiResp := claude.ResponseClaude2OpenAI(requestMode, claudeResponse) + usage := relaymodel.Usage{ + PromptTokens: claudeResponse.Usage.InputTokens, + CompletionTokens: claudeResponse.Usage.OutputTokens, + TotalTokens: claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens, + } + openaiResp.Usage = usage + + c.JSON(http.StatusOK, openaiResp) + return nil, &usage +} + +func awsStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*relaymodel.OpenAIErrorWithStatusCode, *relaymodel.Usage) { + awsCli, err := newAwsClient(c, info) + if err != nil { + return wrapErr(errors.Wrap(err, "newAwsClient")), nil + } + + awsModelId, err := awsModelID(c.GetString("request_model")) + if err != nil { + return wrapErr(errors.Wrap(err, "awsModelID")), nil + } + + awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ + ModelId: aws.String(awsModelId), + Accept: aws.String("application/json"), + ContentType: aws.String("application/json"), + } + + claudeReq_, ok := c.Get("converted_request") + if !ok { + return wrapErr(errors.New("request not found")), nil + } + claudeReq := claudeReq_.(*claude.ClaudeRequest) + + awsClaudeReq := &AwsClaudeRequest{ + AnthropicVersion: "bedrock-2023-05-31", + } + if err = copier.Copy(awsClaudeReq, claudeReq); err != nil { + return wrapErr(errors.Wrap(err, "copy request")), nil + } + awsReq.Body, err = json.Marshal(awsClaudeReq) + if err != nil { + return wrapErr(errors.Wrap(err, "marshal request")), nil + } + + awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq) + if err != nil { + return wrapErr(errors.Wrap(err, 
"InvokeModelWithResponseStream")), nil + } + stream := awsResp.GetStream() + defer stream.Close() + + c.Writer.Header().Set("Content-Type", "text/event-stream") + var usage relaymodel.Usage + var id string + var model string + c.Stream(func(w io.Writer) bool { + event, ok := <-stream.Events() + if !ok { + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + + switch v := event.(type) { + case *types.ResponseStreamMemberChunk: + claudeResp := new(claude.ClaudeResponse) + err := json.NewDecoder(bytes.NewReader(v.Value.Bytes)).Decode(claudeResp) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return false + } + + response, claudeUsage := claude.StreamResponseClaude2OpenAI(requestMode, claudeResp) + if claudeUsage != nil { + usage.PromptTokens += claudeUsage.InputTokens + usage.CompletionTokens += claudeUsage.OutputTokens + } + + if response == nil { + return true + } + + if response.Id != "" { + id = response.Id + } + if response.Model != "" { + model = response.Model + } + response.Id = id + response.Model = model + + jsonStr, err := json.Marshal(response) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case *types.UnknownUnionMember: + fmt.Println("unknown tag:", v.Tag) + return false + default: + fmt.Println("union is nil or unknown type") + return false + } + }) + + return nil, &usage +} diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 45efd01..9add208 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,9 +53,9 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.Gen return nil, errors.New("request is nil") } if a.RequestMode == RequestModeCompletion { - return requestOpenAI2ClaudeComplete(*request), nil + return RequestOpenAI2ClaudeComplete(*request), nil 
} else { - return requestOpenAI2ClaudeMessage(*request) + return RequestOpenAI2ClaudeMessage(*request) } } diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 50513d8..2d13e46 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -24,16 +24,15 @@ type ClaudeMessage struct { } type ClaudeRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - System string `json:"system,omitempty"` - Messages []ClaudeMessage `json:"messages,omitempty"` - MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"` - MaxTokens uint `json:"max_tokens,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + System string `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` Stream bool `json:"stream,omitempty"` } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 2b5d3d2..33e742a 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -20,25 +20,25 @@ func stopReasonClaude2OpenAI(reason string) string { case "end_turn": return "stop" case "max_tokens": - return "length" + return "max_tokens" default: return reason } } -func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { +func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { claudeRequest := ClaudeRequest{ - Model: textRequest.Model, - Prompt: "", - MaxTokensToSample: 
textRequest.MaxTokens, - StopSequences: nil, - Temperature: textRequest.Temperature, - TopP: textRequest.TopP, - TopK: textRequest.TopK, - Stream: textRequest.Stream, + Model: textRequest.Model, + Prompt: "", + MaxTokens: textRequest.MaxTokens, + StopSequences: nil, + Temperature: textRequest.Temperature, + TopP: textRequest.TopP, + TopK: textRequest.TopK, + Stream: textRequest.Stream, } - if claudeRequest.MaxTokensToSample == 0 { - claudeRequest.MaxTokensToSample = 1000000 + if claudeRequest.MaxTokens == 0 { + claudeRequest.MaxTokens = 4096 } prompt := "" for _, message := range textRequest.Messages { @@ -57,7 +57,7 @@ func requestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR return &claudeRequest } -func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { +func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { claudeRequest := ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -70,8 +70,39 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 } + formatMessages := make([]dto.Message, 0) + var lastMessage *dto.Message + for i, message := range textRequest.Messages { + if message.Role == "system" { + if i != 0 { + message.Role = "user" + } + } + if message.Role == "" { + message.Role = "user" + } + fmtMessage := dto.Message{ + Role: message.Role, + Content: message.Content, + } + if lastMessage != nil && lastMessage.Role == message.Role { + if lastMessage.IsStringContent() && message.IsStringContent() { + content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) + fmtMessage.Content = content + // delete last message + formatMessages = formatMessages[:len(formatMessages)-1] + } + } + if fmtMessage.Content == nil { + content, _ := json.Marshal("...") + fmtMessage.Content = 
content + } + formatMessages = append(formatMessages, fmtMessage) + lastMessage = &fmtMessage + } + claudeMessages := make([]ClaudeMessage, 0) - for _, message := range textRequest.Messages { + for _, message := range formatMessages { if message.Role == "system" { claudeRequest.System = message.StringContent() } else { @@ -122,7 +153,7 @@ func requestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR return &claudeRequest, nil } -func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { +func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*dto.ChatCompletionsStreamResponse, *ClaudeUsage) { var response dto.ChatCompletionsStreamResponse var claudeUsage *ClaudeUsage response.Object = "chat.completion.chunk" @@ -149,6 +180,8 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* choice.FinishReason = &finishReason } claudeUsage = &claudeResponse.Usage + } else if claudeResponse.Type == "message_stop" { + return nil, nil } } if claudeUsage == nil { @@ -158,7 +191,7 @@ func streamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* return &response, claudeUsage } -func responseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { +func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.OpenAITextResponse { choices := make([]dto.OpenAITextResponseChoice, 0) fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -242,7 +275,10 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c return true } - response, claudeUsage := streamResponseClaude2OpenAI(requestMode, &claudeResponse) + response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + if response == nil { + return true + } if requestMode == RequestModeCompletion { responseText += claudeResponse.Completion responseId
= response.Id @@ -317,7 +353,7 @@ func claudeHandler(requestMode int, c *gin.Context, resp *http.Response, promptT StatusCode: resp.StatusCode, }, nil } - fullTextResponse := responseClaude2OpenAI(requestMode, &claudeResponse) + fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) completionTokens, err, _ := service.CountTokenText(claudeResponse.Completion, model, false) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go new file mode 100644 index 0000000..44b7f38 --- /dev/null +++ b/relay/channel/cohere/adaptor.go @@ -0,0 +1,52 @@ +package cohere + +import ( + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + return requestOpenAI2Cohere(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + } else 
{ + err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/cohere/constant.go b/relay/channel/cohere/constant.go new file mode 100644 index 0000000..189d234 --- /dev/null +++ b/relay/channel/cohere/constant.go @@ -0,0 +1,7 @@ +package cohere + +var ModelList = []string{ + "command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly", +} + +var ChannelName = "cohere" diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go new file mode 100644 index 0000000..958343c --- /dev/null +++ b/relay/channel/cohere/dto.go @@ -0,0 +1,44 @@ +package cohere + +type CohereRequest struct { + Model string `json:"model"` + ChatHistory []ChatHistory `json:"chat_history"` + Message string `json:"message"` + Stream bool `json:"stream"` + MaxTokens int64 `json:"max_tokens"` +} + +type ChatHistory struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type CohereResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + Text string `json:"text,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Response *CohereResponseResult `json:"response"` +} + +type CohereResponseResult struct { + ResponseId string `json:"response_id"` + FinishReason string `json:"finish_reason,omitempty"` + Text string `json:"text"` + Meta CohereMeta `json:"meta"` +} + +type CohereMeta struct { + //Tokens CohereTokens `json:"tokens"` + BilledUnits CohereBilledUnits `json:"billed_units"` +} + +type CohereBilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type CohereTokens struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/relay/channel/cohere/relay-cohere.go 
b/relay/channel/cohere/relay-cohere.go new file mode 100644 index 0000000..a21d4a9 --- /dev/null +++ b/relay/channel/cohere/relay-cohere.go @@ -0,0 +1,189 @@ +package cohere + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/service" + "strings" +) + +func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { + cohereReq := CohereRequest{ + Model: textRequest.Model, + ChatHistory: []ChatHistory{}, + Message: "", + Stream: textRequest.Stream, + MaxTokens: textRequest.GetMaxTokens(), + } + if cohereReq.MaxTokens == 0 { + cohereReq.MaxTokens = 4000 + } + for _, msg := range textRequest.Messages { + if msg.Role == "user" { + cohereReq.Message = msg.StringContent() + } else { + var role string + if msg.Role == "assistant" { + role = "CHATBOT" + } else if msg.Role == "system" { + role = "SYSTEM" + } else { + role = "USER" + } + cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{ + Role: role, + Message: msg.StringContent(), + }) + } + } + return &cohereReq +} + +func stopReasonCohere2OpenAI(reason string) string { + switch reason { + case "COMPLETE": + return "stop" + case "MAX_TOKENS": + return "max_tokens" + default: + return reason + } +} + +func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + usage := &dto.Usage{} + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + 
for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() + service.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + data = strings.TrimSuffix(data, "\r") + var cohereResp CohereResponse + err := json.Unmarshal([]byte(data), &cohereResp) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + return true + } + var openaiResp dto.ChatCompletionsStreamResponse + openaiResp.Id = responseId + openaiResp.Created = createdTime + openaiResp.Object = "chat.completion.chunk" + openaiResp.Model = modelName + if cohereResp.IsFinished { + finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) + openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{}, + Index: 0, + FinishReason: &finishReason, + }, + } + if cohereResp.Response != nil { + usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens + usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens + } + } else { + openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ + Role: "assistant", + Content: cohereResp.Text, + }, + Index: 0, + }, + } + responseText += cohereResp.Text + } + jsonStr, err := json.Marshal(openaiResp) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + if usage.PromptTokens == 0 { + usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) + } + return nil, usage +} + +func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + createdTime := common.GetTimestamp() + 
responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cohereResp CohereResponseResult + err = json.Unmarshal(responseBody, &cohereResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + usage := dto.Usage{} + usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens + usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens + usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens + + var openaiResp dto.TextResponse + openaiResp.Id = cohereResp.ResponseId + openaiResp.Created = createdTime + openaiResp.Object = "chat.completion" + openaiResp.Model = modelName + openaiResp.Usage = usage + + content, _ := json.Marshal(cohereResp.Text) + openaiResp.Choices = []dto.OpenAITextResponseChoice{ + { + Index: 0, + Message: dto.Message{Content: content, Role: "assistant"}, + FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason), + }, + } + + jsonResponse, err := json.Marshal(openaiResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 4e1fd33..8997889 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) 
(usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index 828ddea..f63fe57 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -9,6 +9,7 @@ import ( "net/http" "one-api/dto" "one-api/service" + "strings" ) func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { @@ -41,7 +42,7 @@ func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) *OllamaRequest { func requestOpenAI2Embeddings(request dto.GeneralOpenAIRequest) *OllamaEmbeddingRequest { return &OllamaEmbeddingRequest{ Model: request.Model, - Prompt: request.Input, + Prompt: strings.Join(request.ParseInput(), " "), } } diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index cab6a64..a450c71 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) + var toolCount int + err, responseText, toolCount = OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git 
a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index fe5cd48..5469ed7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -16,9 +16,10 @@ import ( "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { +func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) { //checkSensitive := constant.ShouldCheckCompletionSensitive() var responseTextBuilder strings.Builder + toolCount := 0 scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -68,6 +69,15 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d if err == nil { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } } } } @@ -75,6 +85,15 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d for _, streamResponse := range streamResponses { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) + } + } } } } @@ -123,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d }) err := resp.Body.Close() if err != nil { - return 
service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount } wg.Wait() - return nil, responseTextBuilder.String() + return nil, responseTextBuilder.String(), toolCount } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 24765ff..00d7710 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 1b8866b..fe89ff4 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + var toolCount int + err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = 
service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 8e6f67e..7f11ae2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -18,6 +18,8 @@ const ( APITypeZhipu_v4 APITypeOllama APITypePerplexity + APITypeAws + APITypeCohere APITypeDummy // this one is only for count, do not add any channel after this ) @@ -49,6 +51,10 @@ func ChannelType2APIType(channelType int) int { apiType = APITypeOllama case common.ChannelTypePerplexity: apiType = APITypePerplexity + case common.ChannelTypeAws: + apiType = APITypeAws + case common.ChannelTypeCohere: + apiType = APITypeCohere } return apiType } diff --git a/relay/relay-audio.go b/relay/relay-audio.go index d4458ce..09ac2a0 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -20,15 +20,6 @@ import ( "time" ) -var availableVoices = []string{ - "alloy", - "echo", - "fable", - "onyx", - "nova", - "shimmer", -} - func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { tokenId := c.GetInt("token_id") channelType := c.GetInt("channel") @@ -59,9 +50,6 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if audioRequest.Voice == "" { return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) } - if !common.StringsContains(availableVoices, audioRequest.Voice) { - return service.OpenAIErrorWrapper(errors.New("voice must be one of "+strings.Join(availableVoices, ", ")), "invalid_field_value", http.StatusBadRequest) - } } var err error promptTokens := 0 @@ -100,6 +88,22 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } } + succeed := false + defer func() { + if succeed { + return + } + if preConsumedQuota > 0 { + // we 
need to roll back the pre-consumed quota + defer func() { + go func() { + // negative means add quota back for token & user + returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota) + }() + }() + } + }() + // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" { @@ -163,6 +167,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { if resp.StatusCode != http.StatusOK { return relaycommon.RelayErrorHandler(resp) } + succeed = true var audioResponse dto.AudioResponse diff --git a/relay/relay-image.go b/relay/relay-image.go index aabe4ba..ce072d1 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -34,7 +34,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if imageRequest.Model == "" { - imageRequest.Model = "dall-e-2" + imageRequest.Model = "dall-e-3" } if imageRequest.Size == "" { imageRequest.Size = "1024x1024" @@ -186,7 +186,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if quota != 0 { tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) + quality := "normal" + if imageRequest.Quality == "hd" { + quality = "hd" + } + logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelRatio, groupRatio, imageRequest.Size, quality) model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false) model.UpdateUserUsedQuotaAndRequestCount(userId, quota) channelId := c.GetInt("channel_id") diff --git a/relay/relay-mj.go b/relay/relay-mj.go index 7b3a4e2..27b4c6d 100644 --- a/relay/relay-mj.go +++ b/relay/relay-mj.go @@ -110,11 +110,13 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.StartTime = originTask.StartTime midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.ImageUrl = "" - if originTask.ImageUrl 
!= "" { + if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { midjourneyTask.ImageUrl = common.ServerAddress + "/mj/image/" + originTask.MjId if originTask.Status != "SUCCESS" { midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) } + } else { + midjourneyTask.ImageUrl = originTask.ImageUrl } midjourneyTask.Status = originTask.Status midjourneyTask.FailReason = originTask.FailReason diff --git a/relay/relay-text.go b/relay/relay-text.go index 71a47c2..e9aa7bb 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -154,20 +154,28 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { requestBody = bytes.NewBuffer(jsonData) } + statusCodeMappingStr := c.GetString("status_code_mapping") resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) - return service.RelayErrorHandler(resp) + if resp != nil { + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } } usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) if openaiErr != nil { returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) @@ 
-181,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index e6afab5..01e9cec 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -3,8 +3,10 @@ package relay import ( "one-api/relay/channel" "one-api/relay/channel/ali" + "one-api/relay/channel/aws" "one-api/relay/channel/baidu" "one-api/relay/channel/claude" + "one-api/relay/channel/cohere" "one-api/relay/channel/gemini" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" @@ -45,6 +47,10 @@ func GetAdaptor(apiType int) channel.Adaptor { return &ollama.Adaptor{} case constant.APITypePerplexity: return &perplexity.Adaptor{} + case constant.APITypeAws: + return &aws.Adaptor{} + case constant.APITypeCohere: + return &cohere.Adaptor{} } return nil } diff --git a/service/error.go b/service/error.go index 39eb0f9..4b00f37 100644 --- a/service/error.go +++ b/service/error.go @@ -86,3 +86,22 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW } return } + +func ResetStatusCode(openaiErr *dto.OpenAIErrorWithStatusCode, statusCodeMappingStr string) { + if statusCodeMappingStr == "" || statusCodeMappingStr == "{}" { + return + } + statusCodeMapping := make(map[string]string) + err := json.Unmarshal([]byte(statusCodeMappingStr), &statusCodeMapping) + if err != nil { + return + } + if openaiErr.StatusCode == 
http.StatusOK { + return + } + codeStr := strconv.Itoa(openaiErr.StatusCode) + if _, ok := statusCodeMapping[codeStr]; ok { + intCode, _ := strconv.Atoi(statusCodeMapping[codeStr]) + openaiErr.StatusCode = intCode + } +} diff --git a/service/token_counter.go b/service/token_counter.go index 5255c80..18fc5a3 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { return tiles*170 + 85, nil } +func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) { + tkm := 0 + msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive) + if err != nil { + return 0, err, b + } + tkm += msgTokens + if request.Tools != nil { + toolsData, _ := json.Marshal(request.Tools) + var openaiTools []dto.OpenAITools + err := json.Unmarshal(toolsData, &openaiTools) + if err != nil { + return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())), false + } + countStr := "" + for _, tool := range openaiTools { + countStr = tool.Function.Name + if tool.Function.Description != "" { + countStr += tool.Function.Description + } + if tool.Function.Parameters != nil { + countStr += fmt.Sprintf("%v", tool.Function.Parameters) + } + } + toolTokens, err, _ := CountTokenInput(countStr, model, false) + if err != nil { + return 0, err, false + } + tkm += 8 + tkm += toolTokens + } + + return tkm, nil, false +} + func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { //recover when panic tokenEncoder := getTokenEncoder(model) @@ -138,48 +173,31 @@ func CountTokenMessages(messages []dto.Message, model string, checkSensitive boo tokenNum += tokensPerMessage tokenNum += getTokenNum(tokenEncoder, message.Role) if len(message.Content) > 0 { - var arrayContent []dto.MediaMessage - if err := json.Unmarshal(message.Content, &arrayContent); err != nil { - var stringContent string - 
if err := json.Unmarshal(message.Content, &stringContent); err != nil { - return 0, err, false - } else { - if checkSensitive { - contains, words := SensitiveWordContains(stringContent) - if contains { - err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) - return 0, err, true - } - } - tokenNum += getTokenNum(tokenEncoder, stringContent) - if message.Name != nil { - tokenNum += tokensPerName - tokenNum += getTokenNum(tokenEncoder, *message.Name) + if message.IsStringContent() { + stringContent := message.StringContent() + if checkSensitive { + contains, words := SensitiveWordContains(stringContent) + if contains { + err := fmt.Errorf("message contains sensitive words: [%s]", strings.Join(words, ", ")) + return 0, err, true } } + tokenNum += getTokenNum(tokenEncoder, stringContent) + if message.Name != nil { + tokenNum += tokensPerName + tokenNum += getTokenNum(tokenEncoder, *message.Name) + } } else { + var err error + arrayContent := message.ParseContent() for _, m := range arrayContent { if m.Type == "image_url" { var imageTokenNum int if model == "glm-4v" { imageTokenNum = 1047 } else { - if str, ok := m.ImageUrl.(string); ok { - imageTokenNum, err = getImageToken(&dto.MessageImageUrl{Url: str, Detail: "auto"}) - } else { - imageUrlMap := m.ImageUrl.(map[string]interface{}) - detail, ok := imageUrlMap["detail"] - if ok { - imageUrlMap["detail"] = detail.(string) - } else { - imageUrlMap["detail"] = "auto" - } - imageUrl := dto.MessageImageUrl{ - Url: imageUrlMap["url"].(string), - Detail: imageUrlMap["detail"].(string), - } - imageTokenNum, err = getImageToken(&imageUrl) - } + imageUrl := m.ImageUrl.(dto.MessageImageUrl) + imageTokenNum, err = getImageToken(&imageUrl) if err != nil { return 0, err, false } @@ -211,6 +229,23 @@ func CountTokenInput(input any, model string, check bool) (int, error, bool) { return CountTokenInput(fmt.Sprintf("%v", input), model, check) } +func CountTokenStreamChoices(messages 
[]dto.ChatCompletionsStreamResponseChoice, model string) int { + tokens := 0 + for _, message := range messages { + tkm, _, _ := CountTokenInput(message.Delta.Content, model, false) + tokens += tkm + if message.Delta.ToolCalls != nil { + for _, tool := range message.Delta.ToolCalls { + tkm, _, _ := CountTokenInput(tool.Function.Name, model, false) + tokens += tkm + tkm, _, _ = CountTokenInput(tool.Function.Arguments, model, false) + tokens += tkm + } + } + } + return tokens +} + func CountAudioToken(text string, model string, check bool) (int, error, bool) { if strings.HasPrefix(model, "tts") { contains, words := SensitiveWordContains(text) diff --git a/web/src/components/LoginForm.js b/web/src/components/LoginForm.js index b1179e2..84a9a71 100644 --- a/web/src/components/LoginForm.js +++ b/web/src/components/LoginForm.js @@ -208,7 +208,6 @@ const LoginForm = () => { {status.github_oauth || - status.linuxdo_oauth || status.wechat_login || status.telegram_oauth ? ( <> @@ -226,7 +225,6 @@ const LoginForm = () => { } - style={{ margin: '0 5px' }} onClick={() => onGitHubOAuthClicked(status.github_client_id) } diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index 7a112e5..44b6214 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -39,6 +39,7 @@ const OperationSetting = () => { SensitiveWords: '', MjNotifyEnabled: '', MjModeClearEnabled: '', + MjForwardUrlEnabled: '', DrawingEnabled: '', DataExportEnabled: '', DataExportDefaultTime: 'hour', @@ -322,6 +323,12 @@ const OperationSetting = () => { name='MjNotifyEnabled' onChange={handleInputChange} /> +