From 480f248a3daba22e2efd9b5e9ef6902d7b4715c6 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 19 Feb 2025 01:11:46 +0000 Subject: [PATCH 01/10] feat: support OpenRouter reasoning --- common/conv/any.go | 7 +++++-- relay/adaptor/openai/adaptor.go | 26 ++++++++++++++++++++++- relay/adaptor/openai/main.go | 11 +++++++--- relay/adaptor/openrouter/model.go | 22 +++++++++++++++++++ relay/model/general.go | 7 +++++++ relay/model/message.go | 35 +++++++++++++++++++++++++------ relay/model/misc.go | 31 ++++++++++++++++++++------- 7 files changed, 119 insertions(+), 20 deletions(-) create mode 100644 relay/adaptor/openrouter/model.go diff --git a/common/conv/any.go b/common/conv/any.go index 467e8bb7..33d34aa7 100644 --- a/common/conv/any.go +++ b/common/conv/any.go @@ -1,6 +1,9 @@ package conv func AsString(v any) string { - str, _ := v.(string) - return str + if str, ok := v.(string); ok { + return str + } + + return "" } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 8faf90a5..03bd3c91 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -9,6 +9,8 @@ import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/relay/adaptor" "github.com/songquanpeng/one-api/relay/adaptor/alibailian" "github.com/songquanpeng/one-api/relay/adaptor/baiduv2" @@ -16,6 +18,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/geminiv2" "github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/novita" + "github.com/songquanpeng/one-api/relay/adaptor/openrouter" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -85,7 +88,28 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - if request.Stream { + + meta := meta.GetByContext(c) + switch meta.ChannelType { + case channeltype.OpenRouter: + includeReasoning := true + request.IncludeReasoning = &includeReasoning + if request.Provider == nil || request.Provider.Sort == "" { + if request.Provider == nil { + request.Provider = &openrouter.RequestProvider{} + } + + request.Provider.Sort = "throughput" + } + default: + } + + if request.Stream && !config.EnforceIncludeUsage { + logger.Warn(c.Request.Context(), + "please set ENFORCE_INCLUDE_USAGE=true to ensure accurate billing in stream mode") + } + + if config.EnforceIncludeUsage && request.Stream { // always return usage in stream mode if request.StreamOptions == nil { request.StreamOptions = &model.StreamOptions{} diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index 97080738..545da981 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -8,12 +8,11 @@ import ( "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/relaymode" ) @@ -26,6 +25,7 @@ const ( func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" + reasoningText := "" scanner := 
bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 	var usage *model.Usage
@@ -61,6 +61,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 			}
 			render.StringData(c, data)
 			for _, choice := range streamResponse.Choices {
+				if choice.Delta.Reasoning != nil {
+					reasoningText += *choice.Delta.Reasoning
+				}
+
 				responseText += conv.AsString(choice.Delta.Content)
 			}
 			if streamResponse.Usage != nil {
@@ -93,7 +97,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}
 
-	return nil, responseText, usage
+	return nil, reasoningText + responseText, usage
 }
 
 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
@@ -147,5 +151,6 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 			TotalTokens:      promptTokens + completionTokens,
 		}
 	}
+
 	return nil, &textResponse.Usage
 }
diff --git a/relay/adaptor/openrouter/model.go b/relay/adaptor/openrouter/model.go
new file mode 100644
index 00000000..581bc2cc
--- /dev/null
+++ b/relay/adaptor/openrouter/model.go
@@ -0,0 +1,22 @@
+package openrouter
+
+// RequestProvider customizes how requests are routed using the provider object
+// in the request body for Chat Completions and Completions.
+//
+// https://openrouter.ai/docs/features/provider-routing
+type RequestProvider struct {
+	// Order is the list of provider names to try in order (e.g. ["Anthropic", "OpenAI"]). Default: empty
+	Order []string `json:"order,omitempty"`
+	// AllowFallbacks allows backup providers when the primary is unavailable. Default: true
+	AllowFallbacks bool `json:"allow_fallbacks,omitempty"`
+	// RequireParameters only routes to providers that support all parameters in the request. Default: false
+	RequireParameters bool `json:"require_parameters,omitempty"`
+	// DataCollection controls whether to use providers that may store data ("allow" or "deny"). Default: "allow"
+	DataCollection string `json:"data_collection,omitempty" binding:"omitempty,oneof=allow deny"`
+	// Ignore is the list of provider names to skip for this request. Default: empty
+	Ignore []string `json:"ignore,omitempty"`
+	// Quantizations is the list of quantization levels to filter by (e.g. ["int4", "int8"]). Default: empty
+	Quantizations []string `json:"quantizations,omitempty"`
+	// Sort ranks providers by "price", "throughput", or "latency". Default: empty
+	Sort string `json:"sort,omitempty" binding:"omitempty,oneof=price throughput latency"`
+}
diff --git a/relay/model/general.go b/relay/model/general.go
index 5f5968c8..c26688cd 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -1,5 +1,7 @@
 package model
 
+import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+
 type ResponseFormat struct {
 	Type       string      `json:"type,omitempty"`
 	JsonSchema *JSONSchema `json:"json_schema,omitempty"`
@@ -66,6 +68,11 @@ type GeneralOpenAIRequest struct {
 	// Others
 	Instruction string `json:"instruction,omitempty"`
 	NumCtx      int    `json:"num_ctx,omitempty"`
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Provider         *openrouter.RequestProvider `json:"provider,omitempty"`
+	IncludeReasoning *bool                       `json:"include_reasoning,omitempty"`
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
diff --git a/relay/model/message.go b/relay/model/message.go
index 5ff7b7ae..8ab54732 100644
--- a/relay/model/message.go
+++ b/relay/model/message.go
@@ -1,12 +1,35 @@
 package model
 
 type Message struct {
-	Role             string  `json:"role,omitempty"`
-	Content          any     `json:"content,omitempty"`
-	ReasoningContent any     `json:"reasoning_content,omitempty"`
-	Name             *string `json:"name,omitempty"`
-	ToolCalls        []Tool  `json:"tool_calls,omitempty"`
-	ToolCallId       string  `json:"tool_call_id,omitempty"`
+	Role string `json:"role,omitempty"`
+	// Content is a string or a list of objects
+	Content    any           `json:"content,omitempty"`
+	Name       *string       `json:"name,omitempty"`
+	ToolCalls  []Tool        `json:"tool_calls,omitempty"`
+	ToolCallId string        `json:"tool_call_id,omitempty"`
+	Audio      *messageAudio `json:"audio,omitempty"`
+	// -------------------------------------
+	// Deepseek-specific fields
+	// https://api-docs.deepseek.com/api/create-chat-completion
+	// -------------------------------------
+	// Prefix forces the model to begin its answer with the supplied prefix in the assistant message.
+	// To enable this feature, set base_url to "https://api.deepseek.com/beta".
+	Prefix *bool `json:"prefix,omitempty"`
+	// ReasoningContent is used for the deepseek-reasoner model in the Chat
+	// Prefix Completion feature as the input for the CoT in the last assistant message.
+	// When using this feature, the prefix parameter must be set to true.
+	ReasoningContent *string `json:"reasoning_content,omitempty"`
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Reasoning *string `json:"reasoning,omitempty"`
+	Refusal   *bool   `json:"refusal,omitempty"`
+}
+
+type messageAudio struct {
+	Id         string `json:"id"`
+	Data       string `json:"data,omitempty"`
+	ExpiredAt  int    `json:"expired_at,omitempty"`
+	Transcript string `json:"transcript,omitempty"`
 }
 
 func (m Message) IsStringContent() bool {
diff --git a/relay/model/misc.go b/relay/model/misc.go
index fdba01ea..9d1f7e4f 100644
--- a/relay/model/misc.go
+++ b/relay/model/misc.go
@@ -4,14 +4,12 @@ type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
 	TotalTokens      int `json:"total_tokens"`
-
-	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
-}
-
-type CompletionTokensDetails struct {
-	ReasoningTokens          int `json:"reasoning_tokens"`
-	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
-	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// PromptTokensDetails may be empty for some models
+	PromptTokensDetails *usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"`
+	// CompletionTokensDetails may be empty for some models
+	CompletionTokensDetails *usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"`
+	ServiceTier       string `gorm:"-" json:"service_tier,omitempty"`
+	SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"`
 }
 
 type Error struct {
@@ -25,3 +23,20 @@ type ErrorWithStatusCode struct {
 	Error
 	StatusCode int `json:"status_code"`
 }
+
+type usagePromptTokensDetails struct {
+	CachedTokens int `json:"cached_tokens"`
+	AudioTokens  int `json:"audio_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens  int `json:"text_tokens"`
+	ImageTokens int `json:"image_tokens"`
+}
+
+type usageCompletionTokensDetails struct {
+	ReasoningTokens          int `json:"reasoning_tokens"`
+	AudioTokens              int `json:"audio_tokens"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens int `json:"text_tokens"`
+}

From 5ba60433d7b6c99e19b029da3813d7904148fcc2 Mon Sep 17 00:00:00 2001
From: "Laisky.Cai"
Date: Wed, 19 Feb 2025 08:10:04 +0000
Subject: [PATCH 02/10] feat: enhance reasoning token handling in OpenAI adaptor

---
 relay/adaptor/openai/main.go | 9 +++++++++
 relay/controller/helper.go   | 3 +++
 2 files changed, 12 insertions(+)

diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 545da981..85230633 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -64,6 +64,9 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 				if choice.Delta.Reasoning != nil {
 					reasoningText += *choice.Delta.Reasoning
 				}
+				if choice.Delta.ReasoningContent != nil {
+					reasoningText += *choice.Delta.ReasoningContent
+				}
 
 				responseText += conv.AsString(choice.Delta.Content)
 			}
@@ -97,6 +100,12 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}
 
+	// If no reasoning tokens were reported in the completion, count the reasoning tokens in the response ourselves.
+	if len(reasoningText) > 0 &&
+		(usage.CompletionTokensDetails == nil || usage.CompletionTokensDetails.ReasoningTokens == 0) {
+		usage.CompletionTokens += CountToken(reasoningText)
+	}
+
 	return nil, reasoningText + responseText, usage
 }
 
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index 5b6f023f..3cbd90c4 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -102,6 +102,9 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	var quota int64
 	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
 	promptTokens := usage.PromptTokens
+	// It appears that DeepSeek's official service automatically merges ReasoningTokens into CompletionTokens,
+	// but the behavior of third-party providers may differ, so for now we do not add them manually.
+	// completionTokens := usage.CompletionTokens + usage.CompletionTokensDetails.ReasoningTokens
 	completionTokens := usage.CompletionTokens
 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
 	if ratio != 0 && quota <= 0 {

From 1a6812182bf6fdb755dc9269f766e1ff32b91971 Mon Sep 17 00:00:00 2001
From: "Laisky.Cai"
Date: Wed, 19 Feb 2025 09:13:17 +0000
Subject: [PATCH 03/10] fix: improve reasoning token counting in OpenAI adaptor

---
 relay/adaptor/openai/main.go | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go
index 85230633..ea153d08 100644
--- a/relay/adaptor/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -100,12 +100,6 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}
 
-	// If no reasoning tokens were reported in the completion, count the reasoning tokens in the response ourselves.
-	if len(reasoningText) > 0 &&
-		(usage.CompletionTokensDetails == nil || usage.CompletionTokensDetails.ReasoningTokens == 0) {
-		usage.CompletionTokens += CountToken(reasoningText)
-	}
-
 	return nil, reasoningText + responseText, usage
 }
 
@@ -149,10 +143,17 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}
 
-	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
+	if textResponse.Usage.TotalTokens == 0 ||
+		(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
 			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
+			if choice.Message.Reasoning != nil {
+				completionTokens += CountToken(*choice.Message.Reasoning)
+			}
+			if choice.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.ReasoningContent)
+			}
 		}
 		textResponse.Usage = model.Usage{
 			PromptTokens: promptTokens,

From 7ec33793b74389a19f352ee2aaa06df54fb0431c Mon Sep 17 00:00:00 2001
From: "Laisky.Cai"
Date: Thu, 20 Feb 2025 01:51:19 +0000
Subject: [PATCH 04/10] feat: add OpenrouterProviderSort configuration for provider sorting

---
 common/config/config.go         | 3 +++
 relay/adaptor/openai/adaptor.go | 5 +++--
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/common/config/config.go b/common/config/config.go
index a235a8df..591943ff 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -164,3 +164,6 @@ var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
 var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
 
 var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")
+
+// OpenrouterProviderSort sets the default provider sort order for OpenRouter requests
+var OpenrouterProviderSort = env.String("OPENROUTER_PROVIDER_SORT", "")
diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go
index 03bd3c91..4e44e21b 100644
--- a/relay/adaptor/openai/adaptor.go
+++ b/relay/adaptor/openai/adaptor.go
@@ -94,12 +94,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	case channeltype.OpenRouter:
 		includeReasoning := true
 		request.IncludeReasoning = &includeReasoning
-		if request.Provider == nil || request.Provider.Sort == "" {
+		if (request.Provider == nil || request.Provider.Sort == "") &&
+			config.OpenrouterProviderSort != "" {
 			if request.Provider == nil {
 				request.Provider = &openrouter.RequestProvider{}
 			}
 
-			request.Provider.Sort = "throughput"
+			request.Provider.Sort = config.OpenrouterProviderSort
 		}
 	default:
 	}

From 95527d76efe9b269d5d7a647f6dd58eae6098d8f Mon Sep 17 00:00:00 2001
From: "Laisky.Cai"
Date: Tue, 25 Feb 2025 01:27:07 +0000
Subject: [PATCH 05/10] feat: update model list and pricing for Claude 3.7 versions

---
 relay/adaptor/anthropic/constants.go     | 6 ++++--
 relay/adaptor/aws/claude/main.go         | 2 ++
 relay/adaptor/vertexai/claude/adapter.go | 1 +
 relay/billing/ratio/model.go             | 2 ++
 4 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/relay/adaptor/anthropic/constants.go b/relay/adaptor/anthropic/constants.go
index 9b515c1c..b3a05ee8 100644
--- a/relay/adaptor/anthropic/constants.go
+++ b/relay/adaptor/anthropic/constants.go
@@ -3,11 +3,13 @@ package anthropic
 var ModelList = []string{
 	"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240307", - "claude-3-5-haiku-20241022", "claude-3-5-haiku-latest", + "claude-3-5-haiku-20241022", "claude-3-sonnet-20240229", "claude-3-opus-20240229", + "claude-3-5-sonnet-latest", "claude-3-5-sonnet-20240620", "claude-3-5-sonnet-20241022", - "claude-3-5-sonnet-latest", + "claude-3-7-sonnet-latest", + "claude-3-7-sonnet-20250219", } diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 3fe3dfd8..378acdda 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -36,6 +36,8 @@ var AwsModelIDMap = map[string]string{ "claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-sonnet-latest": "anthropic.claude-3-5-sonnet-20241022-v2:0", "claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0", + "claude-3-7-sonnet-latest": "anthropic.claude-3-7-sonnet-20250219-v1:0", + "claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0", } func awsModelID(requestModel string) (string, error) { diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index cb911cfe..554ada45 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -19,6 +19,7 @@ var ModelList = []string{ "claude-3-5-sonnet@20240620", "claude-3-5-sonnet-v2@20241022", "claude-3-5-haiku@20241022", + "claude-3-7-sonnet@20250219", } const anthropicVersion = "vertex-2023-10-16" diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index e8b3b615..6ef2a457 100644 --- a/relay/billing/ratio/model.go +++ b/relay/billing/ratio/model.go @@ -98,6 +98,8 @@ var ModelRatio = map[string]float64{ "claude-3-5-sonnet-20240620": 3.0 / 1000 * USD, "claude-3-5-sonnet-20241022": 3.0 / 1000 * USD, "claude-3-5-sonnet-latest": 3.0 / 1000 * USD, + "claude-3-7-sonnet-20250219": 3.0 / 1000 * USD, + "claude-3-7-sonnet-latest": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 "ERNIE-4.0-8K": 0.120 * RMB, From 3a8924d7aff617c79fde8cbd7156947d3e03ba38 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Tue, 25 Feb 2025 02:57:37 +0000 Subject: [PATCH 06/10] feat: add support for extended reasoning in Claude 3.7 model --- relay/adaptor/anthropic/adaptor.go | 4 ++-- relay/adaptor/anthropic/main.go | 38 ++++++++++++++++++++++++++---- relay/adaptor/anthropic/model.go | 8 +++++++ relay/model/general.go | 10 ++++++++ 4 files changed, 54 insertions(+), 6 deletions(-) diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index bd0949be..fe4e2ef0 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -36,8 +36,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me // https://x.com/alexalbert__/status/1812921642143900036 // claude-3-5-sonnet can support 8k context - if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") { - req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + if strings.HasPrefix(meta.ActualModelName, "claude-3-7-sonnet") { + req.Header.Set("anthropic-beta", "output-128k-2025-02-19") } return nil diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index d3e306c8..9601164b 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -4,11 +4,12 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common/render" "io" "net/http" "strings" + 
"github.com/songquanpeng/one-api/common/render" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" @@ -61,6 +62,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { TopK: textRequest.TopK, Stream: textRequest.Stream, Tools: claudeTools, + Thinking: textRequest.Thinking, } if len(claudeTools) > 0 { claudeToolChoice := struct { @@ -149,6 +151,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) { var response *Response var responseText string + var reasoningText string var stopReason string tools := make([]model.Tool, 0) @@ -158,6 +161,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo case "content_block_start": if claudeResponse.ContentBlock != nil { responseText = claudeResponse.ContentBlock.Text + if claudeResponse.ContentBlock.Thinking != nil { + reasoningText = *claudeResponse.ContentBlock.Thinking + } + if claudeResponse.ContentBlock.Type == "tool_use" { tools = append(tools, model.Tool{ Id: claudeResponse.ContentBlock.Id, @@ -172,6 +179,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo case "content_block_delta": if claudeResponse.Delta != nil { responseText = claudeResponse.Delta.Text + if claudeResponse.Delta.Thinking != nil { + reasoningText = *claudeResponse.Delta.Thinking + } + if claudeResponse.Delta.Type == "input_json_delta" { tools = append(tools, model.Tool{ Function: model.Function{ @@ -189,9 +200,20 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil { stopReason = *claudeResponse.Delta.StopReason } + case "thinking_delta": + if claudeResponse.Delta != nil && claudeResponse.Delta.Thinking != nil { + reasoningText = *claudeResponse.Delta.Thinking + } + case "ping", + "message_stop", + "content_block_stop": + default: + logger.SysErrorf("unknown stream response type %q", claudeResponse.Type) } + var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = responseText + choice.Delta.Reasoning = &reasoningText if len(tools) > 0 { choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... 
choice.Delta.ToolCalls = tools @@ -209,11 +231,15 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { var responseText string - if len(claudeResponse.Content) > 0 { - responseText = claudeResponse.Content[0].Text - } + var reasoningText string + tools := make([]model.Tool, 0) for _, v := range claudeResponse.Content { + reasoningText += v.Text + if v.Thinking != nil { + reasoningText += *v.Thinking + } + if v.Type == "tool_use" { args, _ := json.Marshal(v.Input) tools = append(tools, model.Tool{ @@ -231,6 +257,7 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { Message: model.Message{ Role: "assistant", Content: responseText, + Reasoning: &reasoningText, Name: nil, ToolCalls: tools, }, @@ -277,6 +304,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC data = strings.TrimPrefix(data, "data:") data = strings.TrimSpace(data) + logger.Debugf(c.Request.Context(), "stream <- %q\n", data) + var claudeResponse StreamResponse err := json.Unmarshal([]byte(data), &claudeResponse) if err != nil { @@ -344,6 +373,7 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st if err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + var claudeResponse Response err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 47f193fa..6dd299c4 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -1,5 +1,7 @@ package anthropic +import "github.com/songquanpeng/one-api/relay/model" + // https://docs.anthropic.com/claude/reference/messages_post type Metadata struct { @@ -22,6 +24,9 @@ type Content struct { Input any `json:"input,omitempty"` Content string `json:"content,omitempty"` ToolUseId string `json:"tool_use_id,omitempty"` + // https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking + Thinking *string `json:"thinking,omitempty"` + Signature *string `json:"signature,omitempty"` } type Message struct { @@ -54,6 +59,7 @@ type Request struct { Tools []Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` //Metadata `json:"metadata,omitempty"` + Thinking *model.Thinking `json:"thinking,omitempty"` } type Usage struct { @@ -84,6 +90,8 @@ type Delta struct { PartialJson string `json:"partial_json,omitempty"` StopReason *string `json:"stop_reason"` StopSequence *string `json:"stop_sequence"` + Thinking *string `json:"thinking,omitempty"` + Signature *string `json:"signature,omitempty"` } type StreamResponse struct { diff --git a/relay/model/general.go b/relay/model/general.go index c26688cd..a87928bd 100644 --- a/relay/model/general.go +++ b/relay/model/general.go @@ -73,6 +73,16 @@ type GeneralOpenAIRequest struct { // ------------------------------------- Provider *openrouter.RequestProvider `json:"provider,omitempty"` IncludeReasoning *bool `json:"include_reasoning,omitempty"` + // ------------------------------------- + // Anthropic + // ------------------------------------- + Thinking *Thinking `json:"thinking,omitempty"` +} + +// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking +type Thinking struct { + Type string `json:"type"` + BudgetTokens int `json:"budget_tokens" binding:"omitempty,min=1024"` } func (r 
GeneralOpenAIRequest) ParseInput() []string { From c61d6440f9be0f9bbcdd18b42b9dd6a51a690c89 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Tue, 25 Feb 2025 03:13:21 +0000 Subject: [PATCH 07/10] fix: claude thinking for non-stream mode --- README.md | 2 +- relay/adaptor/anthropic/main.go | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4a7d1ae8..bf1891a5 100644 --- a/README.md +++ b/README.md @@ -385,7 +385,7 @@ graph LR + 例子:`NODE_TYPE=slave` 9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。 + 例子:`CHANNEL_UPDATE_FREQUENCY=1440` -10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 +10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。 +例子:`CHANNEL_TEST_FREQUENCY=1440` 11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。 + 例子:`POLLING_INTERVAL=5` diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index 9601164b..b906e68d 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -2,6 +2,7 @@ package anthropic import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ -235,9 +236,17 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { tools := make([]model.Tool, 0) for _, v := range claudeResponse.Content { - reasoningText += v.Text - if v.Thinking != nil { - reasoningText += *v.Thinking + switch v.Type { + case "thinking": + if v.Thinking != nil { + reasoningText += *v.Thinking + } else { + logger.Errorf(context.Background(), "thinking is nil in response") + } + case "text": + responseText += v.Text + default: + logger.Warnf(context.Background(), "unknown response type %q", v.Type) } if v.Type == "tool_use" { @@ -252,6 +261,7 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { }) } } + choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ @@ -374,6 +384,8 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } + logger.Debugf(c.Request.Context(), "response <- %s\n", string(responseBody)) + var claudeResponse Response err = json.Unmarshal(responseBody, &claudeResponse) if err != nil { From de10e102bd0cfc52e04e4b317d1b88c809a0e9fd Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 10 Mar 2025 06:37:42 +0000 Subject: [PATCH 08/10] feat: add support for aws's cross region inferences closes #2024, closes #2145 --- relay/adaptor/anthropic/adaptor.go | 2 +- relay/adaptor/anthropic/main.go | 37 ++++++++++-- relay/adaptor/aws/claude/adapter.go | 6 +- relay/adaptor/aws/claude/main.go | 5 +- relay/adaptor/aws/llama3/main.go | 5 +- relay/adaptor/aws/utils/consts.go | 75 ++++++++++++++++++++++++ relay/adaptor/vertexai/claude/adapter.go | 6 +- 7 files changed, 125 insertions(+), 11 deletions(-) create mode 100644 relay/adaptor/aws/utils/consts.go diff --git a/relay/adaptor/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go index fe4e2ef0..7acee1f1 100644 --- a/relay/adaptor/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G if request == nil { return nil, errors.New("request is nil") } - return ConvertRequest(*request), nil + return ConvertRequest(c, *request) } func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index b906e68d..65aaefa0 
100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -6,16 +6,17 @@ import ( "encoding/json" "fmt" "io" + "math" "net/http" "strings" - "github.com/songquanpeng/one-api/common/render" - "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/render" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" ) @@ -38,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string { } } -func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { +// isModelSupportThinking is used to check if the model supports extended thinking +func isModelSupportThinking(model string) bool { + if strings.Contains(model, "claude-3-7-sonnet") { + return true + } + + return false +} + +func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) { claudeTools := make([]Tool, 0, len(textRequest.Tools)) for _, tool := range textRequest.Tools { @@ -65,6 +75,25 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { Tools: claudeTools, Thinking: textRequest.Thinking, } + + if isModelSupportThinking(textRequest.Model) && + c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil { + claudeRequest.Thinking = &model.Thinking{ + Type: "enabled", + BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))), + } + } + + if isModelSupportThinking(textRequest.Model) && + claudeRequest.Thinking != nil { + if claudeRequest.MaxTokens <= 1024 { + return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking") + } + + // top_p must be nil when using extended thinking + claudeRequest.TopP = nil + } + if len(claudeTools) > 0 { claudeToolChoice := struct { Type string `json:"type"` @@ -145,7 +174,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { claudeMessage.Content = contents claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) } - return &claudeRequest + return &claudeRequest, nil } // https://docs.anthropic.com/claude/reference/messages-streaming diff --git a/relay/adaptor/aws/claude/adapter.go b/relay/adaptor/aws/claude/adapter.go index eb3c9fb8..2f6a4cc5 100644 --- a/relay/adaptor/aws/claude/adapter.go +++ b/relay/adaptor/aws/claude/adapter.go @@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) + claudeReq, err := anthropic.ConvertRequest(c, *request) + if err != nil { + return nil, errors.Wrap(err, "convert request") + } + c.Set(ctxkey.RequestModel, request.Model) c.Set(ctxkey.ConvertedRequest, claudeReq) return claudeReq, nil diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index 378acdda..da000b58 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -49,13 +49,14 @@ func awsModelID(requestModel string) (string, error) { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + 
awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go index e5fcd89f..aff3e0cf 100644 --- a/relay/adaptor/aws/llama3/main.go +++ b/relay/adaptor/aws/llama3/main.go @@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request { } func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/utils/consts.go b/relay/adaptor/aws/utils/consts.go new file mode 100644 index 00000000..c91f342e --- /dev/null +++ b/relay/adaptor/aws/utils/consts.go @@ -0,0 +1,75 @@ +package utils + +import ( + "context" + "slices" + "strings" + + "github.com/songquanpeng/one-api/common/logger" +) + +// CrossRegionInferences is a list of model IDs that support cross-region inference. +// +// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html +// +// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)}) +var CrossRegionInferences = []string{ + "us.amazon.nova-lite-v1:0", + "us.amazon.nova-micro-v1:0", + "us.amazon.nova-pro-v1:0", + "us.anthropic.claude-3-5-haiku-20241022-v1:0", + "us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "us.anthropic.claude-3-haiku-20240307-v1:0", + "us.anthropic.claude-3-opus-20240229-v1:0", + "us.anthropic.claude-3-sonnet-20240229-v1:0", + "us.meta.llama3-1-405b-instruct-v1:0", + "us.meta.llama3-1-70b-instruct-v1:0", + "us.meta.llama3-1-8b-instruct-v1:0", + "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-1b-instruct-v1:0", + "us.meta.llama3-2-3b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0", + "us.meta.llama3-3-70b-instruct-v1:0", + "us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0", + "us-gov.anthropic.claude-3-haiku-20240307-v1:0", + "eu.amazon.nova-lite-v1:0", + "eu.amazon.nova-micro-v1:0", + "eu.amazon.nova-pro-v1:0", + "eu.anthropic.claude-3-5-sonnet-20240620-v1:0", + "eu.anthropic.claude-3-haiku-20240307-v1:0", + "eu.anthropic.claude-3-sonnet-20240229-v1:0", + "eu.meta.llama3-2-1b-instruct-v1:0", + "eu.meta.llama3-2-3b-instruct-v1:0", + "apac.amazon.nova-lite-v1:0", + "apac.amazon.nova-micro-v1:0", + "apac.amazon.nova-pro-v1:0", + "apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "apac.anthropic.claude-3-5-sonnet-20241022-v2:0", + "apac.anthropic.claude-3-haiku-20240307-v1:0", + "apac.anthropic.claude-3-sonnet-20240229-v1:0", +} + +// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID. 
+func ConvertModelID2CrossRegionProfile(model, region string) string { + var regionPrefix string + switch prefix := strings.Split(region, "-")[0]; prefix { + case "us", "eu": + regionPrefix = prefix + case "ap": + regionPrefix = "apac" + default: + // not supported, return original model + return model + } + + newModelID := regionPrefix + "." + model + if slices.Contains(CrossRegionInferences, newModelID) { + logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID) + return newModelID + } + + // not found, return original model + return model +} diff --git a/relay/adaptor/vertexai/claude/adapter.go b/relay/adaptor/vertexai/claude/adapter.go index 554ada45..f591e447 100644 --- a/relay/adaptor/vertexai/claude/adapter.go +++ b/relay/adaptor/vertexai/claude/adapter.go @@ -32,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } - claudeReq := anthropic.ConvertRequest(*request) + claudeReq, err := anthropic.ConvertRequest(c, *request) + if err != nil { + return nil, errors.Wrap(err, "convert request") + } + req := Request{ AnthropicVersion: anthropicVersion, // Model: claudeReq.Model, From a0d7d5a965a2c6efb4ef5500a07d7d0b34181fe8 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 10 Mar 2025 06:59:21 +0000 Subject: [PATCH 09/10] fix: support thinking for aws claude --- Dockerfile | 2 +- relay/adaptor/aws/claude/model.go | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 346d9c5b..66edb360 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,4 +44,4 @@ COPY --from=builder2 /build/one-api / EXPOSE 3000 WORKDIR /data -ENTRYPOINT ["/one-api"] \ No newline at end of file +ENTRYPOINT ["/one-api"] diff --git a/relay/adaptor/aws/claude/model.go b/relay/adaptor/aws/claude/model.go index 10622887..b0dd6800 100644 --- a/relay/adaptor/aws/claude/model.go +++ b/relay/adaptor/aws/claude/model.go @@ -1,6 +1,9 @@ package aws -import "github.com/songquanpeng/one-api/relay/adaptor/anthropic" +import ( + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/model" +) // Request is the request to AWS Claude // @@ -17,4 +20,5 @@ type Request struct { StopSequences []string `json:"stop_sequences,omitempty"` Tools []anthropic.Tool `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` + Thinking *model.Thinking `json:"thinking,omitempty"` } From 6e634b85cf48d3b21c2d8464678b5432cddc1bb6 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 12 Mar 2025 00:34:09 +0000 Subject: [PATCH 10/10] fix: update StreamHandler to support cross-region model IDs for AWS --- relay/adaptor/aws/claude/main.go | 5 +++-- relay/adaptor/aws/llama3/main.go | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/relay/adaptor/aws/claude/main.go b/relay/adaptor/aws/claude/main.go index da000b58..c20827b0 100644 --- a/relay/adaptor/aws/claude/main.go +++ b/relay/adaptor/aws/claude/main.go @@ -104,13 +104,14 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (* func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { createdTime := helper.GetTimestamp() - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = 
utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), } diff --git a/relay/adaptor/aws/llama3/main.go b/relay/adaptor/aws/llama3/main.go index aff3e0cf..76b06f91 100644 --- a/relay/adaptor/aws/llama3/main.go +++ b/relay/adaptor/aws/llama3/main.go @@ -141,13 +141,14 @@ func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse { func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) { createdTime := helper.GetTimestamp() - awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel)) + awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel)) if err != nil { return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil } + awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region) awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{ - ModelId: aws.String(awsModelId), + ModelId: aws.String(awsModelID), Accept: aws.String("application/json"), ContentType: aws.String("application/json"), }
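
A minimal usage sketch of the request fields this series introduces. Everything below is illustrative rather than normative: the model name, token budget, and sort value are example inputs, not defaults enforced by these patches, and the snippet only marshals the JSON body a client would send to the relay's OpenAI-compatible endpoint.

	package main

	import (
		"encoding/json"
		"fmt"

		"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
		"github.com/songquanpeng/one-api/relay/model"
	)

	func main() {
		// Enable Anthropic extended thinking and pin OpenRouter's provider
		// sort. anthropic.ConvertRequest rejects thinking requests whose
		// max_tokens is <= 1024, so keep headroom above the thinking budget.
		req := model.GeneralOpenAIRequest{
			Model:     "claude-3-7-sonnet-20250219",
			MaxTokens: 4096,
			Thinking: &model.Thinking{
				Type:         "enabled",
				BudgetTokens: 2048,
			},
			Provider: &openrouter.RequestProvider{
				Sort: "throughput", // only consulted on OpenRouter channels
			},
		}

		body, _ := json.Marshal(req)
		fmt.Println(string(body))
	}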