feat: add support for Claude 3 tool use (function calling) (#1587)

* feat: add tool support for AWS & Claude

* fix: add {} for openai compatibility in streaming tool_use
This commit is contained in:
Mikey
2024-07-02 00:12:01 +08:00
committed by GitHub
parent 1ce1e529ee
commit 0fc07ea558
6 changed files with 168 additions and 14 deletions

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"io"
"net/http"
@@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
c.Writer.Header().Set("Content-Type", "text/event-stream")
var usage relaymodel.Usage
var id string
var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice
c.Stream(func(w io.Writer) bool {
event, ok := <-stream.Events()
if !ok {
@@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
if meta != nil {
usage.PromptTokens += meta.Usage.InputTokens
usage.CompletionTokens += meta.Usage.OutputTokens
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event.
id = fmt.Sprintf("chatcmpl-%s", meta.Id)
return true
} else { // finish_reason case
if len(lastToolCallChoice.Delta.ToolCalls) > 0 {
lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function
if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments.
lastArgs.Arguments = "{}"
response.Choices[len(response.Choices)-1].Delta.Content = nil
response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls
}
}
}
}
if response == nil {
return true
@@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E
response.Id = id
response.Model = c.GetString(ctxkey.OriginalModel)
response.Created = createdTime
for _, choice := range response.Choices {
if len(choice.Delta.ToolCalls) > 0 {
lastToolCallChoice = choice
}
}
jsonStr, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())

View File

@@ -9,9 +9,12 @@ type Request struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
}