From 0ada2371b6d15e6a5b1a1e15a0628d5ba75de4ba Mon Sep 17 00:00:00 2001 From: Yan <1964649083@qq.com> Date: Thu, 5 Sep 2024 00:53:00 +0800 Subject: [PATCH] fix: tool use in claude --- relay/channel/claude/relay-claude.go | 74 ++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 10 deletions(-) diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 1923e35..f73c57e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -4,7 +4,6 @@ import ( "bufio" "encoding/json" "fmt" - "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" @@ -12,6 +11,8 @@ import ( relaycommon "one-api/relay/common" "one-api/service" "strings" + + "github.com/gin-gonic/gin" ) func stopReasonClaude2OpenAI(reason string) string { @@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } } formatMessages := make([]dto.Message, 0) - var lastMessage *dto.Message + lastMessage := dto.Message{ + Role: "tool", + } for i, message := range textRequest.Messages { - //if message.Role == "system" { - // if i != 0 { - // message.Role = "user" - // } - //} if message.Role == "" { textRequest.Messages[i].Role = "user" } @@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR Role: message.Role, Content: message.Content, } - if lastMessage != nil && lastMessage.Role == message.Role { + if message.Role == "tool" { + fmtMessage.ToolCallId = message.ToolCallId + } + if message.Role == "assistant" && message.ToolCalls != nil { + fmtMessage.ToolCalls = message.ToolCalls + } + if lastMessage.Role == message.Role && lastMessage.Role != "tool" { if lastMessage.IsStringContent() && message.IsStringContent() { content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\"")) fmtMessage.Content = content @@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR fmtMessage.Content = content } formatMessages = append(formatMessages, fmtMessage) - lastMessage = &textRequest.Messages[i] + lastMessage = fmtMessage } claudeMessages := make([]ClaudeMessage, 0) @@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR claudeMessage := ClaudeMessage{ Role: message.Role, } - if message.IsStringContent() { + if message.Role == "tool" { + if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" { + lastMessage := claudeMessages[len(claudeMessages)-1] + if content, ok := lastMessage.Content.(string); ok { + lastMessage.Content = []ClaudeMediaMessage{ + { + Type: "text", + Text: content, + }, + } + } + lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{ + Type: "tool_result", + ToolUseId: message.ToolCallId, + Content: message.StringContent(), + }) + claudeMessages[len(claudeMessages)-1] = lastMessage + continue + } else { + claudeMessage.Role = "user" + claudeMessage.Content = []ClaudeMediaMessage{ + { + Type: "tool_result", + ToolUseId: message.ToolCallId, + Content: message.StringContent(), + }, + } + } + } else if message.IsStringContent() && message.ToolCalls == nil { claudeMessage.Content = message.StringContent() } else { claudeMediaMessages := make([]ClaudeMediaMessage, 0) @@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR } claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage) } + if message.ToolCalls != nil { + for _, tc := range message.ToolCalls.([]interface{}) { + toolCallJSON, _ := json.Marshal(tc) + var toolCall dto.ToolCall + err := json.Unmarshal(toolCallJSON, &toolCall) + if err != nil { + common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc)) + continue + } + inputObj := make(map[string]any) + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil { + common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments)) + continue + } + claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{ + Type: "tool_use", + Id: toolCall.ID, + Name: toolCall.Function.Name, + Input: inputObj, + }) + } + } claudeMessage.Content = claudeMediaMessages } claudeMessages = append(claudeMessages, claudeMessage)