fix: gemini stream finish reason (close #378)

This commit is contained in:
CalciumIon
2024-07-19 17:16:20 +08:00
parent 12da7f64cd
commit 3875b141c6
3 changed files with 33 additions and 14 deletions

View File

@@ -198,7 +198,6 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
choice.Delta.SetContentString(respFirst.Text) choice.Delta.SetContentString(respFirst.Text)
} }
} }
choice.FinishReason = &relaycommon.StopFinishReason
var response dto.ChatCompletionsStreamResponse var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk" response.Object = "chat.completion.chunk"
response.Model = "gemini" response.Model = "gemini"
@@ -247,10 +246,14 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
common.LogError(c, err.Error()) common.LogError(c, err.Error())
} }
} }
response := service.GenerateStopResponse(id, createAt, info.UpstreamModelName, relaycommon.StopFinishReason)
service.ObjectData(c, response)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
if info.ShouldIncludeUsage { if info.ShouldIncludeUsage {
response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage) response = service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := service.ObjectData(c, response) err := service.ObjectData(c, response)
if err != nil { if err != nil {
common.SysError("send final response failed: " + err.Error()) common.SysError("send final response failed: " + err.Error())

View File

@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto"
) )
func SetEventStreamHeaders(c *gin.Context) { func SetEventStreamHeaders(c *gin.Context) {
@@ -45,3 +46,30 @@ func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id") logID := c.GetString("X-Oneapi-Request-Id")
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)
} }
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: []dto.ChatCompletionsStreamResponseChoice{
{
FinishReason: &finishReason,
},
},
}
}
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}

View File

@@ -25,18 +25,6 @@ func ResponseText2Usage(responseText string, modeName string, promptTokens int)
return usage, err return usage, err
} }
func GenerateFinalUsageResponse(id string, createAt int64, model string, usage dto.Usage) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: createAt,
Model: model,
SystemFingerprint: nil,
Choices: make([]dto.ChatCompletionsStreamResponseChoice, 0),
Usage: &usage,
}
}
func ValidUsage(usage *dto.Usage) bool { func ValidUsage(usage *dto.Usage) bool {
return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0) return usage != nil && (usage.PromptTokens != 0 || usage.CompletionTokens != 0)
} }