diff --git a/common/constants.go b/common/constants.go index 66cc10d..97e8583 100644 --- a/common/constants.go +++ b/common/constants.go @@ -212,6 +212,7 @@ const ( ChannelTypeSunoAPI = 36 ChannelTypeDify = 37 ChannelTypeJina = 38 + ChannelCloudflare = 39 ChannelTypeDummy // this one is only for count, do not add any channel after this @@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{ "", //36 "", //37 "https://api.jina.ai", //38 + "https://api.cloudflare.com", //39 } diff --git a/controller/channel-test.go b/controller/channel-test.go index 000d7f2..268dac2 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -12,6 +12,7 @@ import ( "net/url" "one-api/common" "one-api/dto" + "one-api/middleware" "one-api/model" "one-api/relay" relaycommon "one-api/relay/common" @@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr Body: nil, Header: make(http.Header), } - c.Request.Header.Set("Authorization", "Bearer "+channel.Key) - c.Request.Header.Set("Content-Type", "application/json") - c.Set("channel", channel.Type) - c.Set("base_url", channel.GetBaseURL()) - switch channel.Type { - case common.ChannelTypeAzure: - c.Set("api_version", channel.Other) - case common.ChannelTypeXunfei: - c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) - case common.ChannelTypeGemini: - c.Set("api_version", channel.Other) - case common.ChannelTypeAli: - c.Set("plugin", channel.Other) - } - meta := relaycommon.GenRelayInfo(c) - apiType, _ := constant.ChannelType2APIType(channel.Type) - adaptor := relay.GetAdaptor(apiType) - if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil - } if testModel == "" { if channel.TestModel != nil && *channel.TestModel != "" { testModel = *channel.TestModel @@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr } } + c.Request.Header.Set("Authorization", "Bearer "+channel.Key) + c.Request.Header.Set("Content-Type", "application/json") + c.Set("channel", channel.Type) + c.Set("base_url", channel.GetBaseURL()) + + middleware.SetupContextForSelectedChannel(c, channel, testModel) + + meta := relaycommon.GenRelayInfo(c) + apiType, _ := constant.ChannelType2APIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) + if adaptor == nil { + return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + } + request := buildTestRequest() request.Model = testModel meta.UpstreamModelName = testModel diff --git a/dto/text_request.go b/dto/text_request.go index e12c9b4..ed36988 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -48,8 +48,8 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r GeneralOpenAIRequest) GetMaxTokens() int64 { - return int64(r.MaxTokens) +func (r GeneralOpenAIRequest) GetMaxTokens() int { + return int(r.MaxTokens) } func (r GeneralOpenAIRequest) ParseInput() []string { diff --git a/middleware/distributor.go b/middleware/distributor.go index 61361e6..9f75207 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("api_version", channel.Other) case common.ChannelTypeXunfei: c.Set("api_version", channel.Other) - //case common.ChannelTypeAIProxyLibrary: - // c.Set("library_id", channel.Other) case common.ChannelTypeGemini: c.Set("api_version", channel.Other) case common.ChannelTypeAli: c.Set("plugin", channel.Other) + case common.ChannelCloudflare: + c.Set("api_version", channel.Other) } } diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go new file mode 100644 index 0000000..571f50c --- /dev/null +++ b/relay/channel/cloudflare/adaptor.go @@ -0,0 +1,76 @@ +package cloudflare + +import ( + "errors" + "fmt" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/dto" + "one-api/relay/channel" + relaycommon "one-api/relay/common" + "one-api/relay/constant" +) + +type Adaptor struct { +} + +func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) { +} + +func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) { +} + +func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil + default: + return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil + } +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, req) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + return nil +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + switch relayMode { + case constant.RelayModeCompletions: + return convertCf2CompletionsRequest(*request), nil + default: + return request, nil + } +} + +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { + return channel.DoApiRequest(a, c, info, requestBody) +} + +func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { + return request, nil +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { + if info.IsStream { + err, usage = cfStreamHandler(c, resp, info) + } else { + err, usage = cfHandler(c, resp, info) + } + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return ChannelName +} diff --git a/relay/channel/cloudflare/constant.go b/relay/channel/cloudflare/constant.go new file mode 100644 index 0000000..a874685 --- /dev/null +++ b/relay/channel/cloudflare/constant.go @@ -0,0 +1,38 @@ +package cloudflare + +var ModelList = []string{ + "@cf/meta/llama-2-7b-chat-fp16", + "@cf/meta/llama-2-7b-chat-int8", + "@cf/mistral/mistral-7b-instruct-v0.1", + "@hf/thebloke/deepseek-coder-6.7b-base-awq", + "@hf/thebloke/deepseek-coder-6.7b-instruct-awq", + "@cf/deepseek-ai/deepseek-math-7b-base", + "@cf/deepseek-ai/deepseek-math-7b-instruct", + "@cf/thebloke/discolm-german-7b-v1-awq", + "@cf/tiiuae/falcon-7b-instruct", + "@cf/google/gemma-2b-it-lora", + "@hf/google/gemma-7b-it", + "@cf/google/gemma-7b-it-lora", + "@hf/nousresearch/hermes-2-pro-mistral-7b", + "@hf/thebloke/llama-2-13b-chat-awq", + "@cf/meta-llama/llama-2-7b-chat-hf-lora", + "@cf/meta/llama-3-8b-instruct", + "@hf/thebloke/llamaguard-7b-awq", + "@hf/thebloke/mistral-7b-instruct-v0.1-awq", + "@hf/mistralai/mistral-7b-instruct-v0.2", + "@cf/mistral/mistral-7b-instruct-v0.2-lora", + "@hf/thebloke/neural-chat-7b-v3-1-awq", + "@cf/openchat/openchat-3.5-0106", + "@hf/thebloke/openhermes-2.5-mistral-7b-awq", + "@cf/microsoft/phi-2", + "@cf/qwen/qwen1.5-0.5b-chat", + "@cf/qwen/qwen1.5-1.8b-chat", + "@cf/qwen/qwen1.5-14b-chat-awq", + "@cf/qwen/qwen1.5-7b-chat-awq", + "@cf/defog/sqlcoder-7b-2", + "@hf/nexusflow/starling-lm-7b-beta", + "@cf/tinyllama/tinyllama-1.1b-chat-v1.0", + "@hf/thebloke/zephyr-7b-beta-awq", +} + +var ChannelName = "cloudflare" diff --git a/relay/channel/cloudflare/model.go b/relay/channel/cloudflare/model.go new file mode 100644 index 0000000..c870813 --- /dev/null +++ b/relay/channel/cloudflare/model.go @@ -0,0 +1,13 @@ +package cloudflare + +import "one-api/dto" + +type CfRequest struct { + Messages []dto.Message `json:"messages,omitempty"` + Lora string `json:"lora,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Prompt string `json:"prompt,omitempty"` + Raw bool `json:"raw,omitempty"` + Stream bool `json:"stream,omitempty"` + Temperature float64 `json:"temperature,omitempty"` +} diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go new file mode 100644 index 0000000..94a7ea0 --- /dev/null +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -0,0 +1,115 @@ +package cloudflare + +import ( + "bufio" + "encoding/json" + "github.com/gin-gonic/gin" + "io" + "net/http" + "one-api/common" + "one-api/dto" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" +) + +func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest { + p, _ := textRequest.Prompt.(string) + return &CfRequest{ + Prompt: p, + MaxTokens: textRequest.GetMaxTokens(), + Stream: textRequest.Stream, + Temperature: textRequest.Temperature, + } +} + +func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + scanner := bufio.NewScanner(resp.Body) + scanner.Split(bufio.ScanLines) + + service.SetEventStreamHeaders(c) + id := service.GetResponseID(c) + var responseText string + + for scanner.Scan() { + data := scanner.Text() + if len(data) < len("data: ") { + continue + } + data = strings.TrimPrefix(data, "data: ") + data = strings.TrimSuffix(data, "\r") + + if data == "[DONE]" { + break + } + + var response dto.ChatCompletionsStreamResponse + err := json.Unmarshal([]byte(data), &response) + if err != nil { + common.LogError(c, "error_unmarshalling_stream_response: "+err.Error()) + continue + } + for _, choice := range response.Choices { + choice.Delta.Role = "assistant" + responseText += choice.Delta.GetContentString() + } + response.Id = id + response.Model = info.UpstreamModelName + err = service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_stream_response: "+err.Error()) + } + } + + if err := scanner.Err(); err != nil { + common.LogError(c, "error_scanning_stream_response: "+err.Error()) + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + if info.ShouldIncludeUsage { + response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage) + err := service.ObjectData(c, response) + if err != nil { + common.LogError(c, "error_rendering_final_usage_response: "+err.Error()) + } + } + service.Done(c) + + err := resp.Body.Close() + if err != nil { + common.LogError(c, "close_response_body_failed: "+err.Error()) + } + + return nil, usage +} + +func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + var response dto.TextResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + response.Model = info.UpstreamModelName + var responseText string + for _, choice := range response.Choices { + responseText += choice.Message.StringContent() + } + usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) + response.Usage = *usage + response.Id = service.GetResponseID(c) + jsonResponse, err := json.Marshal(response) + if err != nil { + return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, _ = c.Writer.Write(jsonResponse) + return nil, usage +} diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index fc6c445..b2c2739 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int64 `json:"max_tokens"` + MaxTokens int `json:"max_tokens"` } type ChatHistory struct { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 42c8381..e07434a 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -68,7 +68,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { info.ApiVersion = GetAPIVersion(c) } if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic || - info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini { + info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini || + info.ChannelType == common.ChannelCloudflare { info.SupportStreamOptions = true } return info diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 0ce2657..6bd93c4 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -22,6 +22,7 @@ const ( APITypeCohere APITypeDify APITypeJina + APITypeCloudflare APITypeDummy // this one is only for count, do not add any channel after this ) @@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) { apiType = APITypeDify case common.ChannelTypeJina: apiType = APITypeJina + case common.ChannelCloudflare: + apiType = APITypeCloudflare } if apiType == -1 { return APITypeOpenAI, false diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 8998540..4c0aef1 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -7,6 +7,7 @@ import ( "one-api/relay/channel/aws" "one-api/relay/channel/baidu" "one-api/relay/channel/claude" + "one-api/relay/channel/cloudflare" "one-api/relay/channel/cohere" "one-api/relay/channel/dify" "one-api/relay/channel/gemini" @@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor { return &dify.Adaptor{} case constant.APITypeJina: return &jina.Adaptor{} + case constant.APITypeCloudflare: + return &cloudflare.Adaptor{} } return nil } diff --git a/service/sse.go b/service/relay.go similarity index 87% rename from service/sse.go rename to service/relay.go index 2d531a4..22f9ce3 100644 --- a/service/sse.go +++ b/service/relay.go @@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error { func Done(c *gin.Context) { StringData(c, "[DONE]") } + +func GetResponseID(c *gin.Context) string { + logID := c.GetString("X-Oneapi-Request-Id") + return fmt.Sprintf("chatcmpl-%s", logID) +} diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js index ff1d281..88614b0 100644 --- a/web/src/constants/channel.constants.js +++ b/web/src/constants/channel.constants.js @@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [ color: 'orange', label: 'Google PaLM2', }, + { key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' }, { key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' }, { key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' }, { key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' }, diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index aec3768..900fdf3 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -605,6 +605,24 @@ const EditChannel = (props) => { /> )} + {inputs.type === 39 && ( + <> +
+ Account ID: +
+ { + handleInputChange('other', value); + }} + value={inputs.other} + autoComplete='new-password' + /> + + )}
模型: