Merge pull request #6 from Sagit-chu/upstream-pr-2182

[Upstream] fix: support aws cross region inferences
commit d2fdcc23d8
Authored by sagit on 2025-03-18 17:58:27 +08:00; committed by GitHub
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
22 changed files with 343 additions and 44 deletions

View File

@@ -164,3 +164,6 @@ var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)

 var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)
 var TestPrompt = env.String("TEST_PROMPT", "Output only your specific model name with no additional text.")
+
+// OpenrouterProviderSort is used to determine the order of the providers in the openrouter
+var OpenrouterProviderSort = env.String("OPENROUTER_PROVIDER_SORT", "")
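With this variable unset, request bodies pass through unchanged; any non-empty value is injected as provider.sort into OpenRouter requests that do not set their own sort (see the openai adaptor hunk further down). Per the binding tag on RequestProvider.Sort introduced below, the accepted values are "price", "throughput" and "latency". An illustrative deployment setting:

OPENROUTER_PROVIDER_SORT=throughput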

View File

@@ -1,6 +1,9 @@
 package conv

 func AsString(v any) string {
-	str, _ := v.(string)
-	return str
+	if str, ok := v.(string); ok {
+		return str
+	}
+
+	return ""
 }

View File

@@ -36,8 +36,8 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *me
 	// https://x.com/alexalbert__/status/1812921642143900036
 	// claude-3-5-sonnet can support 8k context
-	if strings.HasPrefix(meta.ActualModelName, "claude-3-5-sonnet") {
-		req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
+	if strings.HasPrefix(meta.ActualModelName, "claude-3-7-sonnet") {
+		req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
 	}

 	return nil
@@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

-	return ConvertRequest(*request), nil
+	return ConvertRequest(c, *request)
 }

 func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {

View File

@@ -3,10 +3,11 @@ package anthropic
 var ModelList = []string{
 	"claude-instant-1.2", "claude-2.0", "claude-2.1",
 	"claude-3-haiku-20240307",
-	"claude-3-5-haiku-20241022",
 	"claude-3-5-haiku-latest",
+	"claude-3-5-haiku-20241022",
 	"claude-3-sonnet-20240229",
 	"claude-3-opus-20240229",
+	"claude-3-5-sonnet-latest",
 	"claude-3-5-sonnet-20240620",
 	"claude-3-5-sonnet-20241022",
 	"claude-3-5-sonnet-latest",

View File

@@ -2,18 +2,21 @@ package anthropic

 import (
 	"bufio"
+	"context"
 	"encoding/json"
 	"fmt"
-	"github.com/songquanpeng/one-api/common/render"
 	"io"
+	"math"
 	"net/http"
 	"strings"

 	"github.com/gin-gonic/gin"
+	"github.com/pkg/errors"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/helper"
 	"github.com/songquanpeng/one-api/common/image"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/render"
 	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 	"github.com/songquanpeng/one-api/relay/model"
 )
@@ -36,7 +39,16 @@ func stopReasonClaude2OpenAI(reason *string) string {
 	}
 }

-func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
+// isModelSupportThinking is used to check if the model supports extended thinking
+func isModelSupportThinking(model string) bool {
+	if strings.Contains(model, "claude-3-7-sonnet") {
+		return true
+	}
+
+	return false
+}
+
+func ConvertRequest(c *gin.Context, textRequest model.GeneralOpenAIRequest) (*Request, error) {
 	claudeTools := make([]Tool, 0, len(textRequest.Tools))

 	for _, tool := range textRequest.Tools {
@@ -61,7 +73,27 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		TopK:        textRequest.TopK,
 		Stream:      textRequest.Stream,
 		Tools:       claudeTools,
+		Thinking:    textRequest.Thinking,
 	}

+	if isModelSupportThinking(textRequest.Model) &&
+		c.Request.URL.Query().Has("thinking") && claudeRequest.Thinking == nil {
+		claudeRequest.Thinking = &model.Thinking{
+			Type:         "enabled",
+			BudgetTokens: int(math.Min(1024, float64(claudeRequest.MaxTokens/2))),
+		}
+	}
+
+	if isModelSupportThinking(textRequest.Model) &&
+		claudeRequest.Thinking != nil {
+		if claudeRequest.MaxTokens <= 1024 {
+			return nil, errors.New("max_tokens must be greater than 1024 when using extended thinking")
+		}
+
+		// top_p must be nil when using extended thinking
+		claudeRequest.TopP = nil
+	}
+
 	if len(claudeTools) > 0 {
 		claudeToolChoice := struct {
 			Type string `json:"type"`
@@ -142,13 +174,14 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 		claudeMessage.Content = contents
 		claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 	}
-	return &claudeRequest
+	return &claudeRequest, nil
 }

 // https://docs.anthropic.com/claude/reference/messages-streaming
 func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCompletionsStreamResponse, *Response) {
 	var response *Response
 	var responseText string
+	var reasoningText string
 	var stopReason string
 	tools := make([]model.Tool, 0)
@@ -158,6 +191,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	case "content_block_start":
 		if claudeResponse.ContentBlock != nil {
 			responseText = claudeResponse.ContentBlock.Text
+			if claudeResponse.ContentBlock.Thinking != nil {
+				reasoningText = *claudeResponse.ContentBlock.Thinking
+			}
+
 			if claudeResponse.ContentBlock.Type == "tool_use" {
 				tools = append(tools, model.Tool{
 					Id: claudeResponse.ContentBlock.Id,
@@ -172,6 +209,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 	case "content_block_delta":
 		if claudeResponse.Delta != nil {
 			responseText = claudeResponse.Delta.Text
+			if claudeResponse.Delta.Thinking != nil {
+				reasoningText = *claudeResponse.Delta.Thinking
+			}
+
 			if claudeResponse.Delta.Type == "input_json_delta" {
 				tools = append(tools, model.Tool{
 					Function: model.Function{
@@ -189,9 +230,20 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 		if claudeResponse.Delta != nil && claudeResponse.Delta.StopReason != nil {
 			stopReason = *claudeResponse.Delta.StopReason
 		}
+	case "thinking_delta":
+		if claudeResponse.Delta != nil && claudeResponse.Delta.Thinking != nil {
+			reasoningText = *claudeResponse.Delta.Thinking
+		}
+	case "ping",
+		"message_stop",
+		"content_block_stop":
+	default:
+		logger.SysErrorf("unknown stream response type %q", claudeResponse.Type)
 	}

 	var choice openai.ChatCompletionsStreamResponseChoice
 	choice.Delta.Content = responseText
+	choice.Delta.Reasoning = &reasoningText
 	if len(tools) > 0 {
 		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 		choice.Delta.ToolCalls = tools
@@ -209,11 +261,23 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 	var responseText string
-	if len(claudeResponse.Content) > 0 {
-		responseText = claudeResponse.Content[0].Text
-	}
+	var reasoningText string
+
 	tools := make([]model.Tool, 0)
 	for _, v := range claudeResponse.Content {
+		switch v.Type {
+		case "thinking":
+			if v.Thinking != nil {
+				reasoningText += *v.Thinking
+			} else {
+				logger.Errorf(context.Background(), "thinking is nil in response")
+			}
+		case "text":
+			responseText += v.Text
+		default:
+			logger.Warnf(context.Background(), "unknown response type %q", v.Type)
+		}
+
 		if v.Type == "tool_use" {
 			args, _ := json.Marshal(v.Input)
 			tools = append(tools, model.Tool{
@@ -226,11 +290,13 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 			})
 		}
 	}

 	choice := openai.TextResponseChoice{
 		Index: 0,
 		Message: model.Message{
 			Role:      "assistant",
 			Content:   responseText,
+			Reasoning: &reasoningText,
 			Name:      nil,
 			ToolCalls: tools,
 		},
@@ -277,6 +343,8 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 		data = strings.TrimPrefix(data, "data:")
 		data = strings.TrimSpace(data)

+		logger.Debugf(c.Request.Context(), "stream <- %q\n", data)
+
 		var claudeResponse StreamResponse
 		err := json.Unmarshal([]byte(data), &claudeResponse)
 		if err != nil {
@ -344,6 +412,9 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
logger.Debugf(c.Request.Context(), "response <- %s\n", string(responseBody))
var claudeResponse Response var claudeResponse Response
err = json.Unmarshal(responseBody, &claudeResponse) err = json.Unmarshal(responseBody, &claudeResponse)
if err != nil { if err != nil {
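A worked example of the thinking defaults introduced in ConvertRequest above: a claude-3-7-sonnet request carrying ?thinking in the query string and no explicit thinking object gets a budget of min(1024, max_tokens/2) tokens, and any thinking-enabled request with max_tokens <= 1024 is rejected outright. A minimal sketch of the arithmetic (hypothetical values; the formula is copied from the hunk above):

package main

import (
	"fmt"
	"math"
)

func main() {
	maxTokens := 4096 // must exceed 1024 once extended thinking is enabled
	budget := int(math.Min(1024, float64(maxTokens/2)))
	fmt.Println(budget) // 1024: the default budget is capped at 1024 tokens
}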

View File

@@ -1,5 +1,7 @@
 package anthropic

+import "github.com/songquanpeng/one-api/relay/model"
+
 // https://docs.anthropic.com/claude/reference/messages_post
 type Metadata struct {
@@ -22,6 +24,9 @@ type Content struct {
 	Input     any    `json:"input,omitempty"`
 	Content   string `json:"content,omitempty"`
 	ToolUseId string `json:"tool_use_id,omitempty"`
+	// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking
+	Thinking  *string `json:"thinking,omitempty"`
+	Signature *string `json:"signature,omitempty"`
 }

 type Message struct {
@@ -54,6 +59,7 @@ type Request struct {
 	Tools      []Tool `json:"tools,omitempty"`
 	ToolChoice any    `json:"tool_choice,omitempty"`
 	//Metadata    `json:"metadata,omitempty"`
+	Thinking   *model.Thinking `json:"thinking,omitempty"`
 }

 type Usage struct {
@@ -84,6 +90,8 @@ type Delta struct {
 	PartialJson  string  `json:"partial_json,omitempty"`
 	StopReason   *string `json:"stop_reason"`
 	StopSequence *string `json:"stop_sequence"`
+	Thinking     *string `json:"thinking,omitempty"`
+	Signature    *string `json:"signature,omitempty"`
 }

 type StreamResponse struct {
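For context, the new Delta.Thinking field is fed by the delta payload of Anthropic's content_block_delta events when extended thinking is on. A standalone sketch of the decoding (payload shape per the extended-thinking docs linked above; the text itself is invented):

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of the thinking-related Delta fields added above.
type delta struct {
	Type     string  `json:"type"`
	Thinking *string `json:"thinking,omitempty"`
}

func main() {
	raw := "{\"type\":\"thinking_delta\",\"thinking\":\"Re-checking the sum...\"}"
	var d delta
	if err := json.Unmarshal([]byte(raw), &d); err != nil {
		panic(err)
	}
	fmt.Println(d.Type, *d.Thinking) // thinking_delta Re-checking the sum...
}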

View File

@@ -21,7 +21,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}

-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
+
 	c.Set(ctxkey.RequestModel, request.Model)
 	c.Set(ctxkey.ConvertedRequest, claudeReq)
 	return claudeReq, nil

View File

@@ -48,13 +48,14 @@ func awsModelID(requestModel string) (string, error) {
 }

 func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}

+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -102,13 +103,14 @@ func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*
 func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	createdTime := helper.GetTimestamp()
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}

+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}

View File

@@ -1,6 +1,9 @@
 package aws

-import "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+import (
+	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+	"github.com/songquanpeng/one-api/relay/model"
+)

 // Request is the request to AWS Claude
 //
@@ -17,4 +20,5 @@ type Request struct {
 	StopSequences []string         `json:"stop_sequences,omitempty"`
 	Tools         []anthropic.Tool `json:"tools,omitempty"`
 	ToolChoice    any              `json:"tool_choice,omitempty"`
+	Thinking      *model.Thinking  `json:"thinking,omitempty"`
 }

View File

@@ -70,13 +70,14 @@ func ConvertRequest(textRequest relaymodel.GeneralOpenAIRequest) *Request {
 }

 func Handler(c *gin.Context, awsCli *bedrockruntime.Client, modelName string) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}

+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}
@@ -140,13 +141,14 @@ func ResponseLlama2OpenAI(llamaResponse *Response) *openai.TextResponse {
 func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.ErrorWithStatusCode, *relaymodel.Usage) {
 	createdTime := helper.GetTimestamp()
-	awsModelId, err := awsModelID(c.GetString(ctxkey.RequestModel))
+	awsModelID, err := awsModelID(c.GetString(ctxkey.RequestModel))
 	if err != nil {
 		return utils.WrapErr(errors.Wrap(err, "awsModelID")), nil
 	}

+	awsModelID = utils.ConvertModelID2CrossRegionProfile(awsModelID, awsCli.Options().Region)
 	awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
-		ModelId:     aws.String(awsModelId),
+		ModelId:     aws.String(awsModelID),
 		Accept:      aws.String("application/json"),
 		ContentType: aws.String("application/json"),
 	}

View File

@@ -0,0 +1,75 @@
+package utils
+
+import (
+	"context"
+	"slices"
+	"strings"
+
+	"github.com/songquanpeng/one-api/common/logger"
+)
+
+// CrossRegionInferences is a list of model IDs that support cross-region inference.
+//
+// https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
+//
+// document.querySelectorAll('pre.programlisting code').forEach((e) => {console.log(e.innerHTML)})
+var CrossRegionInferences = []string{
+	"us.amazon.nova-lite-v1:0",
+	"us.amazon.nova-micro-v1:0",
+	"us.amazon.nova-pro-v1:0",
+	"us.anthropic.claude-3-5-haiku-20241022-v1:0",
+	"us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+	"us.anthropic.claude-3-haiku-20240307-v1:0",
+	"us.anthropic.claude-3-opus-20240229-v1:0",
+	"us.anthropic.claude-3-sonnet-20240229-v1:0",
+	"us.meta.llama3-1-405b-instruct-v1:0",
+	"us.meta.llama3-1-70b-instruct-v1:0",
+	"us.meta.llama3-1-8b-instruct-v1:0",
+	"us.meta.llama3-2-11b-instruct-v1:0",
+	"us.meta.llama3-2-1b-instruct-v1:0",
+	"us.meta.llama3-2-3b-instruct-v1:0",
+	"us.meta.llama3-2-90b-instruct-v1:0",
+	"us.meta.llama3-3-70b-instruct-v1:0",
+	"us-gov.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"us-gov.anthropic.claude-3-haiku-20240307-v1:0",
+	"eu.amazon.nova-lite-v1:0",
+	"eu.amazon.nova-micro-v1:0",
+	"eu.amazon.nova-pro-v1:0",
+	"eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"eu.anthropic.claude-3-haiku-20240307-v1:0",
+	"eu.anthropic.claude-3-sonnet-20240229-v1:0",
+	"eu.meta.llama3-2-1b-instruct-v1:0",
+	"eu.meta.llama3-2-3b-instruct-v1:0",
+	"apac.amazon.nova-lite-v1:0",
+	"apac.amazon.nova-micro-v1:0",
+	"apac.amazon.nova-pro-v1:0",
+	"apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
+	"apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
+	"apac.anthropic.claude-3-haiku-20240307-v1:0",
+	"apac.anthropic.claude-3-sonnet-20240229-v1:0",
+}
+
+// ConvertModelID2CrossRegionProfile converts the model ID to a cross-region profile ID.
+func ConvertModelID2CrossRegionProfile(model, region string) string {
+	var regionPrefix string
+	switch prefix := strings.Split(region, "-")[0]; prefix {
+	case "us", "eu":
+		regionPrefix = prefix
+	case "ap":
+		regionPrefix = "apac"
+	default:
+		// not supported, return original model
+		return model
+	}
+
+	newModelID := regionPrefix + "." + model
+	if slices.Contains(CrossRegionInferences, newModelID) {
+		logger.Debugf(context.TODO(), "convert model %s to cross-region profile %s", model, newModelID)
+		return newModelID
+	}
+
+	// not found, return original model
+	return model
+}
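A quick sanity check of the mapping (the import path is assumed from the package name and its callers in this diff; model and region values come from the table above):

package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/aws/utils"
)

func main() {
	base := "anthropic.claude-3-5-sonnet-20240620-v1:0"
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(base, "us-east-1"))      // us.anthropic.claude-3-5-sonnet-20240620-v1:0
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(base, "ap-northeast-1")) // apac. prefix: "ap" maps to "apac"
	fmt.Println(utils.ConvertModelID2CrossRegionProfile(base, "sa-east-1"))      // unchanged: no "sa" profiles in the list
}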

View File

@@ -9,6 +9,8 @@ import (

 	"github.com/gin-gonic/gin"

+	"github.com/songquanpeng/one-api/common/config"
+	"github.com/songquanpeng/one-api/common/logger"
 	"github.com/songquanpeng/one-api/relay/adaptor"
 	"github.com/songquanpeng/one-api/relay/adaptor/alibailian"
 	"github.com/songquanpeng/one-api/relay/adaptor/baiduv2"
@@ -16,6 +18,7 @@ import (
 	"github.com/songquanpeng/one-api/relay/adaptor/geminiv2"
 	"github.com/songquanpeng/one-api/relay/adaptor/minimax"
 	"github.com/songquanpeng/one-api/relay/adaptor/novita"
+	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
 	"github.com/songquanpeng/one-api/relay/channeltype"
 	"github.com/songquanpeng/one-api/relay/meta"
 	"github.com/songquanpeng/one-api/relay/model"
@@ -85,7 +88,29 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}

-	if request.Stream {
+	meta := meta.GetByContext(c)
+	switch meta.ChannelType {
+	case channeltype.OpenRouter:
+		includeReasoning := true
+		request.IncludeReasoning = &includeReasoning
+		if request.Provider == nil || request.Provider.Sort == "" &&
+			config.OpenrouterProviderSort != "" {
+			if request.Provider == nil {
+				request.Provider = &openrouter.RequestProvider{}
+			}
+
+			request.Provider.Sort = config.OpenrouterProviderSort
+		}
+	default:
+	}
+
+	if request.Stream && !config.EnforceIncludeUsage {
+		logger.Warn(c.Request.Context(),
+			"please set ENFORCE_INCLUDE_USAGE=true to ensure accurate billing in stream mode")
+	}
+
+	if config.EnforceIncludeUsage && request.Stream {
 		// always return usage in stream mode
 		if request.StreamOptions == nil {
 			request.StreamOptions = &model.StreamOptions{}
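The net effect for an OpenRouter channel, assuming OPENROUTER_PROVIDER_SORT=price and a client request that set neither provider nor include_reasoning, is sketched below (only the mutated fields are shown; include_reasoning is set unconditionally for OpenRouter):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
)

func main() {
	includeReasoning := true
	body := map[string]any{
		"include_reasoning": &includeReasoning,
		"provider":          &openrouter.RequestProvider{Sort: "price"},
	}
	out, _ := json.Marshal(body)
	fmt.Println(string(out)) // {"include_reasoning":true,"provider":{"sort":"price"}}
}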

View File

@@ -8,12 +8,11 @@ import (
 	"net/http"
 	"strings"

-	"github.com/songquanpeng/one-api/common/render"
-
 	"github.com/gin-gonic/gin"
 	"github.com/songquanpeng/one-api/common"
 	"github.com/songquanpeng/one-api/common/conv"
 	"github.com/songquanpeng/one-api/common/logger"
+	"github.com/songquanpeng/one-api/common/render"
 	"github.com/songquanpeng/one-api/relay/model"
 	"github.com/songquanpeng/one-api/relay/relaymode"
 )
@@ -26,6 +25,7 @@ const (
 func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.ErrorWithStatusCode, string, *model.Usage) {
 	responseText := ""
+	reasoningText := ""
 	scanner := bufio.NewScanner(resp.Body)
 	scanner.Split(bufio.ScanLines)
 	var usage *model.Usage
@@ -61,6 +61,13 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 			}
 			render.StringData(c, data)
 			for _, choice := range streamResponse.Choices {
+				if choice.Delta.Reasoning != nil {
+					reasoningText += *choice.Delta.Reasoning
+				}
+				if choice.Delta.ReasoningContent != nil {
+					reasoningText += *choice.Delta.ReasoningContent
+				}
+
 				responseText += conv.AsString(choice.Delta.Content)
 			}
 			if streamResponse.Usage != nil {
@@ -93,7 +100,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "", nil
 	}

-	return nil, responseText, usage
+	return nil, reasoningText + responseText, usage
 }

 func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName string) (*model.ErrorWithStatusCode, *model.Usage) {
@@ -136,10 +143,17 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 		return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 	}

-	if textResponse.Usage.TotalTokens == 0 || (textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
+	if textResponse.Usage.TotalTokens == 0 ||
+		(textResponse.Usage.PromptTokens == 0 && textResponse.Usage.CompletionTokens == 0) {
 		completionTokens := 0
 		for _, choice := range textResponse.Choices {
 			completionTokens += CountTokenText(choice.Message.StringContent(), modelName)
+			if choice.Message.Reasoning != nil {
+				completionTokens += CountToken(*choice.Message.Reasoning)
+			}
+			if choice.ReasoningContent != nil {
+				completionTokens += CountToken(*choice.ReasoningContent)
+			}
 		}
 		textResponse.Usage = model.Usage{
 			PromptTokens: promptTokens,
@@ -147,5 +161,6 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
 			TotalTokens: promptTokens + completionTokens,
 		}
 	}
+
 	return nil, &textResponse.Usage
 }

View File

@@ -0,0 +1,22 @@
+package openrouter
+
+// RequestProvider customizes how your requests are routed using the provider object
+// in the request body for Chat Completions and Completions.
+//
+// https://openrouter.ai/docs/features/provider-routing
+type RequestProvider struct {
+	// Order is the list of provider names to try in order (e.g. ["Anthropic", "OpenAI"]). Default: empty
+	Order []string `json:"order,omitempty"`
+	// AllowFallbacks is whether to allow backup providers when the primary is unavailable. Default: true
+	AllowFallbacks bool `json:"allow_fallbacks,omitempty"`
+	// RequireParameters means only use providers that support all parameters in your request. Default: false
+	RequireParameters bool `json:"require_parameters,omitempty"`
+	// DataCollection controls whether to use providers that may store data ("allow" or "deny"). Default: "allow"
+	DataCollection string `json:"data_collection,omitempty" binding:"omitempty,oneof=allow deny"`
+	// Ignore is the list of provider names to skip for this request. Default: empty
+	Ignore []string `json:"ignore,omitempty"`
+	// Quantizations is the list of quantization levels to filter by (e.g. ["int4", "int8"]). Default: empty
+	Quantizations []string `json:"quantizations,omitempty"`
+	// Sort sorts providers by price or throughput (e.g. "price" or "throughput"). Default: empty
+	Sort string `json:"sort,omitempty" binding:"omitempty,oneof=price throughput latency"`
+}
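For reference, a provider object built from these fields and the JSON it emits (values are illustrative; omitempty drops the unset fields):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/openrouter"
)

func main() {
	p := openrouter.RequestProvider{
		Order:          []string{"Anthropic", "OpenAI"}, // try Anthropic first
		AllowFallbacks: true,
		DataCollection: "deny", // avoid providers that may store data
		Sort:           "throughput",
	}
	out, _ := json.Marshal(p)
	fmt.Println(string(out))
	// {"order":["Anthropic","OpenAI"],"allow_fallbacks":true,"data_collection":"deny","sort":"throughput"}
}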

View File

@@ -32,7 +32,11 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
 		return nil, errors.New("request is nil")
 	}

-	claudeReq := anthropic.ConvertRequest(*request)
+	claudeReq, err := anthropic.ConvertRequest(c, *request)
+	if err != nil {
+		return nil, errors.Wrap(err, "convert request")
+	}
+
 	req := Request{
 		AnthropicVersion: anthropicVersion,
 		// Model:            claudeReq.Model,

View File

@@ -98,6 +98,8 @@ var ModelRatio = map[string]float64{
 	"claude-3-5-sonnet-20240620": 3.0 / 1000 * USD,
 	"claude-3-5-sonnet-20241022": 3.0 / 1000 * USD,
 	"claude-3-5-sonnet-latest":   3.0 / 1000 * USD,
+	"claude-3-7-sonnet-20250219": 3.0 / 1000 * USD,
+	"claude-3-7-sonnet-latest":   3.0 / 1000 * USD,
 	"claude-3-opus-20240229":     15.0 / 1000 * USD,
 	"claude-3-7-sonnet-20250219": 3.0 / 1000 * USD,
 	"claude-3-7-sonnet-latest":   3.0 / 1000 * USD,

View File

@@ -102,6 +102,9 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
 	var quota int64
 	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
 	promptTokens := usage.PromptTokens
+	// It appears that DeepSeek's official service automatically merges ReasoningTokens into CompletionTokens,
+	// but the behavior of third-party providers may differ, so for now we do not add them manually.
+	// completionTokens := usage.CompletionTokens + usage.CompletionTokensDetails.ReasoningTokens
 	completionTokens := usage.CompletionTokens
 	quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
 	if ratio != 0 && quota <= 0 {

View File

@@ -1,5 +1,7 @@
 package model

+import "github.com/songquanpeng/one-api/relay/adaptor/openrouter"
+
 type ResponseFormat struct {
 	Type       string      `json:"type,omitempty"`
 	JsonSchema *JSONSchema `json:"json_schema,omitempty"`
@@ -66,6 +68,21 @@ type GeneralOpenAIRequest struct {
 	// Others
 	Instruction string `json:"instruction,omitempty"`
 	NumCtx      int    `json:"num_ctx,omitempty"`
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Provider         *openrouter.RequestProvider `json:"provider,omitempty"`
+	IncludeReasoning *bool                       `json:"include_reasoning,omitempty"`
+	// -------------------------------------
+	// Anthropic
+	// -------------------------------------
+	Thinking *Thinking `json:"thinking,omitempty"`
+}
+
+// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#implementing-extended-thinking
+type Thinking struct {
+	Type         string `json:"type"`
+	BudgetTokens int    `json:"budget_tokens" binding:"omitempty,min=1024"`
 }

 func (r GeneralOpenAIRequest) ParseInput() []string {
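The wire shape this produces matches the Anthropic extended-thinking docs linked above (budget value illustrative; the binding tag enforces the 1024 minimum when the field is present):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

func main() {
	t := model.Thinking{Type: "enabled", BudgetTokens: 2048}
	out, _ := json.Marshal(t)
	fmt.Println(string(out)) // {"type":"enabled","budget_tokens":2048}
}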

View File

@@ -2,11 +2,34 @@ package model

 type Message struct {
 	Role string `json:"role,omitempty"`
+	// Content is a string or a list of objects
 	Content          any     `json:"content,omitempty"`
-	ReasoningContent any     `json:"reasoning_content,omitempty"`
 	Name             *string `json:"name,omitempty"`
 	ToolCalls        []Tool  `json:"tool_calls,omitempty"`
 	ToolCallId       string  `json:"tool_call_id,omitempty"`
+	Audio            *messageAudio `json:"audio,omitempty"`
+
+	// -------------------------------------
+	// Deepseek-specific fields
+	// https://api-docs.deepseek.com/api/create-chat-completion
+	// -------------------------------------
+	// Prefix forces the model to begin its answer with the supplied prefix in the assistant message.
+	// To enable this feature, set base_url to "https://api.deepseek.com/beta".
+	Prefix *bool `json:"prefix,omitempty"`
+	// ReasoningContent is used for the deepseek-reasoner model in the Chat
+	// Prefix Completion feature as the input for the CoT in the last assistant message.
+	// When using this feature, the prefix parameter must be set to true.
+	ReasoningContent *string `json:"reasoning_content,omitempty"`
+
+	// -------------------------------------
+	// Openrouter
+	// -------------------------------------
+	Reasoning *string `json:"reasoning,omitempty"`
+	Refusal   *bool   `json:"refusal,omitempty"`
+}
+
+type messageAudio struct {
+	Id         string `json:"id"`
+	Data       string `json:"data,omitempty"`
+	ExpiredAt  int    `json:"expired_at,omitempty"`
+	Transcript string `json:"transcript,omitempty"`
 }

 func (m Message) IsStringContent() bool {

View File

@@ -4,14 +4,12 @@ type Usage struct {
 	PromptTokens     int `json:"prompt_tokens"`
 	CompletionTokens int `json:"completion_tokens"`
 	TotalTokens      int `json:"total_tokens"`
-
-	CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
-}
-
-type CompletionTokensDetails struct {
-	ReasoningTokens          int `json:"reasoning_tokens"`
-	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
-	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// PromptTokensDetails may be empty for some models
+	PromptTokensDetails *usagePromptTokensDetails `gorm:"-" json:"prompt_tokens_details,omitempty"`
+	// CompletionTokensDetails may be empty for some models
+	CompletionTokensDetails *usageCompletionTokensDetails `gorm:"-" json:"completion_tokens_details,omitempty"`
+	ServiceTier       string `gorm:"-" json:"service_tier,omitempty"`
+	SystemFingerprint string `gorm:"-" json:"system_fingerprint,omitempty"`
 }

 type Error struct {
@@ -25,3 +23,20 @@ type ErrorWithStatusCode struct {
 	Error
 	StatusCode int `json:"status_code"`
 }
+
+type usagePromptTokensDetails struct {
+	CachedTokens int `json:"cached_tokens"`
+	AudioTokens  int `json:"audio_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens  int `json:"text_tokens"`
+	ImageTokens int `json:"image_tokens"`
+}
+
+type usageCompletionTokensDetails struct {
+	ReasoningTokens          int `json:"reasoning_tokens"`
+	AudioTokens              int `json:"audio_tokens"`
+	AcceptedPredictionTokens int `json:"accepted_prediction_tokens"`
+	RejectedPredictionTokens int `json:"rejected_prediction_tokens"`
+	// TextTokens could be zero for pure text chats
+	TextTokens int `json:"text_tokens"`
+}
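Since both detail structs are unexported (and gorm-ignored), a standalone demo needs a local mirror; the JSON keys follow the OpenAI usage format (numbers invented):

package main

import (
	"encoding/json"
	"fmt"
)

// Local mirror of usageCompletionTokensDetails above.
type completionTokensDetails struct {
	ReasoningTokens int `json:"reasoning_tokens"`
	AudioTokens     int `json:"audio_tokens"`
}

func main() {
	raw := "{\"reasoning_tokens\":384,\"audio_tokens\":0}"
	var d completionTokensDetails
	if err := json.Unmarshal([]byte(raw), &d); err != nil {
		panic(err)
	}
	// Reasoning tokens are reported here but, per the billing comment earlier,
	// are not manually folded into completion_tokens.
	fmt.Println(d.ReasoningTokens) // 384
}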