package cohere import ( "bufio" "encoding/json" "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "one-api/dto" relaycommon "one-api/relay/common" "one-api/service" "strings" "time" ) func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest { cohereReq := CohereRequest{ Model: textRequest.Model, ChatHistory: []ChatHistory{}, Message: "", Stream: textRequest.Stream, MaxTokens: textRequest.GetMaxTokens(), } if common.CohereSafetySetting != "NONE" { cohereReq.SafetyMode = common.CohereSafetySetting } if cohereReq.MaxTokens == 0 { cohereReq.MaxTokens = 4000 } for _, msg := range textRequest.Messages { if msg.Role == "user" { cohereReq.Message = msg.StringContent() } else { var role string if msg.Role == "assistant" { role = "CHATBOT" } else if msg.Role == "system" { role = "SYSTEM" } else { role = "USER" } cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{ Role: role, Message: msg.StringContent(), }) } } return &cohereReq } func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest { if rerankRequest.TopN == 0 { rerankRequest.TopN = 1 } cohereReq := CohereRerankRequest{ Query: rerankRequest.Query, Documents: rerankRequest.Documents, Model: rerankRequest.Model, TopN: rerankRequest.TopN, ReturnDocuments: true, } return &cohereReq } func stopReasonCohere2OpenAI(reason string) string { switch reason { case "COMPLETE": return "stop" case "MAX_TOKENS": return "max_tokens" default: return reason } } func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) createdTime := common.GetTimestamp() usage := &dto.Usage{} responseText := "" scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { return 0, nil, nil } if i := strings.Index(string(data), "\n"); i >= 0 { return i + 1, data[0:i], nil } if atEOF { return len(data), data, nil } return 0, nil, nil }) dataChan := make(chan string) stopChan := make(chan bool) go func() { for scanner.Scan() { data := scanner.Text() dataChan <- data } stopChan <- true }() service.SetEventStreamHeaders(c) isFirst := true c.Stream(func(w io.Writer) bool { select { case data := <-dataChan: if isFirst { isFirst = false info.FirstResponseTime = time.Now() } data = strings.TrimSuffix(data, "\r") var cohereResp CohereResponse err := json.Unmarshal([]byte(data), &cohereResp) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) return true } var openaiResp dto.ChatCompletionsStreamResponse openaiResp.Id = responseId openaiResp.Created = createdTime openaiResp.Object = "chat.completion.chunk" openaiResp.Model = info.UpstreamModelName if cohereResp.IsFinished { finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ { Delta: dto.ChatCompletionsStreamResponseChoiceDelta{}, Index: 0, FinishReason: &finishReason, }, } if cohereResp.Response != nil { usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens } } else { openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ { Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ Role: "assistant", Content: &cohereResp.Text, }, Index: 0, }, } responseText += cohereResp.Text } jsonStr, err := json.Marshal(openaiResp) if err != nil { common.SysError("error marshalling stream response: " + err.Error()) return true } c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) return true case <-stopChan: c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) return false } }) if usage.PromptTokens == 0 { usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } return nil, usage } func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { createdTime := common.GetTimestamp() responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var cohereResp CohereResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } usage := dto.Usage{} usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens var openaiResp dto.TextResponse openaiResp.Id = cohereResp.ResponseId openaiResp.Created = createdTime openaiResp.Object = "chat.completion" openaiResp.Model = modelName openaiResp.Usage = usage content, _ := json.Marshal(cohereResp.Text) openaiResp.Choices = []dto.OpenAITextResponseChoice{ { Index: 0, Message: dto.Message{Content: content, Role: "assistant"}, FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason), }, } jsonResponse, err := json.Marshal(openaiResp) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) return nil, &usage } func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil } err = resp.Body.Close() if err != nil { return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil } var cohereResp CohereRerankResponseResult err = json.Unmarshal(responseBody, &cohereResp) if err != nil { return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil } usage := dto.Usage{} if cohereResp.Meta.BilledUnits.InputTokens == 0 { usage.PromptTokens = info.PromptTokens usage.CompletionTokens = 0 usage.TotalTokens = info.PromptTokens } else { usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens } var rerankResp dto.RerankResponse rerankResp.Results = cohereResp.Results rerankResp.Usage = usage jsonResponse, err := json.Marshal(rerankResp) if err != nil { return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil } c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(resp.StatusCode) _, err = c.Writer.Write(jsonResponse) return nil, &usage }