Mirror of https://github.com/songquanpeng/one-api.git (synced 2025-09-18 01:26:37 +08:00)
feat: enhance error handling and reasoning mechanisms across middleware

- Improve error handling across multiple middleware and adapter components, ensuring consistent JSON error response formats.
- Enhance the request conversion functions by passing context parameters and wrapping errors robustly.
- Introduce new features for reasoning content in the messaging model, with better customization options and explanations in the documentation.
This commit is contained in:
parent 07d9a8e144
commit f6cfe7cd4f
README.md | 26
@@ -30,6 +30,10 @@ Also welcome to register and use my deployed one-api gateway, which supports var
 - [Support claude-3-7-sonnet \& thinking](#support-claude-3-7-sonnet--thinking)
   - [Stream](#stream)
   - [Non-Stream](#non-stream)
+- [Automatically Enable Thinking and Customize Reasoning Format via URL Parameters](#automatically-enable-thinking-and-customize-reasoning-format-via-url-parameters)
+  - [Reasoning Format - reasoning-content](#reasoning-format---reasoning-content)
+  - [Reasoning Format - reasoning](#reasoning-format---reasoning)
+  - [Reasoning Format - thinking](#reasoning-format---thinking)
 - [Bug fix](#bug-fix)

 ## Tutorial
@@ -172,6 +176,28 @@ By default, the thinking mode is not enabled. You need to manually pass the `thi

 

+### Automatically Enable Thinking and Customize Reasoning Format via URL Parameters
+
+Supports two URL parameters: `thinking` and `reasoning_format`.
+
+- `thinking`: Whether to enable thinking mode, disabled by default.
+- `reasoning_format`: Specifies the format of the returned reasoning.
+  - `reasoning_content`: DeepSeek official API format, returned in the `reasoning_content` field.
+  - `reasoning`: OpenRouter format, returned in the `reasoning` field.
+  - `thinking`: Claude format, returned in the `thinking` field.
+
+#### Reasoning Format - reasoning-content
+
+
+
+#### Reasoning Format - reasoning
+
+
+
+#### Reasoning Format - thinking
+
+
+
 ## Bug fix

 - [BUGFIX: Some issues when updating tokens #1933](https://github.com/songquanpeng/one-api/pull/1933)
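To make the two parameters concrete, here is a minimal Go client exercising both against a locally deployed gateway; the base URL, model name, and token below are placeholders, not part of this commit:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	// thinking enables extended thinking; reasoning_format picks the field the
	// reasoning is returned in (reasoning_content, reasoning, or thinking).
	url := "http://localhost:3000/v1/chat/completions?thinking=true&reasoning_format=reasoning_content"

	body := strings.NewReader(`{
		"model": "claude-3-7-sonnet-20250219",
		"messages": [{"role": "user", "content": "Why is the sky blue?"}],
		"max_tokens": 4096
	}`)

	req, err := http.NewRequest(http.MethodPost, url, body)
	if err != nil {
		panic(err)
	}
	req.Header.Set("Authorization", "Bearer sk-your-token") // placeholder token
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// With reasoning_format=reasoning_content, the reasoning appears in
	// choices[].message.reasoning_content (DeepSeek-style).
	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```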
@@ -1,12 +1,12 @@
 package middleware

 import (
-    "fmt"
     "net/http"
     "strings"

    "github.com/gin-contrib/sessions"
    "github.com/gin-gonic/gin"
+   "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common/blacklist"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/logger"
@@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) {
         key = parts[0]
         token, err := model.ValidateUserToken(key)
         if err != nil {
-            abortWithMessage(c, http.StatusUnauthorized, err.Error())
+            abortWithError(c, http.StatusUnauthorized, err)
             return
         }
         if token.Subnet != nil && *token.Subnet != "" {
             if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
-                abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
+                abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP()))
                 return
             }
         }
         userEnabled, err := model.CacheIsUserEnabled(token.UserId)
         if err != nil {
-            abortWithMessage(c, http.StatusInternalServerError, err.Error())
+            abortWithError(c, http.StatusInternalServerError, err)
             return
         }
         if !userEnabled || blacklist.IsUserBanned(token.UserId) {
-            abortWithMessage(c, http.StatusForbidden, "User has been banned")
+            abortWithError(c, http.StatusForbidden, errors.New("User has been banned"))
             return
         }
         requestModel, err := getRequestModel(c)
         if err != nil && shouldCheckModel(c) {
-            abortWithMessage(c, http.StatusBadRequest, err.Error())
+            abortWithError(c, http.StatusBadRequest, err)
             return
         }
         c.Set(ctxkey.RequestModel, requestModel)
         if token.Models != nil && *token.Models != "" {
             c.Set(ctxkey.AvailableModels, *token.Models)
             if requestModel != "" && !isModelInList(requestModel, *token.Models) {
-                abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("This API key does not have permission to use the model: %s", requestModel))
+                abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel))
                 return
             }
         }
@@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) {
         if model.IsAdmin(token.UserId) {
             c.Set(ctxkey.SpecificChannelId, parts[1])
         } else {
-            abortWithMessage(c, http.StatusForbidden, "Ordinary users do not support specifying channels")
+            abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels"))
             return
         }
     }
@@ -8,6 +8,7 @@ import (

    gutils "github.com/Laisky/go-utils/v5"
    "github.com/gin-gonic/gin"
+   "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common/ctxkey"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/model"
@@ -31,16 +32,16 @@ func Distribute() func(c *gin.Context) {
         if ok {
             id, err := strconv.Atoi(channelId.(string))
             if err != nil {
-                abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+                abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
                 return
             }
             channel, err = model.GetChannelById(id, true)
             if err != nil {
-                abortWithMessage(c, http.StatusBadRequest, "Invalid Channel Id")
+                abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id"))
                 return
             }
             if channel.Status != model.ChannelStatusEnabled {
-                abortWithMessage(c, http.StatusForbidden, "The channel has been disabled")
+                abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled"))
                 return
             }
         } else {
@@ -53,7 +54,7 @@ func Distribute() func(c *gin.Context) {
             logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id))
             message = "Database consistency has been broken, please contact the administrator"
         }
-        abortWithMessage(c, http.StatusServiceUnavailable, message)
+        abortWithError(c, http.StatusServiceUnavailable, errors.New(message))
         return
     }
 }
@@ -3,6 +3,8 @@ package middleware
 import (
     "strings"

+   gmw "github.com/Laisky/gin-middlewares/v6"
+   "github.com/Laisky/zap"
    "github.com/gin-gonic/gin"
    "github.com/pkg/errors"
    "github.com/songquanpeng/one-api/common"
@@ -21,6 +23,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
     logger.Error(c.Request.Context(), message)
 }

+func abortWithError(c *gin.Context, statusCode int, err error) {
+    logger := gmw.GetLogger(c)
+    logger.Error("server abort", zap.Error(err))
+    c.JSON(statusCode, gin.H{
+        "error": gin.H{
+            "message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)),
+            "type":    "one_api_error",
+        },
+    })
+    c.Abort()
+}
+
 func getRequestModel(c *gin.Context) (string, error) {
     var modelRequest ModelRequest
     err := common.UnmarshalBodyReusable(c, &modelRequest)
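Every aborted request now yields the same envelope, `{"error": {"message": "...", "type": "one_api_error"}}`, with the request id appended to the message. As a sketch of how further middleware could reuse the helper (this guard is hypothetical, not part of the commit):

```go
package middleware

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/pkg/errors"
)

// RequireJSON is a hypothetical guard illustrating the abortWithError pattern:
// reject early, emit the consistent JSON error envelope, and stop the chain.
func RequireJSON() gin.HandlerFunc {
	return func(c *gin.Context) {
		if c.ContentType() != "application/json" {
			abortWithError(c, http.StatusUnsupportedMediaType,
				errors.New("this endpoint only accepts application/json"))
			return
		}
		c.Next()
	}
}
```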
@@ -51,7 +51,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
     }

     c.Set("claude_model", request.Model)
-    return ConvertRequest(*request), nil
+    return ConvertRequest(c, *request)
 }

 func (a *Adaptor) ConvertImageRequest(_ *gin.Context, request *model.ImageRequest) (any, error) {
@@ -6,6 +6,7 @@ import (
     "encoding/json"
     "fmt"
     "io"
+    "math"
     "net/http"
     "strings"

@@ -38,7 +39,7 @@ func stopReasonClaude2OpenAI(reason *string) string {
     }
 }

-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
+func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
     claudeTools := make([]Tool, 0, len(textRequest.Tools))

     for _, tool := range textRequest.Tools {
@@ -66,7 +67,18 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
         Thinking:    textRequest.Thinking,
     }

+    if c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
+        claudeRequest.Thinking = &model.Thinking{
+            Type:         "enabled",
+            BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
+        }
+    }
+
     if claudeRequest.Thinking != nil {
+        if claudeRequest.MaxTokens <= 1024 {
+            return nil, fmt.Errorf("max_tokens must be greater than 1024 when using extended thinking")
+        }
+
         // top_p must be nil when using extended thinking
         claudeRequest.TopP = nil
     }
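The default budget above works out to min(1024, max_tokens/2). A small self-contained sketch of the arithmetic (the helper merely mirrors the commit's formula, it is not exported anywhere):

```go
package main

import (
	"fmt"
	"math"
)

// defaultThinkingBudget mirrors the commit's formula:
// budget_tokens = min(1024, max_tokens/2), with integer division applied first.
func defaultThinkingBudget(maxTokens int) int {
	return int(math.Min(1024, float64(maxTokens/2)))
}

func main() {
	for _, mt := range []int{4096, 2050, 1500} {
		fmt.Printf("max_tokens=%d -> budget_tokens=%d\n", mt, defaultThinkingBudget(mt))
	}
	// max_tokens <= 1024 never reaches the budget at all: the check that
	// follows rejects it, since extended thinking requires max_tokens > 1024.
}
```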
@@ -151,11 +163,11 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
         claudeMessage.Content = contents
         claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
     }
-    return &claudeRequest
+    return &claudeRequest, nil
 }

 // https://docs.anthropic.com/claude/reference/messages-streaming
-func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
+func StreamResponseClaude2OpenAI(c *gin.Context, claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
     var response *Response
     var responseText string
     var reasoningText string
@@ -220,7 +232,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {

     var choice openai.ChatCompletionsStreamResponseChoice
     choice.Delta.Content = responseText
-    choice.Delta.Reasoning = &reasoningText
+    choice.Delta.SetReasoningContent(c.Query("reasoning_format"), reasoningText)
     if len(tools) > 0 {
         choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
         choice.Delta.ToolCalls = tools
@@ -236,7 +248,7 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
     return &openaiResponse, response
 }

-func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
+func ResponseClaude2OpenAI(c *gin.Context, claudeResponse *Response) *openai.TextResponse {
     var responseText string
     var reasoningText string

@@ -273,12 +285,12 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
         Message: model.Message{
             Role:      "assistant",
             Content:   responseText,
-            Reasoning: &reasoningText,
             Name:      nil,
             ToolCalls: tools,
         },
         FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
     }
+    choice.Message.SetReasoningContent(c.Query("reasoning_format"), reasoningText)
     fullTextResponse := openai.TextResponse{
         Id:    fmt.Sprintf("chatcmpl-%s", claudeResponse.Id),
         Model: claudeResponse.Model,
@@ -328,7 +340,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
             continue
         }

-        response, meta := StreamResponseClaude2OpenAI(&claudeResponse)
+        response, meta := StreamResponseClaude2OpenAI(c, &claudeResponse)
         if meta != nil {
             usage.PromptTokens += meta.Usage.InputTokens
             usage.CompletionTokens += meta.Usage.OutputTokens
@@ -407,7 +419,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
             StatusCode: resp.StatusCode,
         }, nil
     }
-    fullTextResponse := ResponseClaude2OpenAI(&claudeResponse)
+    fullTextResponse := ResponseClaude2OpenAI(c, &claudeResponse)
     fullTextResponse.Model = modelName
     usage := model.Usage{
         PromptTokens: claudeResponse.Usage.InputTokens,
@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
         return nil, errors.New("request is nil")
     }

-    claudeReq := anthropic.ConvertRequest(*request)
+    claudeReq, err := anthropic.ConvertRequest(c, *request)
+    if err != nil {
+        return nil, errors.Wrap(err, "convert request")
+    }
+
     c.Set(ctxkey.RequestModel, request.Model)
     c.Set(ctxkey.ConvertedRequest, claudeReq)
     return claudeReq, nil
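The `errors.Wrap` call annotates the converter's failure with its call site while keeping the original cause intact. A minimal standalone illustration of the pattern (the message mirrors the converter's own error):

```go
package main

import (
	"fmt"

	"github.com/pkg/errors"
)

func main() {
	// The converter's validation error...
	cause := errors.New("max_tokens must be greater than 1024 when using extended thinking")

	// ...wrapped exactly as the adaptor does.
	err := errors.Wrap(cause, "convert request")

	fmt.Println(err)
	// convert request: max_tokens must be greater than 1024 when using extended thinking

	fmt.Println(errors.Cause(err) == cause) // true: the original cause is preserved
}
```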
@@ -88,7 +88,7 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
         return utils.WrapErr(errors.Wrap(err, "unmarshal response")), nil
     }

-    openaiResp := anthropic.ResponseClaude2OpenAI(claudeResponse)
+    openaiResp := anthropic.ResponseClaude2OpenAI(c, claudeResponse)
     openaiResp.Model = modelName
     usage := relaymodel.Usage{
         PromptTokens: claudeResponse.Usage.InputTokens,
@@ -159,7 +159,7 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
             return false
         }

-        response, meta := anthropic.StreamResponseClaude2OpenAI(claudeResp)
+        response, meta := anthropic.StreamResponseClaude2OpenAI(c, claudeResp)
         if meta != nil {
             usage.PromptTokens += meta.Usage.InputTokens
             usage.CompletionTokens += meta.Usage.OutputTokens
@@ -2,6 +2,8 @@ package doubao

 import (
     "fmt"
+    "strings"
+
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/relaymode"
 )
@@ -25,103 +25,166 @@ const (
     dataPrefixLength = len(dataPrefix)
 )

+// StreamHandler processes streaming responses from OpenAI API
+// It handles incremental content delivery and accumulates the final response text
+// Returns error (if any), accumulated response text, and token usage information
 func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
+    // Initialize accumulators for the response
     responseText := ""
     reasoningText := ""
-    scanner := bufio.NewScanner(resp.Body)
-    buffer := make([]byte, 256*1024)
-    scanner.Buffer(buffer, len(buffer))
-    scanner.Split(bufio.ScanLines)
     var usage *model.Usage
+
+    // Set up scanner for reading the stream line by line
+    scanner := bufio.NewScanner(resp.Body)
+    buffer := make([]byte, 256*1024) // 256KB buffer for large messages
+    scanner.Buffer(buffer, len(buffer))
+    scanner.Split(bufio.ScanLines)
+
+    // Set response headers for SSE
     common.SetEventStreamHeaders(c)

     doneRendered := false
+
+    // Process each line from the stream
     for scanner.Scan() {
         data := NormalizeDataLine(scanner.Text())
-        if len(data) < dataPrefixLength { // ignore blank line or wrong format
-            continue
+        // Skip lines that don't match expected format
+        if len(data) < dataPrefixLength {
+            continue // Ignore blank line or wrong format
         }
+
+        // Verify line starts with expected prefix
         if data[:dataPrefixLength] != dataPrefix && data[:dataPrefixLength] != done {
             continue
         }
+
+        // Check for stream termination
         if strings.HasPrefix(data[dataPrefixLength:], done) {
             render.StringData(c, data)
             doneRendered = true
             continue
         }
+
+        // Process based on relay mode
         switch relayMode {
         case relaymode.ChatCompletions:
             var streamResponse ChatCompletionsStreamResponse
+
+            // Parse the JSON response
             err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
             if err != nil {
                 logger.SysError("error unmarshalling stream response: " + err.Error())
-                render.StringData(c, data) // if error happened, pass the data to client
-                continue // just ignore the error
+                render.StringData(c, data) // Pass raw data to client if parsing fails
+                continue
             }
+
+            // Skip empty choices (Azure specific behavior)
             if len(streamResponse.Choices) == 0 && streamResponse.Usage == nil {
-                // but for empty choice and no usage, we should not pass it to client, this is for azure
-                continue // just ignore empty choice
+                continue
             }
-            render.StringData(c, data)
+
+            // Process each choice in the response
             for _, choice := range streamResponse.Choices {
-                if choice.Delta.Reasoning != nil {
-                    reasoningText += *choice.Delta.Reasoning
-                }
-                if choice.Delta.ReasoningContent != nil {
-                    reasoningText += *choice.Delta.ReasoningContent
+                // Extract reasoning content from different possible fields
+                currentReasoningChunk := extractReasoningContent(&choice.Delta)
+
+                // Update accumulated reasoning text
+                if currentReasoningChunk != "" {
+                    reasoningText += currentReasoningChunk
                 }
+
+                // Set the reasoning content in the format requested by client
+                choice.Delta.SetReasoningContent(c.Query("reasoning_format"), currentReasoningChunk)
+
+                // Accumulate response content
                 responseText += conv.AsString(choice.Delta.Content)
             }
+
+            // Send the processed data to the client
+            render.StringData(c, data)
+
+            // Update usage information if available
             if streamResponse.Usage != nil {
                 usage = streamResponse.Usage
             }
+
         case relaymode.Completions:
+            // Send the data immediately for Completions mode
             render.StringData(c, data)
+
             var streamResponse CompletionsStreamResponse
             err := json.Unmarshal([]byte(data[dataPrefixLength:]), &streamResponse)
             if err != nil {
                 logger.SysError("error unmarshalling stream response: " + err.Error())
                 continue
             }
+
+            // Accumulate text from all choices
             for _, choice := range streamResponse.Choices {
                 responseText += choice.Text
             }
         }
     }

+    // Check for scanner errors
     if err := scanner.Err(); err != nil {
         logger.SysError("error reading stream: " + err.Error())
     }

+    // Ensure stream termination is sent to client
     if !doneRendered {
         render.Done(c)
     }

-    err := resp.Body.Close()
-    if err != nil {
+    // Clean up resources
+    if err := resp.Body.Close(); err != nil {
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
     }

+    // Return the complete response text (reasoning + content) and usage
     return nil, reasoningText + responseText, usage
 }

-// Handler handles the non-stream response from OpenAI API
+// Helper function to extract reasoning content from message delta
+func extractReasoningContent(delta *model.Message) string {
+    content := ""
+
+    // Extract reasoning from different possible fields
+    if delta.Reasoning != nil {
+        content += *delta.Reasoning
+        delta.Reasoning = nil
+    }
+
+    if delta.ReasoningContent != nil {
+        content += *delta.ReasoningContent
+        delta.ReasoningContent = nil
+    }
+
+    return content
+}
+
+// Handler processes non-streaming responses from OpenAI API
+// Returns error (if any) and token usage information
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
-    var textResponse SlimTextResponse
+    // Read the entire response body
     responseBody, err := io.ReadAll(resp.Body)
     if err != nil {
         return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
     }
-    err = resp.Body.Close()
-    if err != nil {
+
+    // Close the original response body
+    if err = resp.Body.Close(); err != nil {
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }
-    err = json.Unmarshal(responseBody, &textResponse)
-    if err != nil {
+
+    // Parse the response JSON
+    var textResponse SlimTextResponse
+    if err = json.Unmarshal(responseBody, &textResponse); err != nil {
         return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
     }
+
+    // Check for API errors
     if textResponse.Error.Type != "" {
         return &model.ErrorWithStatusCode{
             Error: textResponse.Error,

@@ -129,68 +192,131 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
         }, nil
     }

-    // Reset response body
+    // Process reasoning content in each choice
+    for _, msg := range textResponse.Choices {
+        reasoningContent := processReasoningContent(&msg)
+
+        // Set reasoning in requested format if content exists
+        if reasoningContent != "" {
+            msg.SetReasoningContent(c.Query("reasoning_format"), reasoningContent)
+        }
+    }
+
+    // Reset response body for forwarding to client
     resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
     logger.Debugf(c.Request.Context(), "handler response: %s", string(responseBody))

-    // We shouldn't set the header before we parse the response body, because the parse part may fail.
-    // And then we will have to send an error response, but in this case, the header has already been set.
-    // So the HTTPClient will be confused by the response.
-    // For example, Postman will report error, and we cannot check the response at all.
-    for k, v := range resp.Header {
-        c.Writer.Header().Set(k, v[0])
+    // Forward all response headers (not just first value of each)
+    for k, values := range resp.Header {
+        for _, v := range values {
+            c.Writer.Header().Add(k, v)
+        }
     }

+    // Set response status and copy body to client
     c.Writer.WriteHeader(resp.StatusCode)
-    _, err = io.Copy(c.Writer, resp.Body)
-    if err != nil {
+    if _, err = io.Copy(c.Writer, resp.Body); err != nil {
         return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
     }
-    err = resp.Body.Close()
-    if err != nil {
+
+    // Close the reset body
+    if err = resp.Body.Close(); err != nil {
         return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
     }

-    if textResponse.Usage.TotalTokens == 0 ||
-        (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
+    // Calculate token usage if not provided by API
+    calculateTokenUsage(&textResponse, promptTokens, modelName)
+
+    return nil, &textResponse.Usage
+}
+
+// processReasoningContent is a helper function to extract and process reasoning content from the message
+func processReasoningContent(msg *TextResponseChoice) string {
+    var reasoningContent string
+
+    // Check different locations for reasoning content
+    switch {
+    case msg.Reasoning != nil:
+        reasoningContent = *msg.Reasoning
+        msg.Reasoning = nil
+    case msg.ReasoningContent != nil:
+        reasoningContent = *msg.ReasoningContent
+        msg.ReasoningContent = nil
+    case msg.Message.Reasoning != nil:
+        reasoningContent = *msg.Message.Reasoning
+        msg.Message.Reasoning = nil
+    case msg.Message.ReasoningContent != nil:
+        reasoningContent = *msg.Message.ReasoningContent
+        msg.Message.ReasoningContent = nil
+    }
+
+    return reasoningContent
+}
+
+// Helper function to calculate token usage
+func calculateTokenUsage(response *SlimTextResponse, promptTokens int, modelName string) {
+    // Calculate tokens if not provided by the API
+    if response.Usage.TotalTokens == 0 ||
+        (response.Usage.PromptTokens == 0 && response.Usage.CompletionTokens == 0) {
         completionTokens := 0
-        for _, choice := range textResponse.Choices {
+        for _, choice := range response.Choices {
+            // Count content tokens
             completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
+
+            // Count reasoning tokens in all possible locations
             if choice.Message.Reasoning != nil {
                 completionTokens += CountToken(*choice.Message.Reasoning)
             }
+            if choice.Message.ReasoningContent != nil {
+                completionTokens += CountToken(*choice.Message.ReasoningContent)
+            }
+            if choice.Reasoning != nil {
+                completionTokens += CountToken(*choice.Reasoning)
+            }
             if choice.ReasoningContent != nil {
                 completionTokens += CountToken(*choice.ReasoningContent)
             }
         }
-        textResponse.Usage = model.Usage{
+
+        // Set usage values
+        response.Usage = model.Usage{
             PromptTokens:     promptTokens,
             CompletionTokens: completionTokens,
             TotalTokens:      promptTokens + completionTokens,
         }
-    } else if (textResponse.PromptTokensDetails != nil && textResponse.PromptTokensDetails.AudioTokens > 0) ||
-        (textResponse.CompletionTokensDetails != nil && textResponse.CompletionTokensDetails.AudioTokens > 0) {
-        // Convert the more expensive audio tokens to uniformly priced text tokens.
-        // Note that when there are no audio tokens in prompt and completion,
-        // OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading.
-        if textResponse.PromptTokensDetails != nil {
-            textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
-                int(math.Ceil(
-                    float64(textResponse.PromptTokensDetails.AudioTokens)*
-                        ratio.GetAudioPromptRatio(modelName),
-                ))
-        }
-        if textResponse.CompletionTokensDetails != nil {
-            textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
-                int(math.Ceil(
-                    float64(textResponse.CompletionTokensDetails.AudioTokens)*
-                        ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
-                ))
-        }
-        textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens +
-            textResponse.Usage.CompletionTokens
+    } else if hasAudioTokens(response) {
+        // Handle audio tokens conversion
+        calculateAudioTokens(response, modelName)
     }
+}

-    return nil, &textResponse.Usage
+// Helper function to check if response has audio tokens
+func hasAudioTokens(response *SlimTextResponse) bool {
+    return (response.PromptTokensDetails != nil && response.PromptTokensDetails.AudioTokens > 0) ||
+        (response.CompletionTokensDetails != nil && response.CompletionTokensDetails.AudioTokens > 0)
+}
+
+// Helper function to calculate audio token usage
+func calculateAudioTokens(response *SlimTextResponse, modelName string) {
+    // Convert audio tokens for prompt
+    if response.PromptTokensDetails != nil {
+        response.Usage.PromptTokens = response.PromptTokensDetails.TextTokens +
+            int(math.Ceil(
+                float64(response.PromptTokensDetails.AudioTokens)*
+                    ratio.GetAudioPromptRatio(modelName),
+            ))
+    }
+
+    // Convert audio tokens for completion
+    if response.CompletionTokensDetails != nil {
+        response.Usage.CompletionTokens = response.CompletionTokensDetails.TextTokens +
+            int(math.Ceil(
+                float64(response.CompletionTokensDetails.AudioTokens)*
+                    ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
+            ))
+    }

+    // Calculate total tokens
+    response.Usage.TotalTokens = response.Usage.PromptTokens + response.Usage.CompletionTokens
 }
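A worked sketch of the audio conversion above; the ratio value here is an illustrative stand-in, not the gateway's actual pricing table:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	// Assumed figures for one response:
	textTokens, audioTokens := 120, 50
	audioPromptRatio := 8.0 // illustrative; the real value comes from ratio.GetAudioPromptRatio

	// Same shape as calculateAudioTokens: the more expensive audio tokens are
	// converted to uniformly priced text tokens and folded into the prompt count.
	promptTokens := textTokens + int(math.Ceil(float64(audioTokens)*audioPromptRatio))

	fmt.Println(promptTokens) // 120 + ceil(50*8.0) = 520
}
```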
@@ -3,6 +3,7 @@ package openai
 import (
     "context"
     "fmt"
+    "strings"

     "github.com/songquanpeng/one-api/common/logger"
     "github.com/songquanpeng/one-api/relay/model"
@@ -7,7 +7,6 @@ import (
     "github.com/pkg/errors"
     "github.com/songquanpeng/one-api/common/ctxkey"
     "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
-
     "github.com/songquanpeng/one-api/relay/meta"
     "github.com/songquanpeng/one-api/relay/model"
 )

@@ -32,7 +31,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
         return nil, errors.New("request is nil")
     }

-    claudeReq := anthropic.ConvertRequest(*request)
+    claudeReq, err := anthropic.ConvertRequest(c, *request)
+    if err != nil {
+        return nil, errors.Wrap(err, "convert request")
+    }
+
     req := Request{
         AnthropicVersion: anthropicVersion,
         // Model:            claudeReq.Model,
@@ -2,10 +2,32 @@ package model

 import (
     "context"
+    "strings"

     "github.com/songquanpeng/one-api/common/logger"
 )

+// ReasoningFormat is the format of reasoning content,
+// can be set by the reasoning_format parameter in the request url.
+type ReasoningFormat string
+
+const (
+    ReasoningFormatUnspecified ReasoningFormat = ""
+    // ReasoningFormatReasoningContent is the reasoning format used by deepseek official API
+    ReasoningFormatReasoningContent ReasoningFormat = "reasoning_content"
+    // ReasoningFormatReasoning is the reasoning format used by openrouter
+    ReasoningFormatReasoning ReasoningFormat = "reasoning"
+
+    // ReasoningFormatThinkTag is the reasoning format used by 3rd party deepseek-r1 providers.
+    //
+    // Deprecated: I believe <think> is a very poor format, especially in stream mode, it is difficult to extract and convert.
+    // Considering that only a few deepseek-r1 third-party providers use this format, it has been decided to no longer support it.
+    // ReasoningFormatThinkTag ReasoningFormat = "think-tag"
+
+    // ReasoningFormatThinking is the reasoning format used by anthropic
+    ReasoningFormatThinking ReasoningFormat = "thinking"
+)
+
 type Message struct {
     Role string `json:"role,omitempty"`
     // Content is a string or a list of objects

@@ -29,6 +51,28 @@ type Message struct {
     // -------------------------------------
     Reasoning *string `json:"reasoning,omitempty"`
     Refusal   *bool   `json:"refusal,omitempty"`
+    // -------------------------------------
+    // Anthropic
+    // -------------------------------------
+    Thinking  *string `json:"thinking,omitempty"`
+    Signature *string `json:"signature,omitempty"`
+}
+
+// SetReasoningContent sets the reasoning content based on the format
+func (m *Message) SetReasoningContent(format string, reasoningContent string) {
+    switch ReasoningFormat(strings.ToLower(strings.TrimSpace(format))) {
+    case ReasoningFormatReasoningContent:
+        m.ReasoningContent = &reasoningContent
+    // case ReasoningFormatThinkTag:
+    //     m.Content = fmt.Sprintf("<think>%s</think>%s", reasoningContent, m.Content)
+    case ReasoningFormatThinking:
+        m.Thinking = &reasoningContent
+    case ReasoningFormatReasoning,
+        ReasoningFormatUnspecified:
+        m.Reasoning = &reasoningContent
+    default:
+        logger.Warnf(context.TODO(), "unknown reasoning format: %q", format)
+    }
 }

 type messageAudio struct {
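To round off, a usage sketch of the new setter; the message values are invented, but the field routing follows the switch above:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	msg := model.Message{Role: "assistant", Content: "The answer is 42."}
	cot := "First, recall the question..."

	// "thinking" routes the reasoning into the Claude-style field.
	msg.SetReasoningContent("thinking", cot)
	fmt.Println(*msg.Thinking)

	// "reasoning_content" uses the DeepSeek-style field instead.
	msg.SetReasoningContent("reasoning_content", cot)
	fmt.Println(*msg.ReasoningContent)

	// An empty or "reasoning" format falls back to the OpenRouter-style field.
	msg.SetReasoningContent("", cot)
	fmt.Println(*msg.Reasoning)
}
```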