mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-19 00:46:37 +08:00
feat: support cloudflare worker ai
This commit is contained in:
parent
c88f3741e6
commit
7b36a2b885
@ -212,6 +212,7 @@ const (
|
||||
ChannelTypeSunoAPI = 36
|
||||
ChannelTypeDify = 37
|
||||
ChannelTypeJina = 38
|
||||
ChannelCloudflare = 39
|
||||
|
||||
ChannelTypeDummy // this one is only for count, do not add any channel after this
|
||||
|
||||
@ -257,4 +258,5 @@ var ChannelBaseURLs = []string{
|
||||
"", //36
|
||||
"", //37
|
||||
"https://api.jina.ai", //38
|
||||
"https://api.cloudflare.com", //39
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/middleware"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
relaycommon "one-api/relay/common"
|
||||
@ -40,29 +41,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
}
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
if testModel == "" {
|
||||
if channel.TestModel != nil && *channel.TestModel != "" {
|
||||
testModel = *channel.TestModel
|
||||
@ -88,6 +67,20 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
|
||||
}
|
||||
}
|
||||
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
|
||||
middleware.SetupContextForSelectedChannel(c, channel, testModel)
|
||||
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType, _ := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relay.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
|
||||
request := buildTestRequest()
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
|
@ -48,8 +48,8 @@ type StreamOptions struct {
|
||||
IncludeUsage bool `json:"include_usage,omitempty"`
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int64 {
|
||||
return int64(r.MaxTokens)
|
||||
func (r GeneralOpenAIRequest) GetMaxTokens() int {
|
||||
return int(r.MaxTokens)
|
||||
}
|
||||
|
||||
func (r GeneralOpenAIRequest) ParseInput() []string {
|
||||
|
@ -198,11 +198,11 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set("api_version", channel.Other)
|
||||
//case common.ChannelTypeAIProxyLibrary:
|
||||
// c.Set("library_id", channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set("plugin", channel.Other)
|
||||
case common.ChannelCloudflare:
|
||||
c.Set("api_version", channel.Other)
|
||||
}
|
||||
}
|
||||
|
76
relay/channel/cloudflare/adaptor.go
Normal file
76
relay/channel/cloudflare/adaptor.go
Normal file
@ -0,0 +1,76 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
)
|
||||
|
||||
type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) InitRerank(info *relaycommon.RelayInfo, request dto.RerankRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) {
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
switch info.RelayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
|
||||
case constant.RelayModeEmbeddings:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
|
||||
default:
|
||||
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) {
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
switch relayMode {
|
||||
case constant.RelayModeCompletions:
|
||||
return convertCf2CompletionsRequest(*request), nil
|
||||
default:
|
||||
return request, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.IsStream {
|
||||
err, usage = cfStreamHandler(c, resp, info)
|
||||
} else {
|
||||
err, usage = cfHandler(c, resp, info)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetModelList() []string {
|
||||
return ModelList
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
38
relay/channel/cloudflare/constant.go
Normal file
38
relay/channel/cloudflare/constant.go
Normal file
@ -0,0 +1,38 @@
|
||||
package cloudflare
|
||||
|
||||
var ModelList = []string{
|
||||
"@cf/meta/llama-2-7b-chat-fp16",
|
||||
"@cf/meta/llama-2-7b-chat-int8",
|
||||
"@cf/mistral/mistral-7b-instruct-v0.1",
|
||||
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
|
||||
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
|
||||
"@cf/deepseek-ai/deepseek-math-7b-base",
|
||||
"@cf/deepseek-ai/deepseek-math-7b-instruct",
|
||||
"@cf/thebloke/discolm-german-7b-v1-awq",
|
||||
"@cf/tiiuae/falcon-7b-instruct",
|
||||
"@cf/google/gemma-2b-it-lora",
|
||||
"@hf/google/gemma-7b-it",
|
||||
"@cf/google/gemma-7b-it-lora",
|
||||
"@hf/nousresearch/hermes-2-pro-mistral-7b",
|
||||
"@hf/thebloke/llama-2-13b-chat-awq",
|
||||
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
|
||||
"@cf/meta/llama-3-8b-instruct",
|
||||
"@hf/thebloke/llamaguard-7b-awq",
|
||||
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
|
||||
"@hf/mistralai/mistral-7b-instruct-v0.2",
|
||||
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
|
||||
"@hf/thebloke/neural-chat-7b-v3-1-awq",
|
||||
"@cf/openchat/openchat-3.5-0106",
|
||||
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
|
||||
"@cf/microsoft/phi-2",
|
||||
"@cf/qwen/qwen1.5-0.5b-chat",
|
||||
"@cf/qwen/qwen1.5-1.8b-chat",
|
||||
"@cf/qwen/qwen1.5-14b-chat-awq",
|
||||
"@cf/qwen/qwen1.5-7b-chat-awq",
|
||||
"@cf/defog/sqlcoder-7b-2",
|
||||
"@hf/nexusflow/starling-lm-7b-beta",
|
||||
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
|
||||
"@hf/thebloke/zephyr-7b-beta-awq",
|
||||
}
|
||||
|
||||
var ChannelName = "cloudflare"
|
13
relay/channel/cloudflare/model.go
Normal file
13
relay/channel/cloudflare/model.go
Normal file
@ -0,0 +1,13 @@
|
||||
package cloudflare
|
||||
|
||||
import "one-api/dto"
|
||||
|
||||
type CfRequest struct {
|
||||
Messages []dto.Message `json:"messages,omitempty"`
|
||||
Lora string `json:"lora,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Prompt string `json:"prompt,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
}
|
115
relay/channel/cloudflare/relay_cloudflare.go
Normal file
115
relay/channel/cloudflare/relay_cloudflare.go
Normal file
@ -0,0 +1,115 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/json"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
|
||||
p, _ := textRequest.Prompt.(string)
|
||||
return &CfRequest{
|
||||
Prompt: p,
|
||||
MaxTokens: textRequest.GetMaxTokens(),
|
||||
Stream: textRequest.Stream,
|
||||
Temperature: textRequest.Temperature,
|
||||
}
|
||||
}
|
||||
|
||||
func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
||||
service.SetEventStreamHeaders(c)
|
||||
id := service.GetResponseID(c)
|
||||
var responseText string
|
||||
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < len("data: ") {
|
||||
continue
|
||||
}
|
||||
data = strings.TrimPrefix(data, "data: ")
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
|
||||
if data == "[DONE]" {
|
||||
break
|
||||
}
|
||||
|
||||
var response dto.ChatCompletionsStreamResponse
|
||||
err := json.Unmarshal([]byte(data), &response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
|
||||
continue
|
||||
}
|
||||
for _, choice := range response.Choices {
|
||||
choice.Delta.Role = "assistant"
|
||||
responseText += choice.Delta.GetContentString()
|
||||
}
|
||||
response.Id = id
|
||||
response.Model = info.UpstreamModelName
|
||||
err = service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_stream_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
if info.ShouldIncludeUsage {
|
||||
response := service.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||
err := service.ObjectData(c, response)
|
||||
if err != nil {
|
||||
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
|
||||
}
|
||||
}
|
||||
service.Done(c)
|
||||
|
||||
err := resp.Body.Close()
|
||||
if err != nil {
|
||||
common.LogError(c, "close_response_body_failed: "+err.Error())
|
||||
}
|
||||
|
||||
return nil, usage
|
||||
}
|
||||
|
||||
func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
var response dto.TextResponse
|
||||
err = json.Unmarshal(responseBody, &response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
response.Model = info.UpstreamModelName
|
||||
var responseText string
|
||||
for _, choice := range response.Choices {
|
||||
responseText += choice.Message.StringContent()
|
||||
}
|
||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||
response.Usage = *usage
|
||||
response.Id = service.GetResponseID(c)
|
||||
jsonResponse, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
}
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, _ = c.Writer.Write(jsonResponse)
|
||||
return nil, usage
|
||||
}
|
@ -7,7 +7,7 @@ type CohereRequest struct {
|
||||
ChatHistory []ChatHistory `json:"chat_history"`
|
||||
Message string `json:"message"`
|
||||
Stream bool `json:"stream"`
|
||||
MaxTokens int64 `json:"max_tokens"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
}
|
||||
|
||||
type ChatHistory struct {
|
||||
|
@ -68,7 +68,8 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
|
||||
info.ApiVersion = GetAPIVersion(c)
|
||||
}
|
||||
if info.ChannelType == common.ChannelTypeOpenAI || info.ChannelType == common.ChannelTypeAnthropic ||
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini {
|
||||
info.ChannelType == common.ChannelTypeAws || info.ChannelType == common.ChannelTypeGemini ||
|
||||
info.ChannelType == common.ChannelCloudflare {
|
||||
info.SupportStreamOptions = true
|
||||
}
|
||||
return info
|
||||
|
@ -22,6 +22,7 @@ const (
|
||||
APITypeCohere
|
||||
APITypeDify
|
||||
APITypeJina
|
||||
APITypeCloudflare
|
||||
|
||||
APITypeDummy // this one is only for count, do not add any channel after this
|
||||
)
|
||||
@ -63,6 +64,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
|
||||
apiType = APITypeDify
|
||||
case common.ChannelTypeJina:
|
||||
apiType = APITypeJina
|
||||
case common.ChannelCloudflare:
|
||||
apiType = APITypeCloudflare
|
||||
}
|
||||
if apiType == -1 {
|
||||
return APITypeOpenAI, false
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"one-api/relay/channel/aws"
|
||||
"one-api/relay/channel/baidu"
|
||||
"one-api/relay/channel/claude"
|
||||
"one-api/relay/channel/cloudflare"
|
||||
"one-api/relay/channel/cohere"
|
||||
"one-api/relay/channel/dify"
|
||||
"one-api/relay/channel/gemini"
|
||||
@ -59,6 +60,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
|
||||
return &dify.Adaptor{}
|
||||
case constant.APITypeJina:
|
||||
return &jina.Adaptor{}
|
||||
case constant.APITypeCloudflare:
|
||||
return &cloudflare.Adaptor{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -35,3 +35,8 @@ func ObjectData(c *gin.Context, object interface{}) error {
|
||||
func Done(c *gin.Context) {
|
||||
StringData(c, "[DONE]")
|
||||
}
|
||||
|
||||
func GetResponseID(c *gin.Context) string {
|
||||
logID := c.GetString("X-Oneapi-Request-Id")
|
||||
return fmt.Sprintf("chatcmpl-%s", logID)
|
||||
}
|
@ -99,6 +99,7 @@ export const CHANNEL_OPTIONS = [
|
||||
color: 'orange',
|
||||
label: 'Google PaLM2',
|
||||
},
|
||||
{ key: 39, text: 'Cloudflare', value: 39, color: 'grey', label: 'Cloudflare' },
|
||||
{ key: 25, text: 'Moonshot', value: 25, color: 'green', label: 'Moonshot' },
|
||||
{ key: 19, text: '360 智脑', value: 19, color: 'blue', label: '360 智脑' },
|
||||
{ key: 23, text: '腾讯混元', value: 23, color: 'teal', label: '腾讯混元' },
|
||||
|
@ -605,6 +605,24 @@ const EditChannel = (props) => {
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{inputs.type === 39 && (
|
||||
<>
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>Account ID:</Typography.Text>
|
||||
</div>
|
||||
<Input
|
||||
name='other'
|
||||
placeholder={
|
||||
'请输入Account ID,例如:d6b5da8hk1awo8nap34ube6gh'
|
||||
}
|
||||
onChange={(value) => {
|
||||
handleInputChange('other', value);
|
||||
}}
|
||||
value={inputs.other}
|
||||
autoComplete='new-password'
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<div style={{ marginTop: 10 }}>
|
||||
<Typography.Text strong>模型:</Typography.Text>
|
||||
</div>
|
||||
|
Loading…
Reference in New Issue
Block a user