From 11fd993574f74f3acdf0df750fd0e1755a2e4959 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 18 Jul 2024 00:36:05 +0800 Subject: [PATCH] feat: support claude tool calling --- dto/text_request.go | 7 ++- dto/text_response.go | 6 +- relay/channel/claude/dto.go | 61 +++++++++++------- relay/channel/claude/relay-claude.go | 93 +++++++++++++++++++++++----- 4 files changed, 128 insertions(+), 39 deletions(-) diff --git a/dto/text_request.go b/dto/text_request.go index 5c403da..801d1c3 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -26,7 +26,7 @@ type GeneralOpenAIRequest struct { PresencePenalty float64 `json:"presence_penalty,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"` Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` + Tools []ToolCall `json:"tools,omitempty"` ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` LogProbs bool `json:"logprobs,omitempty"` @@ -104,6 +104,11 @@ func (m Message) StringContent() string { return string(m.Content) } +func (m *Message) SetStringContent(content string) { + jsonContent, _ := json.Marshal(content) + m.Content = jsonContent +} + func (m Message) IsStringContent() bool { var stringContent string if err := json.Unmarshal(m.Content, &stringContent); err == nil { diff --git a/dto/text_response.go b/dto/text_response.go index e1f0cc0..9b12683 100644 --- a/dto/text_response.go +++ b/dto/text_response.go @@ -86,9 +86,11 @@ type ToolCall struct { } type FunctionCall struct { - Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + Name string `json:"name,omitempty"` // call function with arguments in JSON format - Arguments string `json:"arguments,omitempty"` + Parameters any `json:"parameters,omitempty"` // request + Arguments string `json:"arguments,omitempty"` } type ChatCompletionsStreamResponse struct { diff --git a/relay/channel/claude/dto.go b/relay/channel/claude/dto.go index 47f0c3b..e2a898e 100644 --- a/relay/channel/claude/dto.go +++ b/relay/channel/claude/dto.go @@ -5,11 +5,18 @@ type ClaudeMetadata struct { } type ClaudeMediaMessage struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - Source *ClaudeMessageSource `json:"source,omitempty"` - Usage *ClaudeUsage `json:"usage,omitempty"` - StopReason *string `json:"stop_reason,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Source *ClaudeMessageSource `json:"source,omitempty"` + Usage *ClaudeUsage `json:"usage,omitempty"` + StopReason *string `json:"stop_reason,omitempty"` + PartialJson string `json:"partial_json,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type ClaudeMessageSource struct { @@ -23,6 +30,18 @@ type ClaudeMessage struct { Content any `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type ClaudeRequest struct { Model string `json:"model"` Prompt string `json:"prompt,omitempty"` @@ -35,7 +54,9 @@ type ClaudeRequest struct { TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"` - Stream bool `json:"stream,omitempty"` + Stream bool `json:"stream,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } type ClaudeError struct { @@ -44,24 +65,20 @@ type ClaudeError struct { } type ClaudeResponse struct { - Id string `json:"id"` - Type string `json:"type"` - Content []ClaudeMediaMessage `json:"content"` - Completion string `json:"completion"` - StopReason string `json:"stop_reason"` - Model string `json:"model"` - Error ClaudeError `json:"error"` - Usage ClaudeUsage `json:"usage"` - Index int `json:"index"` // stream only - Delta *ClaudeMediaMessage `json:"delta"` // stream only - Message *ClaudeResponse `json:"message"` // stream only: message_start + Id string `json:"id"` + Type string `json:"type"` + Content []ClaudeMediaMessage `json:"content"` + Completion string `json:"completion"` + StopReason string `json:"stop_reason"` + Model string `json:"model"` + Error ClaudeError `json:"error"` + Usage ClaudeUsage `json:"usage"` + Index int `json:"index"` // stream only + ContentBlock *ClaudeMediaMessage `json:"content_block"` + Delta *ClaudeMediaMessage `json:"delta"` // stream only + Message *ClaudeResponse `json:"message"` // stream only: message_start } -//type ClaudeResponseChoice struct { -// Index int `json:"index"` -// Type string `json:"type"` -//} - type ClaudeUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 945b20d..0d70715 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -30,6 +30,7 @@ func stopReasonClaude2OpenAI(reason string) string { } func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { + claudeRequest := ClaudeRequest{ Model: textRequest.Model, Prompt: "", @@ -60,6 +61,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR } func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := ClaudeRequest{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -68,6 +85,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 @@ -184,6 +202,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* response.Object = "chat.completion.chunk" response.Model = claudeResponse.Model response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) + tools := make([]dto.ToolCall, 0) var choice dto.ChatCompletionsStreamResponseChoice if reqMode == RequestModeCompletion { choice.Delta.SetContentString(claudeResponse.Completion) @@ -199,10 +218,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* choice.Delta.SetContentString("") choice.Delta.Role = "assistant" } else if claudeResponse.Type == "content_block_start" { - return nil, nil + if claudeResponse.ContentBlock != nil { + //choice.Delta.SetContentString(claudeResponse.ContentBlock.Text) + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, dto.ToolCall{ + ID: claudeResponse.ContentBlock.Id, + Type: "function", + Function: dto.FunctionCall{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } + } else { + return nil, nil + } } else if claudeResponse.Type == "content_block_delta" { - choice.Index = claudeResponse.Index - choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta != nil { + choice.Index = claudeResponse.Index + choice.Delta.SetContentString(claudeResponse.Delta.Text) + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, dto.ToolCall{ + Function: dto.FunctionCall{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } + } } else if claudeResponse.Type == "message_delta" { finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) if finishReason != "null" { @@ -218,6 +260,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (* if claudeUsage == nil { claudeUsage = &ClaudeUsage{} } + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... + choice.Delta.ToolCalls = tools + } response.Choices = append(response.Choices, choice) return &response, claudeUsage @@ -230,6 +276,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope Object: "chat.completion", Created: common.GetTimestamp(), } + var responseText string + if len(claudeResponse.Content) > 0 { + responseText = claudeResponse.Content[0].Text + } + tools := make([]dto.ToolCall, 0) if reqMode == RequestModeCompletion { content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) choice := dto.OpenAITextResponseChoice{ @@ -244,20 +295,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope choices = append(choices, choice) } else { fullTextResponse.Id = claudeResponse.Id - for i, message := range claudeResponse.Content { - content, _ := json.Marshal(message.Text) - choice := dto.OpenAITextResponseChoice{ - Index: i, - Message: dto.Message{ - Role: "assistant", - Content: content, - }, - FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + for _, message := range claudeResponse.Content { + if message.Type == "tool_use" { + args, _ := json.Marshal(message.Input) + tools = append(tools, dto.ToolCall{ + ID: message.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: dto.FunctionCall{ + Name: message.Name, + Arguments: string(args), + }, + }) } - choices = append(choices, choice) } } - + choice := dto.OpenAITextResponseChoice{ + Index: 0, + Message: dto.Message{ + Role: "assistant", + }, + FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), + } + choice.SetStringContent(responseText) + if len(tools) > 0 { + choice.Message.ToolCalls = tools + } + choices = append(choices, choice) fullTextResponse.Choices = choices return &fullTextResponse } @@ -334,6 +397,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon. } else if claudeResponse.Type == "message_delta" { usage.CompletionTokens = claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens + } else if claudeResponse.Type == "content_block_start" { + } else { return true }