feat: support cohere first response time

This commit is contained in:
CalciumIon 2024-06-28 23:32:02 +08:00
parent d767ae04ff
commit a7e3168c17
2 changed files with 11 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
if info.IsStream { if info.IsStream {
err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens) err, usage = cohereStreamHandler(c, resp, info)
} else { } else {
err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens)
} }

View File

@@ -9,8 +9,10 @@ import (
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"strings" "strings"
"time"
) )
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
@@ -56,7 +58,7 @@ func stopReasonCohere2OpenAI(reason string) string {
} }
} }
func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
createdTime := common.GetTimestamp() createdTime := common.GetTimestamp()
usage := &dto.Usage{} usage := &dto.Usage{}
@@ -84,9 +86,14 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
stopChan <- true stopChan <- true
}() }()
service.SetEventStreamHeaders(c) service.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool { c.Stream(func(w io.Writer) bool {
select { select {
case data := <-dataChan: case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r") data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp) err := json.Unmarshal([]byte(data), &cohereResp)
@@ -98,7 +105,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
openaiResp.Id = responseId openaiResp.Id = responseId
openaiResp.Created = createdTime openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk" openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = modelName openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished { if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
@@ -137,7 +144,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string,
} }
}) })
if usage.PromptTokens == 0 { if usage.PromptTokens == 0 {
usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} }
return nil, usage return nil, usage
} }