diff --git a/dto/text_request.go b/dto/text_request.go index 936660e..cc2d92e 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -32,6 +32,17 @@ type GeneralOpenAIRequest struct { TopLogProbs int `json:"top_logprobs,omitempty"` } +type OpenAITools struct { + Type string `json:"type"` + Function OpenAIFunction `json:"function"` +} + +type OpenAIFunction struct { + Description string `json:"description,omitempty"` + Name string `json:"name"` + Parameters any `json:"parameters,omitempty"` +} + func (r GeneralOpenAIRequest) ParseInput() []string { if r.Input == nil { return nil diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 4e1fd33..8997889 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -52,7 +52,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { if info.RelayMode == relayconstant.RelayModeEmbeddings { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index cab6a64..a450c71 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -72,8 +72,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) + var toolCount int + err, responseText, toolCount = 
OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 } else { err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index dae9fd5..5469ed7 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -16,9 +16,10 @@ import ( "time" ) -func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { +func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string, int) { //checkSensitive := constant.ShouldCheckCompletionSensitive() var responseTextBuilder strings.Builder + toolCount := 0 scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { @@ -69,6 +70,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } for _, tool := range choice.Delta.ToolCalls { responseTextBuilder.WriteString(tool.Function.Name) responseTextBuilder.WriteString(tool.Function.Arguments) @@ -82,6 +86,9 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d for _, choice := range streamResponse.Choices { responseTextBuilder.WriteString(choice.Delta.Content) if choice.Delta.ToolCalls != nil { + if len(choice.Delta.ToolCalls) > toolCount { + toolCount = len(choice.Delta.ToolCalls) + } for _, tool := range choice.Delta.ToolCalls { responseTextBuilder.WriteString(tool.Function.Name) 
responseTextBuilder.WriteString(tool.Function.Arguments) @@ -135,10 +142,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d }) err := resp.Body.Close() if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", toolCount } wg.Wait() - return nil, responseTextBuilder.String() + return nil, responseTextBuilder.String(), toolCount } func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index 24765ff..00d7710 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + err, responseText, _ = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 1b8866b..fe89ff4 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -47,8 +47,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string 
- err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + var toolCount int + err, responseText, toolCount = openai.OpenaiStreamHandler(c, resp, info.RelayMode) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + usage.CompletionTokens += toolCount * 7 } else { err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/relay-text.go b/relay/relay-text.go index 890f543..e9aa7bb 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -189,7 +189,7 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re checkSensitive := constant.ShouldCheckPromptSensitive() switch info.RelayMode { case relayconstant.RelayModeChatCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenMessages(textRequest.Messages, textRequest.Model, checkSensitive) + promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: diff --git a/service/token_counter.go b/service/token_counter.go index 897f49c..c1bac75 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -116,6 +116,41 @@ func getImageToken(imageUrl *dto.MessageImageUrl) (int, error) { return tiles*170 + 85, nil } +func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string, checkSensitive bool) (int, error, bool) { + tkm := 0 + msgTokens, err, b := CountTokenMessages(request.Messages, model, checkSensitive) + if err != nil { + return 0, err, b + } + tkm += msgTokens + if request.Tools != nil { + toolsData, _ := json.Marshal(request.Tools) + var openaiTools []dto.OpenAITools + err := json.Unmarshal(toolsData, &openaiTools) + if err != nil { + return 0, errors.New(fmt.Sprintf("count tools token fail: 
%s", err.Error())), false + } + countStr := "" + for _, tool := range openaiTools { + countStr += tool.Function.Name + if tool.Function.Description != "" { + countStr += tool.Function.Description + } + if tool.Function.Parameters != nil { + countStr += fmt.Sprintf("%v", tool.Function.Parameters) + } + } + toolTokens, err, _ := CountTokenInput(countStr, model, false) + if err != nil { + return 0, err, false + } + tkm += 8 + tkm += toolTokens + } + + return tkm, nil, false +} + func CountTokenMessages(messages []dto.Message, model string, checkSensitive bool) (int, error, bool) { //recover when panic tokenEncoder := getTokenEncoder(model)