fix: vertex claude request max tokens

This commit is contained in:
linzhaoming 2024-06-24 06:53:05 +08:00
parent 5e1125e4cb
commit 811d89fee9
2 changed files with 13 additions and 91 deletions

View File

@ -5,8 +5,12 @@ import "one-api/relay/channel/claude"
type VertexClaudeRequest struct {
// vertex-2023-10-16
AnthropicVersion string `json:"anthropic_version"`
System string `json:"system,omitempty"`
System string `json:"system"`
Messages []claude.ClaudeMessage `json:"messages"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
}

View File

@ -4,8 +4,8 @@ import (
"bufio"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
@ -51,97 +51,15 @@ func getRedirectModel(requestModel string) (string, error) {
}
func requestOpenAI2VertexClaude(request dto.GeneralOpenAIRequest) (*VertexClaudeRequest, error) {
vertexClaudeRequest := VertexClaudeRequest{
vertexClaudeRequest := &VertexClaudeRequest{
AnthropicVersion: "vertex-2023-10-16",
Stream: request.Stream,
}
if vertexClaudeRequest.MaxTokens == 0 {
vertexClaudeRequest.MaxTokens = 4096
}
formatMessages := make([]dto.Message, 0)
var lastMessage *dto.Message
for i, message := range request.Messages {
if message.Role == "" {
request.Messages[i].Role = "user"
}
fmtMessage := dto.Message{
Role: message.Role,
Content: message.Content,
}
if lastMessage != nil && lastMessage.Role == message.Role {
if lastMessage.IsStringContent() && message.IsStringContent() {
content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
fmtMessage.Content = content
// delete last message
formatMessages = formatMessages[:len(formatMessages)-1]
}
}
if fmtMessage.Content == nil {
content, _ := json.Marshal("...")
fmtMessage.Content = content
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = &request.Messages[i]
}
claudeMessages := make([]claude.ClaudeMessage, 0)
for _, message := range formatMessages {
if message.Role == "system" {
if message.IsStringContent() {
vertexClaudeRequest.System = message.StringContent()
} else {
contents := message.ParseContent()
content := ""
for _, ctx := range contents {
if ctx.Type == "text" {
content += ctx.Text
}
}
vertexClaudeRequest.System = content
}
} else {
claudeMessage := claude.ClaudeMessage{
Role: message.Role,
}
if message.IsStringContent() {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]claude.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := claude.ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = mediaMessage.Text
} else {
imageUrl := mediaMessage.ImageUrl.(dto.MessageImageUrl)
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &claude.ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
mimeType, data, _ := common.GetImageFromUrl(imageUrl.Url)
claudeMediaMessage.Source.MediaType = mimeType
claudeMediaMessage.Source.Data = data
} else {
_, format, base64String, err := common.DecodeBase64ImageData(imageUrl.Url)
claudeRequest, _ := claude.RequestOpenAI2ClaudeMessage(request)
err := copier.Copy(vertexClaudeRequest, claudeRequest)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
vertexClaudeRequest.Messages = claudeMessages
return &vertexClaudeRequest, nil
return vertexClaudeRequest, nil
}
func vertexClaudeHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {