feat: Support Cloudflare AI

This commit is contained in:
Martial BE
2024-04-16 18:08:56 +08:00
parent 5606a104f6
commit 344555418e
14 changed files with 606 additions and 50 deletions

View File

@@ -0,0 +1,86 @@
package cloudflareAI
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
)
type CloudflareAIProviderFactory struct{}
// 创建 CloudflareAIProvider
func (f CloudflareAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
cf := &CloudflareAIProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
}
tokens := strings.Split(channel.Key, "|")
if len(tokens) == 2 {
cf.AccountID = tokens[0]
cf.CFToken = tokens[1]
}
return cf
}
type CloudflareAIProvider struct {
base.BaseProvider
AccountID string
CFToken string
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.cloudflare.com/client/v4/accounts/%s/ai/run/%s",
ImagesGenerations: "true",
ChatCompletions: "true",
AudioTranscriptions: "true",
}
}
// 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
CloudflareAIError := &CloudflareAIError{}
err := json.NewDecoder(resp.Body).Decode(CloudflareAIError)
if err != nil {
return nil
}
return errorHandle(CloudflareAIError)
}
// 错误处理
func errorHandle(CloudflareAIError *CloudflareAIError) *types.OpenAIError {
if CloudflareAIError.Success || len(CloudflareAIError.Error) == 0 {
return nil
}
return &types.OpenAIError{
Message: CloudflareAIError.Error[0].Message,
Type: "CloudflareAI error",
Code: CloudflareAIError.Error[0].Code,
}
}
// 获取请求头
func (p *CloudflareAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = fmt.Sprintf("Bearer %s", p.CFToken)
return headers
}
func (p *CloudflareAIProvider) GetFullRequestURL(modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf(baseURL, p.AccountID, modelName)
}

View File

@@ -0,0 +1,184 @@
package cloudflareAI
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
type CloudflareAIStreamHandler struct {
Usage *types.Usage
Request *types.ChatCompletionRequest
}
func (p *CloudflareAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (*types.ChatCompletionResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
chatResponse := &ChatRespone{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, chatResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
return p.convertToChatOpenai(chatResponse, request)
}
func (p *CloudflareAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[string], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getChatRequest(request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := &CloudflareAIStreamHandler{
Usage: p.Usage,
Request: request,
}
return requester.RequestStream[string](p.Requester, resp, chatHandler.handlerStream)
}
func (p *CloudflareAIProvider) getChatRequest(request *types.ChatCompletionRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_cloudflare_ai_config", http.StatusInternalServerError)
}
// 获取请求头
headers := p.GetRequestHeaders()
chatRequest := p.convertFromChatOpenai(request)
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(chatRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func (p *CloudflareAIProvider) convertToChatOpenai(response *ChatRespone, request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
err := errorHandle(&response.CloudflareAIError)
if err != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *err,
StatusCode: http.StatusBadRequest,
}
return
}
openaiResponse = &types.ChatCompletionResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
Model: request.Model,
Choices: []types.ChatCompletionChoice{{
Index: 0,
Message: types.ChatCompletionMessage{
Role: types.ChatMessageRoleAssistant,
Content: response.Result.Response,
},
FinishReason: types.FinishReasonStop,
}},
}
completionTokens := common.CountTokenText(response.Result.Response, request.Model)
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + completionTokens
openaiResponse.Usage = p.Usage
return
}
func (p *CloudflareAIProvider) convertFromChatOpenai(request *types.ChatCompletionRequest) *ChatRequest {
chatRequest := &ChatRequest{
Stream: request.Stream,
MaxTokens: request.MaxTokens,
Messages: make([]Message, 0, len(request.Messages)),
}
for _, message := range request.Messages {
chatRequest.Messages = append(chatRequest.Messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
return chatRequest
}
// 转换为OpenAI聊天流式请求体
func (h *CloudflareAIStreamHandler) handlerStream(rawLine *[]byte, dataChan chan string, errChan chan error) {
// 如果rawLine 前缀不为data: 或者 meta:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return
}
*rawLine = (*rawLine)[6:]
if strings.HasPrefix(string(*rawLine), "[DONE]") {
h.convertToOpenaiStream(nil, dataChan, true)
errChan <- io.EOF
*rawLine = requester.StreamClosed
return
}
chatResponse := &ChatResult{}
err := json.Unmarshal(*rawLine, chatResponse)
if err != nil {
errChan <- common.ErrorToOpenAIError(err)
return
}
h.convertToOpenaiStream(chatResponse, dataChan, false)
}
func (h *CloudflareAIStreamHandler) convertToOpenaiStream(chatResponse *ChatResult, dataChan chan string, isStop bool) {
streamResponse := types.ChatCompletionStreamResponse{
ID: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: h.Request.Model,
}
choice := types.ChatCompletionStreamChoice{
Index: 0,
Delta: types.ChatCompletionStreamChoiceDelta{
Role: types.ChatMessageRoleAssistant,
Content: "",
},
}
if isStop {
choice.FinishReason = types.FinishReasonStop
} else {
choice.Delta.Content = chatResponse.Response
h.Usage.CompletionTokens += common.CountTokenText(chatResponse.Response, h.Request.Model)
h.Usage.TotalTokens = h.Usage.PromptTokens + h.Usage.CompletionTokens
}
streamResponse.Choices = []types.ChatCompletionStreamChoice{choice}
responseBody, _ := json.Marshal(streamResponse)
dataChan <- string(responseBody)
}

View File

@@ -0,0 +1,62 @@
package cloudflareAI
import (
"encoding/base64"
"io"
"net/http"
"one-api/common"
"one-api/types"
"time"
)
func (p *CloudflareAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(request.Model)
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_cloudflare_ai_config", http.StatusInternalServerError)
}
// 获取请求头
headers := p.GetRequestHeaders()
cfRequest := convertFromIamgeOpenai(request)
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(cfRequest), p.Requester.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer req.Body.Close()
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
defer resp.Body.Close()
if resp.Header.Get("Content-Type") != "image/png" {
return nil, common.StringErrorWrapper("invalid_image_response", "invalid_image_response", http.StatusInternalServerError)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, common.ErrorWrapper(err, "read_response_failed", http.StatusInternalServerError)
}
base64Image := base64.StdEncoding.EncodeToString(body)
openaiResponse := &types.ImageResponse{
Created: time.Now().Unix(),
Data: []types.ImageResponseDataInner{{B64JSON: base64Image}},
}
p.Usage.PromptTokens = 1000
return openaiResponse, nil
}
func convertFromIamgeOpenai(request *types.ImageRequest) *ImageRequest {
return &ImageRequest{
Prompt: request.Prompt,
}
}

View File

@@ -0,0 +1,94 @@
package cloudflareAI
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
)
func (p *CloudflareAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestAudioBody(request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
var resp *http.Response
var err error
audioResponse := &AudioResponse{}
resp, errWithCode = p.Requester.SendRequest(req, audioResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
errWithOP := errorHandle(&audioResponse.CloudflareAIError)
if errWithOP != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *errWithOP,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
chatResult := audioResponse.Result
audioResponseWrapper := &types.AudioResponseWrapper{}
audioResponseWrapper.Headers = map[string]string{
"Content-Type": resp.Header.Get("Content-Type"),
}
audioResponseWrapper.Body, err = json.Marshal(&chatResult)
if err != nil {
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
completionTokens := common.CountTokenText(chatResult.Text, request.Model)
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
return audioResponseWrapper, nil
}
func (p *CloudflareAIProvider) getRequestAudioBody(ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(ModelName)
// 获取请求头
headers := p.GetRequestHeaders()
// 创建请求
var req *http.Request
var err error
var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
if err := audioMultipartForm(request, builder); err != nil {
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(&formBody),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
return req, nil
}
func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error {
err := b.CreateFormFile("file", request.File)
if err != nil {
return fmt.Errorf("creating form file: %w", err)
}
return b.Close()
}

View File

@@ -0,0 +1,60 @@
package cloudflareAI
import "one-api/types"
type CloudflareAIError struct {
Error []struct {
Code int `json:"code"`
Message string `json:"message"`
} `json:"errors,omitempty"`
Success bool `json:"success"`
}
type ImageRequest struct {
Prompt string `json:"prompt"`
Image interface{} `json:"image,omitempty"` // 可以是 string 或者 ImageObject
Mask interface{} `json:"mask,omitempty"` // 可以是 string 或者 MaskObject
NumSteps int `json:"num_steps,omitempty"`
Strength float64 `json:"strength,omitempty"`
Guidance float64 `json:"guidance,omitempty"`
}
type ImageObject struct {
Image []float64 `json:"image"`
}
type MaskObject struct {
Mask []float64 `json:"mask"`
}
type ChatRequest struct {
Messages []Message `json:"messages"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatRespone struct {
Result ChatResult `json:"result,omitempty"`
CloudflareAIError
}
type ChatResult struct {
Response string `json:"response"`
}
type AudioResponse struct {
Result AudioResult `json:"result,omitempty"`
CloudflareAIError
}
type AudioResult struct {
Text string `json:"text,omitempty"`
WordCount int `json:"word_count,omitempty"`
Words []types.AudioWordsList `json:"words,omitempty"`
Vtt string `json:"vtt,omitempty"`
}

View File

@@ -11,6 +11,7 @@ import (
"one-api/providers/base"
"one-api/providers/bedrock"
"one-api/providers/claude"
"one-api/providers/cloudflareAI"
"one-api/providers/deepseek"
"one-api/providers/gemini"
"one-api/providers/groq"
@@ -54,6 +55,7 @@ func init() {
providerFactories[common.ChannelTypeGroq] = groq.GroqProviderFactory{}
providerFactories[common.ChannelTypeBedrock] = bedrock.BedrockProviderFactory{}
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
}