♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
Buer
2024-01-19 02:47:10 +08:00
committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

View File

@@ -3,7 +3,6 @@ package openai
import (
"errors"
"fmt"
"one-api/common"
"one-api/model"
"time"
)
@@ -16,15 +15,14 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
// 发送请求
var subscription OpenAISubscriptionResponse
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
_, errWithCode := p.Requester.SendRequest(req, &subscription, false)
if errWithCode != nil {
return 0, errors.New(errWithCode.OpenAIError.Message)
}
@@ -37,12 +35,15 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
}
fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "")
req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
req, err = p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
if err != nil {
return 0, err
}
usage := OpenAIUsageResponse{}
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
_, errWithCode = p.Requester.SendRequest(req, &usage, false)
if errWithCode != nil {
return 0, errWithCode
}
balance := subscription.HardLimitUSD - usage.TotalUsage/100
channel.UpdateBalance(balance)

View File

@@ -1,69 +1,100 @@
package openai
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/types"
"strings"
"one-api/providers/base"
"github.com/gin-gonic/gin"
)
type OpenAIProviderFactory struct{}
// 创建 OpenAIProvider
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
openAIProvider := CreateOpenAIProvider(c, "")
openAIProvider.BalanceAction = true
return openAIProvider
}
type OpenAIProvider struct {
base.BaseProvider
IsAzure bool
BalanceAction bool
}
// 创建 OpenAIProvider
func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
openAIProvider := CreateOpenAIProvider(channel, "https://api.openai.com")
openAIProvider.BalanceAction = true
return openAIProvider
}
// 创建 OpenAIProvider
// https://platform.openai.com/docs/api-reference/introduction
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
if baseURL == "" {
baseURL = "https://api.openai.com"
}
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
config := getOpenAIConfig(baseURL)
return &OpenAIProvider{
BaseProvider: base.BaseProvider{
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edits",
ImagesVariations: "/v1/images/variations",
Context: c,
Config: config,
Channel: channel,
Requester: requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle),
},
IsAzure: false,
BalanceAction: true,
}
}
func getOpenAIConfig(baseURL string) base.ProviderConfig {
return base.ProviderConfig{
BaseURL: baseURL,
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",
ImagesGenerations: "/v1/images/generations",
ImagesEdit: "/v1/images/edits",
ImagesVariations: "/v1/images/variations",
}
}
// 请求错误处理
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
var errorResponse *types.OpenAIErrorResponse
err := json.NewDecoder(resp.Body).Decode(errorResponse)
if err != nil {
return nil
}
return ErrorHandle(errorResponse)
}
// 错误处理
func ErrorHandle(openaiError *types.OpenAIErrorResponse) *types.OpenAIError {
if openaiError.Error.Message == "" {
return nil
}
return &openaiError.Error
}
// 获取完整请求 URL
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
if p.IsAzure {
apiVersion := p.Channel.Other
// 以-分割检测modelName 最后一个元素是否为4位数字,必须是数字如果是则删除modelName最后一个元素
modelNameSlice := strings.Split(modelName, "-")
lastModelNameSlice := modelNameSlice[len(modelNameSlice)-1]
modelNum := common.String2Int(lastModelNameSlice)
if modelNum > 999 && modelNum < 10000 {
modelName = strings.TrimSuffix(modelName, "-"+lastModelNameSlice)
}
// 检测模型是是否包含 . 如果有则直接去掉
modelName = strings.Replace(modelName, ".", "", -1)
if modelName == "dall-e-2" {
// 因为dall-e-3需要api-version=2023-12-01-preview但是该版本
// 已经没有dall-e-2了所以暂时写死
@@ -72,10 +103,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
// 检测模型是是否包含 . 如果有则直接去掉
if strings.Contains(requestURL, ".") {
requestURL = strings.Replace(requestURL, ".", "", -1)
}
}
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
@@ -102,89 +129,21 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
return headers
}
// 获取请求体
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
if isModelMapped {
jsonStr, err := json.Marshal(request)
if err != nil {
return nil, err
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = p.Context.Request.Body
func (p *OpenAIProvider) GetRequestTextBody(relayMode int, ModelName string, request any) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(relayMode)
if errWithCode != nil {
return nil, errWithCode
}
return
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, ModelName)
// 发送流式请求
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
defer req.Body.Close()
client := common.GetHttpClient(p.Channel.Proxy)
resp, err := client.Do(req)
// 获取请求
headers := p.GetRequestHeaders()
// 创建请求
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
if err != nil {
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
}
common.PutHttpClient(client)
if common.IsFailureStatusCode(resp) {
return common.HandleErrorResp(resp), ""
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format
continue
}
if data[:6] != "data: " && data[:6] != "[DONE]" {
continue
}
dataChan <- data
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
err := json.Unmarshal([]byte(data), response)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue // just ignore the error
}
responseText += response.responseStreamHandler()
}
}
stopChan <- true
}()
common.SetEventStreamHeaders(p.Context)
p.Context.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if strings.HasPrefix(data, "data: [DONE]") {
data = data[:12]
}
// some implementations may add \r at the end of data
data = strings.TrimSuffix(data, "\r")
p.Context.Render(-1, common.CustomEvent{Data: data})
return true
case <-stopChan:
return false
}
})
return nil, responseText
return req, nil
}

View File

@@ -1,82 +1,101 @@
package openai
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
type OpenAIStreamHandler struct {
Usage *types.Usage
ModelName string
}
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderChatResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
// 检测是否错误
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return
}
return nil, nil
}
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Delta.Content
return nil, errWithCode
}
return
*p.Usage = *response.Usage
return &response.ChatCompletionResponse, nil
}
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
}
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
}
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
}
// 去除前缀
*rawLine = (*rawLine)[6:]
// 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" {
*isFinished = true
return nil
}
var openaiResponse OpenAIProviderChatStreamResponse
err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream && headers["Accept"] == "" {
headers["Accept"] = "text/event-stream"
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil {
return error
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText
if request.Stream {
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
var textResponse string
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
if errWithCode != nil {
return
}
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(textResponse, request.Model),
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
}
} else {
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
errWithCode = p.SendRequest(req, openAIProviderChatResponse, true)
if errWithCode != nil {
return
}
usage = openAIProviderChatResponse.Usage
if usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range openAIProviderChatResponse.Choices {
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
}
return
return nil
}

View File

@@ -1,82 +1,96 @@
package openai
import (
"encoding/json"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"strings"
)
func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderCompletionResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
// 检测是否错误
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return
}
return nil, nil
}
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Text
return nil, errWithCode
}
return
*p.Usage = *response.Usage
return &response.CompletionResponse, nil
}
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
resp, errWithCode := p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
chatHandler := OpenAIStreamHandler{
Usage: p.Usage,
ModelName: request.Model,
}
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
}
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
// 如果rawLine 前缀不为data:,则直接返回
if !strings.HasPrefix(string(*rawLine), "data: ") {
*rawLine = nil
return nil
}
// 去除前缀
*rawLine = (*rawLine)[6:]
// 如果等于 DONE 则结束
if string(*rawLine) == "[DONE]" {
*isFinished = true
return nil
}
var openaiResponse OpenAIProviderCompletionResponse
err := json.Unmarshal(*rawLine, &openaiResponse)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
return common.ErrorToOpenAIError(err)
}
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
headers := p.GetRequestHeaders()
if request.Stream && headers["Accept"] == "" {
headers["Accept"] = "text/event-stream"
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
if error != nil {
return error
}
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
h.Usage.CompletionTokens += countTokenText
h.Usage.TotalTokens += countTokenText
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
if request.Stream {
// TODO
var textResponse string
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
if errWithCode != nil {
return
}
*response = append(*response, openaiResponse.CompletionResponse)
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: common.CountTokenText(textResponse, request.Model),
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
}
} else {
errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true)
if errWithCode != nil {
return
}
usage = openAIProviderCompletionResponse.Usage
if usage.TotalTokens == 0 {
completionTokens := 0
for _, choice := range openAIProviderCompletionResponse.Choices {
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
}
return
return nil
}

View File

@@ -6,40 +6,30 @@ import (
"one-api/types"
)
func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil, nil
}
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true)
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request)
if errWithCode != nil {
return
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderEmbeddingsResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
usage = openAIProviderEmbeddingsResponse.Usage
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
return
*p.Usage = *response.Usage
return &response.EmbeddingResponse, nil
}

View File

@@ -5,28 +5,71 @@ import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderImageResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
p.Usage.TotalTokens = p.Usage.PromptTokens
return &response.ImageResponse, nil
}
func (p *OpenAIProvider) getRequestImageBody(relayMode int, ModelName string, request *types.ImageEditRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(relayMode)
if errWithCode != nil {
return nil, errWithCode
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, ModelName)
// 获取请求头
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
// 创建请求
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if p.OriginalModel != request.Model {
var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
if err := imagesEditsMultipartForm(request, builder); err != nil {
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req, err = p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(&formBody),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req, err = p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(p.Context.Request.Body),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
@@ -34,22 +77,10 @@ func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isMod
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
if errWithCode != nil {
return
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
}
return
return req, nil
}
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
func imagesEditsMultipartForm(request *types.ImageEditRequest, b requester.FormBuilder) error {
err := b.CreateFormFile("image", request.Image)
if err != nil {
return fmt.Errorf("creating form image: %w", err)

View File

@@ -6,53 +6,41 @@ import (
"one-api/types"
)
func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil, nil
}
func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
if !isWithinRange(request.Model, request.N) {
func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
if !IsWithinRange(request.Model, request.N) {
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
}
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
if errWithCode != nil {
return
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderImageResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
// 检测是否错误
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
return
p.Usage.TotalTokens = p.Usage.PromptTokens
return &response.ImageResponse, nil
}
func isWithinRange(element string, value int) bool {
func IsWithinRange(element string, value int) bool {
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
return false
}

View File

@@ -1,49 +1,35 @@
package openai
import (
"bytes"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if err := imagesEditsMultipartForm(request, builder); err != nil {
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request)
if errWithCode != nil {
return
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderImageResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
return
p.Usage.TotalTokens = p.Usage.PromptTokens
return &response.ImageResponse, nil
}

View File

@@ -6,44 +6,31 @@ import (
"one-api/types"
)
func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
}
return nil, nil
}
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true)
req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request)
if errWithCode != nil {
return
return nil, errWithCode
}
defer req.Body.Close()
response := &OpenAIProviderModerationResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, response, false)
if errWithCode != nil {
return nil, errWithCode
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
if openaiErr != nil {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: *openaiErr,
StatusCode: http.StatusBadRequest,
}
return nil, errWithCode
}
return
p.Usage.TotalTokens = p.Usage.PromptTokens
return &response.ModerationResponse, nil
}

View File

@@ -3,35 +3,29 @@ package openai
import (
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
)
func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
requestBody, err := p.GetRequestBody(&request, isModelMapped)
if err != nil {
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
errWithCode = p.SendRequestRaw(req)
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request)
if errWithCode != nil {
return
return nil, errWithCode
}
defer req.Body.Close()
// 发送请求
var resp *http.Response
resp, errWithCode = p.Requester.SendRequestRaw(req)
if errWithCode != nil {
return nil, errWithCode
}
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: 0,
TotalTokens: promptTokens,
if resp.Header.Get("Content-Type") == "application/json" {
return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler)
}
return
p.Usage.TotalTokens = p.Usage.PromptTokens
return resp, nil
}

View File

@@ -4,48 +4,99 @@ import (
"bufio"
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/types"
"regexp"
"strconv"
"strings"
)
func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
if c.Error.Type != "" {
errWithCode = &types.OpenAIErrorWithStatusCode{
OpenAIError: c.Error,
StatusCode: resp.StatusCode,
}
return
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
return nil, nil
defer req.Body.Close()
var textResponse string
var resp *http.Response
var err error
audioResponseWrapper := &types.AudioResponseWrapper{}
if hasJSONResponse(request) {
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
if errWithCode != nil {
return nil, errWithCode
}
textResponse = openAIProviderTranscriptionsResponse.Text
} else {
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
if errWithCode != nil {
return nil, errWithCode
}
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
}
defer resp.Body.Close()
audioResponseWrapper.Headers = map[string]string{
"Content-Type": resp.Header.Get("Content-Type"),
}
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
completionTokens := common.CountTokenText(textResponse, request.Model)
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
return audioResponseWrapper, nil
}
func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
return nil, nil
func hasJSONResponse(request *types.AudioRequest) bool {
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
}
func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model)
func (p *OpenAIProvider) getRequestAudioBody(relayMode int, ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(relayMode)
if errWithCode != nil {
return nil, errWithCode
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, ModelName)
// 获取请求头
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
// 创建请求
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if p.OriginalModel != request.Model {
var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
if err := audioMultipartForm(request, builder); err != nil {
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req, err = p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(&formBody),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req, err = p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(p.Context.Request.Body),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
@@ -53,37 +104,10 @@ func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isMod
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
var textResponse string
if hasJSONResponse(request) {
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
if errWithCode != nil {
return
}
textResponse = openAIProviderTranscriptionsResponse.Text
} else {
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
if errWithCode != nil {
return
}
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
}
completionTokens := common.CountTokenText(textResponse, request.Model)
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
return
return req, nil
}
func hasJSONResponse(request *types.AudioRequest) bool {
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
}
func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error {
func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error {
err := b.CreateFormFile("file", request.File)
if err != nil {

View File

@@ -1,60 +1,53 @@
package openai
import (
"bytes"
"io"
"net/http"
"one-api/common"
"one-api/types"
)
func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model)
headers := p.GetRequestHeaders()
client := common.NewClient()
var formBody bytes.Buffer
var req *http.Request
var err error
if isModelMapped {
builder := client.CreateFormBuilder(&formBody)
if err := audioMultipartForm(request, builder); err != nil {
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
}
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
} else {
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
req.ContentLength = p.Context.Request.ContentLength
}
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
if errWithCode != nil {
return nil, errWithCode
}
defer req.Body.Close()
var textResponse string
var resp *http.Response
var err error
audioResponseWrapper := &types.AudioResponseWrapper{}
if hasJSONResponse(request) {
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
if errWithCode != nil {
return
return nil, errWithCode
}
textResponse = openAIProviderTranscriptionsResponse.Text
} else {
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
if errWithCode != nil {
return
return nil, errWithCode
}
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
}
defer resp.Body.Close()
audioResponseWrapper.Headers = map[string]string{
"Content-Type": resp.Header.Get("Content-Type"),
}
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
if err != nil {
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
completionTokens := common.CountTokenText(textResponse, request.Model)
usage = &types.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
return
p.Usage.CompletionTokens = completionTokens
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
return audioResponseWrapper, nil
}

View File

@@ -12,11 +12,27 @@ type OpenAIProviderChatStreamResponse struct {
types.OpenAIErrorResponse
}
func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Delta.Content
}
return
}
type OpenAIProviderCompletionResponse struct {
types.CompletionResponse
types.OpenAIErrorResponse
}
func (c *OpenAIProviderCompletionResponse) getResponseText() (responseText string) {
for _, choice := range c.Choices {
responseText += choice.Text
}
return
}
type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
@@ -38,7 +54,7 @@ func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string {
return (*string)(a)
}
type OpenAIProviderImageResponseResponse struct {
type OpenAIProviderImageResponse struct {
types.ImageResponse
types.OpenAIErrorResponse
}