From 0ada2371b6d15e6a5b1a1e15a0628d5ba75de4ba Mon Sep 17 00:00:00 2001 From: Yan <1964649083@qq.com> Date: Thu, 5 Sep 2024 00:53:00 +0800 Subject: [PATCH 1/2] fix: tool use in claude --- relay/channel/claude/relay-claude.go | 74 ++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1923e35..f73c57e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -4,7 +4,6 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -12,6 +11,8 @@ import ( relaycommon "one-api/relay/common" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) func stopReasonClaude2OpenAI(reason string) string { @@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } } formatMessages := make([]dto.Message, 0) - var lastMessage *dto.Message + lastMessage := dto.Message{ + Role: "tool", + } for i, message := range textRequest.Messages { - //if message.Role == "system" { - // if i != 0 { - // message.Role = "user" - // } - //} if message.Role == "" { textRequest.Messages[i].Role = "user" } @@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR Role: message.Role, Content: message.Content, } - if lastMessage != nil && lastMessage.Role == message.Role { + if message.Role == "tool" { + fmtMessage.ToolCallId = message.ToolCallId + } + if message.Role == "assistant" && message.ToolCalls != nil { + fmtMessage.ToolCalls = message.ToolCalls + } + if lastMessage.Role == message.Role && lastMessage.Role != "tool" { if lastMessage.IsStringContent() && message.IsStringContent() { content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) fmtMessage.Content = content @@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR fmtMessage.Content = content } formatMessages = append(formatMessages, fmtMessage) - lastMessage = &textRequest.Messages[i] + lastMessage = fmtMessage } claudeMessages := make([]ClaudeMessage, 0) @@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR claudeMessage := ClaudeMessage{ Role: message.Role, } - if message.IsStringContent() { + if message.Role == "tool" { + if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" { + lastMessage := claudeMessages[len(claudeMessages)-1] + if content, ok := lastMessage.Content.(string); ok { + lastMessage.Content = []ClaudeMediaMessage{ + { + Type: "text", + Text: content, + }, + } + } + lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{ + Type: "tool_result", + ToolUseId: message.ToolCallId, + Content: message.StringContent(), + }) + claudeMessages[len(claudeMessages)-1] = lastMessage + continue + } else { + claudeMessage.Role = "user" + claudeMessage.Content = []ClaudeMediaMessage{ + { + Type: "tool_result", + ToolUseId: message.ToolCallId, + Content: message.StringContent(), + }, + } + } + } else if message.IsStringContent() && message.ToolCalls == nil { claudeMessage.Content = message.StringContent() } else { claudeMediaMessages := make([]ClaudeMediaMessage, 0) @@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) } + if message.ToolCalls != nil { + for _, tc := range message.ToolCalls.([]interface{}) { + toolCallJSON, _ := json.Marshal(tc) + var toolCall dto.ToolCall + err := json.Unmarshal(toolCallJSON, &toolCall) + if err != nil { + common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc)) + continue + } + inputObj := make(map[string]any) + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { + common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + continue + } + claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{ + Type: "tool_use", + Id: toolCall.ID, + Name: toolCall.Function.Name, + Input: inputObj, + }) + } + } claudeMessage.Content = claudeMediaMessages } claudeMessages = append(claudeMessages, claudeMessage) From 0b5f2a7089b3584febcd9b28412c30c13eb861b5 Mon Sep 17 00:00:00 2001 From: Yan <1964649083@qq.com> Date: Wed, 11 Sep 2024 19:37:03 +0800 Subject: [PATCH 2/2] add gemini exp --- common/model-ratio.go | 4 +++- constant/env.go | 18 ++++++++++-------- relay/channel/gemini/constant.go | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/common/model-ratio.go b/common/model-ratio.go index 664cd1d..4916449 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -106,8 +106,10 @@ var defaultModelRatio = map[string]float64{ "gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens "gemini-1.0-pro-vision-001": 1, "gemini-1.0-pro-001": 1, - "gemini-1.5-pro-latest": 1, + "gemini-1.5-pro-latest": 1.75, // $3.5 / 1M tokens + "gemini-1.5-pro-exp-0827": 1.75, // $3.5 / 1M tokens "gemini-1.5-flash-latest": 1, + "gemini-1.5-flash-exp-0827": 1, "gemini-1.0-pro-latest": 1, "gemini-1.0-pro-vision-latest": 1, "gemini-ultra": 1, diff --git a/constant/env.go b/constant/env.go index dd3ae65..c5d498d 100644 --- a/constant/env.go +++ b/constant/env.go @@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true) var GeminiModelMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-pro-001": "v1beta", - "gemini-1.5-pro": "v1beta", - "gemini-1.5-pro-exp-0801": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-1.5-flash-001": "v1beta", - "gemini-1.5-flash": "v1beta", - "gemini-ultra": "v1beta", + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-pro-001": "v1beta", + "gemini-1.5-pro": "v1beta", + "gemini-1.5-pro-exp-0801": "v1beta", + "gemini-1.5-pro-exp-0827": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-1.5-flash-exp-0827": "v1beta", + "gemini-1.5-flash-001": "v1beta", + "gemini-1.5-flash": "v1beta", + "gemini-ultra": "v1beta", } func InitEnv() { diff --git a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index 621336b..4a2e4dd 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -6,7 +6,7 @@ const ( var ModelList = []string{ "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra", - "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", + "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827", } var ChannelName = "google gemini"