feat: enhance error handling and reasoning mechanisms across middleware

- Improve error handling across middleware and adapter components, ensuring a consistent JSON format for error responses.
- Pass the request context into the request-conversion functions and wrap errors with meaningful context.
- Add reasoning-content support to the message model, with customizable output formats and documentation for each.
Laisky.Cai 2025-02-26 05:38:21 +00:00
parent 07d9a8e144
commit f6cfe7cd4f
13 changed files with 320 additions and 87 deletions

View File

@@ -30,6 +30,10 @@ Also welcome to register and use my deployed one-api gateway, which supports var
 - [Support claude-3-7-sonnet \& thinking](#support-claude-3-7-sonnet--thinking)
   - [Stream](#stream)
   - [Non-Stream](#non-stream)
+- [Automatically Enable Thinking and Customize Reasoning Format via URL Parameters](#automatically-enable-thinking-and-customize-reasoning-format-via-url-parameters)
+  - [Reasoning Format - reasoning-content](#reasoning-format---reasoning-content)
+  - [Reasoning Format - reasoning](#reasoning-format---reasoning)
+  - [Reasoning Format - thinking](#reasoning-format---thinking)
 - [Bug fix](#bug-fix)

 ## Turtorial
@@ -172,6 +176,28 @@ By default, the thinking mode is not enabled. You need to manually pass the `thi
 ![](https://s3.laisky.com/uploads/2025/02/claude-thinking-non-stream.png)
+
+### Automatically Enable Thinking and Customize Reasoning Format via URL Parameters
+
+Supports two URL parameters: `thinking` and `reasoning_format`.
+
+- `thinking`: Whether to enable thinking mode, disabled by default.
+- `reasoning_format`: Specifies the format of the returned reasoning.
+  - `reasoning_content`: DeepSeek official API format, returned in the `reasoning_content` field.
+  - `reasoning`: OpenRouter format, returned in the `reasoning` field.
+  - `thinking`: Claude format, returned in the `thinking` field.
+
+#### Reasoning Format - reasoning-content
+
+![](https://s3.laisky.com/uploads/2025/02/reasoning_format-reasoning_content.png)
+
+#### Reasoning Format - reasoning
+
+![](https://s3.laisky.com/uploads/2025/02/reasoning_format-reasoning.png)
+
+#### Reasoning Format - thinking
+
+![](https://s3.laisky.com/uploads/2025/02/reasoning_format-thinking.png)

 ## Bug fix
 - [BUGFIX: 更新令牌时的一些问题 #1933](https://github.com/songquanpeng/one-api/pull/1933)
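
For a concrete sense of how these parameters are used, here is a minimal client-side sketch in Go (the gateway host is a placeholder; note that, per the adapter change below, the mere presence of the `thinking` key is what enables thinking mode):

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Hypothetical gateway endpoint; substitute your own deployment.
	u, _ := url.Parse("https://oneapi.example.com/v1/chat/completions")
	q := u.Query()
	q.Set("thinking", "true")                      // presence of the key enables thinking mode
	q.Set("reasoning_format", "reasoning_content") // or "reasoning" / "thinking"
	u.RawQuery = q.Encode()
	fmt.Println(u.String()) // POST chat completions to this URL as usual
}
```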

View File

@@ -1,12 +1,12 @@
 package middleware

 import (
-	"fmt"
 	"net/http"
 	"strings"

 	"github.com/gin-contrib/sessions"
 	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common/blacklist"
 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/logger"
@@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) {
 		key = parts[0]
 		token, err := model.ValidateUserToken(key)
 		if err != nil {
-			abortWithMessage(c, http.StatusUnauthorized, err.Error())
+			abortWithError(c, http.StatusUnauthorized, err)
 			return
 		}
 		if token.Subnet != nil && *token.Subnet != "" {
 			if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
-				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
+				abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
 				return
 			}
 		}
 		userEnabled, err := model.CacheIsUserEnabled(token.UserId)
 		if err != nil {
-			abortWithMessage(c, http.StatusInternalServerError, err.Error())
+			abortWithError(c, http.StatusInternalServerError, err)
 			return
 		}
 		if !userEnabled || blacklist.IsUserBanned(token.UserId) {
-			abortWithMessage(c, http.StatusForbidden, "User has been banned")
+			abortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
 			return
 		}
 		requestModel, err := getRequestModel(c)
 		if err != nil && shouldCheckModel(c) {
-			abortWithMessage(c, http.StatusBadRequest, err.Error())
+			abortWithError(c, http.StatusBadRequest, err)
 			return
 		}
 		c.Set(ctxkey.RequestModel, requestModel)
 		if token.Models != nil && *token.Models != "" {
 			c.Set(ctxkey.AvailableModels, *token.Models)
 			if requestModel != "" && !isModelInList(requestModel, *token.Models) {
-				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key does not have permission to use the model: %s", requestModel))
+				abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
 				return
 			}
 		}
@@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) {
 			if model.IsAdmin(token.UserId) {
 				c.Set(ctxkey.SpecificChannelId, parts[1])
 			} else {
-				abortWithMessage(c, http.StatusForbidden, "Ordinary users do not support specifying channels")
+				abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
 				return
 			}
 		}

View File

@@ -8,6 +8,7 @@ import (
 	gutils "github.com/Laisky/go-utils/v5"
 	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"

 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/model"
@@ -31,16 +32,16 @@ func Distribute() func(c *gin.Context) {
 		if ok {
 			id, err := strconv.Atoi(channelId.(string))
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 				return
 			}
 			channel, err = model.GetChannelById(id, true)
 			if err != nil {
-				abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+				abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
 				return
 			}
 			if channel.Status != model.ChannelStatusEnabled {
-				abortWithMessage(c, http.StatusForbidden, "The channel has been disabled")
+				abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
 				return
 			}
 		} else {
@@ -53,7 +54,7 @@ func Distribute() func(c *gin.Context) {
 				logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
 				message = "Database consistency has been broken, please contact the administrator"
 			}
-			abortWithMessage(c, http.StatusServiceUnavailable, message)
+			abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
 			return
 		}
 	}

View File

@@ -3,6 +3,8 @@ package middleware
 import (
 	"strings"

+	gmw "github.com/Laisky/gin-middlewares/v6"
+	"github.com/Laisky/zap"
 	"github.com/gin-gonic/gin"
 	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common"
@@ -21,6 +23,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
 	logger.Error(c.Request.Context(), message)
 }

+func abortWithError(c *gin.Context, statusCode int, err error) {
+	logger := gmw.GetLogger(c)
+	logger.Error("server abort", zap.Error(err))
+	c.JSON(statusCode, gin.H{
+		"error": gin.H{
+			"message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)),
+			"type":    "one_api_error",
+		},
+	})
+	c.Abort()
+}
+
 func getRequestModel(c *gin.Context) (string, error) {
 	var modelRequest ModelRequest
 	err := common.UnmarshalBodyReusable(c, &modelRequest)
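
With this helper in place, every middleware abort returns the same JSON envelope. A representative 403 body might look like the following (illustrative only: the request id is made up, and the exact message suffix is whatever `helper.MessageWithRequestId` produces):

```json
{
  "error": {
    "message": "User has been banned (request id: 202502260538211234567890)",
    "type": "one_api_error"
  }
}
```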

View File

@@ -51,7 +51,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	}
 	c.Set("claude_model", request.Model)
-	return ConvertRequest(*request), nil
+	return ConvertRequest(c, *request)
 }

 func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {

View File

@@ -6,6 +6,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"math"
 	"net/http"
 	"strings"
@@ -38,7 +39,7 @@ func stopReasonClaude2OpenAI(reason *string) string {
 	}
 }

-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
+func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 	for _, tool := range textRequest.Tools {
@@ -66,7 +67,18 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		Thinking: textRequest.Thinking,
 	}

+	if c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
+		claudeRequest.Thinking = &model.Thinking{
+			Type:         "enabled",
+			BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
+		}
+	}
 	if claudeRequest.Thinking != nil {
+		if claudeRequest.MaxTokens <= 1024 {
+			return nil, fmt.Errorf("max_tokens must be greater than 1024 when using extended thinking")
+		}
 		// top_p must be nil when using extended thinking
 		claudeRequest.TopP = nil
 	}
@@ -151,11 +163,11 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		claudeMessage.Content = contents
 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 	}
-	return &claudeRequest
+	return &claudeRequest, nil
 }

 // https://docs.anthropic.com/claude/reference/messages-streaming
-func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
+func StreamResponseClaude2OpenAI(c *gin.Context, claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
 	var response *Response
 	var responseText string
 	var reasoningText string
@@ -220,7 +232,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = responseText
-	choice.Delta.Reasoning = &reasoningText
+	choice.Delta.SetReasoningContent(c.Query("reasoning_format"), reasoningText)
 	if len(tools) > 0 {
 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 		choice.Delta.ToolCalls = tools
@@ -236,7 +248,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	return &openaiResponse, response
 }

-func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
+func ResponseClaude2OpenAI(c *gin.Context, claudeResponse *Response) *openai.TextResponse {
 	var responseText string
 	var reasoningText string
@@ -273,12 +285,12 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 		Message: model.Message{
 			Role:      "assistant",
 			Content:   responseText,
-			Reasoning: &reasoningText,
 			Name:      nil,
 			ToolCalls: tools,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 	}
+	choice.Message.SetReasoningContent(c.Query("reasoning_format"), reasoningText)
 	fullTextResponse := openai.TextResponse{
 		Id:    fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
 		Model: claudeResponse.Model,
@@ -328,7 +340,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 			continue
 		}
-		response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
+		response, meta := StreamResponseClaude2OpenAI(c, &claudeResponse)
 		if meta != nil {
 			usage.PromptTokens += meta.Usage.InputTokens
 			usage.CompletionTokens += meta.Usage.OutputTokens
@@ -407,7 +419,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 			StatusCode: resp.StatusCode,
 		}, nil
 	}
-	fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
+	fullTextResponse := ResponseClaude2OpenAI(c, &claudeResponse)
 	fullTextResponse.Model = modelName
 	usage := model.Usage{
 		PromptTokens: claudeResponse.Usage.InputTokens,
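
As a quick sanity check of the auto-enable branch above, a sketch with a hypothetical `max_tokens` of 2048: the injected budget is min(1024, 2048/2) = 1024 tokens, and the guard that follows rejects any thinking-enabled request whose `max_tokens` is 1024 or less.

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Sketch of the budget arithmetic in ConvertRequest, with made-up inputs.
	maxTokens := 2048
	budget := int(math.Min(1024, float64(maxTokens/2))) // min(1024, 1024) = 1024
	fmt.Println(budget, maxTokens > 1024)               // prints: 1024 true
}
```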

View File

@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
 	c.Set(ctxkey.RequestModel, request.Model)
 	c.Set(ctxkey.ConvertedRequest, claudeReq)
 	return claudeReq, nil

View File

@@ -88,7 +88,7 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
 		return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
 	}
-	openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
+	openaiResp := anthropic.ResponseClaude2OpenAI(c, claudeResponse)
 	openaiResp.Model = modelName
 	usage := relaymodel.Usage{
 		PromptTokens: claudeResponse.Usage.InputTokens,
@@ -159,7 +159,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 			return false
 		}
-		response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
+		response, meta := anthropic.StreamResponseClaude2OpenAI(c, claudeResp)
 		if meta != nil {
 			usage.PromptTokens += meta.Usage.InputTokens
 			usage.CompletionTokens += meta.Usage.OutputTokens

View File

@@ -2,6 +2,8 @@ package doubao
 import (
 	"fmt"
+	"strings"

 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/relaymode"
 )

View File

@@ -25,103 +25,166 @@ const (
 	dataPrefixLength = len(dataPrefix)
 )

+// StreamHandler processes streaming responses from OpenAI API
+// It handles incremental content delivery and accumulates the final response text
+// Returns error (if any), accumulated response text, and token usage information
 func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
+	// Initialize accumulators for the response
 	responseText := ""
 	reasoningText := ""
-	scanner := bufio.NewScanner(resp.Body)
-	buffer := make([]byte, 256*1024)
-	scanner.Buffer(buffer, len(buffer))
-	scanner.Split(bufio.ScanLines)
 	var usage *model.Usage

+	// Set up scanner for reading the stream line by line
+	scanner := bufio.NewScanner(resp.Body)
+	buffer := make([]byte, 256*1024) // 256KB buffer for large messages
+	scanner.Buffer(buffer, len(buffer))
+	scanner.Split(bufio.ScanLines)

+	// Set response headers for SSE
 	common.SetEventStreamHeaders(c)
 	doneRendered := false

+	// Process each line from the stream
 	for scanner.Scan() {
 		data := NormalizeDataLine(scanner.Text())
-		if len(data) < dataPrefixLength { // ignore blank line or wrong format
-			continue
+		// Skip lines that don't match expected format
+		if len(data) < dataPrefixLength {
+			continue // Ignore blank line or wrong format
 		}
+		// Verify line starts with expected prefix
 		if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
 			continue
 		}
+		// Check for stream termination
 		if strings.HasPrefix(data[dataPrefixLength:], done) {
 			render.StringData(c, data)
 			doneRendered = true
 			continue
 		}
+		// Process based on relay mode
 		switch relayMode {
 		case relaymode.ChatCompletions:
 			var streamResponse ChatCompletionsStreamResponse
+			// Parse the JSON response
 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 			if err != nil {
 				logger.SysError("error unmarshalling stream response: " + err.Error())
-				render.StringData(c, data) // if error happened, pass the data to client
-				continue // just ignore the error
+				render.StringData(c, data) // Pass raw data to client if parsing fails
+				continue
 			}
+			// Skip empty choices (Azure specific behavior)
 			if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
-				// but for empty choice and no usage, we should not pass it to client, this is for azure
-				continue // just ignore empty choice
+				continue
 			}
-			render.StringData(c, data)
+			// Process each choice in the response
 			for _, choice := range streamResponse.Choices {
-				if choice.Delta.Reasoning != nil {
-					reasoningText += *choice.Delta.Reasoning
-				}
-				if choice.Delta.ReasoningContent != nil {
-					reasoningText += *choice.Delta.ReasoningContent
+				// Extract reasoning content from different possible fields
+				currentReasoningChunk := extractReasoningContent(&choice.Delta)
+				// Update accumulated reasoning text
+				if currentReasoningChunk != "" {
+					reasoningText += currentReasoningChunk
 				}
+				// Set the reasoning content in the format requested by client
+				choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
+				// Accumulate response content
 				responseText += conv.AsString(choice.Delta.Content)
 			}
+			// Send the processed data to the client
+			render.StringData(c, data)
+			// Update usage information if available
 			if streamResponse.Usage != nil {
 				usage = streamResponse.Usage
 			}
 		case relaymode.Completions:
+			// Send the data immediately for Completions mode
 			render.StringData(c, data)
 			var streamResponse CompletionsStreamResponse
 			err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
 			if err != nil {
 				logger.SysError("error unmarshalling stream response: " + err.Error())
 				continue
 			}
+			// Accumulate text from all choices
 			for _, choice := range streamResponse.Choices {
 				responseText += choice.Text
 			}
 		}
 	}

+	// Check for scanner errors
 	if err := scanner.Err(); err != nil {
 		logger.SysError("error reading stream: " + err.Error())
 	}

+	// Ensure stream termination is sent to client
 	if !doneRendered {
 		render.Done(c)
 	}

-	err := resp.Body.Close()
-	if err != nil {
+	// Clean up resources
+	if err := resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}

+	// Return the complete response text (reasoning + content) and usage
 	return nil, reasoningText + responseText, usage
 }

-// Handler handles the non-stream response from OpenAI API
+// Helper function to extract reasoning content from message delta
+func extractReasoningContent(delta *model.Message) string {
+	content := ""
+	// Extract reasoning from different possible fields
+	if delta.Reasoning != nil {
+		content += *delta.Reasoning
+		delta.Reasoning = nil
+	}
+	if delta.ReasoningContent != nil {
+		content += *delta.ReasoningContent
+		delta.ReasoningContent = nil
+	}
+	return content
+}

+// Handler processes non-streaming responses from OpenAI API
+// Returns error (if any) and token usage information
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
-	var textResponse SlimTextResponse
+	// Read the entire response body
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
 		return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 	}

-	err = resp.Body.Close()
-	if err != nil {
+	// Close the original response body
+	if err = resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}

-	err = json.Unmarshal(responseBody, &textResponse)
-	if err != nil {
+	// Parse the response JSON
+	var textResponse SlimTextResponse
+	if err = json.Unmarshal(responseBody, &textResponse); err != nil {
 		return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 	}

+	// Check for API errors
 	if textResponse.Error.Type != "" {
 		return &model.ErrorWithStatusCode{
 			Error: textResponse.Error,
@@ -129,68 +192,131 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 		}, nil
 	}

-	// Reset response body
+	// Process reasoning content in each choice
+	for _, msg := range textResponse.Choices {
+		reasoningContent := processReasoningContent(&msg)
+		// Set reasoning in requested format if content exists
+		if reasoningContent != "" {
+			msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
+		}
+	}

+	// Reset response body for forwarding to client
 	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 	logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))

-	// We shouldn't set the header before we parse the response body, because the parse part may fail.
-	// And then we will have to send an error response, but in this case, the header has already been set.
-	// So the HTTPClient will be confused by the response.
-	// For example, Postman will report error, and we cannot check the response at all.
-	for k, v := range resp.Header {
-		c.Writer.Header().Set(k, v[0])
+	// Forward all response headers (not just first value of each)
+	for k, values := range resp.Header {
+		for _, v := range values {
+			c.Writer.Header().Add(k, v)
+		}
 	}

+	// Set response status and copy body to client
 	c.Writer.WriteHeader(resp.StatusCode)
-	_, err = io.Copy(c.Writer, resp.Body)
-	if err != nil {
+	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
 		return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 	}

-	err = resp.Body.Close()
-	if err != nil {
+	// Close the reset body
+	if err = resp.Body.Close(); err != nil {
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}

-	if textResponse.Usage.TotalTokens == 0 ||
-		(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
-		completionTokens := 0
-		for _, choice := range textResponse.Choices {
-			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
-			if choice.Message.Reasoning != nil {
-				completionTokens += CountToken(*choice.Message.Reasoning)
-			}
-			if choice.ReasoningContent != nil {
-				completionTokens += CountToken(*choice.ReasoningContent)
-			}
-		}
-		textResponse.Usage = model.Usage{
-			PromptTokens:     promptTokens,
-			CompletionTokens: completionTokens,
-			TotalTokens:      promptTokens + completionTokens,
-		}
-	} else if (textResponse.PromptTokensDetails != nil && textResponse.PromptTokensDetails.AudioTokens > 0) ||
-		(textResponse.CompletionTokensDetails != nil && textResponse.CompletionTokensDetails.AudioTokens > 0) {
-		// Convert the more expensive audio tokens to uniformly priced text tokens.
-		// Note that when there are no audio tokens in prompt and completion,
-		// OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading.
-		if textResponse.PromptTokensDetails != nil {
-			textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
-				int(math.Ceil(
-					float64(textResponse.PromptTokensDetails.AudioTokens)*
-						ratio.GetAudioPromptRatio(modelName),
-				))
-		}
-		if textResponse.CompletionTokensDetails != nil {
-			textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
-				int(math.Ceil(
-					float64(textResponse.CompletionTokensDetails.AudioTokens)*
-						ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
-				))
-		}
-		textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens +
-			textResponse.Usage.CompletionTokens
-	}
+	// Calculate token usage if not provided by API
+	calculateTokenUsage(&textResponse, promptTokens, modelName)
 	return nil, &textResponse.Usage
 }

+// processReasoningContent is a helper function to extract and process reasoning content from the message
+func processReasoningContent(msg *TextResponseChoice) string {
+	var reasoningContent string
+	// Check different locations for reasoning content
+	switch {
+	case msg.Reasoning != nil:
+		reasoningContent = *msg.Reasoning
+		msg.Reasoning = nil
+	case msg.ReasoningContent != nil:
+		reasoningContent = *msg.ReasoningContent
+		msg.ReasoningContent = nil
+	case msg.Message.Reasoning != nil:
+		reasoningContent = *msg.Message.Reasoning
+		msg.Message.Reasoning = nil
+	case msg.Message.ReasoningContent != nil:
+		reasoningContent = *msg.Message.ReasoningContent
+		msg.Message.ReasoningContent = nil
+	}
+	return reasoningContent
+}

+// Helper function to calculate token usage
+func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
+	// Calculate tokens if not provided by the API
+	if response.Usage.TotalTokens == 0 ||
+		(response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
+		completionTokens := 0
+		for _, choice := range response.Choices {
+			// Count content tokens
+			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
+			// Count reasoning tokens in all possible locations
+			if choice.Message.Reasoning != nil {
+				completionTokens += CountToken(*choice.Message.Reasoning)
+			}
+			if choice.Message.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.Message.ReasoningContent)
+			}
+			if choice.Reasoning != nil {
+				completionTokens += CountToken(*choice.Reasoning)
+			}
+			if choice.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.ReasoningContent)
+			}
+		}
+		// Set usage values
+		response.Usage = model.Usage{
+			PromptTokens:     promptTokens,
+			CompletionTokens: completionTokens,
+			TotalTokens:      promptTokens + completionTokens,
+		}
+	} else if hasAudioTokens(response) {
+		// Handle audio tokens conversion
+		calculateAudioTokens(response, modelName)
+	}
+}

+// Helper function to check if response has audio tokens
+func hasAudioTokens(response *SlimTextResponse) bool {
+	return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
+		(response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
+}

+// Helper function to calculate audio token usage
+func calculateAudioTokens(response *SlimTextResponse, modelName string) {
+	// Convert audio tokens for prompt
+	if response.PromptTokensDetails != nil {
+		response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(response.PromptTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName),
+			))
+	}
+	// Convert audio tokens for completion
+	if response.CompletionTokensDetails != nil {
+		response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
+			int(math.Ceil(
+				float64(response.CompletionTokensDetails.AudioTokens)*
+					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
+			))
+	}
+	// Calculate total tokens
+	response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
+}
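
A small illustration of the drain-and-reattach flow above (a sketch with hypothetical values; it assumes it runs inside this package so `extractReasoningContent` and `model.Message` are in scope):

```go
// A delta arriving with DeepSeek-style reasoning_content...
reasoning := "let me think about this..."
delta := model.Message{ReasoningContent: &reasoning}

// ...is drained (both source fields are nil afterwards)...
chunk := extractReasoningContent(&delta) // chunk == "let me think about this..."

// ...and re-emitted in whatever format the client asked for.
delta.SetReasoningContent("reasoning", chunk) // OpenRouter-style `reasoning` field
```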

View File

@@ -3,6 +3,7 @@ package openai
 import (
 	"context"
 	"fmt"
+	"strings"

 	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/model"

View File

@@ -7,7 +7,6 @@ import (
 	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common/ctxkey"
 	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -32,7 +31,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}
-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
 	req := Request{
 		AnthropicVersion: anthropicVersion,
 		// Model: claudeReq.Model,

View File

@@ -2,10 +2,32 @@ package model
 import (
 	"context"
+	"strings"

 	"github.com/songquanpeng/one-api/common/logger"
 )

+// ReasoningFormat is the format of reasoning content,
+// can be set by the reasoning_format parameter in the request url.
+type ReasoningFormat string

+const (
+	ReasoningFormatUnspecified ReasoningFormat = ""
+	// ReasoningFormatReasoningContent is the reasoning format used by deepseek official API
+	ReasoningFormatReasoningContent ReasoningFormat = "reasoning_content"
+	// ReasoningFormatReasoning is the reasoning format used by openrouter
+	ReasoningFormatReasoning ReasoningFormat = "reasoning"
+	// ReasoningFormatThinkTag is the reasoning format used by 3rd party deepseek-r1 providers.
+	//
+	// Deprecated: I believe <think> is a very poor format, especially in stream mode, it is difficult to extract and convert.
+	// Considering that only a few deepseek-r1 third-party providers use this format, it has been decided to no longer support it.
+	// ReasoningFormatThinkTag ReasoningFormat = "think-tag"
+	// ReasoningFormatThinking is the reasoning format used by anthropic
+	ReasoningFormatThinking ReasoningFormat = "thinking"
+)

 type Message struct {
 	Role string `json:"role,omitempty"`
 	// Content is a string or a list of objects
@@ -29,6 +51,28 @@ type Message struct {
 	// -------------------------------------
 	Reasoning *string `json:"reasoning,omitempty"`
 	Refusal   *bool   `json:"refusal,omitempty"`
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
+}

+// SetReasoningContent sets the reasoning content based on the format
+func (m *Message) SetReasoningContent(format string, reasoningContent string) {
+	switch ReasoningFormat(strings.ToLower(strings.TrimSpace(format))) {
+	case ReasoningFormatReasoningContent:
+		m.ReasoningContent = &reasoningContent
+	// case ReasoningFormatThinkTag:
+	// 	m.Content = fmt.Sprintf("<think>%s</think>%s", reasoningContent, m.Content)
+	case ReasoningFormatThinking:
+		m.Thinking = &reasoningContent
+	case ReasoningFormatReasoning,
+		ReasoningFormatUnspecified:
+		m.Reasoning = &reasoningContent
+	default:
+		logger.Warnf(context.TODO(), "unknown reasoning format: %q", format)
+	}
 }

 type messageAudio struct {
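
To make the mapping concrete, a minimal sketch of how each `reasoning_format` value routes the same text into a different `Message` field (assumed to run inside this package; values are illustrative):

```go
var m Message
m.SetReasoningContent("reasoning_content", "step 1 ...") // -> m.ReasoningContent (DeepSeek style)
m.SetReasoningContent("thinking", "step 1 ...")          // -> m.Thinking (Claude style)
m.SetReasoningContent("", "step 1 ...")                  // -> m.Reasoning (default, OpenRouter style)
```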