feat: support claude tool calling

This commit is contained in:
CalciumIon 2024-07-18 00:36:05 +08:00
parent 0f94ff47b5
commit 11fd993574
4 changed files with 128 additions and 39 deletions

View File

@ -26,7 +26,7 @@ type GeneralOpenAIRequest struct {
PresencePenalty float64 `json:"presence_penalty,omitempty"` PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"` ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"` Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"` Tools []ToolCall `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"` ToolChoice any `json:"tool_choice,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
LogProbs bool `json:"logprobs,omitempty"` LogProbs bool `json:"logprobs,omitempty"`
@ -104,6 +104,11 @@ func (m Message) StringContent() string {
return string(m.Content) return string(m.Content)
} }
func (m *Message) SetStringContent(content string) {
jsonContent, _ := json.Marshal(content)
m.Content = jsonContent
}
func (m Message) IsStringContent() bool { func (m Message) IsStringContent() bool {
var stringContent string var stringContent string
if err := json.Unmarshal(m.Content, &stringContent); err == nil { if err := json.Unmarshal(m.Content, &stringContent); err == nil {

View File

@ -86,9 +86,11 @@ type ToolCall struct {
} }
type FunctionCall struct { type FunctionCall struct {
Name string `json:"name,omitempty"` Description string `json:"description,omitempty"`
Name string `json:"name,omitempty"`
// call function with arguments in JSON format // call function with arguments in JSON format
Arguments string `json:"arguments,omitempty"` Parameters any `json:"parameters,omitempty"` // request
Arguments string `json:"arguments,omitempty"`
} }
type ChatCompletionsStreamResponse struct { type ChatCompletionsStreamResponse struct {

View File

@ -5,11 +5,18 @@ type ClaudeMetadata struct {
} }
type ClaudeMediaMessage struct { type ClaudeMediaMessage struct {
Type string `json:"type"` Type string `json:"type"`
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
Source *ClaudeMessageSource `json:"source,omitempty"` Source *ClaudeMessageSource `json:"source,omitempty"`
Usage *ClaudeUsage `json:"usage,omitempty"` Usage *ClaudeUsage `json:"usage,omitempty"`
StopReason *string `json:"stop_reason,omitempty"` StopReason *string `json:"stop_reason,omitempty"`
PartialJson string `json:"partial_json,omitempty"`
// tool_calls
Id string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Input any `json:"input,omitempty"`
Content string `json:"content,omitempty"`
ToolUseId string `json:"tool_use_id,omitempty"`
} }
type ClaudeMessageSource struct { type ClaudeMessageSource struct {
@ -23,6 +30,18 @@ type ClaudeMessage struct {
Content any `json:"content"` Content any `json:"content"`
} }
type Tool struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
InputSchema InputSchema `json:"input_schema"`
}
type InputSchema struct {
Type string `json:"type"`
Properties any `json:"properties,omitempty"`
Required any `json:"required,omitempty"`
}
type ClaudeRequest struct { type ClaudeRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
@ -35,7 +54,9 @@ type ClaudeRequest struct {
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
//ClaudeMetadata `json:"metadata,omitempty"` //ClaudeMetadata `json:"metadata,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
} }
type ClaudeError struct { type ClaudeError struct {
@ -44,24 +65,20 @@ type ClaudeError struct {
} }
type ClaudeResponse struct { type ClaudeResponse struct {
Id string `json:"id"` Id string `json:"id"`
Type string `json:"type"` Type string `json:"type"`
Content []ClaudeMediaMessage `json:"content"` Content []ClaudeMediaMessage `json:"content"`
Completion string `json:"completion"` Completion string `json:"completion"`
StopReason string `json:"stop_reason"` StopReason string `json:"stop_reason"`
Model string `json:"model"` Model string `json:"model"`
Error ClaudeError `json:"error"` Error ClaudeError `json:"error"`
Usage ClaudeUsage `json:"usage"` Usage ClaudeUsage `json:"usage"`
Index int `json:"index"` // stream only Index int `json:"index"` // stream only
Delta *ClaudeMediaMessage `json:"delta"` // stream only ContentBlock *ClaudeMediaMessage `json:"content_block"`
Message *ClaudeResponse `json:"message"` // stream only: message_start Delta *ClaudeMediaMessage `json:"delta"` // stream only
Message *ClaudeResponse `json:"message"` // stream only: message_start
} }
//type ClaudeResponseChoice struct {
// Index int `json:"index"`
// Type string `json:"type"`
//}
type ClaudeUsage struct { type ClaudeUsage struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`

View File

@ -30,6 +30,7 @@ func stopReasonClaude2OpenAI(reason string) string {
} }
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest { func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeRequest {
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
Prompt: "", Prompt: "",
@ -60,6 +61,22 @@ func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *ClaudeR
} }
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) { func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeRequest, error) {
claudeTools := make([]Tool, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTools = append(claudeTools, Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
InputSchema: InputSchema{
Type: params["type"].(string),
Properties: params["properties"],
Required: params["required"],
},
})
}
}
claudeRequest := ClaudeRequest{ claudeRequest := ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.MaxTokens,
@ -68,6 +85,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
TopP: textRequest.TopP, TopP: textRequest.TopP,
TopK: textRequest.TopK, TopK: textRequest.TopK,
Stream: textRequest.Stream, Stream: textRequest.Stream,
Tools: claudeTools,
} }
if claudeRequest.MaxTokens == 0 { if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = 4096 claudeRequest.MaxTokens = 4096
@ -184,6 +202,7 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0) response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCall, 0)
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion) choice.Delta.SetContentString(claudeResponse.Completion)
@ -199,10 +218,33 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
choice.Delta.SetContentString("") choice.Delta.SetContentString("")
choice.Delta.Role = "assistant" choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" { } else if claudeResponse.Type == "content_block_start" {
return nil, nil if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCall{
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionCall{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil, nil
}
} else if claudeResponse.Type == "content_block_delta" { } else if claudeResponse.Type == "content_block_delta" {
choice.Index = claudeResponse.Index if claudeResponse.Delta != nil {
choice.Delta.SetContentString(claudeResponse.Delta.Text) choice.Index = claudeResponse.Index
choice.Delta.SetContentString(claudeResponse.Delta.Text)
if claudeResponse.Delta.Type == "input_json_delta" {
tools = append(tools, dto.ToolCall{
Function: dto.FunctionCall{
Arguments: claudeResponse.Delta.PartialJson,
},
})
}
}
} else if claudeResponse.Type == "message_delta" { } else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason) finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" { if finishReason != "null" {
@ -218,6 +260,10 @@ func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) (*
if claudeUsage == nil { if claudeUsage == nil {
claudeUsage = &ClaudeUsage{} claudeUsage = &ClaudeUsage{}
} }
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice) response.Choices = append(response.Choices, choice)
return &response, claudeUsage return &response, claudeUsage
@ -230,6 +276,11 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
Object: "chat.completion", Object: "chat.completion",
Created: common.GetTimestamp(), Created: common.GetTimestamp(),
} }
var responseText string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].Text
}
tools := make([]dto.ToolCall, 0)
if reqMode == RequestModeCompletion { if reqMode == RequestModeCompletion {
content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " ")) content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
choice := dto.OpenAITextResponseChoice{ choice := dto.OpenAITextResponseChoice{
@ -244,20 +295,32 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope
choices = append(choices, choice) choices = append(choices, choice)
} else { } else {
fullTextResponse.Id = claudeResponse.Id fullTextResponse.Id = claudeResponse.Id
for i, message := range claudeResponse.Content { for _, message := range claudeResponse.Content {
content, _ := json.Marshal(message.Text) if message.Type == "tool_use" {
choice := dto.OpenAITextResponseChoice{ args, _ := json.Marshal(message.Input)
Index: i, tools = append(tools, dto.ToolCall{
Message: dto.Message{ ID: message.Id,
Role: "assistant", Type: "function", // compatible with other OpenAI derivative applications
Content: content, Function: dto.FunctionCall{
}, Name: message.Name,
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), Arguments: string(args),
},
})
} }
choices = append(choices, choice)
} }
} }
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choice.SetStringContent(responseText)
if len(tools) > 0 {
choice.Message.ToolCalls = tools
}
choices = append(choices, choice)
fullTextResponse.Choices = choices fullTextResponse.Choices = choices
return &fullTextResponse return &fullTextResponse
} }
@ -334,6 +397,8 @@ func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
} else if claudeResponse.Type == "message_delta" { } else if claudeResponse.Type == "message_delta" {
usage.CompletionTokens = claudeUsage.OutputTokens usage.CompletionTokens = claudeUsage.OutputTokens
usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens usage.TotalTokens = claudeUsage.InputTokens + claudeUsage.OutputTokens
} else if claudeResponse.Type == "content_block_start" {
} else { } else {
return true return true
} }