mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	Merge pull request #464 from Yan-Zero/main
fix: tool use in claude and add gemini mapping
This commit is contained in:
		@@ -106,8 +106,10 @@ var defaultModelRatio = map[string]float64{
 | 
			
		||||
	"gemini-pro-vision":              1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
 | 
			
		||||
	"gemini-1.0-pro-vision-001":      1,
 | 
			
		||||
	"gemini-1.0-pro-001":             1,
 | 
			
		||||
	"gemini-1.5-pro-latest":          1,
 | 
			
		||||
	"gemini-1.5-pro-latest":          1.75, // $3.5 / 1M tokens
 | 
			
		||||
	"gemini-1.5-pro-exp-0827":        1.75, // $3.5 / 1M tokens
 | 
			
		||||
	"gemini-1.5-flash-latest":        1,
 | 
			
		||||
	"gemini-1.5-flash-exp-0827":      1,
 | 
			
		||||
	"gemini-1.0-pro-latest":          1,
 | 
			
		||||
	"gemini-1.0-pro-vision-latest":   1,
 | 
			
		||||
	"gemini-ultra":                   1,
 | 
			
		||||
 
 | 
			
		||||
@@ -20,14 +20,16 @@ var GetMediaTokenNotStream = common.GetEnvOrDefaultBool("GET_MEDIA_TOKEN_NOT_STR
 | 
			
		||||
var UpdateTask = common.GetEnvOrDefaultBool("UPDATE_TASK", true)
 | 
			
		||||
 | 
			
		||||
var GeminiModelMap = map[string]string{
 | 
			
		||||
	"gemini-1.5-pro-latest":   "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-001":      "v1beta",
 | 
			
		||||
	"gemini-1.5-pro":          "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-exp-0801": "v1beta",
 | 
			
		||||
	"gemini-1.5-flash-latest": "v1beta",
 | 
			
		||||
	"gemini-1.5-flash-001":    "v1beta",
 | 
			
		||||
	"gemini-1.5-flash":        "v1beta",
 | 
			
		||||
	"gemini-ultra":            "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-latest":     "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-001":        "v1beta",
 | 
			
		||||
	"gemini-1.5-pro":            "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-exp-0801":   "v1beta",
 | 
			
		||||
	"gemini-1.5-pro-exp-0827":   "v1beta",
 | 
			
		||||
	"gemini-1.5-flash-latest":   "v1beta",
 | 
			
		||||
	"gemini-1.5-flash-exp-0827": "v1beta",
 | 
			
		||||
	"gemini-1.5-flash-001":      "v1beta",
 | 
			
		||||
	"gemini-1.5-flash":          "v1beta",
 | 
			
		||||
	"gemini-ultra":              "v1beta",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func InitEnv() {
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,6 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
@@ -12,6 +11,8 @@ import (
 | 
			
		||||
	relaycommon "one-api/relay/common"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func stopReasonClaude2OpenAI(reason string) string {
 | 
			
		||||
@@ -108,13 +109,10 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	formatMessages := make([]dto.Message, 0)
 | 
			
		||||
	var lastMessage *dto.Message
 | 
			
		||||
	lastMessage := dto.Message{
 | 
			
		||||
		Role: "tool",
 | 
			
		||||
	}
 | 
			
		||||
	for i, message := range textRequest.Messages {
 | 
			
		||||
		//if message.Role == "system" {
 | 
			
		||||
		//	if i != 0 {
 | 
			
		||||
		//		message.Role = "user"
 | 
			
		||||
		//	}
 | 
			
		||||
		//}
 | 
			
		||||
		if message.Role == "" {
 | 
			
		||||
			textRequest.Messages[i].Role = "user"
 | 
			
		||||
		}
 | 
			
		||||
@@ -122,7 +120,13 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 | 
			
		||||
			Role:    message.Role,
 | 
			
		||||
			Content: message.Content,
 | 
			
		||||
		}
 | 
			
		||||
		if lastMessage != nil && lastMessage.Role == message.Role {
 | 
			
		||||
		if message.Role == "tool" {
 | 
			
		||||
			fmtMessage.ToolCallId = message.ToolCallId
 | 
			
		||||
		}
 | 
			
		||||
		if message.Role == "assistant" && message.ToolCalls != nil {
 | 
			
		||||
			fmtMessage.ToolCalls = message.ToolCalls
 | 
			
		||||
		}
 | 
			
		||||
		if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
 | 
			
		||||
			if lastMessage.IsStringContent() && message.IsStringContent() {
 | 
			
		||||
				content, _ := json.Marshal(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
 | 
			
		||||
				fmtMessage.Content = content
 | 
			
		||||
@@ -135,7 +139,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 | 
			
		||||
			fmtMessage.Content = content
 | 
			
		||||
		}
 | 
			
		||||
		formatMessages = append(formatMessages, fmtMessage)
 | 
			
		||||
		lastMessage = &textRequest.Messages[i]
 | 
			
		||||
		lastMessage = fmtMessage
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claudeMessages := make([]ClaudeMessage, 0)
 | 
			
		||||
@@ -174,7 +178,35 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 | 
			
		||||
			claudeMessage := ClaudeMessage{
 | 
			
		||||
				Role: message.Role,
 | 
			
		||||
			}
 | 
			
		||||
			if message.IsStringContent() {
 | 
			
		||||
			if message.Role == "tool" {
 | 
			
		||||
				if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
 | 
			
		||||
					lastMessage := claudeMessages[len(claudeMessages)-1]
 | 
			
		||||
					if content, ok := lastMessage.Content.(string); ok {
 | 
			
		||||
						lastMessage.Content = []ClaudeMediaMessage{
 | 
			
		||||
							{
 | 
			
		||||
								Type: "text",
 | 
			
		||||
								Text: content,
 | 
			
		||||
							},
 | 
			
		||||
						}
 | 
			
		||||
					}
 | 
			
		||||
					lastMessage.Content = append(lastMessage.Content.([]ClaudeMediaMessage), ClaudeMediaMessage{
 | 
			
		||||
						Type:      "tool_result",
 | 
			
		||||
						ToolUseId: message.ToolCallId,
 | 
			
		||||
						Content:   message.StringContent(),
 | 
			
		||||
					})
 | 
			
		||||
					claudeMessages[len(claudeMessages)-1] = lastMessage
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					claudeMessage.Role = "user"
 | 
			
		||||
					claudeMessage.Content = []ClaudeMediaMessage{
 | 
			
		||||
						{
 | 
			
		||||
							Type:      "tool_result",
 | 
			
		||||
							ToolUseId: message.ToolCallId,
 | 
			
		||||
							Content:   message.StringContent(),
 | 
			
		||||
						},
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
			} else if message.IsStringContent() && message.ToolCalls == nil {
 | 
			
		||||
				claudeMessage.Content = message.StringContent()
 | 
			
		||||
			} else {
 | 
			
		||||
				claudeMediaMessages := make([]ClaudeMediaMessage, 0)
 | 
			
		||||
@@ -207,6 +239,28 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*ClaudeR
 | 
			
		||||
					}
 | 
			
		||||
					claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
 | 
			
		||||
				}
 | 
			
		||||
				if message.ToolCalls != nil {
 | 
			
		||||
					for _, tc := range message.ToolCalls.([]interface{}) {
 | 
			
		||||
						toolCallJSON, _ := json.Marshal(tc)
 | 
			
		||||
						var toolCall dto.ToolCall
 | 
			
		||||
						err := json.Unmarshal(toolCallJSON, &toolCall)
 | 
			
		||||
						if err != nil {
 | 
			
		||||
							common.SysError("tool call is not a dto.ToolCall: " + fmt.Sprintf("%v", tc))
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
						inputObj := make(map[string]any)
 | 
			
		||||
						if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
 | 
			
		||||
							common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
 | 
			
		||||
							continue
 | 
			
		||||
						}
 | 
			
		||||
						claudeMediaMessages = append(claudeMediaMessages, ClaudeMediaMessage{
 | 
			
		||||
							Type:  "tool_use",
 | 
			
		||||
							Id:    toolCall.ID,
 | 
			
		||||
							Name:  toolCall.Function.Name,
 | 
			
		||||
							Input: inputObj,
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				claudeMessage.Content = claudeMediaMessages
 | 
			
		||||
			}
 | 
			
		||||
			claudeMessages = append(claudeMessages, claudeMessage)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,7 @@ const (
 | 
			
		||||
 | 
			
		||||
var ModelList = []string{
 | 
			
		||||
	"gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra",
 | 
			
		||||
	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001",
 | 
			
		||||
	"gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", "gemini-1.5-pro-exp-0827", "gemini-1.5-flash-exp-0827",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var ChannelName = "google gemini"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user