🔖 chore: update gemini model and token calculation.

This commit is contained in:
MartialBE
2024-05-15 03:49:22 +08:00
parent e9d3b4654f
commit 46aea4731b
3 changed files with 69 additions and 10 deletions

View File

@@ -173,10 +173,7 @@ func (p *GeminiProvider) convertToChatOpenai(response *GeminiChatResponse, reque
openaiResponse.Choices = append(openaiResponse.Choices, choice)
}
completionTokens := common.CountTokenText(response.GetResponseText(), response.Model)
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + completionTokens
*p.Usage = convertOpenAIUsage(request.Model, response.UsageMetadata)
openaiResponse.Usage = p.Usage
return
@@ -270,6 +267,60 @@ func (h *geminiStreamHandler) convertToOpenaiStream(geminiResponse *GeminiChatRe
dataChan <- string(responseBody)
}
h.Usage.CompletionTokens += common.CountTokenText(geminiResponse.GetResponseText(), h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
if geminiResponse.UsageMetadata != nil {
*h.Usage = convertOpenAIUsage(h.Request.Model, geminiResponse.UsageMetadata)
}
}
const tokenThreshold = 1000000
var modelAdjustRatios = map[string]int{
"gemini-1.5-pro": 2,
"gemini-1.5-flash": 2,
}
func adjustTokenCounts(modelName string, usage *GeminiUsageMetadata) {
if usage.PromptTokenCount <= tokenThreshold && usage.CandidatesTokenCount <= tokenThreshold {
return
}
currentRatio := 1
for model, r := range modelAdjustRatios {
if strings.HasPrefix(modelName, model) {
currentRatio = r
break
}
}
if currentRatio == 1 {
return
}
adjustTokenCount := func(count int) int {
if count > tokenThreshold {
return tokenThreshold + (count-tokenThreshold)*currentRatio
}
return count
}
if usage.PromptTokenCount > tokenThreshold {
usage.PromptTokenCount = adjustTokenCount(usage.PromptTokenCount)
}
if usage.CandidatesTokenCount > tokenThreshold {
usage.CandidatesTokenCount = adjustTokenCount(usage.CandidatesTokenCount)
}
usage.TotalTokenCount = usage.PromptTokenCount + usage.CandidatesTokenCount
}
func convertOpenAIUsage(modelName string, geminiUsage *GeminiUsageMetadata) types.Usage {
adjustTokenCounts(modelName, geminiUsage)
return types.Usage{
PromptTokens: geminiUsage.PromptTokenCount,
CompletionTokens: geminiUsage.CandidatesTokenCount,
TotalTokens: geminiUsage.TotalTokenCount,
}
}

View File

@@ -94,11 +94,17 @@ type GeminiErrorResponse struct {
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
Usage *types.Usage `json:"usage,omitempty"`
UsageMetadata *GeminiUsageMetadata `json:"usageMetadata,omitempty"`
Model string `json:"model,omitempty"`
GeminiErrorResponse
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason string `json:"finishReason"`