From c88f3741e64ae9c945c4ef54f775f264b47f951f Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 11 Jul 2024 18:44:45 +0800 Subject: [PATCH 01/34] feat: support claude stop_sequences --- relay/channel/claude/relay-claude.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 9457f1e..945b20d 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -72,6 +72,19 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 } + if textRequest.Stop != nil { + // stop maybe string/array string, convert to array string + switch textRequest.Stop.(type) { + case string: + claudeRequest.StopSequences = []string{textRequest.Stop.(string)} + case []interface{}: + stopSequences := make([]string, 0) + for _, stop := range textRequest.Stop.([]interface{}) { + stopSequences = append(stopSequences, stop.(string)) + } + claudeRequest.StopSequences = stopSequences + } + } formatMessages := make([]dto.Message, 0) var lastMessage *dto.Message for i, message := range textRequest.Messages { From 7b36a2b885f5264ae1d00c2a99d9fec52f0eb1c3 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Sat, 13 Jul 2024 19:55:22 +0800 Subject: [PATCH 02/34] feat: support cloudflare worker ai --- common/constants.go | 2 + controller/channel-test.go | 37 +++--- dto/text_request.go | 4 +- middleware/distributor.go | 4 +- relay/channel/cloudflare/adaptor.go | 76 ++++++++++++ relay/channel/cloudflare/constant.go | 38 ++++++ relay/channel/cloudflare/model.go | 13 +++ relay/channel/cloudflare/relay_cloudflare.go | 115 +++++++++++++++++++ relay/channel/cohere/dto.go | 2 +- relay/common/relay_info.go | 3 +- relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + service/{sse.go => relay.go} | 5 + web/src/constants/channel.constants.js | 1 + web/src/pages/Channel/EditChannel.js | 18 +++ 15 files changed, 296 insertions(+), 28 deletions(-) create mode 100644 relay/channel/cloudflare/adaptor.go create mode 100644 relay/channel/cloudflare/constant.go create mode 100644 relay/channel/cloudflare/model.go create mode 100644 relay/channel/cloudflare/relay_cloudflare.go rename service/{sse.go => relay.go} (87%) diff --git a/common/constants.go b/common/constants.go index 66cc10d..97e8583 100644 --- a/common/constants.go +++ b/common/constants.go @@ -212,6 +212,7 @@ const ( ChannelTypeSunoAPI = 36 ChannelTypeDify = 37 ChannelTypeJina = 38 + ChannelCloudflare = 39 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{ "", //36 "", //37 "https://api.jina.ai", //38 + "https://api.cloudflare.com", //39 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 000d7f2..268dac2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -12,6 +12,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/middleware" "one-api/model" "one-api/relay" relaycommon "one-api/relay/common" @@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr Body: nil, Header: make(http.Header), } - c.Request.Header.Set("Authorization", "Bearer "+channel.Key) - c.Request.Header.Set("Content-Type", "application/json") - c.Set("channel", channel.Type) - c.Set("base_url", channel.GetBaseURL()) - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } - meta := relaycommon.GenRelayInfo(c) - apiType, _ := constant.ChannelType2APIType(channel.Type) - adaptor := relay.GetAdaptor(apiType) - if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil - } if testModel == "" { if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel @@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } } + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + c.Request.Header.Set("Content-Type", "application/json") + c.Set("channel", channel.Type) + c.Set("base_url", channel.GetBaseURL()) + + middleware.SetupContextForSelectedChannel(c, channel, testModel) + + meta := relaycommon.GenRelayInfo(c) + apiType, _ := constant.ChannelType2APIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) + if adaptor == nil { + return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + } + request := buildTestRequest() request.Model = testModel meta.UpstreamModelName = testModel diff --git a/dto/text_request.go b/dto/text_request.go index e12c9b4..ed36988 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -48,8 +48,8 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r GeneralOpenAIRequest) GetMaxTokens() int64 { - return int64(r.MaxTokens) +func (r GeneralOpenAIRequest) GetMaxTokens() int { + return int(r.MaxTokens) } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/middleware/distributor.go b/middleware/distributor.go index 61361e6..9f75207 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) case common.ChannelTypeGemini: c.Set("api_version", channel.Other) case common.ChannelTypeAli: c.Set("plugin", channel.Other) + case common.ChannelCloudflare: + c.Set("api_version", channel.Other) } } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go new file mode 100644 index 0000000..571f50c --- /dev/null +++ b/relay/channel/cloudflare/adaptor.go @@ -0,0 +1,76 @@ +package cloudflare + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + default: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeCompletions: + return convertCf2CompletionsRequest(*request), nil + default: + return request, nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go new file mode 100644 index 0000000..a874685 --- /dev/null +++ b/relay/channel/cloudflare/constant.go @@ -0,0 +1,38 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} + +var ChannelName = "cloudflare" diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/model.go new file mode 100644 index 0000000..c870813 --- /dev/null +++ b/relay/channel/cloudflare/model.go @@ -0,0 +1,13 @@ +package cloudflare + +import "one-api/dto" + +type CfRequest struct { + Messages []dto.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go new file mode 100644 index 0000000..94a7ea0 --- /dev/null +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -0,0 +1,115 @@ +package cloudflare + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { + p, _ := textRequest.Prompt.(string) + return &CfRequest{ + Prompt: p, + MaxTokens: textRequest.GetMaxTokens(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + id := service.GetResponseID(c) + var responseText string + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response dto.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + continue + } + for _, choice := range response.Choices { + choice.Delta.Role = "assistant" + responseText += choice.Delta.GetContentString() + } + response.Id = id + response.Model = info.UpstreamModelName + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_stream_response: "+err.Error()) + } + } + + if err := scanner.Err(); err != nil { + common.LogError(c, "error_scanning_stream_response: "+err.Error()) + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + } + } + service.Done(c) + + err := resp.Body.Close() + if err != nil { + common.LogError(c, "close_response_body_failed: "+err.Error()) + } + + return nil, usage +} + +func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var response dto.TextResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + var responseText string + for _, choice := range response.Choices { + responseText += choice.Message.StringContent() + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + response.Usage = *usage + response.Id = service.GetResponseID(c) + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index fc6c445..b2c2739 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int64 `json:"max_tokens"` + MaxTokens int `json:"max_tokens"` } type ChatHistory struct { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 42c8381..e07434a 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -68,7 +68,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info.ApiVersion = GetAPIVersion(c) } if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || - info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini { + info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || + info.ChannelType == common.ChannelCloudflare { info.SupportStreamOptions = true } return info diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 0ce2657..6bd93c4 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -22,6 +22,7 @@ const ( APITypeCohere APITypeDify APITypeJina + APITypeCloudflare APITypeDummy // this one is only for count, do not add any channel after this ) @@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeDify case common.ChannelTypeJina: apiType = APITypeJina + case common.ChannelCloudflare: + apiType = APITypeCloudflare } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8998540..4c0aef1 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -7,6 +7,7 @@ import ( "one-api/relay/channel/aws" "one-api/relay/channel/baidu" "one-api/relay/channel/claude" + "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &dify.Adaptor{} case constant.APITypeJina: return &jina.Adaptor{} + case constant.APITypeCloudflare: + return &cloudflare.Adaptor{} } return nil } diff --git a/service/sse.go b/service/relay.go similarity index 87% rename from service/sse.go rename to service/relay.go index 2d531a4..22f9ce3 100644 --- a/service/sse.go +++ b/service/relay.go @@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error { func Done(c *gin.Context) { StringData(c, "[DONE]") } + +func GetResponseID(c *gin.Context) string { + logID := c.GetString("X-Oneapi-Request-Id") + return fmt.Sprintf("chatcmpl-%s", logID) +} diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index ff1d281..88614b0 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [ color: 'orange', label: 'Google PaLM2', }, + { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' }, { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index aec3768..900fdf3 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -605,6 +605,24 @@ const EditChannel = (props) => { /> )} + {inputs.type === 39 && ( + <> +
+ Account ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
模型:
From e67aa370bc14c8c4b89419ce11630533f27f0b25 Mon Sep 17 00:00:00 2001 From: FENG Date: Sun, 14 Jul 2024 00:14:07 +0800 Subject: [PATCH 03/34] fix: channel timeout auto-ban and auto-enable --- controller/channel-test.go | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 268dac2..6f82cd7 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -231,27 +231,33 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) ban = true } + + // request error disables the channel if openaiErr != nil { err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) - ban = true - } - // parse *int to bool - if channel.AutoBan != nil && *channel.AutoBan == 0 { - ban = false - } - if openaiErr != nil { openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ StatusCode: -1, Error: *openaiErr, LocalError: false, } - if isChannelEnabled && service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) && ban { - service.DisableChannel(channel.Id, channel.Name, err.Error()) - } - if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { - service.EnableChannel(channel.Id, channel.Name) - } + ban = service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) } + + // parse *int to bool + if channel.AutoBan != nil && *channel.AutoBan == 0 { + ban = false + } + + // disable channel + if ban && isChannelEnabled { + service.DisableChannel(channel.Id, channel.Name, err.Error()) + } + + // enable channel + if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { + service.EnableChannel(channel.Id, channel.Name) + } + channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } From d55cb35c1c45b3b7f89239f6ea43e0ef39088a56 Mon Sep 17 00:00:00 2001 From: FENG Date: Sun, 14 Jul 2024 01:21:05 +0800 Subject: [PATCH 04/34] fix: http code is not properly disabled --- controller/channel-test.go | 25 ++++++++++--------------- service/channel.go | 4 ++-- 2 files changed, 12 insertions(+), 17 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 6f82cd7..2174ff1 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -25,7 +25,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { +func testChannel(channel *model.Channel, testModel string) (err error, openAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { tik := time.Now() if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil @@ -58,8 +58,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) if err != nil { - openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error - return err, &openaiErr + return err, service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) } if modelMap[testModel] != "" { testModel = modelMap[testModel] @@ -104,11 +103,11 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } if resp != nil && resp.StatusCode != http.StatusOK { err := relaycommon.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + return fmt.Errorf("%s", respErr.Error.Message), respErr } if usage == nil { return errors.New("usage is nil"), nil @@ -222,7 +221,7 @@ func testAllChannels(notify bool) error { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() - err, openaiErr := testChannel(channel, "") + err, openaiWithStatusErr := testChannel(channel, "") tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() @@ -233,14 +232,10 @@ func testAllChannels(notify bool) error { } // request error disables the channel - if openaiErr != nil { - err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message)) - openAiErrWithStatus := dto.OpenAIErrorWithStatusCode{ - StatusCode: -1, - Error: *openaiErr, - LocalError: false, - } - ban = service.ShouldDisableChannel(channel.Type, &openAiErrWithStatus) + if openaiWithStatusErr != nil { + oaiErr := openaiWithStatusErr.Error + err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message)) + ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr) } // parse *int to bool @@ -254,7 +249,7 @@ func testAllChannels(notify bool) error { } // enable channel - if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr, channel.Status) { + if !isChannelEnabled && service.ShouldEnableChannel(err, openaiWithStatusErr, channel.Status) { service.EnableChannel(channel.Id, channel.Name) } diff --git a/service/channel.go b/service/channel.go index 76be271..5716a6d 100644 --- a/service/channel.go +++ b/service/channel.go @@ -74,14 +74,14 @@ func ShouldDisableChannel(channelType int, err *relaymodel.OpenAIErrorWithStatus return false } -func ShouldEnableChannel(err error, openAIErr *relaymodel.OpenAIError, status int) bool { +func ShouldEnableChannel(err error, openaiWithStatusErr *relaymodel.OpenAIErrorWithStatusCode, status int) bool { if !common.AutomaticEnableChannelEnabled { return false } if err != nil { return false } - if openAIErr != nil { + if openaiWithStatusErr != nil { return false } if status != common.ChannelStatusAutoDisabled { From 0f687aab9a39687d80b40aa61cd116458028582e Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 16:05:30 +0800 Subject: [PATCH 05/34] fix: azure stream options --- controller/channel-test.go | 2 +- relay/channel/adapter.go | 2 +- relay/channel/ali/adaptor.go | 4 ++-- relay/channel/aws/adaptor.go | 2 +- relay/channel/baidu/adaptor.go | 4 ++-- relay/channel/claude/adaptor.go | 2 +- relay/channel/cloudflare/adaptor.go | 4 ++-- relay/channel/cloudflare/relay_cloudflare.go | 6 ++++++ relay/channel/cohere/adaptor.go | 2 +- relay/channel/dify/adaptor.go | 2 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/jina/adaptor.go | 2 +- relay/channel/ollama/adaptor.go | 4 ++-- relay/channel/openai/adaptor.go | 5 ++++- relay/channel/palm/adaptor.go | 2 +- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/tencent/adaptor.go | 2 +- relay/channel/xunfei/adaptor.go | 2 +- relay/channel/zhipu/adaptor.go | 2 +- relay/channel/zhipu_4v/adaptor.go | 2 +- relay/relay-text.go | 2 +- 21 files changed, 33 insertions(+), 24 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 2174ff1..4ad7457 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr adaptor.Init(meta, *request) - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, meta, request) if err != nil { return err, nil } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index e222a70..7064b88 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -14,7 +14,7 @@ type Adaptor interface { InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error - ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) + ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index fbaf546..e03d29f 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -42,11 +42,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case constant.RelayModeEmbeddings: baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) return baiduEmbeddingRequest, nil diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 6452392..8214777 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -41,7 +41,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 17f5384..40a0696 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -99,11 +99,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case constant.RelayModeEmbeddings: baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(*request) return baiduEmbeddingRequest, nil diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 4623318..8e4c75d 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -53,7 +53,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 571f50c..53b5a91 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -38,11 +38,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case constant.RelayModeCompletions: return convertCf2CompletionsRequest(*request), nil default: diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 94a7ea0..d9319ef 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -11,6 +11,7 @@ import ( relaycommon "one-api/relay/common" "one-api/service" "strings" + "time" ) func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { @@ -30,6 +31,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela service.SetEventStreamHeaders(c) id := service.GetResponseID(c) var responseText string + isFirst := true for scanner.Scan() { data := scanner.Text() @@ -56,6 +58,10 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela response.Id = id response.Model = info.UpstreamModelName err = service.ObjectData(c, response) + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } if err != nil { common.LogError(c, "error_rendering_stream_response: "+err.Error()) } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index b5f3521..84243aa 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return requestOpenAI2Cohere(*request), nil } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index a54b95b..8dbe8b8 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -32,7 +32,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 9755163..f223fbf 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -51,7 +51,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index 48616b6..d0a379a 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -36,7 +36,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { return request, nil } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 76de148..b0550ca 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -36,11 +36,11 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } - switch relayMode { + switch info.RelayMode { case relayconstant.RelayModeEmbeddings: return requestOpenAI2Embeddings(*request), nil default: diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 00f01fd..e327027 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -74,10 +74,13 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } + if info.ChannelType != common.ChannelTypeOpenAI { + request.StreamOptions = nil + } return request, nil } diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 8f6dd0a..51d1399 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 3c65b2d..a220076 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -34,7 +34,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index d79330e..3dd9115 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index 9852aa1..adb054e 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -33,7 +33,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 0893a83..09345ca 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -37,7 +37,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 508861f..9b8bd49 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -35,7 +35,7 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *re return nil } -func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { +func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { if request == nil { return nil, errors.New("request is nil") } diff --git a/relay/relay-text.go b/relay/relay-text.go index 6e74fbb..ef169fa 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -153,7 +153,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { adaptor.Init(relayInfo, *textRequest) var requestBody io.Reader - convertedRequest, err := adaptor.ConvertRequest(c, relayInfo.RelayMode, textRequest) + convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) if err != nil { return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } From 7029065892142e8d240477dd888f297ba018ac94 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 18:04:05 +0800 Subject: [PATCH 06/34] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=B5=81?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dto/text_response.go | 15 +- relay/channel/openai/adaptor.go | 9 +- relay/channel/openai/relay-openai.go | 209 +++++++++++++-------------- service/usage_helpr.go | 4 + 4 files changed, 114 insertions(+), 123 deletions(-) diff --git a/dto/text_response.go b/dto/text_response.go index 3310d02..e1f0cc0 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -66,10 +66,6 @@ type ChatCompletionsStreamResponseChoiceDelta struct { ToolCalls []ToolCall `json:"tool_calls,omitempty"` } -func (c *ChatCompletionsStreamResponseChoiceDelta) IsEmpty() bool { - return c.Content == nil && len(c.ToolCalls) == 0 -} - func (c *ChatCompletionsStreamResponseChoiceDelta) SetContentString(s string) { c.Content = &s } @@ -105,6 +101,17 @@ type ChatCompletionsStreamResponse struct { Usage *Usage `json:"usage"` } +func (c *ChatCompletionsStreamResponse) GetSystemFingerprint() string { + if c.SystemFingerprint == nil { + return "" + } + return *c.SystemFingerprint +} + +func (c *ChatCompletionsStreamResponse) SetSystemFingerprint(s string) { + c.SystemFingerprint = &s +} + type ChatCompletionsStreamResponseSimple struct { Choices []ChatCompletionsStreamResponseChoice `json:"choices"` Usage *Usage `json:"usage"` diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index e327027..688dedc 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -14,7 +14,6 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" - "one-api/service" "strings" ) @@ -90,13 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - var toolCount int - err, usage, responseText, toolCount = OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - usage.CompletionTokens += toolCount * 7 - } + err, usage, _, _ = OpenaiStreamHandler(c, resp, info) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index dace39c..3fd7f03 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -14,38 +14,33 @@ import ( relayconstant "one-api/relay/constant" "one-api/service" "strings" - "sync" "time" ) func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) { - //checkSensitive := constant.ShouldCheckCompletionSensitive() + hasStreamUsage := false + responseId := "" + var createAt int64 = 0 + var systemFingerprint string + var responseTextBuilder strings.Builder - var usage dto.Usage + var usage = &dto.Usage{} toolCount := 0 scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - dataChan := make(chan string, 5) + scanner.Split(bufio.ScanLines) + var streamItems []string // store stream items + + service.SetEventStreamHeaders(c) + + ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) + defer ticker.Stop() + stopChan := make(chan bool, 2) defer close(stopChan) - defer close(dataChan) - var wg sync.WaitGroup + go func() { - wg.Add(1) - defer wg.Done() - var streamItems []string // store stream items for scanner.Scan() { + ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format continue @@ -53,54 +48,42 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. if data[:6] != "data: " && data[:6] != "[DONE]" { continue } - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } data = data[6:] if !strings.HasPrefix(data, "[DONE]") { + service.StringData(c, data) streamItems = append(streamItems, data) } } - // 计算token - streamResp := "[" + strings.Join(streamItems, ",") + "]" - switch info.RelayMode { - case relayconstant.RelayModeChatCompletions: - var streamResponses []dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.ChatCompletionsStreamResponseSimple - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } - } - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Delta.GetContentString()) - if choice.Delta.ToolCalls != nil { - if len(choice.Delta.ToolCalls) > toolCount { - toolCount = len(choice.Delta.ToolCalls) - } - for _, tool := range choice.Delta.ToolCalls { - responseTextBuilder.WriteString(tool.Function.Name) - responseTextBuilder.WriteString(tool.Function.Arguments) - } - } - } - } - } - } else { - for _, streamResponse := range streamResponses { - if streamResponse.Usage != nil { - if streamResponse.Usage.TotalTokens != 0 { - usage = *streamResponse.Usage - } + stopChan <- true + }() + + select { + case <-ticker.C: + // 超时处理逻辑 + common.LogError(c, "streaming timeout") + case <-stopChan: + // 正常结束 + } + + // 计算token + streamResp := "[" + strings.Join(streamItems, ",") + "]" + switch info.RelayMode { + case relayconstant.RelayModeChatCompletions: + var streamResponses []dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.ChatCompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + hasStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -116,67 +99,71 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } } - case relayconstant.RelayModeCompletions: - var streamResponses []dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) - if err != nil { - // 一次性解析失败,逐个解析 - common.SysError("error unmarshalling stream response: " + err.Error()) - for _, item := range streamItems { - var streamResponse dto.CompletionsStreamResponse - err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) - if err == nil { - for _, choice := range streamResponse.Choices { - responseTextBuilder.WriteString(choice.Text) + } else { + for _, streamResponse := range streamResponses { + responseId = streamResponse.Id + createAt = streamResponse.Created + systemFingerprint = streamResponse.GetSystemFingerprint() + if service.ValidUsage(streamResponse.Usage) { + usage = streamResponse.Usage + hasStreamUsage = true + } + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Delta.GetContentString()) + if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } + for _, tool := range choice.Delta.ToolCalls { + responseTextBuilder.WriteString(tool.Function.Name) + responseTextBuilder.WriteString(tool.Function.Arguments) } } } - } else { - for _, streamResponse := range streamResponses { + } + } + case relayconstant.RelayModeCompletions: + var streamResponses []dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses) + if err != nil { + // 一次性解析失败,逐个解析 + common.SysError("error unmarshalling stream response: " + err.Error()) + for _, item := range streamItems { + var streamResponse dto.CompletionsStreamResponse + err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse) + if err == nil { for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Text) } } } - } - if len(dataChan) > 0 { - // wait data out - time.Sleep(2 * time.Second) - } - common.SafeSendBool(stopChan, true) - }() - service.SetEventStreamHeaders(c) - isFirst := true - ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) - defer ticker.Stop() - c.Stream(func(w io.Writer) bool { - select { - case <-ticker.C: - common.LogError(c, "reading data from upstream timeout") - return false - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() + } else { + for _, streamResponse := range streamResponses { + for _, choice := range streamResponse.Choices { + responseTextBuilder.WriteString(choice.Text) + } } - ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) - if strings.HasPrefix(data, "data: [DONE]") { - data = data[:12] - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - c.Render(-1, common.CustomEvent{Data: data}) - return true - case <-stopChan: - return false } - }) + } + + if !hasStreamUsage { + usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 + } + + if info.ShouldIncludeUsage && !hasStreamUsage { + response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage) + response.SetSystemFingerprint(systemFingerprint) + service.ObjectData(c, response) + } + + service.Done(c) + err := resp.Body.Close() if err != nil { return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount } - wg.Wait() - return nil, &usage, responseTextBuilder.String(), toolCount + return nil, usage, responseTextBuilder.String(), toolCount } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { diff --git a/service/usage_helpr.go b/service/usage_helpr.go index 528f3d4..adec566 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -36,3 +36,7 @@ func GenerateFinalUsageResponse(id string, createAt int64, model string, usage d Usage: &usage, } } + +func ValidUsage(usage *dto.Usage) bool { + return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0) +} From 220ab412e26aafc246d11ca8cf42b0019759c1c9 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 18:14:07 +0800 Subject: [PATCH 07/34] fix: openai response time --- relay/channel/openai/relay-openai.go | 1 + relay/common/relay_info.go | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 3fd7f03..16cbb0c 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -40,6 +40,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. go func() { for scanner.Scan() { + info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) data := scanner.Text() if len(data) < 6 { // ignore blank line or wrong format diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index e07434a..564a7ad 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -17,6 +17,7 @@ type RelayInfo struct { TokenUnlimited bool StartTime time.Time FirstResponseTime time.Time + setFirstResponse bool ApiType int IsStream bool RelayMode int @@ -83,6 +84,13 @@ func (info *RelayInfo) SetIsStream(isStream bool) { info.IsStream = isStream } +func (info *RelayInfo) SetFirstResponseTime() { + if !info.setFirstResponse { + info.FirstResponseTime = time.Now() + info.setFirstResponse = true + } +} + type TaskRelayInfo struct { ChannelType int ChannelId int From e2b906165086ded46c21bc8033aa51375d0c532e Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 19:06:13 +0800 Subject: [PATCH 08/34] fix: openai stream response --- relay/channel/openai/relay-openai.go | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 16cbb0c..8fc4f6f 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -22,20 +22,22 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. responseId := "" var createAt int64 = 0 var systemFingerprint string + model := info.UpstreamModelName var responseTextBuilder strings.Builder var usage = &dto.Usage{} + var streamItems []string // store stream items + toolCount := 0 scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) - var streamItems []string // store stream items service.SetEventStreamHeaders(c) ticker := time.NewTicker(time.Duration(constant.StreamingTimeout) * time.Second) defer ticker.Stop() - stopChan := make(chan bool, 2) + stopChan := make(chan bool) defer close(stopChan) go func() { @@ -55,7 +57,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. streamItems = append(streamItems, data) } } - stopChan <- true + common.SafeSendBool(stopChan, true) }() select { @@ -82,6 +84,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. responseId = streamResponse.Id createAt = streamResponse.Created systemFingerprint = streamResponse.GetSystemFingerprint() + model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage hasStreamUsage = true @@ -105,6 +108,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. responseId = streamResponse.Id createAt = streamResponse.Created systemFingerprint = streamResponse.GetSystemFingerprint() + model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage hasStreamUsage = true @@ -153,7 +157,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } if info.ShouldIncludeUsage && !hasStreamUsage { - response := service.GenerateFinalUsageResponse(responseId, createAt, info.UpstreamModelName, *usage) + response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) service.ObjectData(c, response) } @@ -162,7 +166,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. err := resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil, "", toolCount + common.LogError(c, "close_response_body_failed: "+err.Error()) } return nil, usage, responseTextBuilder.String(), toolCount } From 9bbe8e7d1ba584295f322c46ce9cde3180acc342 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 20:23:19 +0800 Subject: [PATCH 09/34] =?UTF-8?q?fix:=20=E6=97=A5=E5=BF=97=E8=AF=A6?= =?UTF-8?q?=E6=83=85=E9=9D=9E=E6=B6=88=E8=B4=B9=E7=B1=BB=E5=9E=8B=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- web/src/components/LogsTable.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 4bbacf0..55106f2 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -367,7 +367,7 @@ const LogsTable = () => { dataIndex: 'content', render: (text, record, index) => { let other = getLogOther(record.other); - if (other == null) { + if (other == null || record.type !== 2) { return ( Date: Mon, 15 Jul 2024 22:07:50 +0800 Subject: [PATCH 10/34] chore: openai stream --- common/model-ratio.go | 13 +++++++------ relay/channel/ollama/adaptor.go | 7 +------ relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/relay-openai.go | 4 ++-- relay/channel/perplexity/adaptor.go | 7 +------ relay/channel/zhipu_4v/adaptor.go | 9 +-------- relay/channel/zhipu_4v/constants.go | 2 +- 7 files changed, 14 insertions(+), 30 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index c554036..294a0cc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -105,12 +105,13 @@ var defaultModelRatio = map[string]float64{ "gemini-1.0-pro-latest": 1, "gemini-1.0-pro-vision-latest": 1, "gemini-ultra": 1, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "glm-4": 7.143, // ¥0.1 / 1k tokens - "glm-4v": 7.143, // ¥0.1 / 1k tokens + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "glm-4": 7.143, // ¥0.1 / 1k tokens + "glm-4v": 0.05 * RMB, // ¥0.05 / 1k tokens + "glm-4-alltools": 0.1 * RMB, // ¥0.1 / 1k tokens "glm-3-turbo": 0.3572, "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index b0550ca..15ced27 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -10,7 +10,6 @@ import ( "one-api/relay/channel/openai" relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" - "one-api/service" ) type Adaptor struct { @@ -58,11 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } + err, usage = openai.OpenaiStreamHandler(c, resp, info) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 688dedc..6dc56d0 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -89,7 +89,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage, _, _ = OpenaiStreamHandler(c, resp, info) + err, usage = OpenaiStreamHandler(c, resp, info) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 8fc4f6f..b71fcce 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -17,7 +17,7 @@ import ( "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, string, int) { +func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { hasStreamUsage := false responseId := "" var createAt int64 = 0 @@ -168,7 +168,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. if err != nil { common.LogError(c, "close_response_body_failed: "+err.Error()) } - return nil, usage, responseTextBuilder.String(), toolCount + return nil, usage } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index a220076..40aa0f4 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -10,7 +10,6 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" - "one-api/service" ) type Adaptor struct { @@ -54,11 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - err, usage, responseText, _ = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } + err, usage = openai.OpenaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 9b8bd49..bdce639 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -10,7 +10,6 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" - "one-api/service" ) type Adaptor struct { @@ -55,13 +54,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - var toolCount int - err, usage, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info) - if usage == nil || usage.TotalTokens == 0 || (usage.PromptTokens+usage.CompletionTokens) == 0 { - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - usage.CompletionTokens += toolCount * 7 - } + err, usage = openai.OpenaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu_4v/constants.go b/relay/channel/zhipu_4v/constants.go index 1b0b0cc..3383eb3 100644 --- a/relay/channel/zhipu_4v/constants.go +++ b/relay/channel/zhipu_4v/constants.go @@ -1,7 +1,7 @@ package zhipu_4v var ModelList = []string{ - "glm-4", "glm-4v", "glm-3-turbo", + "glm-4", "glm-4v", "glm-3-turbo", "glm-4-alltools", } var ChannelName = "zhipu_4v" From ba27da9e2cee5850851f62e4228544220aa25e50 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 22:09:11 +0800 Subject: [PATCH 11/34] fix: try to fix mj --- controller/midjourney.go | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index 508c5dd..1a8cd36 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -146,7 +146,7 @@ func UpdateMidjourneyTaskBulk() { buttonStr, _ := json.Marshal(responseItem.Buttons) task.Buttons = string(buttonStr) } - + shouldReturnQuota := false if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" @@ -154,20 +154,23 @@ func UpdateMidjourneyTaskBulk() { if err != nil { common.LogError(ctx, "error update user quota cache: "+err.Error()) } else { - quota := task.Quota - if quota != 0 { - err = model.IncreaseUserQuota(task.UserId, quota) - if err != nil { - common.LogError(ctx, "fail to increase user quota: "+err.Error()) - } - logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(quota)) - model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + if task.Quota != 0 { + shouldReturnQuota = true } } } err = task.Update() if err != nil { common.LogError(ctx, "UpdateMidjourneyTask task error: "+err.Error()) + } else { + if shouldReturnQuota { + err = model.IncreaseUserQuota(task.UserId, task.Quota) + if err != nil { + common.LogError(ctx, "fail to increase user quota: "+err.Error()) + } + logContent := fmt.Sprintf("构图失败 %s,补偿 %s", task.MjId, common.LogQuota(task.Quota)) + model.RecordLog(task.UserId, model.LogTypeSystem, logContent) + } } } } From a3880d558ad0aaa5913809348fac829255c62fea Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Mon, 15 Jul 2024 22:14:30 +0800 Subject: [PATCH 12/34] chore: mj --- controller/midjourney.go | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/controller/midjourney.go b/controller/midjourney.go index 1a8cd36..01ddb2f 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -150,13 +150,8 @@ func UpdateMidjourneyTaskBulk() { if (task.Progress != "100%" && responseItem.FailReason != "") || (task.Progress == "100%" && task.Status == "FAILURE") { common.LogInfo(ctx, task.MjId+" 构建失败,"+task.FailReason) task.Progress = "100%" - err = model.CacheUpdateUserQuota(task.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } else { - if task.Quota != 0 { - shouldReturnQuota = true - } + if task.Quota != 0 { + shouldReturnQuota = true } } err = task.Update() From 963985e76c05124aea21430f9bb1d2af4cfcbbe0 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 14:54:03 +0800 Subject: [PATCH 13/34] chore: update model radio --- common/model-ratio.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/common/model-ratio.go b/common/model-ratio.go index 294a0cc..1200310 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -159,6 +159,8 @@ var defaultModelRatio = map[string]float64{ } var defaultModelPrice = map[string]float64{ + "suno_music": 0.1, + "suno_lyrics": 0.01, "dall-e-3": 0.04, "gpt-4-gizmo-*": 0.1, "mj_imagine": 0.1, From eb9b4b07ad94cc2378756e8eaba8a984b45ac61c Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 15:48:56 +0800 Subject: [PATCH 14/34] feat: update register page --- web/src/components/RegisterForm.js | 217 +++++++++++++++-------------- 1 file changed, 113 insertions(+), 104 deletions(-) diff --git a/web/src/components/RegisterForm.js b/web/src/components/RegisterForm.js index fcd2638..5ff2588 100644 --- a/web/src/components/RegisterForm.js +++ b/web/src/components/RegisterForm.js @@ -1,16 +1,10 @@ import React, { useEffect, useState } from 'react'; -import { - Button, - Form, - Grid, - Header, - Image, - Message, - Segment, -} from 'semantic-ui-react'; import { Link, useNavigate } from 'react-router-dom'; import { API, getLogo, showError, showInfo, showSuccess } from '../helpers'; import Turnstile from 'react-turnstile'; +import { Button, Card, Form, Layout } from '@douyinfe/semi-ui'; +import Title from '@douyinfe/semi-ui/lib/es/typography/title'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; const RegisterForm = () => { const [inputs, setInputs] = useState({ @@ -18,7 +12,7 @@ const RegisterForm = () => { password: '', password2: '', email: '', - verification_code: '', + verification_code: '' }); const { username, password, password2 } = inputs; const [showEmailVerification, setShowEmailVerification] = useState(false); @@ -46,9 +40,7 @@ const RegisterForm = () => { let navigate = useNavigate(); - function handleChange(e) { - const { name, value } = e.target; - console.log(name, value); + function handleChange(name, value) { setInputs((inputs) => ({ ...inputs, [name]: value })); } @@ -73,7 +65,7 @@ const RegisterForm = () => { inputs.aff_code = affCode; const res = await API.post( `/api/user/register?turnstile=${turnstileToken}`, - inputs, + inputs ); const { success, message } = res.data; if (success) { @@ -94,7 +86,7 @@ const RegisterForm = () => { } setLoading(true); const res = await API.get( - `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}`, + `/api/verification?email=${inputs.email}&turnstile=${turnstileToken}` ); const { success, message } = res.data; if (success) { @@ -106,96 +98,113 @@ const RegisterForm = () => { }; return ( - - -
- 新用户注册 -
-
- - - - - {showEmailVerification ? ( - <> - - 获取验证码 - - } +
+ + + +
+
+ + + 新用户注册 + + + handleChange('username', value)} + /> + handleChange('password', value)} + /> + handleChange('password2', value)} + /> + {showEmailVerification ? ( + <> + handleChange('email', value)} + name="email" + type="email" + suffix={ + + } + /> + handleChange('verification_code', value)} + name="verification_code" + /> + + ) : ( + <> + )} + + +
+ + 已有账户? + + 点击登录 + + +
+
+ {turnstileEnabled ? ( + { + setTurnstileToken(token); + }} /> - - - ) : ( - <> - )} - {turnstileEnabled ? ( - { - setTurnstileToken(token); - }} - /> - ) : ( - <> - )} - - - - - 已有账户? - - 点击登录 - - - - + ) : ( + <> + )} +
+
+
+
+
); }; From 11856ab39eb7bcd37176c5b829eb7e92615c9aae Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 17:02:37 +0800 Subject: [PATCH 15/34] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 4aca707..785cca8 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,8 @@ ## 比原版One API多出的配置 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒 -- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true`, 可选值为 `true` 和 `false` -- `FORCE_STREAM_OPTION`:覆盖客户端stream_options参数,请求上游返回流模式usage,目前仅支持 `OpenAI` 渠道类型 +- `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` +- `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true` ## 部署 ### 部署要求 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机) From bcc7f3edb28f390c19e368281643a472f3c802b2 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 22:07:10 +0800 Subject: [PATCH 16/34] refactor: audio relay --- common/str.go | 73 ++++++++ common/utils.go | 60 ------- controller/channel-test.go | 2 +- controller/model.go | 2 +- dto/audio.go | 33 +++- middleware/distributor.go | 20 ++- relay/channel/adapter.go | 5 +- relay/channel/ali/adaptor.go | 10 +- relay/channel/api_request.go | 36 +++- relay/channel/aws/adaptor.go | 11 +- relay/channel/baidu/adaptor.go | 11 +- relay/channel/claude/adaptor.go | 11 +- relay/channel/cloudflare/adaptor.go | 11 +- relay/channel/cohere/adaptor.go | 12 +- relay/channel/dify/adaptor.go | 11 +- relay/channel/gemini/adaptor.go | 11 +- relay/channel/jina/adaptor.go | 11 +- relay/channel/ollama/adaptor.go | 11 +- relay/channel/openai/adaptor.go | 84 ++++++++-- relay/channel/openai/relay-openai.go | 137 +++++++++++++++ relay/channel/palm/adaptor.go | 11 +- relay/channel/perplexity/adaptor.go | 11 +- relay/channel/tencent/adaptor.go | 11 +- relay/channel/xunfei/adaptor.go | 11 +- relay/channel/zhipu/adaptor.go | 11 +- relay/channel/zhipu_4v/adaptor.go | 11 +- relay/constant/relay_mode.go | 10 +- relay/relay-audio.go | 239 ++++++++------------------- relay/relay-text.go | 8 +- relay/relay_rerank.go | 2 +- 30 files changed, 567 insertions(+), 320 deletions(-) create mode 100644 common/str.go diff --git a/common/str.go b/common/str.go new file mode 100644 index 0000000..d61adb1 --- /dev/null +++ b/common/str.go @@ -0,0 +1,73 @@ +package common + +import ( + "encoding/json" + "math/rand" + "strconv" + "unsafe" +) + +func GetStringIfEmpty(str string, defaultValue string) string { + if str == "" { + return defaultValue + } + return str +} + +func GetRandomString(length int) string { + //rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func MapToJsonStr(m map[string]interface{}) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func MapToJsonStrFloat(m map[string]float64) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} + +func StrToMap(str string) map[string]interface{} { + m := make(map[string]interface{}) + err := json.Unmarshal([]byte(str), &m) + if err != nil { + return nil + } + return m +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} + +func StringsContains(strs []string, str string) bool { + for _, s := range strs { + if s == str { + return true + } + } + return false +} + +// StringToByteSlice []byte only read, panic on append +func StringToByteSlice(s string) []byte { + tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) + tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} + return *(*[]byte)(unsafe.Pointer(&tmp2)) +} diff --git a/common/utils.go b/common/utils.go index 3e047c4..3d95508 100644 --- a/common/utils.go +++ b/common/utils.go @@ -1,7 +1,6 @@ package common import ( - "encoding/json" "fmt" "github.com/google/uuid" "html/template" @@ -13,7 +12,6 @@ import ( "strconv" "strings" "time" - "unsafe" ) func OpenBrowser(url string) { @@ -159,15 +157,6 @@ func GenerateKey() string { return string(key) } -func GetRandomString(length int) string { - //rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - func GetRandomInt(max int) int { //rand.Seed(time.Now().UnixNano()) return rand.Intn(max) @@ -194,56 +183,7 @@ func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} - -func StringsContains(strs []string, str string) bool { - for _, s := range strs { - if s == str { - return true - } - } - return false -} - -// StringToByteSlice []byte only read, panic on append -func StringToByteSlice(s string) []byte { - tmp1 := (*[2]uintptr)(unsafe.Pointer(&s)) - tmp2 := [3]uintptr{tmp1[0], tmp1[1], tmp1[1]} - return *(*[]byte)(unsafe.Pointer(&tmp2)) -} - func RandomSleep() { // Sleep for 0-3000 ms time.Sleep(time.Duration(rand.Intn(3000)) * time.Millisecond) } - -func MapToJsonStr(m map[string]interface{}) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - -func MapToJsonStrFloat(m map[string]float64) string { - bytes, err := json.Marshal(m) - if err != nil { - return "" - } - return string(bytes) -} - -func StrToMap(str string) map[string]interface{} { - m := make(map[string]interface{}) - err := json.Unmarshal([]byte(str), &m) - if err != nil { - return nil - } - return m -} diff --git a/controller/channel-test.go b/controller/channel-test.go index 4ad7457..e1af673 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -85,7 +85,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr meta.UpstreamModelName = testModel common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) - adaptor.Init(meta, *request) + adaptor.Init(meta) convertedRequest, err := adaptor.ConvertRequest(c, meta, request) if err != nil { diff --git a/controller/model.go b/controller/model.go index 7e3a321..6b4a878 100644 --- a/controller/model.go +++ b/controller/model.go @@ -131,7 +131,7 @@ func init() { } meta := &relaycommon.RelayInfo{ChannelType: i} adaptor := relay.GetAdaptor(apiType) - adaptor.Init(meta, dto.GeneralOpenAIRequest{}) + adaptor.Init(meta) channelId2Models[i] = adaptor.GetModelList() } } diff --git a/dto/audio.go b/dto/audio.go index c67d678..c36b3da 100644 --- a/dto/audio.go +++ b/dto/audio.go @@ -1,13 +1,34 @@ package dto -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` +type AudioRequest struct { + Model string `json:"model"` + Input string `json:"input"` + Voice string `json:"voice"` + Speed float64 `json:"speed,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` } type AudioResponse struct { Text string `json:"text"` } + +type WhisperVerboseJSONResponse struct { + Task string `json:"task,omitempty"` + Language string `json:"language,omitempty"` + Duration float64 `json:"duration,omitempty"` + Text string `json:"text,omitempty"` + Segments []Segment `json:"segments,omitempty"` +} + +type Segment struct { + Id int `json:"id"` + Seek int `json:"seek"` + Start float64 `json:"start"` + End float64 `json:"end"` + Text string `json:"text"` + Tokens []int `json:"tokens"` + Temperature float64 `json:"temperature"` + AvgLogprob float64 `json:"avg_logprob"` + CompressionRatio float64 `json:"compression_ratio"` + NoSpeechProb float64 `json:"no_speech_prob"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index 9f75207..2552f29 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -154,18 +154,20 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } } if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e" - } + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "dall-e") } if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") { - if modelRequest.Model == "" { - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - modelRequest.Model = "tts-1" - } else { - modelRequest.Model = "whisper-1" - } + relayMode := relayconstant.RelayModeAudioSpeech + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranslation + } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") + relayMode = relayconstant.RelayModeAudioTranscription } + c.Set("relay_mode", relayMode) } return &modelRequest, shouldSelectChannel, nil } diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 7064b88..870b2b0 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -10,12 +10,13 @@ import ( type Adaptor interface { // Init IsStream bool - Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) - InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) + Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) + ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) + ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index e03d29f..88990d1 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -15,11 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index ab1131f..423a91d 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -7,14 +7,19 @@ import ( "io" "net/http" "one-api/relay/common" + "one-api/relay/constant" "one-api/service" ) func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if info.IsStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + // multipart/form-data + } else { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Header.Set("Accept", c.Request.Header.Get("Accept")) + if info.IsStream && c.Request.Header.Get("Accept") == "" { + req.Header.Set("Accept", "text/event-stream") + } } } @@ -38,6 +43,29 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody return resp, nil } +func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + if err != nil { + return nil, fmt.Errorf("new request failed: %w", err) + } + // set form data + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + + err = a.SetupRequestHeader(c, req, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + resp, err := doRequest(c, req) + if err != nil { + return nil, fmt.Errorf("do request failed: %w", err) + } + return resp, nil +} + func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { resp, err := service.GetHttpClient().Do(req) if err != nil { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 8214777..44a870d 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -20,12 +20,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index 40a0696..cc0be56 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -16,12 +16,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 8e4c75d..0544695 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -21,12 +21,17 @@ type Adaptor struct { RequestMode int } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { if strings.HasPrefix(info.UpstreamModelName, "claude-3") { a.RequestMode = RequestModeMessage } else { diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 53b5a91..2f3c46d 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 84243aa..3945774 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -1,6 +1,7 @@ package cohere import ( + "errors" "fmt" "github.com/gin-gonic/gin" "io" @@ -14,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 8dbe8b8..b582da2 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index f223fbf..de7761a 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -14,10 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } // 定义一个映射,存储模型名称和对应的版本 diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index d0a379a..6a04d08 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 15ced27..540ec85 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -15,10 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 6dc56d0..820c2bc 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -1,10 +1,13 @@ package openai import ( + "bytes" + "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/dto" @@ -14,21 +17,16 @@ import ( "one-api/relay/channel/minimax" "one-api/relay/channel/moonshot" relaycommon "one-api/relay/common" + "one-api/relay/constant" "strings" ) type Adaptor struct { - ChannelType int + ChannelType int + ResponseFormat string } -func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil -} - -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { -} - -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.ChannelType = info.ChannelType } @@ -83,15 +81,73 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return request, nil } +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + a.ResponseFormat = request.ResponseFormat + if info.RelayMode == constant.RelayModeAudioSpeech { + jsonData, err := json.Marshal(request) + if err != nil { + return nil, fmt.Errorf("error marshalling object: %w", err) + } + return bytes.NewReader(jsonData), nil + } else { + var requestBody bytes.Buffer + writer := multipart.NewWriter(&requestBody) + + writer.WriteField("model", request.Model) + + // 添加文件字段 + file, header, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + + part, err := writer.CreateFormFile("file", header.Filename) + if err != nil { + return nil, errors.New("create form file failed") + } + if _, err := io.Copy(part, file); err != nil { + return nil, errors.New("copy file failed") + } + + // 关闭 multipart 编写器以设置分界线 + writer.Close() + c.Request.Header.Set("Content-Type", writer.FormDataContentType()) + return &requestBody, nil + } +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { - return channel.DoApiRequest(a, c, info, requestBody) + if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { + return channel.DoFormRequest(a, c, info, requestBody) + } else { + return channel.DoApiRequest(a, c, info, requestBody) + } } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = OpenaiStreamHandler(c, resp, info) - } else { - err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + switch info.RelayMode { + case constant.RelayModeAudioSpeech: + err, usage = OpenaiTTSHandler(c, resp, info) + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) + default: + if info.IsStream { + err, usage = OpenaiStreamHandler(c, resp, info) + } else { + err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + } } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b71fcce..4b27a07 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "fmt" "github.com/gin-gonic/gin" "io" "net/http" @@ -224,3 +225,139 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model } return nil, &simpleResponse.Usage } + +func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.TotalTokens = info.PromptTokens + + return nil, usage +} + +func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var audioResp dto.AudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + // Reset response body + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // We shouldn't set the header before we parse the response body, because the parse part may fail. + // And then we will have to send an error response, but in this case, the header has already been set. + // So the httpClient will be confused by the response. + // For example, Postman will report error, and we cannot check the response at all. + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + + var text string + switch responseFormat { + case "json": + text, err = getTextFromJSON(responseBody) + case "text": + text, err = getTextFromText(responseBody) + case "srt": + text, err = getTextFromSRT(responseBody) + case "verbose_json": + text, err = getTextFromVerboseJSON(responseBody) + case "vtt": + text, err = getTextFromVTT(responseBody) + } + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} + +func getTextFromVTT(body []byte) (string, error) { + return getTextFromSRT(body) +} + +func getTextFromVerboseJSON(body []byte) (string, error) { + var whisperResponse dto.WhisperVerboseJSONResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} + +func getTextFromSRT(body []byte) (string, error) { + scanner := bufio.NewScanner(strings.NewReader(string(body))) + var builder strings.Builder + var textLine bool + for scanner.Scan() { + line := scanner.Text() + if textLine { + builder.WriteString(line) + textLine = false + continue + } else if strings.Contains(line, "-->") { + textLine = true + continue + } + } + if err := scanner.Err(); err != nil { + return "", err + } + return builder.String(), nil +} + +func getTextFromText(body []byte) (string, error) { + return strings.TrimSuffix(string(body), "\n"), nil +} + +func getTextFromJSON(body []byte) (string, error) { + var whisperResponse dto.AudioResponse + if err := json.Unmarshal(body, &whisperResponse); err != nil { + return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err) + } + return whisperResponse.Text, nil +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index 51d1399..d8c4ffb 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 40aa0f4..d3ed222 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 3dd9115..5811c87 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -23,12 +23,17 @@ type Adaptor struct { Timestamp int64 } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { a.Action = "ChatCompletions" a.Version = "2023-09-01" a.Timestamp = common.GetTimestamp() diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index adb054e..f499bec 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -16,12 +16,17 @@ type Adaptor struct { request *dto.GeneralOpenAIRequest } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index 09345ca..f98581f 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -14,12 +14,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index bdce639..b34b756 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -15,12 +15,17 @@ import ( type Adaptor struct { } -func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { //TODO implement me - + return nil, errors.New("not implemented") } -func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index ed15b08..a072c74 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -13,6 +13,7 @@ const ( RelayModeModerations RelayModeImagesGenerations RelayModeEdits + RelayModeMidjourneyImagine RelayModeMidjourneyDescribe RelayModeMidjourneyBlend @@ -22,16 +23,19 @@ const ( RelayModeMidjourneyTaskFetch RelayModeMidjourneyTaskImageSeed RelayModeMidjourneyTaskFetchByCondition - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation RelayModeMidjourneyAction RelayModeMidjourneyModal RelayModeMidjourneyShorten RelayModeSwapFace + + RelayModeAudioSpeech // tts + RelayModeAudioTranscription // whisper + RelayModeAudioTranslation // whisper + RelayModeSunoFetch RelayModeSunoFetchByID RelayModeSunoSubmit + RelayModeRerank ) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 9137721..05b723c 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -1,13 +1,10 @@ package relay import ( - "bytes" - "context" "encoding/json" "errors" "fmt" "github.com/gin-gonic/gin" - "io" "net/http" "one-api/common" "one-api/constant" @@ -16,69 +13,71 @@ import ( relaycommon "one-api/relay/common" relayconstant "one-api/relay/constant" "one-api/service" - "strings" - "time" ) -func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var audioRequest dto.TextToSpeechRequest - if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - err := common.UnmarshalBodyReusable(c, &audioRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - } else { - audioRequest = dto.TextToSpeechRequest{ - Model: "whisper-1", - } +func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { + audioRequest := &dto.AudioRequest{} + err := common.UnmarshalBodyReusable(c, audioRequest) + if err != nil { + return nil, err } - //err := common.UnmarshalBodyReusable(c, &audioRequest) - - // request validation - if audioRequest.Model == "" { - return service.OpenAIErrorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - - if strings.HasPrefix(audioRequest.Model, "tts-1") { - if audioRequest.Voice == "" { - return service.OpenAIErrorWrapper(errors.New("voice is required"), "required_field_missing", http.StatusBadRequest) + switch info.RelayMode { + case relayconstant.RelayModeAudioSpeech: + if audioRequest.Model == "" { + return nil, errors.New("model is required") } - } - var err error - promptTokens := 0 - preConsumedTokens := common.PreConsumedQuota - if strings.HasPrefix(audioRequest.Model, "tts-1") { if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(audioRequest.Input) + err := service.CheckSensitiveInput(audioRequest.Input) if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) + return nil, err } } + default: + if audioRequest.Model == "" { + audioRequest.Model = c.PostForm("model") + } + if audioRequest.Model == "" { + return nil, errors.New("model is required") + } + if audioRequest.ResponseFormat == "" { + audioRequest.ResponseFormat = "json" + } + } + return audioRequest, nil +} + +func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + audioRequest, err := getAndValidAudioRequest(c, relayInfo) + + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_audio_request", http.StatusBadRequest) + } + + promptTokens := 0 + preConsumedTokens := common.PreConsumedQuota + if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } preConsumedTokens = promptTokens + relayInfo.PromptTokens = promptTokens } + modelRatio := common.GetModelRatio(audioRequest.Model) - groupRatio := common.GetGroupRatio(group) + groupRatio := common.GetGroupRatio(relayInfo.Group) ratio := modelRatio * groupRatio preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) if err != nil { return service.OpenAIErrorWrapperLocal(err, "get_user_quota_failed", http.StatusInternalServerError) } if userQuota-preConsumedQuota < 0 { return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) + err = model.CacheDecreaseUserQuota(relayInfo.UserId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "decrease_user_quota_failed", http.StatusInternalServerError) } @@ -88,28 +87,12 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { preConsumedQuota = 0 } if preConsumedQuota > 0 { - userQuota, err = model.PreConsumeTokenQuota(tokenId, preConsumedQuota) + userQuota, err = model.PreConsumeTokenQuota(relayInfo.TokenId, preConsumedQuota) if err != nil { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } - succeed := false - defer func() { - if succeed { - return - } - if preConsumedQuota > 0 { - // we need to roll back the pre-consumed quota - defer func() { - go func() { - // negative means add quota back for token & user - returnPreConsumedQuota(c, tokenId, userQuota, preConsumedQuota) - }() - }() - } - }() - // map model name modelMapping := c.GetString("model_mapping") if modelMapping != "" { @@ -123,132 +106,42 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } + adaptor.Init(relayInfo) - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiVersion := relaycommon.GetAPIVersion(c) - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioRequest.Model, apiVersion) - } - - requestBody := c.Request.Body - - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - if relayMode == relayconstant.RelayModeAudioTranscription && channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - req.Header.Set("api-key", apiKey) - req.ContentLength = c.Request.ContentLength - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := service.GetHttpClient().Do(req) + resp, err := adaptor.DoRequest(c, relayInfo, ioReader) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) - } - succeed = true - - var audioResponse dto.AudioResponse - - defer func(ctx context.Context) { - go func() { - useTimeSeconds := time.Now().Unix() - startTime.Unix() - quota := 0 - if strings.HasPrefix(audioRequest.Model, "tts-1") { - quota = promptTokens - } else { - quota, err = service.CountAudioToken(audioResponse.Text, audioRequest.Model) - } - quota = int(float64(quota) * ratio) - if ratio != 0 && quota <= 0 { - quota = 1 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - other := make(map[string]interface{}) - other["model_ratio"] = modelRatio - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, 0, audioRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }() - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - if strings.HasPrefix(audioRequest.Model, "tts-1") { - - } else { - err = json.Unmarshal(responseBody, &audioResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) - } - contains, words := service.SensitiveWordContains(audioResponse.Text) - if contains { - return service.OpenAIErrorWrapper(errors.New("response contains sensitive words: "+strings.Join(words, ", ")), "response_contains_sensitive_words", http.StatusBadRequest) + statusCodeMappingStr := c.GetString("status_code_mapping") + if resp != nil { + if resp.StatusCode != http.StatusOK { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) + postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index ef169fa..0438eba 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -91,7 +91,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } } relayInfo.UpstreamModelName = textRequest.Model - modelPrice, success := common.GetModelPrice(textRequest.Model, false) + modelPrice, getModelPriceSuccess := common.GetModelPrice(textRequest.Model, false) groupRatio := common.GetGroupRatio(relayInfo.Group) var preConsumedQuota int @@ -112,7 +112,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) } - if !success { + if !getModelPriceSuccess { preConsumedTokens := common.PreConsumedQuota if textRequest.MaxTokens != 0 { preConsumedTokens = promptTokens + int(textRequest.MaxTokens) @@ -150,7 +150,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.Init(relayInfo, *textRequest) + adaptor.Init(relayInfo) var requestBody io.Reader convertedRequest, err := adaptor.ConvertRequest(c, relayInfo, textRequest) @@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess) return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e32ca88..2fc4854 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -66,7 +66,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) } - adaptor.InitRerank(relayInfo, *rerankRequest) + adaptor.Init(relayInfo) convertedRequest, err := adaptor.ConvertRerankRequest(c, relayInfo.RelayMode, *rerankRequest) if err != nil { From ebb9b675b6f7d1db1e1eecb3345f20119ae75e1d Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 23:24:47 +0800 Subject: [PATCH 17/34] feat: support cloudflare audio --- controller/channel-test.go | 2 +- middleware/distributor.go | 2 + relay/channel/cloudflare/adaptor.go | 50 +++++++++++++------ relay/channel/cloudflare/{model.go => dto.go} | 8 +++ relay/channel/cloudflare/relay_cloudflare.go | 35 +++++++++++++ relay/channel/openai/relay-openai.go | 18 ++----- relay/common/relay_utils.go | 33 ------------ relay/relay-audio.go | 1 + relay/relay-image.go | 2 +- service/error.go | 7 ++- 10 files changed, 90 insertions(+), 68 deletions(-) rename relay/channel/cloudflare/{model.go => dto.go} (78%) diff --git a/controller/channel-test.go b/controller/channel-test.go index e1af673..90d02d6 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -102,7 +102,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr return err, nil } if resp != nil && resp.StatusCode != http.StatusOK { - err := relaycommon.RelayErrorHandler(resp) + err := service.RelayErrorHandler(resp) return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err } usage, respErr := adaptor.DoResponse(c, resp, meta) diff --git a/middleware/distributor.go b/middleware/distributor.go index 2552f29..1ce787e 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -161,9 +161,11 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "tts-1") } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranslation } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { + modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, c.PostForm("model")) modelRequest.Model = common.GetStringIfEmpty(modelRequest.Model, "whisper-1") relayMode = relayconstant.RelayModeAudioTranscription } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index 2f3c46d..a518da8 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -1,6 +1,7 @@ package cloudflare import ( + "bytes" "errors" "fmt" "github.com/gin-gonic/gin" @@ -15,16 +16,6 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } @@ -65,11 +56,42 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + // 添加文件字段 + file, _, err := c.Request.FormFile("file") + if err != nil { + return nil, errors.New("file is required") + } + defer file.Close() + // 打开临时文件用于保存上传的文件内容 + requestBody := &bytes.Buffer{} + + // 将上传的文件内容复制到临时文件 + if _, err := io.Copy(requestBody, file); err != nil { + return nil, err + } + return requestBody, nil +} + +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + //TODO implement me + return nil, errors.New("not implemented") +} + func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = cfStreamHandler(c, resp, info) - } else { - err, usage = cfHandler(c, resp, info) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fallthrough + case constant.RelayModeChatCompletions: + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + case constant.RelayModeAudioTranslation: + fallthrough + case constant.RelayModeAudioTranscription: + err, usage = cfSTTHandler(c, resp, info) } return } diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/dto.go similarity index 78% rename from relay/channel/cloudflare/model.go rename to relay/channel/cloudflare/dto.go index c870813..2f6531c 100644 --- a/relay/channel/cloudflare/model.go +++ b/relay/channel/cloudflare/dto.go @@ -11,3 +11,11 @@ type CfRequest struct { Stream bool `json:"stream,omitempty"` Temperature float64 `json:"temperature,omitempty"` } + +type CfAudioResponse struct { + Result CfSTTResult `json:"result"` +} + +type CfSTTResult struct { + Text string `json:"text"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index d9319ef..69d6b85 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -119,3 +119,38 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) _, _ = c.Writer.Write(jsonResponse) return nil, usage } + +func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + var cfResp CfAudioResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &cfResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + audioResp := &dto.AudioResponse{ + Text: cfResp.Result.Text, + } + + jsonResponse, err := json.Marshal(audioResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + + usage := &dto.Usage{} + usage.PromptTokens = info.PromptTokens + usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + + return nil, usage +} diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 4b27a07..651e82e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -165,10 +165,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. service.Done(c) - err := resp.Body.Close() - if err != nil { - common.LogError(c, "close_response_body_failed: "+err.Error()) - } + resp.Body.Close() return nil, usage } @@ -206,11 +203,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } - + resp.Body.Close() if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { @@ -257,7 +250,6 @@ func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens usage.TotalTokens = info.PromptTokens - return nil, usage } @@ -290,10 +282,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel if err != nil { return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() var text string switch responseFormat { @@ -313,7 +302,6 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage.PromptTokens = info.PromptTokens usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return nil, usage } diff --git a/relay/common/relay_utils.go b/relay/common/relay_utils.go index 9ef9a8b..6daf003 100644 --- a/relay/common/relay_utils.go +++ b/relay/common/relay_utils.go @@ -1,50 +1,17 @@ package common import ( - "encoding/json" "fmt" "github.com/gin-gonic/gin" _ "image/gif" _ "image/jpeg" _ "image/png" - "io" - "net/http" "one-api/common" - "one-api/dto" - "strconv" "strings" ) var StopFinishReason = "stop" -func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.OpenAIErrorWithStatusCode) { - OpenAIErrorWithStatusCode = &dto.OpenAIErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: dto.OpenAIError{ - Message: fmt.Sprintf("bad response status code %d", resp.StatusCode), - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - err = resp.Body.Close() - if err != nil { - return - } - var textResponse dto.TextResponseWithError - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - OpenAIErrorWithStatusCode.Error.Message = fmt.Sprintf("error unmarshalling response body: %s", responseBody) - return - } - OpenAIErrorWithStatusCode.Error = textResponse.Error - return -} - func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 05b723c..2a0278e 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -105,6 +105,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { audioRequest.Model = modelMap[audioRequest.Model] } } + relayInfo.UpstreamModelName = audioRequest.Model adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { diff --git a/relay/relay-image.go b/relay/relay-image.go index d83ec26..6d6e4d4 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -180,7 +180,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if resp.StatusCode != http.StatusOK { - return relaycommon.RelayErrorHandler(resp) + return service.RelayErrorHandler(resp) } var textResponse dto.ImageResponse diff --git a/service/error.go b/service/error.go index 0f6d472..3410de8 100644 --- a/service/error.go +++ b/service/error.go @@ -56,10 +56,9 @@ func RelayErrorHandler(resp *http.Response) (errWithStatusCode *dto.OpenAIErrorW errWithStatusCode = &dto.OpenAIErrorWithStatusCode{ StatusCode: resp.StatusCode, Error: dto.OpenAIError{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), }, } responseBody, err := io.ReadAll(resp.Body) From 86ca533f7ab44a719210e83db2822e388c7980ee Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 16 Jul 2024 23:40:52 +0800 Subject: [PATCH 18/34] fix: fix bug --- relay/relay-text.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/relay-text.go b/relay/relay-text.go index 0438eba..9e1b9b7 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -300,7 +300,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN } totalTokens := promptTokens + completionTokens var logContent string - if modelPrice == -1 { + if !usePrice { logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio) } else { logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) From 4d0d18931d27ac5c7379170eb6f098c123b78c56 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 17 Jul 2024 16:38:56 +0800 Subject: [PATCH 19/34] fix: try to fix panic #369 --- relay/channel/openai/relay-openai.go | 5 ++++- service/relay.go | 20 ++++++++++++-------- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index b71fcce..b6418cf 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -53,7 +53,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } data = data[6:] if !strings.HasPrefix(data, "[DONE]") { - service.StringData(c, data) + err := service.StringData(c, data) + if err != nil { + common.LogError(c, "streaming error: "+err.Error()) + } streamItems = append(streamItems, data) } } diff --git a/service/relay.go b/service/relay.go index 22f9ce3..4f5ff8d 100644 --- a/service/relay.go +++ b/service/relay.go @@ -2,10 +2,10 @@ package service import ( "encoding/json" + "errors" "fmt" "github.com/gin-gonic/gin" "one-api/common" - "strings" ) func SetEventStreamHeaders(c *gin.Context) { @@ -16,11 +16,16 @@ func SetEventStreamHeaders(c *gin.Context) { c.Writer.Header().Set("X-Accel-Buffering", "no") } -func StringData(c *gin.Context, str string) { - str = strings.TrimPrefix(str, "data: ") - str = strings.TrimSuffix(str, "\r") +func StringData(c *gin.Context, str string) error { + //str = strings.TrimPrefix(str, "data: ") + //str = strings.TrimSuffix(str, "\r") c.Render(-1, common.CustomEvent{Data: "data: " + str}) - c.Writer.Flush() + if c.Writer != nil { + c.Writer.Flush() + } else { + return errors.New("writer is nil") + } + return nil } func ObjectData(c *gin.Context, object interface{}) error { @@ -28,12 +33,11 @@ func ObjectData(c *gin.Context, object interface{}) error { if err != nil { return fmt.Errorf("error marshalling object: %w", err) } - StringData(c, string(jsonData)) - return nil + return StringData(c, string(jsonData)) } func Done(c *gin.Context) { - StringData(c, "[DONE]") + _ = StringData(c, "[DONE]") } func GetResponseID(c *gin.Context) string { From fd872602099e20afd90c8881bf9ef60097eeb95e Mon Sep 17 00:00:00 2001 From: daggeryu <997411652@qq.com> Date: Wed, 17 Jul 2024 16:40:44 +0800 Subject: [PATCH 20/34] fix: embedding model dimensions --- dto/text_request.go | 1 + 1 file changed, 1 insertion(+) diff --git a/dto/text_request.go b/dto/text_request.go index ed36988..5c403da 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -31,6 +31,7 @@ type GeneralOpenAIRequest struct { User string `json:"user,omitempty"` LogProbs bool `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` + Dimensions int `json:"dimensions,omitempty"` } type OpenAITools struct { From e3b83f886f71cd02ef97e53e2d38abed0d8254a3 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 17 Jul 2024 16:43:55 +0800 Subject: [PATCH 21/34] fix: try to fix panic #369 --- service/relay.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/service/relay.go b/service/relay.go index 4f5ff8d..03b005c 100644 --- a/service/relay.go +++ b/service/relay.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "net/http" "one-api/common" ) @@ -20,10 +21,10 @@ func StringData(c *gin.Context, str string) error { //str = strings.TrimPrefix(str, "data: ") //str = strings.TrimSuffix(str, "\r") c.Render(-1, common.CustomEvent{Data: "data: " + str}) - if c.Writer != nil { - c.Writer.Flush() + if flusher, ok := c.Writer.(http.Flusher); ok { + flusher.Flush() } else { - return errors.New("writer is nil") + return errors.New("streaming error: flusher not found") } return nil } From 7a0beb5793724b405b71fb242b71f090282df6da Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 17 Jul 2024 17:01:25 +0800 Subject: [PATCH 22/34] fix: distribute panic --- middleware/distributor.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/middleware/distributor.go b/middleware/distributor.go index 9f75207..9837216 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "fmt" "net/http" "one-api/common" @@ -25,6 +26,10 @@ func Distribute() func(c *gin.Context) { var channel *model.Channel channelId, ok := c.Get("specific_channel_id") modelRequest, shouldSelectChannel, err := getModelRequest(c) + if err != nil { + abortWithOpenAiMessage(c, http.StatusBadRequest, "Invalid request, "+err.Error()) + return + } userGroup, _ := model.CacheGetUserGroup(userId) c.Set("group", userGroup) if ok { @@ -141,7 +146,7 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { } if err != nil { abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) - return nil, false, err + return nil, false, errors.New("无效的请求, " + err.Error()) } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if modelRequest.Model == "" { From b0d5491a2a609e61a4bcb324bb3522c648254de2 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 17 Jul 2024 23:50:37 +0800 Subject: [PATCH 23/34] refactor: image relay --- controller/relay.go | 4 +- relay/channel/openai/adaptor.go | 5 +- relay/relay-audio.go | 4 +- relay/relay-image.go | 219 ++++++++++++-------------------- relay/relay-text.go | 7 +- relay/relay_rerank.go | 2 +- 6 files changed, 92 insertions(+), 149 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index a04c85a..bc951f7 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -22,13 +22,13 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode var err *dto.OpenAIErrorWithStatusCode switch relayMode { case relayconstant.RelayModeImagesGenerations: - err = relay.RelayImageHelper(c, relayMode) + err = relay.ImageHelper(c, relayMode) case relayconstant.RelayModeAudioSpeech: fallthrough case relayconstant.RelayModeAudioTranslation: fallthrough case relayconstant.RelayModeAudioTranscription: - err = relay.AudioHelper(c, relayMode) + err = relay.AudioHelper(c) case relayconstant.RelayModeRerank: err = relay.RerankHelper(c, relayMode) default: diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 820c2bc..2aa743f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -122,8 +122,7 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf } func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") + return request, nil } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { @@ -142,6 +141,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom fallthrough case constant.RelayModeAudioTranscription: err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat) + case constant.RelayModeImagesGenerations: + err, usage = OpenaiTTSHandler(c, resp, info) default: if info.IsStream { err, usage = OpenaiStreamHandler(c, resp, info) diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 2a0278e..b2fadcc 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return audioRequest, nil } -func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { relayInfo := relaycommon.GenRelayInfo(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) @@ -142,7 +142,7 @@ func AudioHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { return openaiErr } - postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false) + postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 6d6e4d4..4b1fbd2 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -2,7 +2,6 @@ package relay import ( "bytes" - "context" "encoding/json" "errors" "fmt" @@ -14,72 +13,71 @@ import ( "one-api/dto" "one-api/model" relaycommon "one-api/relay/common" - relayconstant "one-api/relay/constant" "one-api/service" "strings" - "time" ) -func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { - tokenId := c.GetInt("token_id") - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - userId := c.GetInt("id") - group := c.GetString("group") - startTime := time.Now() - - var imageRequest dto.ImageRequest - err := common.UnmarshalBodyReusable(c, &imageRequest) +func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.ImageRequest, error) { + imageRequest := &dto.ImageRequest{} + err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) + return nil, err } - - if imageRequest.Model == "" { - imageRequest.Model = "dall-e-3" + if imageRequest.Prompt == "" { + return nil, errors.New("prompt is required") } - if imageRequest.Size == "" { - imageRequest.Size = "1024x1024" + if strings.Contains(imageRequest.Size, "×") { + return nil, errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'") } if imageRequest.N == 0 { imageRequest.N = 1 } - // Prompt validation - if imageRequest.Prompt == "" { - return service.OpenAIErrorWrapper(errors.New("prompt is required"), "required_field_missing", http.StatusBadRequest) + if imageRequest.Size == "" { + imageRequest.Size = "1024x1024" } - - if constant.ShouldCheckPromptSensitive() { - err = service.CheckSensitiveInput(imageRequest.Prompt) - if err != nil { - return service.OpenAIErrorWrapper(err, "sensitive_words_detected", http.StatusBadRequest) - } + if imageRequest.Model == "" { + imageRequest.Model = "dall-e-2" } - - if strings.Contains(imageRequest.Size, "×") { - return service.OpenAIErrorWrapper(errors.New("size an unexpected error occurred in the parameter, please use 'x' instead of the multiplication sign '×'"), "invalid_field_value", http.StatusBadRequest) + if imageRequest.Quality == "" { + imageRequest.Quality = "standard" } // Not "256x256", "512x512", or "1024x1024" if imageRequest.Model == "dall-e-2" || imageRequest.Model == "dall-e" { if imageRequest.Size != "" && imageRequest.Size != "256x256" && imageRequest.Size != "512x512" && imageRequest.Size != "1024x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") } } else if imageRequest.Model == "dall-e-3" { if imageRequest.Size != "" && imageRequest.Size != "1024x1024" && imageRequest.Size != "1024x1792" && imageRequest.Size != "1792x1024" { - return service.OpenAIErrorWrapper(errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024"), "invalid_field_value", http.StatusBadRequest) + return nil, errors.New("size must be one of 256x256, 512x512, or 1024x1024, dall-e-3 1024x1792 or 1792x1024") } - if imageRequest.N != 1 { - return service.OpenAIErrorWrapper(errors.New("n must be 1"), "invalid_field_value", http.StatusBadRequest) + //if imageRequest.N != 1 { + // return nil, errors.New("n must be 1") + //} + } + // N should between 1 and 10 + //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { + // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) + //} + if constant.ShouldCheckPromptSensitive() { + err := service.CheckSensitiveInput(imageRequest.Prompt) + if err != nil { + return nil, err } } + return imageRequest, nil +} - // N should between 1 and 10 - if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { - return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) +func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfo(c) + + imageRequest, err := getAndValidImageRequest(c, relayInfo) + if err != nil { + common.LogError(c, fmt.Sprintf("getAndValidImageRequest failed: %s", err.Error())) + return service.OpenAIErrorWrapper(err, "invalid_image_request", http.StatusBadRequest) } // map model name modelMapping := c.GetString("model_mapping") - isModelMapped := false if modelMapping != "" { modelMap := make(map[string]string) err := json.Unmarshal([]byte(modelMapping), &modelMap) @@ -88,31 +86,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC } if modelMap[imageRequest.Model] != "" { imageRequest.Model = modelMap[imageRequest.Model] - isModelMapped = true } } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := relaycommon.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure && relayMode == relayconstant.RelayModeImagesGenerations { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := relaycommon.GetAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2023-06-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader - if isModelMapped || channelType == common.ChannelTypeAzure { // make Azure channel request body - jsonStr, err := json.Marshal(imageRequest) - if err != nil { - return service.OpenAIErrorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } + relayInfo.UpstreamModelName = imageRequest.Model modelPrice, success := common.GetModelPrice(imageRequest.Model, true) if !success { @@ -121,8 +97,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC // per 1 modelRatio = $0.04 / 16 modelPrice = 0.0025 * modelRatio } - groupRatio := common.GetGroupRatio(group) - userQuota, err := model.CacheGetUserQuota(userId) + + groupRatio := common.GetGroupRatio(relayInfo.Group) + userQuota, err := model.CacheGetUserQuota(relayInfo.UserId) sizeRatio := 1.0 // Size @@ -150,98 +127,60 @@ func RelayImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusC return service.OpenAIErrorWrapperLocal(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + + var requestBody io.Reader + + convertedRequest, err := adaptor.ConvertImageRequest(c, relayInfo, *imageRequest) if err != nil { - return service.OpenAIErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) + return service.OpenAIErrorWrapperLocal(err, "convert_request_failed", http.StatusInternalServerError) } - token := c.Request.Header.Get("Authorization") - if channelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) + jsonData, err := json.Marshal(convertedRequest) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "json_marshal_failed", http.StatusInternalServerError) } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) + requestBody = bytes.NewBuffer(jsonData) - resp, err := service.GetHttpClient().Do(req) + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - - if resp.StatusCode != http.StatusOK { - return service.RelayErrorHandler(resp) - } - - var textResponse dto.ImageResponse - defer func(ctx context.Context) { - useTimeSeconds := time.Now().Unix() - startTime.Unix() + if resp != nil { + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") if resp.StatusCode != http.StatusOK { - return + openaiErr := service.RelayErrorHandler(resp) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr } - err := model.PostConsumeTokenQuota(tokenId, userQuota, quota, 0, true) - if err != nil { - common.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.SysError("error update user quota cache: " + err.Error()) - } - if quota != 0 { - tokenName := c.GetString("token_name") - quality := "normal" - if imageRequest.Quality == "hd" { - quality = "hd" - } - logContent := fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f, 大小 %s, 品质 %s", modelPrice, groupRatio, imageRequest.Size, quality) - other := make(map[string]interface{}) - other["model_price"] = modelPrice - other["group_ratio"] = groupRatio - model.RecordConsumeLog(ctx, userId, channelId, 0, 0, imageRequest.Model, tokenName, quota, logContent, tokenId, userQuota, int(useTimeSeconds), false, other) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - channelId := c.GetInt("channel_id") - model.UpdateChannelUsedQuota(channelId, quota) - } - }(c.Request.Context()) - - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &textResponse) - if err != nil { - return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + _, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + if openaiErr != nil { + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) + usage := &dto.Usage{ + PromptTokens: relayInfo.PromptTokens, + TotalTokens: relayInfo.PromptTokens, } - c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + quality := "standard" + if imageRequest.Quality == "hd" { + quality = "hd" } + + logContent := fmt.Sprintf(", 大小 %s, 品质 %s", imageRequest.Size, quality) + postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent) + return nil } diff --git a/relay/relay-text.go b/relay/relay-text.go index 9e1b9b7..d82bd60 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -187,7 +187,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess) + postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } @@ -279,7 +279,7 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, - modelPrice float64, usePrice bool) { + modelPrice float64, usePrice bool, extraContent string) { useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() promptTokens := usage.PromptTokens @@ -338,6 +338,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN logModel = "gpt-4-gizmo-*" logContent += fmt.Sprintf(",模型 %s", modelName) } + if extraContent != "" { + logContent += extraContent + } other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 2fc4854..9885fd3 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -99,6 +99,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success) + postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") return nil } From 11fd993574f74f3acdf0df750fd0e1755a2e4959 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 00:36:05 +0800 Subject: [PATCH 24/34] feat: support claude tool calling --- dto/text_request.go | 7 ++- dto/text_response.go | 6 +- relay/channel/claude/dto.go | 61 +++++++++++------- relay/channel/claude/relay-claude.go | 93 +++++++++++++++++++++++----- 4 files changed, 128 insertions(+), 39 deletions(-) diff --git a/dto/text_request.go b/dto/text_request.go index 5c403da..801d1c3 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -26,7 +26,7 @@ type GeneralOpenAIRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Tools []ToolCall `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` LogProbs bool `json:"logprobs,omitempty"` @@ -104,6 +104,11 @@ func (m Message) StringContent() string { return string(m.Content) } +func (m *Message) SetStringContent(content string) { + jsonContent, _ := json.Marshal(content) + m.Content = jsonContent +} + func (m Message) IsStringContent() bool { var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { diff --git a/dto/text_response.go b/dto/text_response.go index e1f0cc0..9b12683 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -86,9 +86,11 @@ type ToolCall struct { } type FunctionCall struct { - Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Name string `json:"name,omitempty"` // call function with arguments in JSON format - Arguments string `json:"arguments,omitempty"` + Parameters any `json:"parameters,omitempty"` // request + Arguments string `json:"arguments,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 47f0c3b..e2a898e 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -5,11 +5,18 @@ type ClaudeMetadata struct { } type ClaudeMediaMessage struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson string `json:"partial_json,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type ClaudeMessageSource struct { @@ -23,6 +30,18 @@ type ClaudeMessage struct { Content any `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type ClaudeRequest struct { Model string `json:"model"` Prompt string `json:"prompt,omitempty"` @@ -35,7 +54,9 @@ type ClaudeRequest struct { TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ClaudeError struct { @@ -44,24 +65,20 @@ type ClaudeError struct { } type ClaudeResponse struct { - Id string `json:"id"` - Type string `json:"type"` - Content []ClaudeMediaMessage `json:"content"` - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage ClaudeUsage `json:"usage"` - Index int `json:"index"` // stream only - Delta *ClaudeMediaMessage `json:"delta"` // stream only - Message *ClaudeResponse `json:"message"` // stream only: message_start + Id string `json:"id"` + Type string `json:"type"` + Content []ClaudeMediaMessage `json:"content"` + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` + Usage ClaudeUsage `json:"usage"` + Index int `json:"index"` // stream only + ContentBlock *ClaudeMediaMessage `json:"content_block"` + Delta *ClaudeMediaMessage `json:"delta"` // stream only + Message *ClaudeResponse `json:"message"` // stream only: message_start } -//type ClaudeResponseChoice struct { -// Index int `json:"index"` -// Type string `json:"type"` -//} - type ClaudeUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 945b20d..0d70715 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -30,6 +30,7 @@ func stopReasonClaude2OpenAI(reason string) string { } func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { + claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -60,6 +61,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR } func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -68,6 +85,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 @@ -184,6 +202,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) + tools := make([]dto.ToolCall, 0) var choice dto.ChatCompletionsStreamResponseChoice if reqMode == RequestModeCompletion { choice.Delta.SetContentString(claudeResponse.Completion) @@ -199,10 +218,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* choice.Delta.SetContentString("") choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { - return nil, nil + if claudeResponse.ContentBlock != nil { + //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, dto.ToolCall{ + ID: claudeResponse.ContentBlock.Id, + Type: "function", + Function: dto.FunctionCall{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } + } else { + return nil, nil + } } else if claudeResponse.Type == "content_block_delta" { - choice.Index = claudeResponse.Index - choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta != nil { + choice.Index = claudeResponse.Index + choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, dto.ToolCall{ + Function: dto.FunctionCall{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } + } } else if claudeResponse.Type == "message_delta" { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) if finishReason != "null" { @@ -218,6 +260,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeUsage == nil { claudeUsage = &ClaudeUsage{} } + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... + choice.Delta.ToolCalls = tools + } response.Choices = append(response.Choices, choice) return &response, claudeUsage @@ -230,6 +276,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope Object: "chat.completion", Created: common.GetTimestamp(), } + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + tools := make([]dto.ToolCall, 0) if reqMode == RequestModeCompletion { content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := dto.OpenAITextResponseChoice{ @@ -244,20 +295,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope choices = append(choices, choice) } else { fullTextResponse.Id = claudeResponse.Id - for i, message := range claudeResponse.Content { - content, _ := json.Marshal(message.Text) - choice := dto.OpenAITextResponseChoice{ - Index: i, - Message: dto.Message{ - Role: "assistant", - Content: content, - }, - FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + for _, message := range claudeResponse.Content { + if message.Type == "tool_use" { + args, _ := json.Marshal(message.Input) + tools = append(tools, dto.ToolCall{ + ID: message.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: dto.FunctionCall{ + Name: message.Name, + Arguments: string(args), + }, + }) } - choices = append(choices, choice) } } - + choice := dto.OpenAITextResponseChoice{ + Index: 0, + Message: dto.Message{ + Role: "assistant", + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + choice.SetStringContent(responseText) + if len(tools) > 0 { + choice.Message.ToolCalls = tools + } + choices = append(choices, choice) fullTextResponse.Choices = choices return &fullTextResponse } @@ -334,6 +397,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } else if claudeResponse.Type == "message_delta" { usage.CompletionTokens = claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { + } else { return true } From fae918c0550390963f66f502cd8049e47a3bd143 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 00:41:31 +0800 Subject: [PATCH 25/34] chore: log format --- relay/relay-image.go | 2 +- relay/relay-text.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/relay-image.go b/relay/relay-image.go index 4b1fbd2..21e0582 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -179,7 +179,7 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { quality = "hd" } - logContent := fmt.Sprintf(", 大小 %s, 品质 %s", imageRequest.Size, quality) + logContent := fmt.Sprintf("大小 %s, 品质 %s", imageRequest.Size, quality) postConsumeQuota(c, relayInfo, imageRequest.Model, usage, 0, 0, userQuota, 0, groupRatio, modelPrice, true, logContent) return nil diff --git a/relay/relay-text.go b/relay/relay-text.go index d82bd60..6efc2e5 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -339,7 +339,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN logContent += fmt.Sprintf(",模型 %s", modelName) } if extraContent != "" { - logContent += extraContent + logContent += ", " + extraContent } other := service.GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, From ae00a99cf5affade2ea54827bf8d9f655b8e1b5d Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 17:04:19 +0800 Subject: [PATCH 26/34] =?UTF-8?q?feat:=20=E5=AA=92=E4=BD=93=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=E8=AE=A1=E8=B4=B9=E9=80=89=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 ++ constant/env.go | 4 ++++ service/token_counter.go | 10 +++++++++- 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 785cca8..ffc0e5a 100644 --- a/README.md +++ b/README.md @@ -74,6 +74,8 @@ - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒 - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` - `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true` +- `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用, +- `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true` ## 部署 ### 部署要求 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机) diff --git a/constant/env.go b/constant/env.go index 96483fe..a18b875 100644 --- a/constant/env.go +++ b/constant/env.go @@ -9,3 +9,7 @@ var DifyDebug = common.GetEnvOrDefaultBool("DIFY_DEBUG", true) // ForceStreamOption 覆盖请求参数,强制返回usage信息 var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) + +var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) + +var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) diff --git a/service/token_counter.go b/service/token_counter.go index b99fc20..a1ab0dc 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -9,6 +9,7 @@ import ( "log" "math" "one-api/common" + "one-api/constant" "one-api/dto" "strings" "unicode/utf8" @@ -81,13 +82,20 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { } func getImageToken(imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { - // TODO: 非流模式下不计算图片token数量 if model == "glm-4v" { return 1047, nil } if imageUrl.Detail == "low" { return 85, nil } + // TODO: 非流模式下不计算图片token数量 + if !constant.GetMediaTokenNotStream && !stream { + return 1000, nil + } + // 是否统计图片token + if !constant.GetMediaToken { + return 1000, nil + } // 同步One API的图片计费逻辑 if imageUrl.Detail == "auto" || imageUrl.Detail == "" { imageUrl.Detail = "high" From 70491ea1bb93fdd617d45f33880b8f6375e5893d Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 17:12:28 +0800 Subject: [PATCH 27/34] fix: image relay quota --- relay/relay-image.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/relay-image.go b/relay/relay-image.go index 21e0582..f6a2641 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -170,8 +170,8 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { } usage := &dto.Usage{ - PromptTokens: relayInfo.PromptTokens, - TotalTokens: relayInfo.PromptTokens, + PromptTokens: imageRequest.N, + TotalTokens: imageRequest.N, } quality := "standard" From 14bf865034aa17bf1c3a8c778acac7d83492cb4b Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 17:26:21 +0800 Subject: [PATCH 28/34] feat: add UPDATE_TASK env --- README.md | 1 + constant/env.go | 2 ++ main.go | 3 ++- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ffc0e5a..d2c805c 100644 --- a/README.md +++ b/README.md @@ -76,6 +76,7 @@ - `FORCE_STREAM_OPTION`:是否覆盖客户端stream_options参数,请求上游返回流模式usage,默认为 `true` - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用, - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true` +- `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度 ## 部署 ### 部署要求 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机) diff --git a/constant/env.go b/constant/env.go index a18b875..76146ca 100644 --- a/constant/env.go +++ b/constant/env.go @@ -13,3 +13,5 @@ var ForceStreamOption = common.GetEnvOrDefaultBool("FORCE_STREAM_OPTION", true) var GetMediaToken = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN", true) var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STREAM", true) + +var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) diff --git a/main.go b/main.go index e929e0c..ed2ab2e 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "one-api/common" + "one-api/constant" "one-api/controller" "one-api/middleware" "one-api/model" @@ -89,7 +90,7 @@ func main() { } go controller.AutomaticallyTestChannels(frequency) } - if common.IsMasterNode { + if common.IsMasterNode && constant.UpdateTask { common.SafeGoroutine(func() { controller.UpdateMidjourneyTaskBulk() }) From f96291a25a5e5a013e993e594558cd18c667a2b5 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 20:28:47 +0800 Subject: [PATCH 29/34] feat: support gemini tool calling (close #368) --- dto/text_request.go | 11 +- relay/channel/gemini/adaptor.go | 2 +- relay/channel/gemini/dto.go | 10 +- relay/channel/gemini/relay-gemini.go | 177 +++++++++++++-------------- 4 files changed, 105 insertions(+), 95 deletions(-) diff --git a/dto/text_request.go b/dto/text_request.go index 801d1c3..2170e71 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -148,7 +148,7 @@ func (m Message) ParseContent() []MediaMessage { if ok { subObj["detail"] = detail.(string) } else { - subObj["detail"] = "auto" + subObj["detail"] = "high" } contentList = append(contentList, MediaMessage{ Type: ContentTypeImageURL, @@ -157,7 +157,16 @@ func (m Message) ParseContent() []MediaMessage { Detail: subObj["detail"].(string), }, }) + } else if url, ok := contentMap["image_url"].(string); ok { + contentList = append(contentList, MediaMessage{ + Type: ContentTypeImageURL, + ImageUrl: MessageImageUrl{ + Url: url, + Detail: "high", + }, + }) } + } } return contentList diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index de7761a..e132d2f 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { action := "generateContent" if info.IsStream { - action = "streamGenerateContent" + action = "streamGenerateContent?alt=sse" } return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index 99ab654..771a616 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -12,9 +12,15 @@ type GeminiInlineData struct { Data string `json:"data"` } +type FunctionCall struct { + FunctionName string `json:"name"` + Arguments any `json:"args"` +} + type GeminiPart struct { - Text string `json:"text,omitempty"` - InlineData *GeminiInlineData `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` + FunctionCall *FunctionCall `json:"functionCall,omitempty"` } type GeminiChatContent struct { diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index b7080b9..45dfbb9 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -4,18 +4,14 @@ import ( "bufio" "encoding/json" "fmt" + "github.com/gin-gonic/gin" "io" - "log" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" - "time" - - "github.com/gin-gonic/gin" ) // Setting safety to the lowest possible values since Gemini is already powerless enough @@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques MaxOutputTokens: textRequest.MaxTokens, }, } - if textRequest.Functions != nil { + if textRequest.Tools != nil { + functions := make([]dto.FunctionCall, 0, len(textRequest.Tools)) + for _, tool := range textRequest.Tools { + functions = append(functions, tool.Function) + } + geminiRequest.Tools = []GeminiChatTools{ + { + FunctionDeclarations: functions, + }, + } + } else if textRequest.Functions != nil { geminiRequest.Tools = []GeminiChatTools{ { FunctionDeclarations: textRequest.Functions, @@ -126,6 +132,30 @@ func (g *GeminiChatResponse) GetResponseText() string { return "" } +func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall { + var toolCalls []dto.ToolCall + + item := candidate.Content.Parts[0] + if item.FunctionCall == nil { + return toolCalls + } + argsBytes, err := json.Marshal(item.FunctionCall.Arguments) + if err != nil { + //common.SysError("getToolCalls failed: " + err.Error()) + return toolCalls + } + toolCall := dto.ToolCall{ + ID: fmt.Sprintf("call_%s", common.GetUUID()), + Type: "function", + Function: dto.FunctionCall{ + Arguments: string(argsBytes), + Name: item.FunctionCall.FunctionName, + }, + } + toolCalls = append(toolCalls, toolCall) + return toolCalls +} + func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), @@ -144,8 +174,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp FinishReason: relaycommon.StopFinishReason, } if len(candidate.Content.Parts) > 0 { - content, _ = json.Marshal(candidate.Content.Parts[0].Text) - choice.Message.Content = content + if candidate.Content.Parts[0].FunctionCall != nil { + choice.Message.ToolCalls = getToolCalls(&candidate) + } else { + choice.Message.SetStringContent(candidate.Content.Parts[0].Text) + } } fullTextResponse.Choices = append(fullTextResponse.Choices, choice) } @@ -154,7 +187,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.SetContentString(geminiResponse.GetResponseText()) + //choice.Delta.SetContentString(geminiResponse.GetResponseText()) + if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { + respFirst := geminiResponse.Candidates[0].Content.Parts[0] + if respFirst.FunctionCall != nil { + // function response + choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0]) + } else { + // text response + choice.Delta.SetContentString(respFirst.Text) + } + } choice.FinishReason = &relaycommon.StopFinishReason var response dto.ChatCompletionsStreamResponse response.Object = "chat.completion.chunk" @@ -165,92 +208,47 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseText := "" - responseJson := "" id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createAt := common.GetTimestamp() var usage = &dto.Usage{} - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) - go func() { - for scanner.Scan() { - data := scanner.Text() - responseJson += data - data = strings.TrimSpace(data) - if !strings.HasPrefix(data, "\"text\": \"") { - continue - } - data = strings.TrimPrefix(data, "\"text\": \"") - data = strings.TrimSuffix(data, "\"") - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } - } - stopChan <- true - }() - isFirst := true + scanner.Split(bufio.ScanLines) + service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - // this is used to prevent annoying \ related format bug - data = fmt.Sprintf("{\"content\": \"%s\"}", data) - type dummyStruct struct { - Content string `json:"content"` - } - var dummy dummyStruct - err := json.Unmarshal([]byte(data), &dummy) - responseText += dummy.Content - var choice dto.ChatCompletionsStreamResponseChoice - choice.Delta.SetContentString(dummy.Content) - response := dto.ChatCompletionsStreamResponse{ - Id: id, - Object: "chat.completion.chunk", - Created: createAt, - Model: info.UpstreamModelName, - Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, - } - jsonResponse, err := json.Marshal(response) - if err != nil { - common.SysError("error marshalling stream response: " + err.Error()) - return true - } - c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - return true - case <-stopChan: - return false + for scanner.Scan() { + data := scanner.Text() + info.SetFirstResponseTime() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "data: ") { + continue } - }) - var geminiChatResponses []GeminiChatResponse - err := json.Unmarshal([]byte(responseJson), &geminiChatResponses) - if err != nil { - log.Printf("cannot get gemini usage: %s", err.Error()) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) - } else { - for _, response := range geminiChatResponses { - usage.PromptTokens = response.UsageMetadata.PromptTokenCount - usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\"") + var geminiResponse GeminiChatResponse + err := json.Unmarshal([]byte(data), &geminiResponse) + if err != nil { + common.LogError(c, "error unmarshalling stream response: "+err.Error()) + continue + } + + response := streamResponseGeminiChat2OpenAI(&geminiResponse) + if response == nil { + continue + } + response.Id = id + response.Created = createAt + responseText += response.Choices[0].Delta.GetContentString() + if geminiResponse.UsageMetadata.TotalTokenCount != 0 { + usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount + usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + } + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, err.Error()) } - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + if info.ShouldIncludeUsage { response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) err := service.ObjectData(c, response) @@ -259,10 +257,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom } } service.Done(c) - err = resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage - } + resp.Body.Close() return nil, usage } From c9100b219f9212b4e760aec70d120776b99a5c17 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 00:45:52 +0800 Subject: [PATCH 30/34] feat: support ali image --- dto/dalle.go | 12 +- relay/channel/ali/adaptor.go | 54 +++--- relay/channel/ali/dto.go | 35 +++- relay/channel/ali/image.go | 177 ++++++++++++++++++++ relay/channel/ali/{relay-ali.go => text.go} | 54 ++---- 5 files changed, 257 insertions(+), 75 deletions(-) create mode 100644 relay/channel/ali/image.go rename relay/channel/ali/{relay-ali.go => text.go} (82%) diff --git a/dto/dalle.go b/dto/dalle.go index d366051..d0bba65 100644 --- a/dto/dalle.go +++ b/dto/dalle.go @@ -12,9 +12,11 @@ type ImageRequest struct { } type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - B64Json string `json:"b64_json"` - } + Data []ImageData `json:"data"` + Created int64 `json:"created"` +} +type ImageData struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + RevisedPrompt string `json:"revised_prompt"` } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 88990d1..98728a0 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -8,6 +8,7 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" ) @@ -15,23 +16,18 @@ import ( type Adaptor struct { } -func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - -func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { - //TODO implement me - return nil, errors.New("not implemented") -} - func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", info.BaseUrl) - if info.RelayMode == constant.RelayModeEmbeddings { + var fullRequestURL string + switch info.RelayMode { + case constant.RelayModeEmbeddings: fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", info.BaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) } return fullRequestURL, nil } @@ -57,13 +53,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re baiduEmbeddingRequest := embeddingRequestOpenAI2Ali(*request) return baiduEmbeddingRequest, nil default: - baiduRequest := requestOpenAI2Ali(*request) - return baiduRequest, nil + aliReq := requestOpenAI2Ali(*request) + return aliReq, nil } } +func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { + aliRequest := oaiImage2Ali(request) + return aliRequest, nil +} + func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { - return nil, nil + return nil, errors.New("not implemented") +} + +func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { + //TODO implement me + return nil, errors.New("not implemented") } func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { @@ -71,14 +77,16 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { - if info.IsStream { - err, usage = aliStreamHandler(c, resp) - } else { - switch info.RelayMode { - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - default: - err, usage = aliHandler(c, resp) + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + default: + if info.IsStream { + err, usage = openai.OpenaiStreamHandler(c, resp, info) + } else { + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } } return diff --git a/relay/channel/ali/dto.go b/relay/channel/ali/dto.go index fd1f07a..f51286a 100644 --- a/relay/channel/ali/dto.go +++ b/relay/channel/ali/dto.go @@ -60,13 +60,40 @@ type AliUsage struct { TotalTokens int `json:"total_tokens"` } -type AliOutput struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` +type TaskResult struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` } -type AliChatResponse struct { +type AliOutput struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Text string `json:"text"` + FinishReason string `json:"finish_reason"` + Message string `json:"message,omitempty"` + Code string `json:"code,omitempty"` + Results []TaskResult `json:"results,omitempty"` +} + +type AliResponse struct { Output AliOutput `json:"output"` Usage AliUsage `json:"usage"` AliError } + +type AliImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} diff --git a/relay/channel/ali/image.go b/relay/channel/ali/image.go new file mode 100644 index 0000000..160fabf --- /dev/null +++ b/relay/channel/ali/image.go @@ -0,0 +1,177 @@ +package ali + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest { + var imageRequest AliImageRequest + imageRequest.Input.Prompt = request.Prompt + imageRequest.Model = request.Model + imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1) + imageRequest.Parameters.N = request.N + imageRequest.ResponseFormat = request.ResponseFormat + + return &imageRequest +} + +func updateTask(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, error, []byte) { + url := fmt.Sprintf("/api/v1/tasks/%s", taskID) + + var aliResponse AliResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + common.SysError("updateTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response AliResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + common.SysError("updateTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(info *relaycommon.RelayInfo, taskID string, key string) (*AliResponse, []byte, error) { + waitSeconds := 3 + step := 0 + maxStep := 20 + + var taskResponse AliResponse + var responseBody []byte + + for { + step++ + rsp, err, body := updateTask(info, taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse { + imageResponse := dto.ImageResponse{ + Created: info.StartTime.Unix(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + _, b64, err := service.GetImageFromUrl(data.Url) + if err != nil { + common.LogError(c, "get_image_data_failed: "+err.Error()) + continue + } + b64Json = b64 + } else { + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, dto.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse AliResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message) + return service.OpenAIErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return service.OpenAIErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &dto.OpenAIErrorWithStatusCode{ + Error: dto.OpenAIError{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} diff --git a/relay/channel/ali/relay-ali.go b/relay/channel/ali/text.go similarity index 82% rename from relay/channel/ali/relay-ali.go rename to relay/channel/ali/text.go index 4280b1c..aec857f 100644 --- a/relay/channel/ali/relay-ali.go +++ b/relay/channel/ali/text.go @@ -16,34 +16,13 @@ import ( const EnableSearchModelSuffix = "-internet" -func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliChatRequest { - messages := make([]AliMessage, 0, len(request.Messages)) - //prompt := "" - for i := 0; i < len(request.Messages); i++ { - message := request.Messages[i] - messages = append(messages, AliMessage{ - Content: message.StringContent(), - Role: strings.ToLower(message.Role), - }) - } - enableSearch := false - aliModel := request.Model - if strings.HasSuffix(aliModel, EnableSearchModelSuffix) { - enableSearch = true - aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix) - } - return &AliChatRequest{ - Model: request.Model, - Input: AliInput{ - //Prompt: prompt, - Messages: messages, - }, - Parameters: AliParameters{ - IncrementalOutput: request.Stream, - Seed: uint64(request.Seed), - EnableSearch: enableSearch, - }, +func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest { + if request.TopP >= 1 { + request.TopP = 0.999 + } else if request.TopP <= 0 { + request.TopP = 0.001 } + return &request } func embeddingRequestOpenAI2Ali(request dto.GeneralOpenAIRequest) *AliEmbeddingRequest { @@ -110,7 +89,7 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *dto.OpenAIEmbe return &openAIEmbeddingResponse } -func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { +func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse { content, _ := json.Marshal(response.Output.Text) choice := dto.OpenAITextResponseChoice{ Index: 0, @@ -134,7 +113,7 @@ func responseAli2OpenAI(response *AliChatResponse) *dto.OpenAITextResponse { return &fullTextResponse } -func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletionsStreamResponse { +func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse { var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(aliResponse.Output.Text) if aliResponse.Output.FinishReason != "null" { @@ -154,18 +133,7 @@ func streamResponseAli2OpenAI(aliResponse *AliChatResponse) *dto.ChatCompletions func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) stopChan := make(chan bool) go func() { @@ -187,7 +155,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: - var aliResponse AliChatResponse + var aliResponse AliResponse err := json.Unmarshal([]byte(data), &aliResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) @@ -221,7 +189,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWith } func aliHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - var aliResponse AliChatResponse + var aliResponse AliResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil From e84300f4aee6a01270add0fd69e41a4294abf6ad Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 01:07:37 +0800 Subject: [PATCH 31/34] chore: gopool --- controller/channel-test.go | 5 +- go.mod | 1 + go.sum | 4 + main.go | 5 +- model/log.go | 3 +- model/utils.go | 5 +- relay/channel/ali/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 129 ++++++++++----------------- relay/channel/ollama/adaptor.go | 2 +- relay/channel/openai/adaptor.go | 2 +- relay/channel/openai/relay-openai.go | 17 ++-- relay/channel/perplexity/adaptor.go | 2 +- relay/channel/zhipu/relay-zhipu.go | 13 +-- relay/channel/zhipu_4v/adaptor.go | 2 +- 14 files changed, 77 insertions(+), 115 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 90d02d6..fe27978 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/bytedance/gopkg/util/gopool" "io" "math" "net/http" @@ -217,7 +218,7 @@ func testAllChannels(notify bool) error { if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } - go func() { + gopool.Go(func() { for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() @@ -265,7 +266,7 @@ func testAllChannels(notify bool) error { common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } - }() + }) return nil } diff --git a/go.mod b/go.mod index a9d4a1d..f97217b 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect diff --git a/go.sum b/go.sum index a77a89c..f19b88c 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 h1:JgHnonzbnA3pbqj76w github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4/go.mod h1:nZspkhg+9p8iApLFoyAqfyuMP0F38acy2Hm3r5r95Cg= github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b h1:LTGVFpNmNHhj0vhOlfgWueFJ32eK9blaIlHR2ciXOT0= +github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -198,6 +200,7 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -206,6 +209,7 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/main.go b/main.go index ed2ab2e..959b795 100644 --- a/main.go +++ b/main.go @@ -3,6 +3,7 @@ package main import ( "embed" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" @@ -91,10 +92,10 @@ func main() { go controller.AutomaticallyTestChannels(frequency) } if common.IsMasterNode && constant.UpdateTask { - common.SafeGoroutine(func() { + gopool.Go(func() { controller.UpdateMidjourneyTaskBulk() }) - common.SafeGoroutine(func() { + gopool.Go(func() { controller.UpdateTaskBulk() }) } diff --git a/model/log.go b/model/log.go index cea5b98..85c53b1 100644 --- a/model/log.go +++ b/model/log.go @@ -3,6 +3,7 @@ package model import ( "context" "fmt" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "strings" @@ -87,7 +88,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke common.LogError(ctx, "failed to record log: "+err.Error()) } if common.DataExportEnabled { - common.SafeGoroutine(func() { + gopool.Go(func() { LogQuotaData(userId, username, modelName, quota, common.GetTimestamp(), promptTokens+completionTokens) }) } diff --git a/model/utils.go b/model/utils.go index 44bfbb9..3905e95 100644 --- a/model/utils.go +++ b/model/utils.go @@ -2,6 +2,7 @@ package model import ( "errors" + "github.com/bytedance/gopkg/util/gopool" "gorm.io/gorm" "one-api/common" "sync" @@ -28,12 +29,12 @@ func init() { } func InitBatchUpdater() { - go func() { + gopool.Go(func() { for { time.Sleep(time.Duration(common.BatchUpdateInterval) * time.Second) batchUpdate() } - }() + }) } func addNewRecord(type_ int, id int, value int) { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 98728a0..ff9d533 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -84,7 +84,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = aliEmbeddingHandler(c, resp) default: if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 0d70715..031f825 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -8,12 +8,10 @@ import ( "io" "net/http" "one-api/common" - "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" - "time" ) func stopReasonClaude2OpenAI(reason string) string { @@ -332,91 +330,59 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. responseText := "" createdTime := common.GetTimestamp() scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil + scanner.Split(bufio.ScanLines) + service.SetEventStreamHeaders(c) + + for scanner.Scan() { + data := scanner.Text() + info.SetFirstResponseTime() + if len(data) < 6 || !strings.HasPrefix(data, "data:") { + continue } - if i := strings.Index(string(data), "\n"); i >= 0 { - return i + 1, data[0:i], nil + data = strings.TrimPrefix(data, "data:") + data = strings.TrimSpace(data) + var claudeResponse ClaudeResponse + err := json.Unmarshal([]byte(data), &claudeResponse) + if err != nil { + common.SysError("error unmarshalling stream response: " + err.Error()) + continue } - if atEOF { - return len(data), data, nil + + response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) + if response == nil { + continue } - return 0, nil, nil - }) - dataChan := make(chan string, 5) - stopChan := make(chan bool, 2) - go func() { - for scanner.Scan() { - data := scanner.Text() - if !strings.HasPrefix(data, "data: ") { + if requestMode == RequestModeCompletion { + responseText += claudeResponse.Completion + responseId = response.Id + } else { + if claudeResponse.Type == "message_start" { + // message_start, 获取usage + responseId = claudeResponse.Message.Id + info.UpstreamModelName = claudeResponse.Message.Model + usage.PromptTokens = claudeUsage.InputTokens + } else if claudeResponse.Type == "content_block_delta" { + responseText += claudeResponse.Delta.Text + } else if claudeResponse.Type == "message_delta" { + usage.CompletionTokens = claudeUsage.OutputTokens + usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { + + } else { continue } - data = strings.TrimPrefix(data, "data: ") - if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) { - // send data timeout, stop the stream - common.LogError(c, "send data timeout, stop the stream") - break - } } - stopChan <- true - }() - isFirst := true - service.SetEventStreamHeaders(c) - c.Stream(func(w io.Writer) bool { - select { - case data := <-dataChan: - if isFirst { - isFirst = false - info.FirstResponseTime = time.Now() - } - // some implementations may add \r at the end of data - data = strings.TrimSuffix(data, "\r") - var claudeResponse ClaudeResponse - err := json.Unmarshal([]byte(data), &claudeResponse) - if err != nil { - common.SysError("error unmarshalling stream response: " + err.Error()) - return true - } + //response.Id = responseId + response.Id = responseId + response.Created = createdTime + response.Model = info.UpstreamModelName - response, claudeUsage := StreamResponseClaude2OpenAI(requestMode, &claudeResponse) - if response == nil { - return true - } - if requestMode == RequestModeCompletion { - responseText += claudeResponse.Completion - responseId = response.Id - } else { - if claudeResponse.Type == "message_start" { - // message_start, 获取usage - responseId = claudeResponse.Message.Id - info.UpstreamModelName = claudeResponse.Message.Model - usage.PromptTokens = claudeUsage.InputTokens - } else if claudeResponse.Type == "content_block_delta" { - responseText += claudeResponse.Delta.Text - } else if claudeResponse.Type == "message_delta" { - usage.CompletionTokens = claudeUsage.OutputTokens - usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens - } else if claudeResponse.Type == "content_block_start" { - - } else { - return true - } - } - //response.Id = responseId - response.Id = responseId - response.Created = createdTime - response.Model = info.UpstreamModelName - - err = service.ObjectData(c, response) - if err != nil { - common.SysError(err.Error()) - } - return true - case <-stopChan: - return false + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, "send_stream_response_failed: "+err.Error()) } - }) + } + if requestMode == RequestModeCompletion { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { @@ -435,10 +401,7 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } service.Done(c) - err := resp.Body.Close() - if err != nil { - return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil - } + resp.Body.Close() return nil, usage } diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 540ec85..408db6a 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -64,7 +64,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { err, usage = ollamaEmbeddingHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode) diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 2aa743f..4388efd 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -145,7 +145,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom err, usage = OpenaiTTSHandler(c, resp, info) default: if info.IsStream { - err, usage = OpenaiStreamHandler(c, resp, info) + err, usage = OaiStreamHandler(c, resp, info) } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 45e5def..807f4b1 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -5,6 +5,7 @@ import ( "bytes" "encoding/json" "fmt" + "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" "io" "net/http" @@ -18,8 +19,8 @@ import ( "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { - hasStreamUsage := false +func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + containStreamUsage := false responseId := "" var createAt int64 = 0 var systemFingerprint string @@ -41,7 +42,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. stopChan := make(chan bool) defer close(stopChan) - go func() { + gopool.Go(func() { for scanner.Scan() { info.SetFirstResponseTime() ticker.Reset(time.Duration(constant.StreamingTimeout) * time.Second) @@ -62,7 +63,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } common.SafeSendBool(stopChan, true) - }() + }) select { case <-ticker.C: @@ -91,7 +92,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage - hasStreamUsage = true + containStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -115,7 +116,7 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. model = streamResponse.Model if service.ValidUsage(streamResponse.Usage) { usage = streamResponse.Usage - hasStreamUsage = true + containStreamUsage = true } for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.GetContentString()) @@ -155,12 +156,12 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } } - if !hasStreamUsage { + if !containStreamUsage { usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens) usage.CompletionTokens += toolCount * 7 } - if info.ShouldIncludeUsage && !hasStreamUsage { + if info.ShouldIncludeUsage && !containStreamUsage { response := service.GenerateFinalUsageResponse(responseId, createAt, model, *usage) response.SetSystemFingerprint(systemFingerprint) service.ObjectData(c, response) diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index d3ed222..e9d07fb 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -58,7 +58,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/zhipu/relay-zhipu.go b/relay/channel/zhipu/relay-zhipu.go index 5ef9d7a..aaf3c5d 100644 --- a/relay/channel/zhipu/relay-zhipu.go +++ b/relay/channel/zhipu/relay-zhipu.go @@ -153,18 +153,7 @@ func streamMetaResponseZhipu2OpenAI(zhipuResponse *ZhipuStreamMetaResponse) (*dt func zhipuStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var usage *dto.Usage scanner := bufio.NewScanner(resp.Body) - scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - if atEOF && len(data) == 0 { - return 0, nil, nil - } - if i := strings.Index(string(data), "\n\n"); i >= 0 && strings.Index(string(data), ":") >= 0 { - return i + 2, data[0:i], nil - } - if atEOF { - return len(data), data, nil - } - return 0, nil, nil - }) + scanner.Split(bufio.ScanLines) dataChan := make(chan string) metaChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index b34b756..5e0906e 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -59,7 +59,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = openai.OpenaiStreamHandler(c, resp, info) + err, usage = openai.OaiStreamHandler(c, resp, info) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } From 67b74ada00cdbc18bc499d3ddc0da69fe887f50a Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 01:29:08 +0800 Subject: [PATCH 32/34] feat: update model ratio --- common/model-ratio.go | 24 ++++++++++++++++-------- relay/channel/openai/constant.go | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 1200310..41e3eac 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -31,13 +31,15 @@ var defaultModelRatio = map[string]float64{ "gpt-4-32k": 30, //"gpt-4-32k-0314": 30, //deprecated "gpt-4-32k-0613": 30, - "gpt-4-1106-preview": 5, // $0.01 / 1K tokens - "gpt-4-0125-preview": 5, // $0.01 / 1K tokens - "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens - "gpt-4-vision-preview": 5, // $0.01 / 1K tokens - "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens - "gpt-4o": 2.5, // $0.01 / 1K tokens - "gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens + "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens + "gpt-4-vision-preview": 5, // $0.01 / 1K tokens + "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens + "gpt-4o": 2.5, // $0.01 / 1K tokens + "gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens + "gpt-4o-mini": 0.7, + "gpt-4o-mini-2024-07-18": 0.7, "gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens @@ -305,7 +307,13 @@ func GetCompletionRatio(name string) float64 { return 4.0 / 3.0 } if strings.HasPrefix(name, "gpt-4") && !strings.HasSuffix(name, "-all") && !strings.HasSuffix(name, "-gizmo-*") { - if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") || strings.HasPrefix(name, "gpt-4o") { + if strings.HasPrefix(name, "gpt-4-turbo") || strings.HasSuffix(name, "preview") { + return 3 + } + if strings.HasPrefix(name, "gpt-4o") { + if strings.Contains(name, "mini") { + return 4 + } return 3 } return 2 diff --git a/relay/channel/openai/constant.go b/relay/channel/openai/constant.go index 82f7e97..50abc2e 100644 --- a/relay/channel/openai/constant.go +++ b/relay/channel/openai/constant.go @@ -9,6 +9,7 @@ var ModelList = []string{ "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09", "gpt-4-vision-preview", "gpt-4o", "gpt-4o-2024-05-13", + "gpt-4o-mini", "gpt-4o-mini-2024-07-18", "text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large", "text-curie-001", "text-babbage-001", "text-ada-001", "text-moderation-latest", "text-moderation-stable", From 56afe47aa8721e5b06e046445693b0526f9ef562 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 01:34:00 +0800 Subject: [PATCH 33/34] feat: update model ratio --- common/model-ratio.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 41e3eac..5de961d 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -38,8 +38,8 @@ var defaultModelRatio = map[string]float64{ "gpt-4-1106-vision-preview": 5, // $0.01 / 1K tokens "gpt-4o": 2.5, // $0.01 / 1K tokens "gpt-4o-2024-05-13": 2.5, // $0.01 / 1K tokens - "gpt-4o-mini": 0.7, - "gpt-4o-mini-2024-07-18": 0.7, + "gpt-4o-mini": 0.075, + "gpt-4o-mini-2024-07-18": 0.075, "gpt-4-turbo": 5, // $0.01 / 1K tokens "gpt-4-turbo-2024-04-09": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.25, // $0.0015 / 1K tokens From 733b3745969179b6a12e1e6498818bda5d536301 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 19 Jul 2024 03:06:20 +0800 Subject: [PATCH 34/34] Update README.md --- README.md | 63 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 34 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index d2c805c..fc33c8b 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,21 @@ # New API > [!NOTE] -> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发,感谢原作者的无私奉献。 -> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 +> 本项目为开源项目,在[One API](https://github.com/songquanpeng/one-api)的基础上进行二次开发 -> 本项目为个人学习使用,不保证稳定性,且不提供任何技术支持,使用者必须在遵循 OpenAI 的使用条款以及法律法规的情况下使用,不得用于非法用途。 +> [!IMPORTANT] +> 使用者必须在遵循 OpenAI 的[使用条款](https://openai.com/policies/terms-of-use)以及**法律法规**的情况下使用,不得用于非法用途。 +> 本项目仅供个人学习使用,不保证稳定性,且不提供任何技术支持。 > 根据[《生成式人工智能服务管理暂行办法》](http://www.cac.gov.cn/2023-07/13/c_1690898327029107.htm)的要求,请勿对中国地区公众提供一切未经备案的生成式人工智能服务。 -> [!NOTE] -> 最新版Docker镜像 calciumion/new-api:latest -> 更新指令 docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR +> [!TIP] +> 最新版Docker镜像:`calciumion/new-api:latest` +> 默认账号root 密码123456 +> 更新指令: +> ``` +> docker run --rm -v /var/run/docker.sock:/var/run/docker.sock containrrr/watchtower -cR +> ``` + ## 主要变更 此分叉版本的主要变更如下: @@ -18,9 +24,9 @@ 1. 全新的UI界面(部分界面还待更新) 2. 添加[Midjourney-Proxy(Plus)](https://github.com/novicezk/midjourney-proxy)接口的支持,[对接文档](Midjourney.md) 3. 支持在线充值功能,可在系统设置中设置,当前支持的支付接口: - + [x] 易支付 + + [x] 易支付 4. 支持用key查询使用额度: - + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 + + 配合项目[neko-api-key-tool](https://github.com/Calcium-Ion/neko-api-key-tool)可实现用key查询使用 5. 渠道显示已使用额度,支持指定组织访问 6. 分页支持选择每页显示数量 7. 兼容原版One API的数据库,可直接使用原版数据库(one-api.db) @@ -51,25 +57,6 @@ 您可以在渠道中添加自定义模型gpt-4-gizmo-*,此模型并非OpenAI官方模型,而是第三方模型,使用官方key无法调用。 -## 渠道重试 -渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 -如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 -### 缓存设置方法 -1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 - + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` -2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 - + 例子:`MEMORY_CACHE_ENABLED=true` -### 为什么有的时候没有重试 -这些错误码不会重试:400,504,524 -### 我想让400也重试 -在`渠道->编辑`中,将`状态码复写`改为 -```json -{ - "400": "500" -} -``` -可以实现400错误转为500错误,从而重试 - ## 比原版One API多出的配置 - `STREAMING_TIMEOUT`:设置流式一次回复的超时时间,默认为 30 秒 - `DIFY_DEBUG`:设置 Dify 渠道是否输出工作流和节点信息到客户端,默认为 `true` @@ -77,6 +64,7 @@ - `GET_MEDIA_TOKEN`:是统计图片token,默认为 `true`,关闭后将不再在本地计算图片token,可能会导致和上游计费不同,此项覆盖 `GET_MEDIA_TOKEN_NOT_STREAM` 选项作用, - `GET_MEDIA_TOKEN_NOT_STREAM`:是否在非流(`stream=false`)情况下统计图片token,默认为 `true` - `UPDATE_TASK`:是否更新异步任务(Midjourney、Suno),默认为 `true`,关闭后将不会更新任务进度 + ## 部署 ### 部署要求 - 本地数据库(默认):SQLite(Docker 部署默认使用 SQLite,必须挂载 `/data` 目录到宿主机) @@ -99,8 +87,25 @@ docker run --name new-api -d --restart always -p 3000:3000 -e TZ=Asia/Shanghai - docker run --name new-api -d --restart always -p 3000:3000 -e SQL_DSN="root:123456@tcp(宝塔的服务器地址:宝塔数据库端口)/宝塔数据库名称" -e TZ=Asia/Shanghai -v /www/wwwroot/new-api:/data calciumion/new-api:latest # 注意:数据库要开启远程访问,并且只允许服务器IP访问 ``` -### 默认账号密码 -默认账号root 密码123456 + +## 渠道重试 +渠道重试功能已经实现,可以在`设置->运营设置->通用设置`设置重试次数,**建议开启缓存**功能。 +如果开启了重试功能,第一次重试使用同优先级,第二次重试使用下一个优先级,以此类推。 +### 缓存设置方法 +1. `REDIS_CONN_STRING`:设置之后将使用 Redis 作为缓存使用。 + + 例子:`REDIS_CONN_STRING=redis://default:redispw@localhost:49153` +2. `MEMORY_CACHE_ENABLED`:启用内存缓存(如果设置了`REDIS_CONN_STRING`,则无需手动设置),会导致用户额度的更新存在一定的延迟,可选值为 `true` 和 `false`,未设置则默认为 `false`。 + + 例子:`MEMORY_CACHE_ENABLED=true` +### 为什么有的时候没有重试 +这些错误码不会重试:400,504,524 +### 我想让400也重试 +在`渠道->编辑`中,将`状态码复写`改为 +```json +{ + "400": "500" +} +``` +可以实现400错误转为500错误,从而重试 ## Midjourney接口设置文档 [对接文档](Midjourney.md)