From f2654692e8a8c476de3682b84e2db1de9279eb00 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Thu, 27 Jun 2024 00:16:39 +0800 Subject: [PATCH] feat: first response time support gemini and claude --- relay/channel/claude/adaptor.go | 2 +- relay/channel/claude/relay-claude.go | 19 ++++++++++----- relay/channel/gemini/adaptor.go | 36 ++++++++++++++-------------- relay/channel/gemini/relay-gemini.go | 8 ++++++- relay/common/relay_info.go | 30 ++++++++++++----------- 5 files changed, 55 insertions(+), 40 deletions(-) diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 9add208..d302265 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -65,7 +65,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) + err, usage = claudeStreamHandler(c, resp, info, a.RequestMode) } else { err, usage = claudeHandler(a.RequestMode, c, resp, info.PromptTokens, info.UpstreamModelName) } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index ca53433..3818af3 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -9,8 +9,10 @@ import ( "net/http" "one-api/common" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/service" "strings" + "time" ) func stopReasonClaude2OpenAI(reason string) string { @@ -246,7 +248,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *ClaudeResponse) *dto.Ope return &fullTextResponse } -func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func claudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) var usage *dto.Usage usage = &dto.Usage{} @@ -278,10 +280,15 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c } stopChan <- true }() + isFirst := true service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } // some implementations may add \r at the end of data data = strings.TrimSuffix(data, "\r") var claudeResponse ClaudeResponse @@ -302,7 +309,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c if claudeResponse.Type == "message_start" { // message_start, 获取usage responseId = claudeResponse.Message.Id - modelName = claudeResponse.Message.Model + info.UpstreamModelName = claudeResponse.Message.Model usage.PromptTokens = claudeUsage.InputTokens } else if claudeResponse.Type == "content_block_delta" { responseText += claudeResponse.Delta.Text @@ -316,7 +323,7 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c //response.Id = responseId response.Id = responseId response.Created = createdTime - response.Model = modelName + response.Model = info.UpstreamModelName jsonStr, err := json.Marshal(response) if err != nil { @@ -335,13 +342,13 @@ func claudeStreamHandler(requestMode int, modelName string, promptTokens int, c return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } if requestMode == RequestModeCompletion { - usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { if usage.PromptTokens == 0 { - usage.PromptTokens = promptTokens + usage.PromptTokens = info.PromptTokens } if usage.CompletionTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, modelName, usage.PromptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, usage.PromptTokens) } } return nil, usage diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index d372d82..875361c 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -20,27 +20,27 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq // 定义一个映射,存储模型名称和对应的版本 var modelVersionMap = map[string]string{ - "gemini-1.5-pro-latest": "v1beta", - "gemini-1.5-flash-latest": "v1beta", - "gemini-ultra": "v1beta", + "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-flash-latest": "v1beta", + "gemini-ultra": "v1beta", } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" - version, beta := modelVersionMap[info.UpstreamModelName] - if !beta { - if info.ApiVersion != "" { - version = info.ApiVersion - } else { - version = "v1" - } - } + // 从映射中获取模型名称对应的版本,如果找不到就使用 info.ApiVersion 或默认的版本 "v1" + version, beta := modelVersionMap[info.UpstreamModelName] + if !beta { + if info.ApiVersion != "" { + version = info.ApiVersion + } else { + version = "v1" + } + } - action := "generateContent" - if info.IsStream { - action = "streamGenerateContent" - } - return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil + action := "generateContent" + if info.IsStream { + action = "streamGenerateContent" + } + return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { @@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = geminiChatStreamHandler(c, resp) + err, responseText = geminiChatStreamHandler(c, resp, info) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { err, usage = geminiChatHandler(c, resp, info.PromptTokens, info.UpstreamModelName) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 6d45b57..2ba5cef 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -11,6 +11,7 @@ import ( relaycommon "one-api/relay/common" "one-api/service" "strings" + "time" "github.com/gin-gonic/gin" ) @@ -160,7 +161,7 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch return &response } -func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, string) { +func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) { responseText := "" dataChan := make(chan string) stopChan := make(chan bool) @@ -190,10 +191,15 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIEr } stopChan <- true }() + isFirst := true service.SetEventStreamHeaders(c) c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } // this is used to prevent annoying \ related format bug data = fmt.Sprintf("{\"content\": \"%s\"}", data) type dummyStruct struct { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 2a6872e..fc01d17 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -38,24 +38,26 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { group := c.GetString("group") tokenUnlimited := c.GetBool("token_unlimited_quota") startTime := time.Now() + // firstResponseTime = time.Now() - 1 second apiType, _ := constant.ChannelType2APIType(channelType) info := &RelayInfo{ - RelayMode: constant.Path2RelayMode(c.Request.URL.Path), - BaseUrl: c.GetString("base_url"), - RequestURLPath: c.Request.URL.String(), - ChannelType: channelType, - ChannelId: channelId, - TokenId: tokenId, - UserId: userId, - Group: group, - TokenUnlimited: tokenUnlimited, - StartTime: startTime, - ApiType: apiType, - ApiVersion: c.GetString("api_version"), - ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), - Organization: c.GetString("channel_organization"), + RelayMode: constant.Path2RelayMode(c.Request.URL.Path), + BaseUrl: c.GetString("base_url"), + RequestURLPath: c.Request.URL.String(), + ChannelType: channelType, + ChannelId: channelId, + TokenId: tokenId, + UserId: userId, + Group: group, + TokenUnlimited: tokenUnlimited, + StartTime: startTime, + FirstResponseTime: startTime.Add(-time.Second), + ApiType: apiType, + ApiVersion: c.GetString("api_version"), + ApiKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), + Organization: c.GetString("channel_organization"), } if info.BaseUrl == "" { info.BaseUrl = common.ChannelBaseURLs[channelType]