From 4e7e20629050522a2bc774ce0340b9ca1256d201 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Wed, 10 Jul 2024 16:01:09 +0800 Subject: [PATCH] fix: gemini usage (close #354) --- relay/channel/gemini/adaptor.go | 5 +-- relay/channel/gemini/dto.go | 7 ++++ relay/channel/gemini/relay-gemini.go | 48 +++++++++++++++++++++------- relay/common/relay_info.go | 3 +- 4 files changed, 46 insertions(+), 17 deletions(-) diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index f51ae3f..9755163 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -9,7 +9,6 @@ import ( "one-api/dto" "one-api/relay/channel" relaycommon "one-api/relay/common" - "one-api/service" ) type Adaptor struct { @@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - var responseText string - err, responseText = geminiChatStreamHandler(c, resp, info) - usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + err, usage = geminiChatStreamHandler(c, resp, info) } else { err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/gemini/dto.go b/relay/channel/gemini/dto.go index a581c68..99ab654 100644 --- a/relay/channel/gemini/dto.go +++ b/relay/channel/gemini/dto.go @@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct { type GeminiChatResponse struct { Candidates []GeminiChatCandidate `json:"candidates"` PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"` + UsageMetadata GeminiUsageMetadata `json:"usageMetadata"` +} + +type GeminiUsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` } diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 8af08c5..b7080b9 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "log" "net/http" "one-api/common" "one-api/constant" @@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) { +func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseText := "" + responseJson := "" + id := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createAt := common.GetTimestamp() + var usage = &dto.Usage{} dataChan := make(chan string, 5) stopChan := make(chan bool, 2) scanner := bufio.NewScanner(resp.Body) @@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom go func() { for scanner.Scan() { data := scanner.Text() + responseJson += data data = strings.TrimSpace(data) if !strings.HasPrefix(data, "\"text\": \"") { continue @@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom var choice dto.ChatCompletionsStreamResponseChoice choice.Delta.SetContentString(dummy.Content) response := dto.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()), + Id: id, Object: "chat.completion.chunk", - Created: common.GetTimestamp(), - Model: "gemini-pro", + Created: createAt, + Model: info.UpstreamModelName, Choices: []dto.ChatCompletionsStreamResponseChoice{choice}, } jsonResponse, err := json.Marshal(response) @@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) return true case <-stopChan: - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) - err := resp.Body.Close() + var geminiChatResponses []GeminiChatResponse + err := json.Unmarshal([]byte(responseJson), &geminiChatResponses) if err != nil { - return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" + log.Printf("cannot get gemini usage: %s", err.Error()) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + } else { + for _, response := range geminiChatResponses { + usage.PromptTokens = response.UsageMetadata.PromptTokenCount + usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount + } + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } - return nil, responseText + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.SysError("send final response failed: " + err.Error()) + } + } + service.Done(c) + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage + } + return nil, usage } func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { @@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo }, nil } fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse) - completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model) usage := dto.Usage{ - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: promptTokens + completionTokens, + PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount, + CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount, + TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount, } fullTextResponse.Usage = usage jsonResponse, err := json.Marshal(fullTextResponse) diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 8e7c3e6..42c8381 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -67,7 +67,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { if info.ChannelType == common.ChannelTypeAzure { info.ApiVersion = GetAPIVersion(c) } - if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws { + if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || + info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini { info.SupportStreamOptions = true } return info