mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
fix: gemini usage (close #354)
This commit is contained in:
parent
579fc8129e
commit
4e7e206290
@ -9,7 +9,6 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/service"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Adaptor struct {
|
type Adaptor struct {
|
||||||
@ -69,9 +68,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
err, usage = geminiChatStreamHandler(c, resp, info)
|
||||||
err, responseText = geminiChatStreamHandler(c, resp, info)
|
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
||||||
} else {
|
} else {
|
||||||
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
@ -59,4 +59,11 @@ type GeminiChatPromptFeedback struct {
|
|||||||
type GeminiChatResponse struct {
|
type GeminiChatResponse struct {
|
||||||
Candidates []GeminiChatCandidate `json:"candidates"`
|
Candidates []GeminiChatCandidate `json:"candidates"`
|
||||||
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
|
||||||
|
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GeminiUsageMetadata struct {
|
||||||
|
PromptTokenCount int `json:"promptTokenCount"`
|
||||||
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||||
|
TotalTokenCount int `json:"totalTokenCount"`
|
||||||
}
|
}
|
||||||
|
@ -5,6 +5,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
|
|||||||
return &response
|
return &response
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
|
func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
responseText := ""
|
responseText := ""
|
||||||
|
responseJson := ""
|
||||||
|
id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
||||||
|
createAt := common.GetTimestamp()
|
||||||
|
var usage = &dto.Usage{}
|
||||||
dataChan := make(chan string, 5)
|
dataChan := make(chan string, 5)
|
||||||
stopChan := make(chan bool, 2)
|
stopChan := make(chan bool, 2)
|
||||||
scanner := bufio.NewScanner(resp.Body)
|
scanner := bufio.NewScanner(resp.Body)
|
||||||
@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
go func() {
|
go func() {
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
data := scanner.Text()
|
data := scanner.Text()
|
||||||
|
responseJson += data
|
||||||
data = strings.TrimSpace(data)
|
data = strings.TrimSpace(data)
|
||||||
if !strings.HasPrefix(data, "\"text\": \"") {
|
if !strings.HasPrefix(data, "\"text\": \"") {
|
||||||
continue
|
continue
|
||||||
@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
var choice dto.ChatCompletionsStreamResponseChoice
|
var choice dto.ChatCompletionsStreamResponseChoice
|
||||||
choice.Delta.SetContentString(dummy.Content)
|
choice.Delta.SetContentString(dummy.Content)
|
||||||
response := dto.ChatCompletionsStreamResponse{
|
response := dto.ChatCompletionsStreamResponse{
|
||||||
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
|
Id: id,
|
||||||
Object: "chat.completion.chunk",
|
Object: "chat.completion.chunk",
|
||||||
Created: common.GetTimestamp(),
|
Created: createAt,
|
||||||
Model: "gemini-pro",
|
Model: info.UpstreamModelName,
|
||||||
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
|
||||||
}
|
}
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
|
||||||
return true
|
return true
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
err := resp.Body.Close()
|
var geminiChatResponses []GeminiChatResponse
|
||||||
|
err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
|
log.Printf("cannot get gemini usage: %s", err.Error())
|
||||||
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
for _, response := range geminiChatResponses {
|
||||||
|
usage.PromptTokens = response.UsageMetadata.PromptTokenCount
|
||||||
|
usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
|
||||||
}
|
}
|
||||||
return nil, responseText
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
}
|
||||||
|
if info.ShouldIncludeUsage {
|
||||||
|
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
|
||||||
|
err := service.ObjectData(c, response)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("send final response failed: " + err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
service.Done(c)
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
|
||||||
|
}
|
||||||
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||||
@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
|
||||||
completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
|
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||||
TotalTokens: promptTokens + completionTokens,
|
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||||
}
|
}
|
||||||
fullTextResponse.Usage = usage
|
fullTextResponse.Usage = usage
|
||||||
jsonResponse, err := json.Marshal(fullTextResponse)
|
jsonResponse, err := json.Marshal(fullTextResponse)
|
||||||
|
@ -67,7 +67,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
|||||||
if info.ChannelType == common.ChannelTypeAzure {
|
if info.ChannelType == common.ChannelTypeAzure {
|
||||||
info.ApiVersion = GetAPIVersion(c)
|
info.ApiVersion = GetAPIVersion(c)
|
||||||
}
|
}
|
||||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || info.ChannelType == common.ChannelTypeAws {
|
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||||
|
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
|
||||||
info.SupportStreamOptions = true
|
info.SupportStreamOptions = true
|
||||||
}
|
}
|
||||||
return info
|
return info
|
||||||
|
Loading…
Reference in New Issue
Block a user