mirror of https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
250 lines · 7.5 KiB · Go
package cohere
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/service"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
|
|
cohereReq := CohereRequest{
|
|
Model: textRequest.Model,
|
|
ChatHistory: []ChatHistory{},
|
|
Message: "",
|
|
Stream: textRequest.Stream,
|
|
MaxTokens: textRequest.GetMaxTokens(),
|
|
}
|
|
if cohereReq.MaxTokens == 0 {
|
|
cohereReq.MaxTokens = 4000
|
|
}
|
|
for _, msg := range textRequest.Messages {
|
|
if msg.Role == "user" {
|
|
cohereReq.Message = msg.StringContent()
|
|
} else {
|
|
var role string
|
|
if msg.Role == "assistant" {
|
|
role = "CHATBOT"
|
|
} else if msg.Role == "system" {
|
|
role = "SYSTEM"
|
|
} else {
|
|
role = "USER"
|
|
}
|
|
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
|
|
Role: role,
|
|
Message: msg.StringContent(),
|
|
})
|
|
}
|
|
}
|
|
return &cohereReq
|
|
}
|
|
|
|
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
|
|
if rerankRequest.TopN == 0 {
|
|
rerankRequest.TopN = 1
|
|
}
|
|
cohereReq := CohereRerankRequest{
|
|
Query: rerankRequest.Query,
|
|
Documents: rerankRequest.Documents,
|
|
Model: rerankRequest.Model,
|
|
TopN: rerankRequest.TopN,
|
|
ReturnDocuments: true,
|
|
}
|
|
return &cohereReq
|
|
}
|
|
|
|
// stopReasonCohere2OpenAI translates a Cohere finish reason into the
// OpenAI-compatible finish_reason value. Unknown reasons are passed
// through unchanged.
func stopReasonCohere2OpenAI(reason string) string {
	switch reason {
	case "COMPLETE":
		return "stop"
	case "MAX_TOKENS":
		// Bug fix: OpenAI's finish_reason for hitting the token limit is
		// "length", not "max_tokens".
		return "length"
	default:
		return reason
	}
}
|
|
|
|
func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
|
|
createdTime := common.GetTimestamp()
|
|
usage := &dto.Usage{}
|
|
responseText := ""
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
|
if atEOF && len(data) == 0 {
|
|
return 0, nil, nil
|
|
}
|
|
if i := strings.Index(string(data), "\n"); i >= 0 {
|
|
return i + 1, data[0:i], nil
|
|
}
|
|
if atEOF {
|
|
return len(data), data, nil
|
|
}
|
|
return 0, nil, nil
|
|
})
|
|
dataChan := make(chan string)
|
|
stopChan := make(chan bool)
|
|
go func() {
|
|
for scanner.Scan() {
|
|
data := scanner.Text()
|
|
dataChan <- data
|
|
}
|
|
stopChan <- true
|
|
}()
|
|
service.SetEventStreamHeaders(c)
|
|
isFirst := true
|
|
c.Stream(func(w io.Writer) bool {
|
|
select {
|
|
case data := <-dataChan:
|
|
if isFirst {
|
|
isFirst = false
|
|
info.FirstResponseTime = time.Now()
|
|
}
|
|
data = strings.TrimSuffix(data, "\r")
|
|
var cohereResp CohereResponse
|
|
err := json.Unmarshal([]byte(data), &cohereResp)
|
|
if err != nil {
|
|
common.SysError("error unmarshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
var openaiResp dto.ChatCompletionsStreamResponse
|
|
openaiResp.Id = responseId
|
|
openaiResp.Created = createdTime
|
|
openaiResp.Object = "chat.completion.chunk"
|
|
openaiResp.Model = info.UpstreamModelName
|
|
if cohereResp.IsFinished {
|
|
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
|
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
{
|
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
|
|
Index: 0,
|
|
FinishReason: &finishReason,
|
|
},
|
|
}
|
|
if cohereResp.Response != nil {
|
|
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
|
|
}
|
|
} else {
|
|
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
|
|
{
|
|
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
|
|
Role: "assistant",
|
|
Content: &cohereResp.Text,
|
|
},
|
|
Index: 0,
|
|
},
|
|
}
|
|
responseText += cohereResp.Text
|
|
}
|
|
jsonStr, err := json.Marshal(openaiResp)
|
|
if err != nil {
|
|
common.SysError("error marshalling stream response: " + err.Error())
|
|
return true
|
|
}
|
|
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
|
|
return true
|
|
case <-stopChan:
|
|
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
|
|
return false
|
|
}
|
|
})
|
|
if usage.PromptTokens == 0 {
|
|
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
|
}
|
|
return nil, usage
|
|
}
|
|
|
|
func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
createdTime := common.GetTimestamp()
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
var cohereResp CohereResponseResult
|
|
err = json.Unmarshal(responseBody, &cohereResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
usage := dto.Usage{}
|
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
|
|
var openaiResp dto.TextResponse
|
|
openaiResp.Id = cohereResp.ResponseId
|
|
openaiResp.Created = createdTime
|
|
openaiResp.Object = "chat.completion"
|
|
openaiResp.Model = modelName
|
|
openaiResp.Usage = usage
|
|
|
|
content, _ := json.Marshal(cohereResp.Text)
|
|
openaiResp.Choices = []dto.OpenAITextResponseChoice{
|
|
{
|
|
Index: 0,
|
|
Message: dto.Message{Content: content, Role: "assistant"},
|
|
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
|
|
},
|
|
}
|
|
|
|
jsonResponse, err := json.Marshal(openaiResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return nil, &usage
|
|
}
|
|
|
|
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
err = resp.Body.Close()
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
var cohereResp CohereRerankResponseResult
|
|
err = json.Unmarshal(responseBody, &cohereResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
usage := dto.Usage{}
|
|
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
|
|
usage.PromptTokens = info.PromptTokens
|
|
usage.CompletionTokens = 0
|
|
usage.TotalTokens = info.PromptTokens
|
|
} else {
|
|
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
|
|
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
|
|
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
|
|
}
|
|
|
|
var rerankResp dto.RerankResponse
|
|
rerankResp.Results = cohereResp.Results
|
|
rerankResp.Usage = usage
|
|
|
|
jsonResponse, err := json.Marshal(rerankResp)
|
|
if err != nil {
|
|
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
|
}
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
c.Writer.WriteHeader(resp.StatusCode)
|
|
_, err = c.Writer.Write(jsonResponse)
|
|
return nil, &usage
|
|
}
|