diff --git a/dto/text_request.go b/dto/text_request.go
index 801d1c3..2170e71 100644
--- a/dto/text_request.go
+++ b/dto/text_request.go
@@ -148,7 +148,7 @@ func (m Message) ParseContent() []MediaMessage {
 				if ok {
 					subObj["detail"] = detail.(string)
 				} else {
-					subObj["detail"] = "auto"
+					subObj["detail"] = "high"
 				}
 				contentList = append(contentList, MediaMessage{
 					Type: ContentTypeImageURL,
@@ -157,7 +157,16 @@ func (m Message) ParseContent() []MediaMessage {
 						Detail: subObj["detail"].(string),
 					},
 				})
+			} else if url, ok := contentMap["image_url"].(string); ok {
+				contentList = append(contentList, MediaMessage{
+					Type: ContentTypeImageURL,
+					ImageUrl: MessageImageUrl{
+						Url:    url,
+						Detail: "high",
+					},
+				})
 			}
+
 		}
 	}
 	return contentList
diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go
index de7761a..e132d2f 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/channel/gemini/adaptor.go
@@ -47,7 +47,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 
 	action := "generateContent"
 	if info.IsStream {
-		action = "streamGenerateContent"
+		action = "streamGenerateContent?alt=sse"
 	}
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go
index 99ab654..771a616 100644
--- a/relay/channel/gemini/dto.go
+++ b/relay/channel/gemini/dto.go
@@ -12,9 +12,15 @@ type GeminiInlineData struct {
 	Data string `json:"data"`
 }
 
+type FunctionCall struct {
+	FunctionName string `json:"name"`
+	Arguments    any    `json:"args"`
+}
+
 type GeminiPart struct {
-	Text       string            `json:"text,omitempty"`
-	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+	Text         string            `json:"text,omitempty"`
+	InlineData   *GeminiInlineData `json:"inlineData,omitempty"`
+	FunctionCall *FunctionCall     `json:"functionCall,omitempty"`
 }
 
 type GeminiChatContent struct {
diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go
index b7080b9..45dfbb9 100644
--- a/relay/channel/gemini/relay-gemini.go
+++ b/relay/channel/gemini/relay-gemini.go
@@ -4,18 +4,14 @@ import (
 	"bufio"
 	"encoding/json"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
-	"log"
 	"net/http"
 	"one-api/common"
-	"one-api/constant"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
 	"one-api/service"
 	"strings"
-	"time"
-
-	"github.com/gin-gonic/gin"
 )
 
 // Setting safety to the lowest possible values since Gemini is already powerless enough
@@ -46,7 +42,17 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest) *GeminiChatReques
 			MaxOutputTokens: textRequest.MaxTokens,
 		},
 	}
-	if textRequest.Functions != nil {
+	if textRequest.Tools != nil {
+		functions := make([]dto.FunctionCall, 0, len(textRequest.Tools))
+		for _, tool := range textRequest.Tools {
+			functions = append(functions, tool.Function)
+		}
+		geminiRequest.Tools = []GeminiChatTools{
+			{
+				FunctionDeclarations: functions,
+			},
+		}
+	} else if textRequest.Functions != nil {
 		geminiRequest.Tools = []GeminiChatTools{
 			{
 				FunctionDeclarations: textRequest.Functions,
@@ -126,6 +132,30 @@ func (g *GeminiChatResponse) GetResponseText() string {
 	return ""
 }
 
+func getToolCalls(candidate *GeminiChatCandidate) []dto.ToolCall {
+	var toolCalls []dto.ToolCall
+
+	item := candidate.Content.Parts[0]
+	if item.FunctionCall == nil {
+		return toolCalls
+	}
+	argsBytes, err := json.Marshal(item.FunctionCall.Arguments)
+	if err != nil {
+		//common.SysError("getToolCalls failed: " + err.Error())
+		return toolCalls
+	}
+	toolCall := dto.ToolCall{
+		ID:   fmt.Sprintf("call_%s", common.GetUUID()),
+		Type: "function",
+		Function: dto.FunctionCall{
+			Arguments: string(argsBytes),
+			Name:      item.FunctionCall.FunctionName,
+		},
+	}
+	toolCalls = append(toolCalls, toolCall)
+	return toolCalls
+}
+
 func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
@@ -144,8 +174,11 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 			FinishReason: relaycommon.StopFinishReason,
 		}
 		if len(candidate.Content.Parts) > 0 {
-			content, _ = json.Marshal(candidate.Content.Parts[0].Text)
-			choice.Message.Content = content
+			if candidate.Content.Parts[0].FunctionCall != nil {
+				choice.Message.ToolCalls = getToolCalls(&candidate)
+			} else {
+				choice.Message.SetStringContent(candidate.Content.Parts[0].Text)
+			}
 		}
 		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
 	}
@@ -154,7 +187,17 @@ func responseGeminiChat2OpenAI(response *GeminiChatResponse) *dto.OpenAITextResp
 
 func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.ChatCompletionsStreamResponse {
 	var choice dto.ChatCompletionsStreamResponseChoice
-	choice.Delta.SetContentString(geminiResponse.GetResponseText())
+	//choice.Delta.SetContentString(geminiResponse.GetResponseText())
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+		respFirst := geminiResponse.Candidates[0].Content.Parts[0]
+		if respFirst.FunctionCall != nil {
+			// function response
+			choice.Delta.ToolCalls = getToolCalls(&geminiResponse.Candidates[0])
+		} else {
+			// text response
+			choice.Delta.SetContentString(respFirst.Text)
+		}
+	}
 	choice.FinishReason = &relaycommon.StopFinishReason
 	var response dto.ChatCompletionsStreamResponse
 	response.Object = "chat.completion.chunk"
@@ -165,92 +208,47 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 
 func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
-	responseJson := ""
 	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
 	createAt := common.GetTimestamp()
 	var usage = &dto.Usage{}
-	dataChan := make(chan string, 5)
-	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
-	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
-		if atEOF && len(data) == 0 {
-			return 0, nil, nil
-		}
-		if i := strings.Index(string(data), "\n"); i >= 0 {
-			return i + 1, data[0:i], nil
-		}
-		if atEOF {
-			return len(data), data, nil
-		}
-		return 0, nil, nil
-	})
-	go func() {
-		for scanner.Scan() {
-			data := scanner.Text()
-			responseJson += data
-			data = strings.TrimSpace(data)
-			if !strings.HasPrefix(data, "\"text\": \"") {
-				continue
-			}
-			data = strings.TrimPrefix(data, "\"text\": \"")
-			data = strings.TrimSuffix(data, "\"")
-			if !common.SafeSendStringTimeout(dataChan, data, constant.StreamingTimeout) {
-				// send data timeout, stop the stream
-				common.LogError(c, "send data timeout, stop the stream")
-				break
-			}
-		}
-		stopChan <- true
-	}()
-	isFirst := true
+	scanner.Split(bufio.ScanLines)
+
 	service.SetEventStreamHeaders(c)
-	c.Stream(func(w io.Writer) bool {
-		select {
-		case data := <-dataChan:
-			if isFirst {
-				isFirst = false
-				info.FirstResponseTime = time.Now()
-			}
-			// this is used to prevent annoying \ related format bug
-			data = fmt.Sprintf("{\"content\": \"%s\"}", data)
-			type dummyStruct struct {
-				Content string `json:"content"`
-			}
-			var dummy dummyStruct
-			err := json.Unmarshal([]byte(data), &dummy)
-			responseText += dummy.Content
-			var choice dto.ChatCompletionsStreamResponseChoice
-			choice.Delta.SetContentString(dummy.Content)
-			response := dto.ChatCompletionsStreamResponse{
-				Id:      id,
-				Object:  "chat.completion.chunk",
-				Created: createAt,
-				Model:   info.UpstreamModelName,
-				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
-			}
-			jsonResponse, err := json.Marshal(response)
-			if err != nil {
-				common.SysError("error marshalling stream response: " + err.Error())
-				return true
-			}
-			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
-			return true
-		case <-stopChan:
-			return false
+	for scanner.Scan() {
+		data := scanner.Text()
+		info.SetFirstResponseTime()
+		data = strings.TrimSpace(data)
+		if !strings.HasPrefix(data, "data: ") {
+			continue
 		}
-	})
-	var geminiChatResponses []GeminiChatResponse
-	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
-	if err != nil {
-		log.Printf("cannot get gemini usage: %s", err.Error())
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
-	} else {
-		for _, response := range geminiChatResponses {
-			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
-			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		data = strings.TrimPrefix(data, "data: ")
+		data = strings.TrimSuffix(data, "\"")
+		var geminiResponse GeminiChatResponse
+		err := json.Unmarshal([]byte(data), &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			continue
+		}
+
+		response := streamResponseGeminiChat2OpenAI(&geminiResponse)
+		if response == nil {
+			continue
+		}
+		response.Id = id
+		response.Created = createAt
+		responseText += response.Choices[0].Delta.GetContentString()
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+		}
+		err = service.ObjectData(c, response)
+		if err != nil {
+			common.LogError(c, err.Error())
 		}
-		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
+	usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+
 	if info.ShouldIncludeUsage {
 		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
 		err := service.ObjectData(c, response)
@@ -259,10 +257,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 		}
 	}
 	service.Done(c)
-	err = resp.Body.Close()
-	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
-	}
+	resp.Body.Close()
 	return nil, usage
 }
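For reference, a minimal self-contained sketch (not part of the patch) of what the new streaming path does with one "data: {...}" line from the streamGenerateContent?alt=sse endpoint: strip the SSE prefix, unmarshal the JSON chunk, and turn a functionCall part into an OpenAI-style tool call with the arguments re-marshalled to a JSON string. The struct names and the sample payload below are simplified stand-ins, not the repo's actual GeminiChatResponse/dto types; only the JSON tags mirror the ones added in dto.go.

// Illustrative only: simplified mirrors of the GeminiPart / FunctionCall shapes
// added in relay/channel/gemini/dto.go, not the repo's actual types.
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

type functionCall struct {
	Name string `json:"name"`
	Args any    `json:"args"`
}

type part struct {
	Text         string        `json:"text,omitempty"`
	FunctionCall *functionCall `json:"functionCall,omitempty"`
}

type chunk struct {
	Candidates []struct {
		Content struct {
			Parts []part `json:"parts"`
		} `json:"content"`
	} `json:"candidates"`
}

func main() {
	// A single SSE line as emitted with ?alt=sse; the payload is a made-up
	// example of a functionCall part.
	line := `data: {"candidates":[{"content":{"parts":[{"functionCall":{"name":"get_weather","args":{"city":"Paris"}}}]}}]}`
	if !strings.HasPrefix(line, "data: ") {
		return // not a data line; the handler skips these
	}
	var ck chunk
	if err := json.Unmarshal([]byte(strings.TrimPrefix(line, "data: ")), &ck); err != nil {
		fmt.Println("unmarshal failed:", err)
		return
	}
	if len(ck.Candidates) == 0 || len(ck.Candidates[0].Content.Parts) == 0 {
		return
	}
	p := ck.Candidates[0].Content.Parts[0]
	if p.FunctionCall != nil {
		// OpenAI tool calls carry arguments as a JSON string, so re-marshal args.
		argsBytes, _ := json.Marshal(p.FunctionCall.Args)
		fmt.Printf("tool call: name=%s arguments=%s\n", p.FunctionCall.Name, string(argsBytes))
	} else {
		fmt.Println("text delta:", p.Text)
	}
}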