From f6cfe7cd4fb30ed2a997553bd36e865a897549f2 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 26 Feb 2025 05:38:21 +0000 Subject: [PATCH] feat: enhance error handling and reasoning mechanisms across middleware - Improve error handling across multiple middleware and adapter components, ensuring consistent error response formats in JSON. - Enhance the functionality of request conversion functions by including context parameters and robust error wrapping. - Introduce new features related to reasoning content in the messaging model, providing better customization and explanations in the documentation. --- README.md | 26 +++ middleware/auth.go | 16 +- middleware/distributor.go | 9 +- middleware/utils.go | 14 ++ relay/adaptor/anthropic/adaptor.go | 2 +- relay/adaptor/anthropic/main.go | 28 ++- relay/adaptor/aws/claude/adapter.go | 6 +- relay/adaptor/aws/claude/main.go | 4 +- relay/adaptor/doubao/main.go | 2 + relay/adaptor/openai/main.go | 248 +++++++++++++++++------ relay/adaptor/openai/util.go | 1 + relay/adaptor/vertexai/claude/adapter.go | 7 +- relay/model/message.go | 44 ++++ 13 files changed, 320 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index da86cf1d..dc9f5f20 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,10 @@ Also welcome to register and use my deployed one-api gateway, which supports var - [Support claude-3-7-sonnet \& thinking](#support-claude-3-7-sonnet--thinking) - [Stream](#stream) - [Non-Stream](#non-stream) + - [Automatically Enable Thinking and Customize Reasoning Format via URL Parameters](#automatically-enable-thinking-and-customize-reasoning-format-via-url-parameters) + - [Reasoning Format - reasoning-content](#reasoning-format---reasoning-content) + - [Reasoning Format - reasoning](#reasoning-format---reasoning) + - [Reasoning Format - thinking](#reasoning-format---thinking) - [Bug fix](#bug-fix) ## Turtorial @@ -172,6 +176,28 @@ By default, the thinking mode is not enabled. You need to manually pass the `thi ![](https://s3.laisky.com/uploads/2025/02/claude-thinking-non-stream.png) +### Automatically Enable Thinking and Customize Reasoning Format via URL Parameters + +Supports two URL parameters: `thinking` and `reasoning_format`. + +- `thinking`: Whether to enable thinking mode, disabled by default. +- `reasoning_format`: Specifies the format of the returned reasoning. + - `reasoning_content`: DeepSeek official API format, returned in the `reasoning_content` field. + - `reasoning`: OpenRouter format, returned in the `reasoning` field. + - `thinking`: Claude format, returned in the `thinking` field. 
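+
+For example, the following request enables thinking and asks for the DeepSeek-style field (a minimal sketch: the host `http://localhost:3000` and the token `sk-xxx` are placeholders for your own deployment, and `max_tokens` must be greater than 1024 once thinking is enabled):
+
+```bash
+curl "http://localhost:3000/v1/chat/completions?thinking=true&reasoning_format=reasoning_content" \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer sk-xxx" \
+  -d '{
+        "model": "claude-3-7-sonnet-20250219",
+        "max_tokens": 2048,
+        "messages": [{"role": "user", "content": "How many days are in a leap year?"}]
+      }'
+```
+
+The reasoning then appears in the `reasoning_content` field of the assistant message (response abridged for illustration):
+
+```json
+{
+  "choices": [
+    {
+      "message": {
+        "role": "assistant",
+        "content": "366 days.",
+        "reasoning_content": "A leap year adds February 29, so 365 + 1 = 366..."
+      }
+    }
+  ]
+}
+```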
+ +#### Reasoning Format - reasoning-content + +![](https://s3.laisky.com/uploads/2025/02/reasoning_format-reasoning_content.png) + +#### Reasoning Format - reasoning + +![](https://s3.laisky.com/uploads/2025/02/reasoning_format-reasoning.png) + +#### Reasoning Format - thinking + +![](https://s3.laisky.com/uploads/2025/02/reasoning_format-thinking.png) + ## Bug fix - [BUGFIX: 更新令牌时的一些问题 #1933](https://github.com/songquanpeng/one-api/pull/1933) diff --git a/middleware/auth.go b/middleware/auth.go index f3a1f1c7..f3af72aa 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,12 +1,12 @@ package middleware import ( - "fmt" "net/http" "strings" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" @@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) + abortWithError(c, http.StatusUnauthorized, err) return } if token.Subnet != nil && *token.Subnet != "" { if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { - abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP())) + abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP())) return } } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) + abortWithError(c, http.StatusInternalServerError, err) return } if !userEnabled || blacklist.IsUserBanned(token.UserId) { - abortWithMessage(c, http.StatusForbidden, "User has been banned") + abortWithError(c, http.StatusForbidden, errors.New("User has been banned")) return } requestModel, err := getRequestModel(c) if err != nil && shouldCheckModel(c) { - abortWithMessage(c, http.StatusBadRequest, err.Error()) + abortWithError(c, http.StatusBadRequest, err) return } c.Set(ctxkey.RequestModel, requestModel) if token.Models != nil && *token.Models != "" { c.Set(ctxkey.AvailableModels, *token.Models) if requestModel != "" && !isModelInList(requestModel, *token.Models) { - abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key does not have permission to use the model: %s", requestModel)) + abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel)) return } } @@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set(ctxkey.SpecificChannelId, parts[1]) } else { - abortWithMessage(c, http.StatusForbidden, "Ordinary users do not support specifying channels") + abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels")) return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index b4d55a01..3d663557 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -8,6 +8,7 @@ import ( gutils "github.com/Laisky/go-utils/v5" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -31,16 +32,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := 
strconv.Atoi(channelId.(string))
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 				return
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 				return
 			}
 			if channel.Status != model.ChannelStatusEnabled {
-				abortWithMessage(c, http.StatusForbidden, "The channel has been disabled")
+				abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
 				return
 			}
 		} else {
@@ -53,7 +54,7 @@ func Distribute() func(c *gin.Context) {
 				logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
 				message = "Database consistency has been broken, please contact the administrator"
 			}
-			abortWithMessage(c, http.StatusServiceUnavailable, message)
+			abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
 			return
 		}
 	}
diff --git a/middleware/utils.go b/middleware/utils.go
index 46120f2a..7e445b34 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -3,6 +3,8 @@ package middleware
 import (
 	"strings"
 
+	gmw "github.com/Laisky/gin-middlewares/v6"
+	"github.com/Laisky/zap"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common"
@@ -21,6 +23,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 	logger.Error(c.Request.Context(), message)
}
 
+func abortWithError(c *gin.Context, statusCode int, err error) {
+	logger := gmw.GetLogger(c)
+	logger.Error("server abort", zap.Error(err))
+	c.JSON(statusCode, gin.H{
+		"error": gin.H{
+			"message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)),
+			"type":    "one_api_error",
+		},
+	})
+	c.Abort()
+}
+
 func getRequestModel(c *gin.Context) (string, error) {
 	var modelRequest ModelRequest
 	err := common.UnmarshalBodyReusable(c, &modelRequest)
diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go
index d2edc1d8..8a3f02a5 100644
--- a/relay/adaptor/anthropic/adaptor.go
+++ b/relay/adaptor/anthropic/adaptor.go
@@ -51,7 +51,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 
 	c.Set("claude_model", request.Model)
-	return ConvertRequest(*request), nil
+	return ConvertRequest(c, *request)
 }
 
 func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go
index 045bef9f..04260723 100644
--- a/relay/adaptor/anthropic/main.go
+++ b/relay/adaptor/anthropic/main.go
@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"strings"
 
@@ -38,7 +39,7 @@ func stopReasonClaude2OpenAI(reason *string) string {
 	}
 }
 
-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
+func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 
 	for _, tool := range textRequest.Tools {
@@ -66,7 +67,21 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		Thinking: textRequest.Thinking,
 	}
 
+	if c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
+		claudeRequest.Thinking = &model.Thinking{
+			Type: "enabled",
+			// Anthropic requires 1024 <= budget_tokens < max_tokens, so take half of
+			// max_tokens but never drop below the 1024 minimum (the max_tokens guard
+			// below keeps the budget strictly smaller than max_tokens)
+			BudgetTokens: int(math.Max(1024, float64(claudeRequest.MaxTokens/2))),
+		}
+	}
+
+	if claudeRequest.Thinking != nil {
+		if claudeRequest.MaxTokens <= 1024 {
+			return nil, fmt.Errorf("max_tokens 
must be greater than 1024 when using extended thinking") + } + // top_p must be nil when using extended thinking claudeRequest.TopP = nil } @@ -151,11 +163,11 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { claudeMessage.Content = contents claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) } - return &claudeRequest + return &claudeRequest, nil } // https://docs.anthropic.com/claude/reference/messages-streaming -func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { +func StreamResponseClaude2OpenAI(c *gin.Context, claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { var response *Response var responseText string var reasoningText string @@ -220,7 +232,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = responseText - choice.Delta.Reasoning = &reasoningText + choice.Delta.SetReasoningContent(c.Query("reasoning_format"), reasoningText) if len(tools) > 0 { choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... choice.Delta.ToolCalls = tools @@ -236,7 +248,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo return &openaiResponse, response } -func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { +func ResponseClaude2OpenAI(c *gin.Context, claudeResponse *Response) *openai.TextResponse { var responseText string var reasoningText string @@ -273,12 +285,12 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { Message: model.Message{ Role: "assistant", Content: responseText, - Reasoning: &reasoningText, Name: nil, ToolCalls: tools, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } + choice.Message.SetReasoningContent(c.Query("reasoning_format"), reasoningText) fullTextResponse := openai.TextResponse{ Id: fmt.Sprintf("chatcmpl-%s", claudeResponse.Id), Model: claudeResponse.Model, @@ -328,7 +340,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC continue } - response, meta := StreamResponseClaude2OpenAI(&claudeResponse) + response, meta := StreamResponseClaude2OpenAI(c, &claudeResponse) if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens @@ -407,7 +419,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st StatusCode: resp.StatusCode, }, nil } - fullTextResponse := ResponseClaude2OpenAI(&claudeResponse) + fullTextResponse := ResponseClaude2OpenAI(c, &claudeResponse) fullTextResponse.Model = modelName usage := model.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go index eb3c9fb8..2f6a4cc5 100644 --- a/relay/adaptor/aws/claude/adapter.go +++ b/relay/adaptor/aws/claude/adapter.go @@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) + claudeReq, err := anthropic.ConvertRequest(c, *request) + if err != nil { + return nil, errors.Wrap(err, "convert request") + } + c.Set(ctxkey.RequestModel, request.Model) c.Set(ctxkey.ConvertedRequest, claudeReq) return claudeReq, nil diff --git a/relay/adaptor/aws/claude/main.go 
b/relay/adaptor/aws/claude/main.go index 378acdda..69251c3d 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -88,7 +88,7 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil } - openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse) + openaiResp := anthropic.ResponseClaude2OpenAI(c, claudeResponse) openaiResp.Model = modelName usage := relaymodel.Usage{ PromptTokens: claudeResponse.Usage.InputTokens, @@ -159,7 +159,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E return false } - response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp) + response, meta := anthropic.StreamResponseClaude2OpenAI(c, claudeResp) if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens diff --git a/relay/adaptor/doubao/main.go b/relay/adaptor/doubao/main.go index 669a6751..5efa8ad2 100644 --- a/relay/adaptor/doubao/main.go +++ b/relay/adaptor/doubao/main.go @@ -2,6 +2,8 @@ package doubao import ( "fmt" + "strings" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/relaymode" ) diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 8b62619a..78403917 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -25,103 +25,166 @@ const ( dataPrefixLength = len(dataPrefix) ) +// StreamHandler processes streaming responses from OpenAI API +// It handles incremental content delivery and accumulates the final response text +// Returns error (if any), accumulated response text, and token usage information func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { + // Initialize accumulators for the response responseText := "" reasoningText := "" - scanner := bufio.NewScanner(resp.Body) - buffer := make([]byte, 256*1024) - scanner.Buffer(buffer, len(buffer)) - scanner.Split(bufio.ScanLines) var usage *model.Usage + // Set up scanner for reading the stream line by line + scanner := bufio.NewScanner(resp.Body) + buffer := make([]byte, 256*1024) // 256KB buffer for large messages + scanner.Buffer(buffer, len(buffer)) + scanner.Split(bufio.ScanLines) + + // Set response headers for SSE common.SetEventStreamHeaders(c) doneRendered := false + + // Process each line from the stream for scanner.Scan() { data := NormalizeDataLine(scanner.Text()) - if len(data) < dataPrefixLength { // ignore blank line or wrong format - continue + + // Skip lines that don't match expected format + if len(data) < dataPrefixLength { + continue // Ignore blank line or wrong format } + + // Verify line starts with expected prefix if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done { continue } + + // Check for stream termination if strings.HasPrefix(data[dataPrefixLength:], done) { render.StringData(c, data) doneRendered = true continue } + + // Process based on relay mode switch relayMode { case relaymode.ChatCompletions: var streamResponse ChatCompletionsStreamResponse + + // Parse the JSON response err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse) if err != nil { logger.SysError("error unmarshalling stream response: " + err.Error()) - render.StringData(c, data) // if error happened, pass the data to client - continue // just ignore the error + render.StringData(c, data) // Pass raw data to client if parsing fails + 
continue
 			}
 
+			// Skip empty choices (Azure specific behavior)
 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
-				// but for empty choice and no usage, we should not pass it to client, this is for azure
-				continue // just ignore empty choice
+				continue
 			}
-			render.StringData(c, data)
-			for _, choice := range streamResponse.Choices {
-				if choice.Delta.Reasoning != nil {
-					reasoningText += *choice.Delta.Reasoning
-				}
-				if choice.Delta.ReasoningContent != nil {
-					reasoningText += *choice.Delta.ReasoningContent
-				}
-				responseText += conv.AsString(choice.Delta.Content)
-			}
+
+			// Process each choice in the response.
+			// Iterate by index: SetReasoningContent mutates the choice in place,
+			// and a range variable would only be a discarded copy
+			rewritten := false
+			for i := range streamResponse.Choices {
+				choice := &streamResponse.Choices[i]
+				// Extract reasoning content from the different possible fields
+				currentReasoningChunk := extractReasoningContent(&choice.Delta)
+
+				// Accumulate reasoning text and rewrite it into the format
+				// requested by the client via the reasoning_format query parameter
+				if currentReasoningChunk != "" {
+					reasoningText += currentReasoningChunk
+					choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
+					rewritten = true
+				}
+
+				// Accumulate response content
+				responseText += conv.AsString(choice.Delta.Content)
+			}
+
+			// Send the processed data to the client; chunks whose reasoning was
+			// rewritten must be re-serialized so the change is visible downstream
+			if rewritten {
+				if modified, merr := json.Marshal(streamResponse); merr == nil {
+					render.StringData(c, dataPrefix+string(modified))
+				} else {
+					render.StringData(c, data)
+				}
+			} else {
+				render.StringData(c, data)
+			}
+
+			// Update usage information if available
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage
 			}
+
 		case relaymode.Completions:
+			// Send the data immediately for Completions mode
 			render.StringData(c, data)
+
 			var streamResponse CompletionsStreamResponse
 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 			if err != nil {
 				logger.SysError("error unmarshalling stream response: " + err.Error())
 				continue
 			}
+
+			// Accumulate text from all choices
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Text
 			}
 		}
 	}
 
+	// Check for scanner errors
 	if err := scanner.Err(); err != nil {
 		logger.SysError("error reading stream: " + err.Error())
 	}
 
+	// Ensure stream termination is sent to client
 	if !doneRendered {
 		render.Done(c)
 	}
 
-	err := resp.Body.Close()
-	if err != nil {
+	// Clean up resources
+	if err := resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}
 
+	// Return the complete response text (reasoning + content) and usage
 	return nil, reasoningText + responseText, usage
 }
 
-// Handler handles the non-stream response from OpenAI API
+// extractReasoningContent pulls reasoning text out of whichever delta field
+// the upstream provider used, clearing the original field as it goes
+func extractReasoningContent(delta *model.Message) string {
+	content := ""
+
+	// Extract reasoning from different possible fields
+	if delta.Reasoning != nil {
+		content += *delta.Reasoning
+		delta.Reasoning = nil
+	}
+
+	if delta.ReasoningContent != nil {
+		content += *delta.ReasoningContent
+		delta.ReasoningContent = nil
+	}
+
+	return content
+}
+
+// Handler processes non-streaming responses from OpenAI API
+// Returns error (if any) and token usage information
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
-	var textResponse SlimTextResponse
+	// Read the entire response body
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = resp.Body.Close()
-	if err != nil {
+
+	// Close the original response body
+	if err = resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
+
+	// Parse the response JSON
+	var textResponse SlimTextResponse
+	if 
err = json.Unmarshal(responseBody, &textResponse); err != nil { return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } + + // Check for API errors if textResponse.Error.Type != "" { return &model.ErrorWithStatusCode{ Error: textResponse.Error, @@ -129,68 +192,131 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st }, nil } - // Reset response body + // Process reasoning content in each choice + for _, msg := range textResponse.Choices { + reasoningContent := processReasoningContent(&msg) + + // Set reasoning in requested format if content exists + if reasoningContent != "" { + msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent) + } + } + + // Reset response body for forwarding to client resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody)) - // We shouldn't set the header before we parse the response body, because the parse part may fail. - // And then we will have to send an error response, but in this case, the header has already been set. - // So the HTTPClient will be confused by the response. - // For example, Postman will report error, and we cannot check the response at all. - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) + // Forward all response headers (not just first value of each) + for k, values := range resp.Header { + for _, v := range values { + c.Writer.Header().Add(k, v) + } } + + // Set response status and copy body to client c.Writer.WriteHeader(resp.StatusCode) - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { + if _, err = io.Copy(c.Writer, resp.Body); err != nil { return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil } - err = resp.Body.Close() - if err != nil { + + // Close the reset body + if err = resp.Body.Close(); err != nil { return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } - if textResponse.Usage.TotalTokens == 0 || - (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) { + // Calculate token usage if not provided by API + calculateTokenUsage(&textResponse, promptTokens, modelName) + + return nil, &textResponse.Usage +} + +// processReasoningContent is a helper function to extract and process reasoning content from the message +func processReasoningContent(msg *TextResponseChoice) string { + var reasoningContent string + + // Check different locations for reasoning content + switch { + case msg.Reasoning != nil: + reasoningContent = *msg.Reasoning + msg.Reasoning = nil + case msg.ReasoningContent != nil: + reasoningContent = *msg.ReasoningContent + msg.ReasoningContent = nil + case msg.Message.Reasoning != nil: + reasoningContent = *msg.Message.Reasoning + msg.Message.Reasoning = nil + case msg.Message.ReasoningContent != nil: + reasoningContent = *msg.Message.ReasoningContent + msg.Message.ReasoningContent = nil + } + + return reasoningContent +} + +// Helper function to calculate token usage +func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) { + // Calculate tokens if not provided by the API + if response.Usage.TotalTokens == 0 || + (response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) { + completionTokens := 0 - for _, choice := range textResponse.Choices { + for _, choice := range response.Choices { + // Count content tokens completionTokens += 
CountTokenText(choice.Message.StringContent(), modelName) + + // Count reasoning tokens in all possible locations if choice.Message.Reasoning != nil { completionTokens += CountToken(*choice.Message.Reasoning) } + if choice.Message.ReasoningContent != nil { + completionTokens += CountToken(*choice.Message.ReasoningContent) + } + if choice.Reasoning != nil { + completionTokens += CountToken(*choice.Reasoning) + } if choice.ReasoningContent != nil { completionTokens += CountToken(*choice.ReasoningContent) } } - textResponse.Usage = model.Usage{ + + // Set usage values + response.Usage = model.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } - } else if (textResponse.PromptTokensDetails != nil && textResponse.PromptTokensDetails.AudioTokens > 0) || - (textResponse.CompletionTokensDetails != nil && textResponse.CompletionTokensDetails.AudioTokens > 0) { - // Convert the more expensive audio tokens to uniformly priced text tokens. - // Note that when there are no audio tokens in prompt and completion, - // OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading. - if textResponse.PromptTokensDetails != nil { - textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens + - int(math.Ceil( - float64(textResponse.PromptTokensDetails.AudioTokens)* - ratio.GetAudioPromptRatio(modelName), - )) - } + } else if hasAudioTokens(response) { + // Handle audio tokens conversion + calculateAudioTokens(response, modelName) + } +} - if textResponse.CompletionTokensDetails != nil { - textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens + - int(math.Ceil( - float64(textResponse.CompletionTokensDetails.AudioTokens)* - ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName), - )) - } +// Helper function to check if response has audio tokens +func hasAudioTokens(response *SlimTextResponse) bool { + return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) || + (response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0) +} - textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens + - textResponse.Usage.CompletionTokens +// Helper function to calculate audio token usage +func calculateAudioTokens(response *SlimTextResponse, modelName string) { + // Convert audio tokens for prompt + if response.PromptTokensDetails != nil { + response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens + + int(math.Ceil( + float64(response.PromptTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName), + )) } - return nil, &textResponse.Usage + // Convert audio tokens for completion + if response.CompletionTokensDetails != nil { + response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens + + int(math.Ceil( + float64(response.CompletionTokensDetails.AudioTokens)* + ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName), + )) + } + + // Calculate total tokens + response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens } diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go index 0ee3896b..ca5605c4 100644 --- a/relay/adaptor/openai/util.go +++ b/relay/adaptor/openai/util.go @@ -3,6 +3,7 @@ package openai import ( "context" "fmt" + "strings" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/model" diff --git 
a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go
index c615668d..5b4633ab 100644
--- a/relay/adaptor/vertexai/claude/adapter.go
+++ b/relay/adaptor/vertexai/claude/adapter.go
@@ -7,7 +7,6 @@ import (
 	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
-	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
 )
 
@@ -32,7 +31,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
 
-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
+
 	req := Request{
 		AnthropicVersion: anthropicVersion,
 		// Model:            claudeReq.Model,
diff --git a/relay/model/message.go b/relay/model/message.go
index 9f088658..388225f7 100644
--- a/relay/model/message.go
+++ b/relay/model/message.go
@@ -2,10 +2,32 @@ package model
 
 import (
 	"context"
+	"strings"
 
 	"github.com/songquanpeng/one-api/common/logger"
 )
 
+// ReasoningFormat is the format of reasoning content;
+// it can be set via the reasoning_format parameter in the request URL.
+type ReasoningFormat string
+
+const (
+	ReasoningFormatUnspecified ReasoningFormat = ""
+	// ReasoningFormatReasoningContent is the reasoning format used by the DeepSeek official API
+	ReasoningFormatReasoningContent ReasoningFormat = "reasoning_content"
+	// ReasoningFormatReasoning is the reasoning format used by OpenRouter
+	ReasoningFormatReasoning ReasoningFormat = "reasoning"
+
+	// ReasoningFormatThinkTag is the reasoning format used by 3rd party deepseek-r1 providers.
+	//
+	// Deprecated: I believe this is a very poor format; especially in stream mode, it is difficult to extract and convert.
+	// Considering that only a few deepseek-r1 third-party providers use this format, it has been decided to no longer support it.
+	// ReasoningFormatThinkTag ReasoningFormat = "think-tag"
+
+	// ReasoningFormatThinking is the reasoning format used by Anthropic
+	ReasoningFormatThinking ReasoningFormat = "thinking"
+)
+
 type Message struct {
 	Role string `json:"role,omitempty"`
 	// Content is a string or a list of objects
@@ -29,6 +51,28 @@ type Message struct {
 	// -------------------------------------
 	Reasoning *string `json:"reasoning,omitempty"`
 	Refusal   *bool   `json:"refusal,omitempty"`
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
+}
+
+// SetReasoningContent sets the reasoning content in the field selected by format
+func (m *Message) SetReasoningContent(format string, reasoningContent string) {
+	switch ReasoningFormat(strings.ToLower(strings.TrimSpace(format))) {
+	case ReasoningFormatReasoningContent:
+		m.ReasoningContent = &reasoningContent
+		// case ReasoningFormatThinkTag:
+		// 	m.Content = fmt.Sprintf("%s%s", reasoningContent, m.Content)
+	case ReasoningFormatThinking:
+		m.Thinking = &reasoningContent
+	case ReasoningFormatReasoning,
+		ReasoningFormatUnspecified:
+		m.Reasoning = &reasoningContent
+	default:
+		logger.Warnf(context.TODO(), "unknown reasoning format: %q", format)
+	}
+}
 
 type messageAudio struct {