fix: fix tool calls

This commit is contained in:
CaIon 2024-01-31 01:41:38 +08:00
parent 364d4f96c7
commit 6d0479632a
2 changed files with 46 additions and 43 deletions

View File

@ -145,47 +145,48 @@ func countTokenMessages(messages []Message, model string) (int, error) {
for _, message := range messages { for _, message := range messages {
tokenNum += tokensPerMessage tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role) tokenNum += getTokenNum(tokenEncoder, message.Role)
var arrayContent []MediaMessage if len(message.Content) > 0 {
if err := json.Unmarshal(message.Content, &arrayContent); err != nil { var arrayContent []MediaMessage
if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
var stringContent string var stringContent string
if err := json.Unmarshal(message.Content, &stringContent); err != nil { if err := json.Unmarshal(message.Content, &stringContent); err != nil {
return 0, err return 0, err
} else {
tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
} else {
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else { } else {
tokenNum += getTokenNum(tokenEncoder, m.Text) tokenNum += getTokenNum(tokenEncoder, stringContent)
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
}
} else {
for _, m := range arrayContent {
if m.Type == "image_url" {
var imageTokenNum int
if str, ok := m.ImageUrl.(string); ok {
imageTokenNum, err = getImageToken(&MessageImageUrl{Url: str, Detail: "auto"})
} else {
imageUrlMap := m.ImageUrl.(map[string]interface{})
detail, ok := imageUrlMap["detail"]
if ok {
imageUrlMap["detail"] = detail.(string)
} else {
imageUrlMap["detail"] = "auto"
}
imageUrl := MessageImageUrl{
Url: imageUrlMap["url"].(string),
Detail: imageUrlMap["detail"].(string),
}
imageTokenNum, err = getImageToken(&imageUrl)
}
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
} }
} }
} }

View File

@ -13,9 +13,11 @@ import (
) )
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content json.RawMessage `json:"content"` Content json.RawMessage `json:"content"`
Name *string `json:"name,omitempty"` Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"`
} }
type MediaMessage struct { type MediaMessage struct {