From a5f5e85c44b3c0acd3ee14ed8004122da94d139d Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 19 Feb 2025 01:11:46 +0000 Subject: [PATCH] feat: support OpenRouter reasoning --- common/conv/any.go | 7 +++++-- relay/adaptor/openai/adaptor.go | 16 ++++++++++++++ relay/adaptor/openai/main.go | 35 ++++++++++++++++++++----------- relay/adaptor/openrouter/model.go | 22 +++++++++++++++++++ relay/model/general.go | 5 +++++ relay/model/message.go | 5 +++++ relay/model/misc.go | 8 +++---- 7 files changed, 80 insertions(+), 18 deletions(-) create mode 100644 relay/adaptor/openrouter/model.go diff --git a/common/conv/any.go b/common/conv/any.go index 467e8bb7..33d34aa7 100644 --- a/common/conv/any.go +++ b/common/conv/any.go @@ -1,6 +1,9 @@ package conv func AsString(v any) string { - str, _ := v.(string) - return str + if str, ok := v.(string); ok { + return str + } + + return "" } diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 7553dce3..2361462a 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -17,6 +17,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/geminiv2" "github.com/songquanpeng/one-api/relay/adaptor/minimax" "github.com/songquanpeng/one-api/relay/adaptor/novita" + "github.com/songquanpeng/one-api/relay/adaptor/openrouter" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" @@ -95,6 +96,21 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } + meta := meta.GetByContext(c) + switch meta.ChannelType { + case channeltype.OpenRouter: + includeReasoning := true + request.IncludeReasoning = &includeReasoning + if request.Provider == nil || request.Provider.Sort == "" { + if request.Provider == nil { + request.Provider = &openrouter.RequestProvider{} + } + + request.Provider.Sort = "throughput" + } + default: + } + if request.Stream && !config.EnforceIncludeUsage { logger.Warn(c.Request.Context(), "please set ENFORCE_INCLUDE_USAGE=true to ensure accurate billing in stream mode") diff --git a/relay/adaptor/openai/main.go b/relay/adaptor/openai/main.go index f986ed09..5065a5b1 100644 --- a/relay/adaptor/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -27,6 +27,7 @@ const ( func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) { responseText := "" + reasoningText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) var usage *model.Usage @@ -62,6 +63,10 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E } render.StringData(c, data) for _, choice := range streamResponse.Choices { + if choice.Delta.Reasoning != nil { + reasoningText += *choice.Delta.Reasoning + } + responseText += conv.AsString(choice.Delta.Content) } if streamResponse.Usage != nil { @@ -94,7 +99,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil } - return nil, responseText, usage + return nil, reasoningText + responseText, usage } // Handler handles the non-stream response from OpenAI API @@ -150,20 +155,26 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st CompletionTokens: completionTokens, TotalTokens: promptTokens + completionTokens, } - } else if 
textResponse.PromptTokensDetails.AudioTokens+textResponse.CompletionTokensDetails.AudioTokens > 0 {
+	} else if (textResponse.PromptTokensDetails != nil && textResponse.PromptTokensDetails.AudioTokens > 0) ||
+		(textResponse.CompletionTokensDetails != nil && textResponse.CompletionTokensDetails.AudioTokens > 0) {
 		// Convert the more expensive audio tokens to uniformly priced text tokens.
 		// Note that when there are no audio tokens in prompt and completion,
 		// OpenAI will return empty PromptTokensDetails and CompletionTokensDetails, which can be misleading.
-		textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
-			int(math.Ceil(
-				float64(textResponse.PromptTokensDetails.AudioTokens)*
-					ratio.GetAudioPromptRatio(modelName),
-			))
-		textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
-			int(math.Ceil(
-				float64(textResponse.CompletionTokensDetails.AudioTokens)*
-					ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
-			))
+		if textResponse.PromptTokensDetails != nil {
+			textResponse.Usage.PromptTokens = textResponse.PromptTokensDetails.TextTokens +
+				int(math.Ceil(
+					float64(textResponse.PromptTokensDetails.AudioTokens)*
+						ratio.GetAudioPromptRatio(modelName),
+				))
+		}
+
+		if textResponse.CompletionTokensDetails != nil {
+			textResponse.Usage.CompletionTokens = textResponse.CompletionTokensDetails.TextTokens +
+				int(math.Ceil(
+					float64(textResponse.CompletionTokensDetails.AudioTokens)*
+						ratio.GetAudioPromptRatio(modelName)*ratio.GetAudioCompletionRatio(modelName),
+				))
+		}
 
 		textResponse.Usage.TotalTokens = textResponse.Usage.PromptTokens +
 			textResponse.Usage.CompletionTokens
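For intuition, a self-contained sketch of the audio-token conversion the patched
Handler performs above. The token counts and the 8.0 ratio are assumed for
illustration; real values come from the response usage details and from
ratio.GetAudioPromptRatio / ratio.GetAudioCompletionRatio for the model in question.

    package main

    import (
    	"fmt"
    	"math"
    )

    func main() {
    	textTokens, audioTokens := 100, 40 // assumed prompt usage details
    	audioPromptRatio := 8.0            // assumed ratio, model-dependent in one-api

    	// Audio tokens are billed as uniformly priced text tokens, rounding up,
    	// mirroring the formula in the hunk above.
    	billedPromptTokens := textTokens + int(math.Ceil(float64(audioTokens)*audioPromptRatio))
    	fmt.Println(billedPromptTokens) // 100 + ceil(40*8.0) = 420
    }
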
diff --git a/relay/adaptor/openrouter/model.go b/relay/adaptor/openrouter/model.go
new file mode 100644
index 00000000..581bc2cc
--- /dev/null
+++ b/relay/adaptor/openrouter/model.go
@@ -0,0 +1,22 @@
+package openrouter
+
+// RequestProvider customizes how requests are routed, using the provider object
+// in the request body for Chat Completions and Completions.
+//
+// https://openrouter.ai/docs/features/provider-routing
+type RequestProvider struct {
+	// Order is the list of provider names to try in order (e.g. ["Anthropic", "OpenAI"]). Default: empty
+	Order []string `json:"order,omitempty"`
+	// AllowFallbacks controls whether backup providers may be used when the primary is unavailable. Default: true
+	AllowFallbacks bool `json:"allow_fallbacks,omitempty"`
+	// RequireParameters restricts routing to providers that support all parameters in the request. Default: false
+	RequireParameters bool `json:"require_parameters,omitempty"`
+	// DataCollection controls whether providers that may store data can be used ("allow" or "deny"). Default: "allow"
+	DataCollection string `json:"data_collection,omitempty" binding:"omitempty,oneof=allow deny"`
+	// Ignore is the list of provider names to skip for this request. Default: empty
+	Ignore []string `json:"ignore,omitempty"`
+	// Quantizations is the list of quantization levels to filter by (e.g. ["int4", "int8"]). Default: empty
+	Quantizations []string `json:"quantizations,omitempty"`
+	// Sort orders providers by the given attribute ("price", "throughput", or "latency"). Default: empty
+	Sort string `json:"sort,omitempty" binding:"omitempty,oneof=price throughput latency"`
+}
diff --git a/relay/model/general.go b/relay/model/general.go
index cf7005a5..b7285456 100644
--- a/relay/model/general.go
+++ b/relay/model/general.go
@@ -1,5 +1,7 @@
 package model
 
+import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+
 type ResponseFormat struct {
 	Type       string      `json:"type,omitempty"`
 	JsonSchema *JSONSchema `json:"json_schema,omitempty"`
@@ -68,6 +70,9 @@
 	// Others
 	Instruction string `json:"instruction,omitempty"`
 	NumCtx      int    `json:"num_ctx,omitempty"`
+	// OpenRouter
+	Provider         *openrouter.RequestProvider `json:"provider,omitempty"`
+	IncludeReasoning *bool                       `json:"include_reasoning,omitempty"`
 }
 
 func (r GeneralOpenAIRequest) ParseInput() []string {
diff --git a/relay/model/message.go b/relay/model/message.go
index bd30676b..9f088658 100644
--- a/relay/model/message.go
+++ b/relay/model/message.go
@@ -24,6 +24,11 @@ type Message struct {
 	// Prefix Completion feature as the input for the CoT in the last assistant message.
 	// When using this feature, the prefix parameter must be set to true.
 	ReasoningContent *string `json:"reasoning_content,omitempty"`
+	// -------------------------------------
+	// OpenRouter
+	// -------------------------------------
+	Reasoning *string `json:"reasoning,omitempty"`
+	Refusal   *bool   `json:"refusal,omitempty"`
 }
 
 type messageAudio struct {
diff --git a/relay/model/misc.go b/relay/model/misc.go
index 62c3fe6f..9d1f7e4f 100644
--- a/relay/model/misc.go
+++ b/relay/model/misc.go
@@ -5,11 +5,11 @@ type Usage struct {
 	CompletionTokens int `json:"completion_tokens"`
 	TotalTokens      int `json:"total_tokens"`
 	// PromptTokensDetails may be empty for some models
-	PromptTokensDetails usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"`
+	PromptTokensDetails *usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"`
 	// CompletionTokensDetails may be empty for some models
-	CompletionTokensDetails usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"`
-	ServiceTier             string                       `gorm:"-" json:"service_tier,omitempty"`
-	SystemFingerprint       string                       `gorm:"-" json:"system_fingerprint,omitempty"`
+	CompletionTokensDetails *usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"`
+	ServiceTier             string                        `gorm:"-" json:"service_tier,omitempty"`
+	SystemFingerprint       string                        `gorm:"-" json:"system_fingerprint,omitempty"`
 }
 
 type Error struct {
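
Taken together, ConvertRequest now defaults OpenRouter requests to
include_reasoning=true and provider.sort="throughput" when the caller did not
choose a provider sort. A minimal sketch of the resulting request body follows,
assuming it is built inside this module and that GeneralOpenAIRequest carries a
Model field as upstream one-api does; the model name is illustrative.

    package main

    import (
    	"encoding/json"
    	"fmt"

    	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
    	"github.com/songquanpeng/one-api/relay/model"
    )

    func main() {
    	includeReasoning := true // default applied for OpenRouter channels
    	req := model.GeneralOpenAIRequest{
    		Model:            "deepseek/deepseek-r1", // illustrative model name
    		IncludeReasoning: &includeReasoning,
    		Provider: &openrouter.RequestProvider{
    			Sort: "throughput", // default when no provider sort is set
    		},
    	}

    	body, _ := json.Marshal(req)
    	fmt.Println(string(body))
    	// prints, alongside any other non-empty fields:
    	// {"model":"deepseek/deepseek-r1","provider":{"sort":"throughput"},"include_reasoning":true}
    }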