From bfcaccc2e315f8f667d7d13705e63366287bb419 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Wed, 24 Apr 2024 18:49:56 +0800 Subject: [PATCH] feat: support cohere (close #195) --- common/constants.go | 3 +- dto/text_request.go | 4 + relay/channel/claude/relay-claude.go | 2 +- relay/channel/cohere/adaptor.go | 52 +++++++ relay/channel/cohere/constant.go | 7 + relay/channel/cohere/dto.go | 44 ++++++ relay/channel/cohere/relay-cohere.go | 189 +++++++++++++++++++++++++ relay/constant/api_type.go | 3 + relay/relay_adaptor.go | 3 + web/src/constants/channel.constants.js | 7 + web/src/pages/Channel/EditChannel.js | 10 ++ 11 files changed, 322 insertions(+), 2 deletions(-) create mode 100644 relay/channel/cohere/adaptor.go create mode 100644 relay/channel/cohere/constant.go create mode 100644 relay/channel/cohere/dto.go create mode 100644 relay/channel/cohere/relay-cohere.go diff --git a/common/constants.go b/common/constants.go index f0fb1d5..99c78ac 100644 --- a/common/constants.go +++ b/common/constants.go @@ -207,6 +207,7 @@ const ( ChannelTypePerplexity = 27 ChannelTypeLingYiWanWu = 31 ChannelTypeAws = 33 + ChannelTypeCohere = 34 ) var ChannelBaseURLs = []string{ @@ -244,5 +245,5 @@ var ChannelBaseURLs = []string{ "https://api.lingyiwanwu.com", //31 "", //32 "", //33 - + "https://api.cohere.ai", //34 } diff --git a/dto/text_request.go b/dto/text_request.go index cc2d92e..0f696fc 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -43,6 +43,10 @@ type OpenAIFunction struct { Parameters any `json:"parameters,omitempty"` } +func (r GeneralOpenAIRequest) GetMaxTokens() int64 { + return int64(r.MaxTokens) +} + func (r GeneralOpenAIRequest) ParseInput() []string { if r.Input == nil { return nil diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index bd5b7ef..33e742a 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -20,7 +20,7 @@ func stopReasonClaude2OpenAI(reason 
string) string { case "end_turn": return "stop" case "max_tokens": - return "length" + return "max_tokens" default: return reason } diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go new file mode 100644 index 0000000..44b7f38 --- /dev/null +++ b/relay/channel/cohere/adaptor.go @@ -0,0 +1,52 @@ +package cohere + +import ( + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" +) + +type Adaptor struct { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + return requestOpenAI2Cohere(*request), nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = cohereStreamHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + } else { + err, usage = cohereHandler(c, resp, info.UpstreamModelName, info.PromptTokens) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/cohere/constant.go b/relay/channel/cohere/constant.go new file mode 100644 index 0000000..c8e173e --- /dev/null 
+++ b/relay/channel/cohere/constant.go @@ -0,0 +1,7 @@ +package cohere + +var ModelList = []string{ + "command-r", "command-r-plus", "command-light", "command-light-nightly", "command", "command-nightly", +} + +var ChannelName = "cohere" diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go new file mode 100644 index 0000000..958343c --- /dev/null +++ b/relay/channel/cohere/dto.go @@ -0,0 +1,44 @@ +package cohere + +type CohereRequest struct { + Model string `json:"model"` + ChatHistory []ChatHistory `json:"chat_history"` + Message string `json:"message"` + Stream bool `json:"stream"` + MaxTokens int64 `json:"max_tokens"` +} + +type ChatHistory struct { + Role string `json:"role"` + Message string `json:"message"` +} + +type CohereResponse struct { + IsFinished bool `json:"is_finished"` + EventType string `json:"event_type"` + Text string `json:"text,omitempty"` + FinishReason string `json:"finish_reason,omitempty"` + Response *CohereResponseResult `json:"response"` +} + +type CohereResponseResult struct { + ResponseId string `json:"response_id"` + FinishReason string `json:"finish_reason,omitempty"` + Text string `json:"text"` + Meta CohereMeta `json:"meta"` +} + +type CohereMeta struct { + //Tokens CohereTokens `json:"tokens"` + BilledUnits CohereBilledUnits `json:"billed_units"` +} + +type CohereBilledUnits struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +type CohereTokens struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} diff --git a/relay/channel/cohere/relay-cohere.go b/relay/channel/cohere/relay-cohere.go new file mode 100644 index 0000000..a21d4a9 --- /dev/null +++ b/relay/channel/cohere/relay-cohere.go @@ -0,0 +1,189 @@ +package cohere + +import ( + "bufio" + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/service" + "strings" +) + +func requestOpenAI2Cohere(textRequest
dto.GeneralOpenAIRequest) *CohereRequest { + cohereReq := CohereRequest{ + Model: textRequest.Model, + ChatHistory: []ChatHistory{}, + Message: "", + Stream: textRequest.Stream, + MaxTokens: textRequest.GetMaxTokens(), + } + if cohereReq.MaxTokens == 0 { + cohereReq.MaxTokens = 4000 + } + for _, msg := range textRequest.Messages { + if msg.Role == "user" { + cohereReq.Message = msg.StringContent() + } else { + var role string + if msg.Role == "assistant" { + role = "CHATBOT" + } else if msg.Role == "system" { + role = "SYSTEM" + } else { + role = "USER" + } + cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{ + Role: role, + Message: msg.StringContent(), + }) + } + } + return &cohereReq +} + +func stopReasonCohere2OpenAI(reason string) string { + switch reason { + case "COMPLETE": + return "stop" + case "MAX_TOKENS": + return "max_tokens" + default: + return reason + } +} + +func cohereStreamHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID()) + createdTime := common.GetTimestamp() + usage := &dto.Usage{} + responseText := "" + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil + }) + dataChan := make(chan string) + stopChan := make(chan bool) + go func() { + for scanner.Scan() { + data := scanner.Text() + dataChan <- data + } + stopChan <- true + }() + service.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + data = strings.TrimSuffix(data, "\r") + var cohereResp CohereResponse + err := json.Unmarshal([]byte(data), &cohereResp) + if err != nil { + common.SysError("error 
unmarshalling stream response: " + err.Error()) + return true + } + var openaiResp dto.ChatCompletionsStreamResponse + openaiResp.Id = responseId + openaiResp.Created = createdTime + openaiResp.Object = "chat.completion.chunk" + openaiResp.Model = modelName + if cohereResp.IsFinished { + finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason) + openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{}, + Index: 0, + FinishReason: &finishReason, + }, + } + if cohereResp.Response != nil { + usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens + usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens + } + } else { + openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{ + { + Delta: dto.ChatCompletionsStreamResponseChoiceDelta{ + Role: "assistant", + Content: cohereResp.Text, + }, + Index: 0, + }, + } + responseText += cohereResp.Text + } + jsonStr, err := json.Marshal(openaiResp) + if err != nil { + common.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + if usage.PromptTokens == 0 { + usage, _ = service.ResponseText2Usage(responseText, modelName, promptTokens) + } + return nil, usage +} + +func cohereHandler(c *gin.Context, resp *http.Response, modelName string, promptTokens int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + createdTime := common.GetTimestamp() + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var cohereResp CohereResponseResult + err = 
json.Unmarshal(responseBody, &cohereResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + usage := dto.Usage{} + usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens + usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens + usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens + + var openaiResp dto.TextResponse + openaiResp.Id = cohereResp.ResponseId + openaiResp.Created = createdTime + openaiResp.Object = "chat.completion" + openaiResp.Model = modelName + openaiResp.Usage = usage + + content, _ := json.Marshal(cohereResp.Text) + openaiResp.Choices = []dto.OpenAITextResponseChoice{ + { + Index: 0, + Message: dto.Message{Content: content, Role: "assistant"}, + FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason), + }, + } + + jsonResponse, err := json.Marshal(openaiResp) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, &usage +} diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 8ee6a99..7f11ae2 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -19,6 +19,7 @@ const ( APITypeOllama APITypePerplexity APITypeAws + APITypeCohere APITypeDummy // this one is only for count, do not add any channel after this ) @@ -52,6 +53,8 @@ func ChannelType2APIType(channelType int) int { apiType = APITypePerplexity case common.ChannelTypeAws: apiType = APITypeAws + case common.ChannelTypeCohere: + apiType = APITypeCohere } return apiType } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 867fd53..01e9cec 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -6,6 +6,7 @@ import ( "one-api/relay/channel/aws" 
"one-api/relay/channel/baidu" "one-api/relay/channel/claude" + "one-api/relay/channel/cohere" "one-api/relay/channel/gemini" "one-api/relay/channel/ollama" "one-api/relay/channel/openai" @@ -48,6 +49,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &perplexity.Adaptor{} case constant.APITypeAws: return &aws.Adaptor{} + case constant.APITypeCohere: + return &cohere.Adaptor{} } return nil } diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index dd9561e..c6f0899 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -50,6 +50,13 @@ export const CHANNEL_OPTIONS = [ color: 'orange', label: 'Google Gemini', }, + { + key: 34, + text: 'Cohere', + value: 34, + color: 'purple', + label: 'Cohere', + }, { key: 15, text: '百度文心千帆', diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 8da2e30..cc8707d 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -155,6 +155,16 @@ const EditChannel = (props) => { 'gemini-pro-vision', ]; break; + case 34: + localModels = [ + 'command-r', + 'command-r-plus', + 'command-light', + 'command-light-nightly', + 'command', + 'command-nightly', + ]; + break; case 25: localModels = [ 'moonshot-v1-8k',