mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-08 09:43:42 +08:00
feat: add support for Claude 3 tool use (function calling) (#1587)
* feat: add tool support for AWS & Claude
* fix: add {} for openai compatibility in streaming tool_use
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
var usage relaymodel.Usage
|
||||
var id string
|
||||
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
|
||||
|
||||
c.Stream(func(w io.Writer) bool {
|
||||
event, ok := <-stream.Events()
|
||||
if !ok {
|
||||
@@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
||||
if meta != nil {
|
||||
usage.PromptTokens += meta.Usage.InputTokens
|
||||
usage.CompletionTokens += meta.Usage.OutputTokens
|
||||
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||
return true
|
||||
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
|
||||
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
|
||||
return true
|
||||
} else { // finish_reason case
|
||||
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
|
||||
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
|
||||
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
|
||||
lastArgs.Arguments = "{}"
|
||||
response.Choices[len(response.Choices)-1].Delta.Content = nil
|
||||
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if response == nil {
|
||||
return true
|
||||
@@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
|
||||
response.Id = id
|
||||
response.Model = c.GetString(ctxkey.OriginalModel)
|
||||
response.Created = createdTime
|
||||
|
||||
for _, choice := range response.Choices {
|
||||
if len(choice.Delta.ToolCalls) > 0 {
|
||||
lastToolCallChoice = choice
|
||||
}
|
||||
}
|
||||
jsonStr, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
logger.SysError("error marshalling stream response: " + err.Error())
|
||||
|
||||
@@ -9,9 +9,12 @@ type Request struct {
|
||||
// AnthropicVersion should be "bedrock-2023-05-31"
|
||||
AnthropicVersion string `json:"anthropic_version"`
|
||||
Messages []anthropic.Message `json:"messages"`
|
||||
System string `json:"system,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Tools []anthropic.Tool `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user