mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude
* fix: add {} for openai compatibility in streaming tool_use
			
			
This commit is contained in:
		@@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string {
 | 
			
		||||
		return "stop"
 | 
			
		||||
	case "max_tokens":
 | 
			
		||||
		return "length"
 | 
			
		||||
	case "tool_use":
 | 
			
		||||
		return "tool_calls"
 | 
			
		||||
	default:
 | 
			
		||||
		return *reason
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
	claudeTools := make([]Tool, 0, len(textRequest.Tools))
 | 
			
		||||
 | 
			
		||||
	for _, tool := range textRequest.Tools {
 | 
			
		||||
		if params, ok := tool.Function.Parameters.(map[string]any); ok {
 | 
			
		||||
			claudeTools = append(claudeTools, Tool{
 | 
			
		||||
				Name:        tool.Function.Name,
 | 
			
		||||
				Description: tool.Function.Description,
 | 
			
		||||
				InputSchema: InputSchema{
 | 
			
		||||
					Type:       params["type"].(string),
 | 
			
		||||
					Properties: params["properties"],
 | 
			
		||||
					Required:   params["required"],
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claudeRequest := Request{
 | 
			
		||||
		Model:       textRequest.Model,
 | 
			
		||||
		MaxTokens:   textRequest.MaxTokens,
 | 
			
		||||
@@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
		TopP:        textRequest.TopP,
 | 
			
		||||
		TopK:        textRequest.TopK,
 | 
			
		||||
		Stream:      textRequest.Stream,
 | 
			
		||||
		Tools:       claudeTools,
 | 
			
		||||
	}
 | 
			
		||||
	if len(claudeTools) > 0 {
 | 
			
		||||
		claudeToolChoice := struct {
 | 
			
		||||
			Type string `json:"type"`
 | 
			
		||||
			Name string `json:"name,omitempty"`
 | 
			
		||||
		}{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output
 | 
			
		||||
		if choice, ok := textRequest.ToolChoice.(map[string]any); ok {
 | 
			
		||||
			if function, ok := choice["function"].(map[string]any); ok {
 | 
			
		||||
				claudeToolChoice.Type = "tool"
 | 
			
		||||
				claudeToolChoice.Name = function["name"].(string)
 | 
			
		||||
			}
 | 
			
		||||
		} else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok {
 | 
			
		||||
			if toolChoiceType == "any" {
 | 
			
		||||
				claudeToolChoice.Type = toolChoiceType
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		claudeRequest.ToolChoice = claudeToolChoice
 | 
			
		||||
	}
 | 
			
		||||
	if claudeRequest.MaxTokens == 0 {
 | 
			
		||||
		claudeRequest.MaxTokens = 4096
 | 
			
		||||
@@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
 | 
			
		||||
		if message.IsStringContent() {
 | 
			
		||||
			content.Type = "text"
 | 
			
		||||
			content.Text = message.StringContent()
 | 
			
		||||
			if message.Role == "tool" {
 | 
			
		||||
				claudeMessage.Role = "user"
 | 
			
		||||
				content.Type = "tool_result"
 | 
			
		||||
				content.Content = content.Text
 | 
			
		||||
				content.Text = ""
 | 
			
		||||
				content.ToolUseId = message.ToolCallId
 | 
			
		||||
			}
 | 
			
		||||
			claudeMessage.Content = append(claudeMessage.Content, content)
 | 
			
		||||
			for i := range message.ToolCalls {
 | 
			
		||||
				inputParam := make(map[string]any)
 | 
			
		||||
				_ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam)
 | 
			
		||||
				claudeMessage.Content = append(claudeMessage.Content, Content{
 | 
			
		||||
					Type:  "tool_use",
 | 
			
		||||
					Id:    message.ToolCalls[i].Id,
 | 
			
		||||
					Name:  message.ToolCalls[i].Function.Name,
 | 
			
		||||
					Input: inputParam,
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
			claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
@@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 | 
			
		||||
	var response *Response
 | 
			
		||||
	var responseText string
 | 
			
		||||
	var stopReason string
 | 
			
		||||
	tools := make([]model.Tool, 0)
 | 
			
		||||
 | 
			
		||||
	switch claudeResponse.Type {
 | 
			
		||||
	case "message_start":
 | 
			
		||||
		return nil, claudeResponse.Message
 | 
			
		||||
	case "content_block_start":
 | 
			
		||||
		if claudeResponse.ContentBlock != nil {
 | 
			
		||||
			responseText = claudeResponse.ContentBlock.Text
 | 
			
		||||
			if claudeResponse.ContentBlock.Type == "tool_use" {
 | 
			
		||||
				tools = append(tools, model.Tool{
 | 
			
		||||
					Id:   claudeResponse.ContentBlock.Id,
 | 
			
		||||
					Type: "function",
 | 
			
		||||
					Function: model.Function{
 | 
			
		||||
						Name:      claudeResponse.ContentBlock.Name,
 | 
			
		||||
						Arguments: "",
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case "content_block_delta":
 | 
			
		||||
		if claudeResponse.Delta != nil {
 | 
			
		||||
			responseText = claudeResponse.Delta.Text
 | 
			
		||||
			if claudeResponse.Delta.Type == "input_json_delta" {
 | 
			
		||||
				tools = append(tools, model.Tool{
 | 
			
		||||
					Function: model.Function{
 | 
			
		||||
						Arguments: claudeResponse.Delta.PartialJson,
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case "message_delta":
 | 
			
		||||
		if claudeResponse.Usage != nil {
 | 
			
		||||
@@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo
 | 
			
		||||
	}
 | 
			
		||||
	var choice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
	choice.Delta.Content = responseText
 | 
			
		||||
	if len(tools) > 0 {
 | 
			
		||||
		choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
 | 
			
		||||
		choice.Delta.ToolCalls = tools
 | 
			
		||||
	}
 | 
			
		||||
	choice.Delta.Role = "assistant"
 | 
			
		||||
	finishReason := stopReasonClaude2OpenAI(&stopReason)
 | 
			
		||||
	if finishReason != "null" {
 | 
			
		||||
@@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse {
 | 
			
		||||
	if len(claudeResponse.Content) > 0 {
 | 
			
		||||
		responseText = claudeResponse.Content[0].Text
 | 
			
		||||
	}
 | 
			
		||||
	tools := make([]model.Tool, 0)
 | 
			
		||||
	for _, v := range claudeResponse.Content {
 | 
			
		||||
		if v.Type == "tool_use" {
 | 
			
		||||
			args, _ := json.Marshal(v.Input)
 | 
			
		||||
			tools = append(tools, model.Tool{
 | 
			
		||||
				Id:   v.Id,
 | 
			
		||||
				Type: "function", // compatible with other OpenAI derivative applications
 | 
			
		||||
				Function: model.Function{
 | 
			
		||||
					Name:      v.Name,
 | 
			
		||||
					Arguments: string(args),
 | 
			
		||||
				},
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	choice := openai.TextResponseChoice{
 | 
			
		||||
		Index: 0,
 | 
			
		||||
		Message: model.Message{
 | 
			
		||||
			Role:    "assistant",
 | 
			
		||||
			Content: responseText,
 | 
			
		||||
			Name:    nil,
 | 
			
		||||
			Role:      "assistant",
 | 
			
		||||
			Content:   responseText,
 | 
			
		||||
			Name:      nil,
 | 
			
		||||
			ToolCalls: tools,
 | 
			
		||||
		},
 | 
			
		||||
		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
 | 
			
		||||
	}
 | 
			
		||||
@@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
	var usage model.Usage
 | 
			
		||||
	var modelName string
 | 
			
		||||
	var id string
 | 
			
		||||
	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
 | 
			
		||||
	for scanner.Scan() {
 | 
			
		||||
		data := scanner.Text()
 | 
			
		||||
@@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		if meta != nil {
 | 
			
		||||
			usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
			usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
			modelName = meta.Model
 | 
			
		||||
			id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
			continue
 | 
			
		||||
			if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
 | 
			
		||||
				modelName = meta.Model
 | 
			
		||||
				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
				continue
 | 
			
		||||
			} else { // finish_reason case
 | 
			
		||||
				if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
 | 
			
		||||
					lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
 | 
			
		||||
					if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
 | 
			
		||||
						lastArgs.Arguments = "{}"
 | 
			
		||||
						response.Choices[len(response.Choices)-1].Delta.Content = nil
 | 
			
		||||
						response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		if response == nil {
 | 
			
		||||
			continue
 | 
			
		||||
@@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
 | 
			
		||||
		response.Id = id
 | 
			
		||||
		response.Model = modelName
 | 
			
		||||
		response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
		for _, choice := range response.Choices {
 | 
			
		||||
			if len(choice.Delta.ToolCalls) > 0 {
 | 
			
		||||
				lastToolCallChoice = choice
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		err = render.ObjectData(c, response)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.SysError(err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,12 @@ type Content struct {
 | 
			
		||||
	Type   string       `json:"type"`
 | 
			
		||||
	Text   string       `json:"text,omitempty"`
 | 
			
		||||
	Source *ImageSource `json:"source,omitempty"`
 | 
			
		||||
	// tool_calls
 | 
			
		||||
	Id        string `json:"id,omitempty"`
 | 
			
		||||
	Name      string `json:"name,omitempty"`
 | 
			
		||||
	Input     any    `json:"input,omitempty"`
 | 
			
		||||
	Content   string `json:"content,omitempty"`
 | 
			
		||||
	ToolUseId string `json:"tool_use_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -23,6 +29,18 @@ type Message struct {
 | 
			
		||||
	Content []Content `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tool struct {
 | 
			
		||||
	Name        string      `json:"name"`
 | 
			
		||||
	Description string      `json:"description,omitempty"`
 | 
			
		||||
	InputSchema InputSchema `json:"input_schema"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InputSchema struct {
 | 
			
		||||
	Type       string `json:"type"`
 | 
			
		||||
	Properties any    `json:"properties,omitempty"`
 | 
			
		||||
	Required   any    `json:"required,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Request struct {
 | 
			
		||||
	Model         string    `json:"model"`
 | 
			
		||||
	Messages      []Message `json:"messages"`
 | 
			
		||||
@@ -33,6 +51,8 @@ type Request struct {
 | 
			
		||||
	Temperature   float64   `json:"temperature,omitempty"`
 | 
			
		||||
	TopP          float64   `json:"top_p,omitempty"`
 | 
			
		||||
	TopK          int       `json:"top_k,omitempty"`
 | 
			
		||||
	Tools         []Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice    any       `json:"tool_choice,omitempty"`
 | 
			
		||||
	//Metadata    `json:"metadata,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -61,6 +81,7 @@ type Response struct {
 | 
			
		||||
type Delta struct {
 | 
			
		||||
	Type         string  `json:"type"`
 | 
			
		||||
	Text         string  `json:"text"`
 | 
			
		||||
	PartialJson  string  `json:"partial_json,omitempty"`
 | 
			
		||||
	StopReason   *string `json:"stop_reason"`
 | 
			
		||||
	StopSequence *string `json:"stop_sequence"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
			
		||||
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
@@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
	c.Writer.Header().Set("Content-Type", "text/event-stream")
 | 
			
		||||
	var usage relaymodel.Usage
 | 
			
		||||
	var id string
 | 
			
		||||
	var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
 | 
			
		||||
 | 
			
		||||
	c.Stream(func(w io.Writer) bool {
 | 
			
		||||
		event, ok := <-stream.Events()
 | 
			
		||||
		if !ok {
 | 
			
		||||
@@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
			if meta != nil {
 | 
			
		||||
				usage.PromptTokens += meta.Usage.InputTokens
 | 
			
		||||
				usage.CompletionTokens += meta.Usage.OutputTokens
 | 
			
		||||
				id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
				return true
 | 
			
		||||
				if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
 | 
			
		||||
					id = fmt.Sprintf("chatcmpl-%s", meta.Id)
 | 
			
		||||
					return true
 | 
			
		||||
				} else { // finish_reason case
 | 
			
		||||
					if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
 | 
			
		||||
						lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
 | 
			
		||||
						if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
 | 
			
		||||
							lastArgs.Arguments = "{}"
 | 
			
		||||
							response.Choices[len(response.Choices)-1].Delta.Content = nil
 | 
			
		||||
							response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			if response == nil {
 | 
			
		||||
				return true
 | 
			
		||||
@@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
 | 
			
		||||
			response.Id = id
 | 
			
		||||
			response.Model = c.GetString(ctxkey.OriginalModel)
 | 
			
		||||
			response.Created = createdTime
 | 
			
		||||
 | 
			
		||||
			for _, choice := range response.Choices {
 | 
			
		||||
				if len(choice.Delta.ToolCalls) > 0 {
 | 
			
		||||
					lastToolCallChoice = choice
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			jsonStr, err := json.Marshal(response)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.SysError("error marshalling stream response: " + err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -9,9 +9,12 @@ type Request struct {
 | 
			
		||||
	// AnthropicVersion should be "bedrock-2023-05-31"
 | 
			
		||||
	AnthropicVersion string              `json:"anthropic_version"`
 | 
			
		||||
	Messages         []anthropic.Message `json:"messages"`
 | 
			
		||||
	System           string              `json:"system,omitempty"`
 | 
			
		||||
	MaxTokens        int                 `json:"max_tokens,omitempty"`
 | 
			
		||||
	Temperature      float64             `json:"temperature,omitempty"`
 | 
			
		||||
	TopP             float64             `json:"top_p,omitempty"`
 | 
			
		||||
	TopK             int                 `json:"top_k,omitempty"`
 | 
			
		||||
	StopSequences    []string            `json:"stop_sequences,omitempty"`
 | 
			
		||||
	Tools            []anthropic.Tool    `json:"tools,omitempty"`
 | 
			
		||||
	ToolChoice       any                 `json:"tool_choice,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package model
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
	Role      string  `json:"role,omitempty"`
 | 
			
		||||
	Content   any     `json:"content,omitempty"`
 | 
			
		||||
	Name      *string `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls []Tool  `json:"tool_calls,omitempty"`
 | 
			
		||||
	Role       string  `json:"role,omitempty"`
 | 
			
		||||
	Content    any     `json:"content,omitempty"`
 | 
			
		||||
	Name       *string `json:"name,omitempty"`
 | 
			
		||||
	ToolCalls  []Tool  `json:"tool_calls,omitempty"`
 | 
			
		||||
	ToolCallId string  `json:"tool_call_id,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m Message) IsStringContent() bool {
 | 
			
		||||
 
 | 
			
		||||
@@ -2,13 +2,13 @@ package model
 | 
			
		||||
 | 
			
		||||
type Tool struct {
 | 
			
		||||
	Id       string   `json:"id,omitempty"`
 | 
			
		||||
	Type     string   `json:"type"`
 | 
			
		||||
	Type     string   `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty
 | 
			
		||||
	Function Function `json:"function"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Function struct {
 | 
			
		||||
	Description string `json:"description,omitempty"`
 | 
			
		||||
	Name        string `json:"name"`
 | 
			
		||||
	Name        string `json:"name,omitempty"`       // when splicing claude tools stream messages, it is empty
 | 
			
		||||
	Parameters  any    `json:"parameters,omitempty"` // request
 | 
			
		||||
	Arguments   any    `json:"arguments,omitempty"`  // response
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user