fix: gemini usage (close #354)

CalciumIon 2024-07-10 16:01:09 +08:00
parent 579fc8129e
commit 4e7e206290
4 changed files with 46 additions and 17 deletions

View File

@@ -9,7 +9,6 @@ import (
 	"one-api/dto"
 	"one-api/relay/channel"
 	relaycommon "one-api/relay/common"
-	"one-api/service"
 )

 type Adaptor struct {
@@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 	if info.IsStream {
-		var responseText string
-		err, responseText = geminiChatStreamHandler(c, resp, info)
-		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+		err, usage = geminiChatStreamHandler(c, resp, info)
 	} else {
 		err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 	}

View File

@@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
 type GeminiChatResponse struct {
 	Candidates     []GeminiChatCandidate    `json:"candidates"`
 	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+	UsageMetadata  GeminiUsageMetadata      `json:"usageMetadata"`
+}
+
+type GeminiUsageMetadata struct {
+	PromptTokenCount     int `json:"promptTokenCount"`
+	CandidatesTokenCount int `json:"candidatesTokenCount"`
+	TotalTokenCount      int `json:"totalTokenCount"`
 }
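
These field names match the usageMetadata object that Gemini's generateContent and streamGenerateContent responses carry. A minimal standalone sketch, separate from the commit, of decoding such a body into the new struct (the token counts below are made up):

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    // Mirrors the struct added in the hunk above.
    type GeminiUsageMetadata struct {
    	PromptTokenCount     int `json:"promptTokenCount"`
    	CandidatesTokenCount int `json:"candidatesTokenCount"`
    	TotalTokenCount      int `json:"totalTokenCount"`
    }

    func main() {
    	// Illustrative response body, trimmed to the usage portion.
    	body := []byte(`{"usageMetadata":{"promptTokenCount":11,"candidatesTokenCount":42,"totalTokenCount":53}}`)
    	var resp struct {
    		UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
    	}
    	if err := json.Unmarshal(body, &resp); err != nil {
    		panic(err)
    	}
    	fmt.Printf("prompt=%d completion=%d total=%d\n",
    		resp.UsageMetadata.PromptTokenCount,
    		resp.UsageMetadata.CandidatesTokenCount,
    		resp.UsageMetadata.TotalTokenCount)
    }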

View File

@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"io"
+	"log"
 	"net/http"
 	"one-api/common"
 	"one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
 	return &response
 }

-func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 	responseText := ""
+	responseJson := ""
+	id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	createAt := common.GetTimestamp()
+	var usage = &dto.Usage{}
 	dataChan := make(chan string, 5)
 	stopChan := make(chan bool, 2)
 	scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 	go func() {
 		for scanner.Scan() {
 			data := scanner.Text()
+			responseJson += data
 			data = strings.TrimSpace(data)
 			if !strings.HasPrefix(data, "\"text\": \"") {
 				continue
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			var choice dto.ChatCompletionsStreamResponseChoice
 			choice.Delta.SetContentString(dummy.Content)
 			response := dto.ChatCompletionsStreamResponse{
-				Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+				Id:      id,
 				Object:  "chat.completion.chunk",
-				Created: common.GetTimestamp(),
-				Model:   "gemini-pro",
+				Created: createAt,
+				Model:   info.UpstreamModelName,
 				Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
 			}
 			jsonResponse, err := json.Marshal(response)
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
 			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
 			return true
 		case <-stopChan:
-			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
 			return false
 		}
 	})
-	err := resp.Body.Close()
+	var geminiChatResponses []GeminiChatResponse
+	err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+		log.Printf("cannot get gemini usage: %s", err.Error())
+		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+	} else {
+		for _, response := range geminiChatResponses {
+			usage.PromptTokens = response.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+		}
+		usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
 	}
-	return nil, responseText
+	if info.ShouldIncludeUsage {
+		response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+		err := service.ObjectData(c, response)
+		if err != nil {
+			common.SysError("send final response failed: " + err.Error())
+		}
+	}
+	service.Done(c)
+	err = resp.Body.Close()
+	if err != nil {
+		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
+	}
+	return nil, usage
 }

 func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
 		}, nil
 	}
 	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-	completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
 	usage := dto.Usage{
-		PromptTokens:     promptTokens,
-		CompletionTokens: completionTokens,
-		TotalTokens:      promptTokens + completionTokens,
+		PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+		CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 	fullTextResponse.Usage = usage
 	jsonResponse, err := json.Marshal(fullTextResponse)
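
Taken together, the streaming path now works like this: every raw line is appended to responseJson before trimming, and because Gemini's default (non-SSE) streamGenerateContent response is a single JSON array emitted over the whole stream, the concatenation parses as []GeminiChatResponse once the stream ends. The loop then overwrites the counters on each element, so the final usage comes from the last chunk's usageMetadata; if parsing fails, the handler falls back to ResponseText2Usage's local token count. A standalone sketch of that last-chunk-wins aggregation (the chunk values below are made up, and cumulative per-chunk counts are an assumption about Gemini's stream, not something the commit states):

    package main

    import "fmt"

    // Stand-in for the usageMetadata carried by each streamed Gemini chunk.
    type usageMetadata struct {
    	PromptTokenCount     int
    	CandidatesTokenCount int
    }

    func main() {
    	// Illustrative per-chunk metadata, assumed cumulative: the handler's
    	// loop relies on the last chunk carrying the final values.
    	chunks := []usageMetadata{{11, 5}, {11, 23}, {11, 42}}

    	var prompt, completion int
    	for _, m := range chunks {
    		prompt = m.PromptTokenCount         // overwritten every iteration,
    		completion = m.CandidatesTokenCount // so the last chunk wins
    	}
    	// TotalTokens is derived, not read from the upstream totalTokenCount.
    	fmt.Println(prompt, completion, prompt+completion) // 11 42 53
    }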

View File

@@ -67,7 +67,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	if info.ChannelType == common.ChannelTypeAzure {
 		info.ApiVersion = GetAPIVersion(c)
 	}
-	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws {
+	if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
+		info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
 		info.SupportStreamOptions = true
 	}
 	return info
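
Adding Gemini to this list is what makes the ShouldIncludeUsage branch in the stream handler reachable: with SupportStreamOptions set, an OpenAI-style stream_options.include_usage flag on a streamed request is honored instead of being dropped. A sketch of the request shape, assuming the OpenAI-compatible field names; this is illustrative, not code from the repository:

    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    type streamOptions struct {
    	IncludeUsage bool `json:"include_usage"`
    }

    type chatRequest struct {
    	Model         string         `json:"model"`
    	Stream        bool           `json:"stream"`
    	StreamOptions *streamOptions `json:"stream_options,omitempty"`
    }

    func main() {
    	// A streamed Gemini request asking for a trailing usage chunk.
    	req := chatRequest{
    		Model:         "gemini-pro",
    		Stream:        true,
    		StreamOptions: &streamOptions{IncludeUsage: true},
    	}
    	b, err := json.Marshal(req)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(string(b))
    	// {"model":"gemini-pro","stream":true,"stream_options":{"include_usage":true}}
    }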