mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 13:13:41 +08:00
♻️ refactor: provider refactor (#41)
* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
@@ -3,7 +3,6 @@ package openai
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"time"
|
||||
)
|
||||
@@ -16,15 +15,14 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
fullRequestURL := p.GetFullRequestURL("/v1/dashboard/billing/subscription", "")
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err := p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
var subscription OpenAISubscriptionResponse
|
||||
_, errWithCode := common.SendRequest(req, &subscription, false, p.Channel.Proxy)
|
||||
_, errWithCode := p.Requester.SendRequest(req, &subscription, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errors.New(errWithCode.OpenAIError.Message)
|
||||
}
|
||||
@@ -37,12 +35,15 @@ func (p *OpenAIProvider) Balance(channel *model.Channel) (float64, error) {
|
||||
}
|
||||
|
||||
fullRequestURL = p.GetFullRequestURL(fmt.Sprintf("/v1/dashboard/billing/usage?start_date=%s&end_date=%s", startDate, endDate), "")
|
||||
req, err = client.NewRequest("GET", fullRequestURL, common.WithHeader(headers))
|
||||
req, err = p.Requester.NewRequest("GET", fullRequestURL, p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
usage := OpenAIUsageResponse{}
|
||||
_, errWithCode = common.SendRequest(req, &usage, false, p.Channel.Proxy)
|
||||
_, errWithCode = p.Requester.SendRequest(req, &usage, false)
|
||||
if errWithCode != nil {
|
||||
return 0, errWithCode
|
||||
}
|
||||
|
||||
balance := subscription.HardLimitUSD - usage.TotalUsage/100
|
||||
channel.UpdateBalance(balance)
|
||||
|
||||
@@ -1,69 +1,100 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
|
||||
"one-api/providers/base"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type OpenAIProviderFactory struct{}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func (f OpenAIProviderFactory) Create(c *gin.Context) base.ProviderInterface {
|
||||
openAIProvider := CreateOpenAIProvider(c, "")
|
||||
openAIProvider.BalanceAction = true
|
||||
return openAIProvider
|
||||
}
|
||||
|
||||
type OpenAIProvider struct {
|
||||
base.BaseProvider
|
||||
IsAzure bool
|
||||
BalanceAction bool
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
func (f OpenAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
openAIProvider := CreateOpenAIProvider(channel, "https://api.openai.com")
|
||||
openAIProvider.BalanceAction = true
|
||||
return openAIProvider
|
||||
}
|
||||
|
||||
// 创建 OpenAIProvider
|
||||
// https://platform.openai.com/docs/api-reference/introduction
|
||||
func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://api.openai.com"
|
||||
}
|
||||
func CreateOpenAIProvider(channel *model.Channel, baseURL string) *OpenAIProvider {
|
||||
config := getOpenAIConfig(baseURL)
|
||||
|
||||
return &OpenAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
Moderation: "/v1/moderations",
|
||||
AudioSpeech: "/v1/audio/speech",
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
Context: c,
|
||||
Config: config,
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(channel.Proxy, RequestErrorHandle),
|
||||
},
|
||||
IsAzure: false,
|
||||
BalanceAction: true,
|
||||
}
|
||||
}
|
||||
|
||||
func getOpenAIConfig(baseURL string) base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: baseURL,
|
||||
Completions: "/v1/completions",
|
||||
ChatCompletions: "/v1/chat/completions",
|
||||
Embeddings: "/v1/embeddings",
|
||||
Moderation: "/v1/moderations",
|
||||
AudioSpeech: "/v1/audio/speech",
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func RequestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
var errorResponse *types.OpenAIErrorResponse
|
||||
err := json.NewDecoder(resp.Body).Decode(errorResponse)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ErrorHandle(errorResponse)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func ErrorHandle(openaiError *types.OpenAIErrorResponse) *types.OpenAIError {
|
||||
if openaiError.Error.Message == "" {
|
||||
return nil
|
||||
}
|
||||
return &openaiError.Error
|
||||
}
|
||||
|
||||
// 获取完整请求 URL
|
||||
func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
if p.IsAzure {
|
||||
apiVersion := p.Channel.Other
|
||||
// 以-分割,检测modelName 最后一个元素是否为4位数字,必须是数字,如果是则删除modelName最后一个元素
|
||||
modelNameSlice := strings.Split(modelName, "-")
|
||||
lastModelNameSlice := modelNameSlice[len(modelNameSlice)-1]
|
||||
modelNum := common.String2Int(lastModelNameSlice)
|
||||
if modelNum > 999 && modelNum < 10000 {
|
||||
modelName = strings.TrimSuffix(modelName, "-"+lastModelNameSlice)
|
||||
}
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
modelName = strings.Replace(modelName, ".", "", -1)
|
||||
|
||||
if modelName == "dall-e-2" {
|
||||
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
|
||||
// 已经没有dall-e-2了,所以暂时写死
|
||||
@@ -72,10 +103,6 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
|
||||
}
|
||||
|
||||
// 检测模型是是否包含 . 如果有则直接去掉
|
||||
if strings.Contains(requestURL, ".") {
|
||||
requestURL = strings.Replace(requestURL, ".", "", -1)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
|
||||
@@ -102,89 +129,21 @@ func (p *OpenAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
return headers
|
||||
}
|
||||
|
||||
// 获取请求体
|
||||
func (p *OpenAIProvider) GetRequestBody(request any, isModelMapped bool) (requestBody io.Reader, err error) {
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = p.Context.Request.Body
|
||||
func (p *OpenAIProvider) GetRequestTextBody(relayMode int, ModelName string, request any) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
return
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 发送流式请求
|
||||
func (p *OpenAIProvider) SendStreamRequest(req *http.Request, response OpenAIProviderStreamResponseHandler) (openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode, responseText string) {
|
||||
defer req.Body.Close()
|
||||
|
||||
client := common.GetHttpClient(p.Channel.Proxy)
|
||||
resp, err := client.Do(req)
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
// 创建请求
|
||||
req, err := p.Requester.NewRequest(http.MethodPost, fullRequestURL, p.Requester.WithBody(request), p.Requester.WithHeader(headers))
|
||||
if err != nil {
|
||||
return common.ErrorWrapper(err, "http_request_failed", http.StatusInternalServerError), ""
|
||||
}
|
||||
common.PutHttpClient(client)
|
||||
|
||||
if common.IsFailureStatusCode(resp) {
|
||||
return common.HandleErrorResp(resp), ""
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
|
||||
if atEOF && len(data) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
if i := strings.Index(string(data), "\n"); i >= 0 {
|
||||
return i + 1, data[0:i], nil
|
||||
}
|
||||
if atEOF {
|
||||
return len(data), data, nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
})
|
||||
dataChan := make(chan string)
|
||||
stopChan := make(chan bool)
|
||||
go func() {
|
||||
for scanner.Scan() {
|
||||
data := scanner.Text()
|
||||
if len(data) < 6 { // ignore blank line or wrong format
|
||||
continue
|
||||
}
|
||||
if data[:6] != "data: " && data[:6] != "[DONE]" {
|
||||
continue
|
||||
}
|
||||
dataChan <- data
|
||||
data = data[6:]
|
||||
if !strings.HasPrefix(data, "[DONE]") {
|
||||
err := json.Unmarshal([]byte(data), response)
|
||||
if err != nil {
|
||||
common.SysError("error unmarshalling stream response: " + err.Error())
|
||||
continue // just ignore the error
|
||||
}
|
||||
responseText += response.responseStreamHandler()
|
||||
}
|
||||
}
|
||||
stopChan <- true
|
||||
}()
|
||||
common.SetEventStreamHeaders(p.Context)
|
||||
p.Context.Stream(func(w io.Writer) bool {
|
||||
select {
|
||||
case data := <-dataChan:
|
||||
if strings.HasPrefix(data, "data: [DONE]") {
|
||||
data = data[:12]
|
||||
}
|
||||
// some implementations may add \r at the end of data
|
||||
data = strings.TrimSuffix(data, "\r")
|
||||
p.Context.Render(-1, common.CustomEvent{Data: data})
|
||||
return true
|
||||
case <-stopChan:
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
return nil, responseText
|
||||
return req, nil
|
||||
}
|
||||
|
||||
@@ -1,82 +1,101 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderChatResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
type OpenAIStreamHandler struct {
|
||||
Usage *types.Usage
|
||||
ModelName string
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CreateChatCompletion(request *types.ChatCompletionRequest) (openaiResponse *types.ChatCompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderChatResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.ChatCompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ChatAction(request *types.ChatCompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
func (p *OpenAIProvider) CreateChatCompletionStream(request *types.ChatCompletionRequest) (requester.StreamReaderInterface[types.ChatCompletionStreamResponse], *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.ChatCompletionStreamResponse](p.Requester, resp, chatHandler.HandlerChatStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) HandlerChatStream(rawLine *[]byte, isFinished *bool, response *[]types.ChatCompletionStreamResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderChatStreamResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ChatCompletions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
if request.Stream {
|
||||
openAIProviderChatStreamResponse := &OpenAIProviderChatStreamResponse{}
|
||||
var textResponse string
|
||||
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderChatStreamResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
*response = append(*response, openaiResponse.ChatCompletionStreamResponse)
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
openAIProviderChatResponse := &OpenAIProviderChatResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderChatResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderChatResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderChatResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Message.StringContent(), openAIProviderChatResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,82 +1,96 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
func (p *OpenAIProvider) CreateCompletion(request *types.CompletionRequest) (openaiResponse *types.CompletionResponse, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderCompletionResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) responseStreamHandler() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.CompletionResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) CompleteAction(request *types.CompletionRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
func (p *OpenAIProvider) CreateCompletionStream(request *types.CompletionRequest) (stream requester.StreamReaderInterface[types.CompletionResponse], errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeChatCompletions, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
resp, errWithCode := p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
chatHandler := OpenAIStreamHandler{
|
||||
Usage: p.Usage,
|
||||
ModelName: request.Model,
|
||||
}
|
||||
|
||||
return requester.RequestStream[types.CompletionResponse](p.Requester, resp, chatHandler.handlerCompletionStream)
|
||||
}
|
||||
|
||||
func (h *OpenAIStreamHandler) handlerCompletionStream(rawLine *[]byte, isFinished *bool, response *[]types.CompletionResponse) error {
|
||||
// 如果rawLine 前缀不为data:,则直接返回
|
||||
if !strings.HasPrefix(string(*rawLine), "data: ") {
|
||||
*rawLine = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去除前缀
|
||||
*rawLine = (*rawLine)[6:]
|
||||
|
||||
// 如果等于 DONE 则结束
|
||||
if string(*rawLine) == "[DONE]" {
|
||||
*isFinished = true
|
||||
return nil
|
||||
}
|
||||
|
||||
var openaiResponse OpenAIProviderCompletionResponse
|
||||
err := json.Unmarshal(*rawLine, &openaiResponse)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
return common.ErrorToOpenAIError(err)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Completions, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
if request.Stream && headers["Accept"] == "" {
|
||||
headers["Accept"] = "text/event-stream"
|
||||
error := ErrorHandle(&openaiResponse.OpenAIErrorResponse)
|
||||
if error != nil {
|
||||
return error
|
||||
}
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
countTokenText := common.CountTokenText(openaiResponse.getResponseText(), h.ModelName)
|
||||
h.Usage.CompletionTokens += countTokenText
|
||||
h.Usage.TotalTokens += countTokenText
|
||||
|
||||
openAIProviderCompletionResponse := &OpenAIProviderCompletionResponse{}
|
||||
if request.Stream {
|
||||
// TODO
|
||||
var textResponse string
|
||||
errWithCode, textResponse = p.SendStreamRequest(req, openAIProviderCompletionResponse)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
*response = append(*response, openaiResponse.CompletionResponse)
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: common.CountTokenText(textResponse, request.Model),
|
||||
TotalTokens: promptTokens + common.CountTokenText(textResponse, request.Model),
|
||||
}
|
||||
|
||||
} else {
|
||||
errWithCode = p.SendRequest(req, openAIProviderCompletionResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = openAIProviderCompletionResponse.Usage
|
||||
|
||||
if usage.TotalTokens == 0 {
|
||||
completionTokens := 0
|
||||
for _, choice := range openAIProviderCompletionResponse.Choices {
|
||||
completionTokens += common.CountTokenText(choice.Text, openAIProviderCompletionResponse.Model)
|
||||
}
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,40 +6,30 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderEmbeddingsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Embeddings, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderEmbeddingsResponse := &OpenAIProviderEmbeddingsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderEmbeddingsResponse, true)
|
||||
func (p *OpenAIProvider) CreateEmbeddings(request *types.EmbeddingRequest) (*types.EmbeddingResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeEmbeddings, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderEmbeddingsResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = openAIProviderEmbeddingsResponse.Usage
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
*p.Usage = *response.Usage
|
||||
|
||||
return &response.EmbeddingResponse, nil
|
||||
}
|
||||
|
||||
@@ -5,28 +5,71 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
|
||||
func (p *OpenAIProvider) CreateImageEdits(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestImageBody(common.RelayModeEdits, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) getRequestImageBody(relayMode int, ModelName string, request *types.ImageEditRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
// 创建请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if p.OriginalModel != request.Model {
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(p.Context.Request.Body),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
@@ -34,22 +77,10 @@ func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isMod
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b requester.FormBuilder) error {
|
||||
err := b.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form image: %w", err)
|
||||
|
||||
@@ -6,53 +6,41 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderImageResponseResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
if !isWithinRange(request.Model, request.N) {
|
||||
func (p *OpenAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
if !IsWithinRange(request.Model, request.N) {
|
||||
return nil, common.StringErrorWrapper("n_not_within_range", "n_not_within_range", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesGenerations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeImagesGenerations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
// 检测是否错误
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
|
||||
}
|
||||
|
||||
func isWithinRange(element string, value int) bool {
|
||||
func IsWithinRange(element string, value int) bool {
|
||||
if _, ok := common.DalleGenerationImageAmounts[element]; !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,49 +1,35 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
func (p *OpenAIProvider) CreateImageVariations(request *types.ImageEditRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestImageBody(common.RelayModeImagesVariations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderImageResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ImageResponse, nil
|
||||
}
|
||||
|
||||
@@ -6,44 +6,31 @@ import (
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderModerationResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
func (p *OpenAIProvider) CreateModeration(request *types.ModerationRequest) (*types.ModerationResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderModerationResponse, true)
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeModerations, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
response := &OpenAIProviderModerationResponse{}
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, response, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
openaiErr := ErrorHandle(&response.OpenAIErrorResponse)
|
||||
if openaiErr != nil {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: *openaiErr,
|
||||
StatusCode: http.StatusBadRequest,
|
||||
}
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return &response.ModerationResponse, nil
|
||||
}
|
||||
|
||||
@@ -3,35 +3,29 @@ package openai
|
||||
import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) SpeechAction(request *types.SpeechAudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
|
||||
requestBody, err := p.GetRequestBody(&request, isModelMapped)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioSpeech, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
errWithCode = p.SendRequestRaw(req)
|
||||
func (p *OpenAIProvider) CreateSpeech(request *types.SpeechAudioRequest) (*http.Response, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.GetRequestTextBody(common.RelayModeAudioSpeech, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
// 发送请求
|
||||
var resp *http.Response
|
||||
resp, errWithCode = p.Requester.SendRequestRaw(req)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
if resp.Header.Get("Content-Type") == "application/json" {
|
||||
return nil, requester.HandleErrorResp(resp, p.Requester.ErrorHandler)
|
||||
}
|
||||
|
||||
return
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -4,48 +4,99 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/requester"
|
||||
"one-api/types"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *OpenAIProviderTranscriptionsResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
if c.Error.Type != "" {
|
||||
errWithCode = &types.OpenAIErrorWithStatusCode{
|
||||
OpenAIError: c.Error,
|
||||
StatusCode: resp.StatusCode,
|
||||
}
|
||||
return
|
||||
func (p *OpenAIProvider) CreateTranscriptions(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranscription, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
return nil, nil
|
||||
defer req.Body.Close()
|
||||
|
||||
var textResponse string
|
||||
var resp *http.Response
|
||||
var err error
|
||||
audioResponseWrapper := &types.AudioResponseWrapper{}
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
audioResponseWrapper.Headers = map[string]string{
|
||||
"Content-Type": resp.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
|
||||
p.Usage.CompletionTokens = completionTokens
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
|
||||
|
||||
return audioResponseWrapper, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderTranscriptionsTextResponse) ResponseHandler(resp *http.Response) (OpenAIResponse any, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
return nil, nil
|
||||
func hasJSONResponse(request *types.AudioRequest) bool {
|
||||
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
|
||||
}
|
||||
|
||||
func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioTranscriptions, request.Model)
|
||||
func (p *OpenAIProvider) getRequestAudioBody(relayMode int, ModelName string, request *types.AudioRequest) (*http.Request, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(relayMode)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, ModelName)
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
// 创建请求
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if p.OriginalModel != request.Model {
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
if err := audioMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req, err = p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(p.Context.Request.Body),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
@@ -53,37 +104,10 @@ func (p *OpenAIProvider) TranscriptionsAction(request *types.AudioRequest, isMod
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
var textResponse string
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
return
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func hasJSONResponse(request *types.AudioRequest) bool {
|
||||
return request.ResponseFormat == "" || request.ResponseFormat == "json" || request.ResponseFormat == "verbose_json"
|
||||
}
|
||||
|
||||
func audioMultipartForm(request *types.AudioRequest, b common.FormBuilder) error {
|
||||
func audioMultipartForm(request *types.AudioRequest, b requester.FormBuilder) error {
|
||||
|
||||
err := b.CreateFormFile("file", request.File)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,60 +1,53 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.AudioTranslations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := audioMultipartForm(request, builder); err != nil {
|
||||
return nil, common.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
func (p *OpenAIProvider) CreateTranslation(request *types.AudioRequest) (*types.AudioResponseWrapper, *types.OpenAIErrorWithStatusCode) {
|
||||
req, errWithCode := p.getRequestAudioBody(common.RelayModeAudioTranslation, request.Model, request)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
defer req.Body.Close()
|
||||
|
||||
var textResponse string
|
||||
var resp *http.Response
|
||||
var err error
|
||||
audioResponseWrapper := &types.AudioResponseWrapper{}
|
||||
if hasJSONResponse(request) {
|
||||
openAIProviderTranscriptionsResponse := &OpenAIProviderTranscriptionsResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = openAIProviderTranscriptionsResponse.Text
|
||||
} else {
|
||||
openAIProviderTranscriptionsTextResponse := new(OpenAIProviderTranscriptionsTextResponse)
|
||||
errWithCode = p.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
resp, errWithCode = p.Requester.SendRequest(req, openAIProviderTranscriptionsTextResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
return nil, errWithCode
|
||||
}
|
||||
textResponse = getTextContent(*openAIProviderTranscriptionsTextResponse.GetString(), request.ResponseFormat)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
audioResponseWrapper.Headers = map[string]string{
|
||||
"Content-Type": resp.Header.Get("Content-Type"),
|
||||
}
|
||||
|
||||
audioResponseWrapper.Body, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
completionTokens := common.CountTokenText(textResponse, request.Model)
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: completionTokens,
|
||||
TotalTokens: promptTokens + completionTokens,
|
||||
}
|
||||
return
|
||||
|
||||
p.Usage.CompletionTokens = completionTokens
|
||||
p.Usage.TotalTokens = p.Usage.PromptTokens + p.Usage.CompletionTokens
|
||||
|
||||
return audioResponseWrapper, nil
|
||||
}
|
||||
|
||||
@@ -12,11 +12,27 @@ type OpenAIProviderChatStreamResponse struct {
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderChatStreamResponse) getResponseText() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Delta.Content
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type OpenAIProviderCompletionResponse struct {
|
||||
types.CompletionResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
func (c *OpenAIProviderCompletionResponse) getResponseText() (responseText string) {
|
||||
for _, choice := range c.Choices {
|
||||
responseText += choice.Text
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
type OpenAIProviderEmbeddingsResponse struct {
|
||||
types.EmbeddingResponse
|
||||
types.OpenAIErrorResponse
|
||||
@@ -38,7 +54,7 @@ func (a *OpenAIProviderTranscriptionsTextResponse) GetString() *string {
|
||||
return (*string)(a)
|
||||
}
|
||||
|
||||
type OpenAIProviderImageResponseResponse struct {
|
||||
type OpenAIProviderImageResponse struct {
|
||||
types.ImageResponse
|
||||
types.OpenAIErrorResponse
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user