diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 44b7f38..cd01634 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + err, usage = cohereStreamHandler(c, resp, info) } else { err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) } diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go index 463e8b1..cc424b0 100644 --- a/relay/channel/cohere/relay-cohere.go +++ b/relay/channel/cohere/relay-cohere.go @@ -9,8 +9,10 @@ import ( "net/http" "one-api/common" "one-api/dto" + relaycommon "one-api/relay/common" "one-api/service" "strings" + "time" ) func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { @@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string { } } -func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createdTime := common.GetTimestamp() usage := &dto.Usage{} @@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, stopChan <- true }() service.SetEventStreamHeaders(c) + isFirst := true c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: + if isFirst { + isFirst = false + info.FirstResponseTime = time.Now() + } data = strings.TrimSuffix(data, "\r") var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) @@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, openaiResp.Id = responseId openaiResp.Created = createdTime openaiResp.Object = "chat.completion.chunk" - openaiResp.Model = modelName + openaiResp.Model = info.UpstreamModelName if cohereResp.IsFinished { finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ @@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, } }) if usage.PromptTokens == 0 { - usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) + usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } return nil, usage }