mirror of
https://github.com/linux-do/new-api.git
synced 2025-09-17 16:06:38 +08:00
222 lines
7.1 KiB
Go
222 lines
7.1 KiB
Go
package openai
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/dto"
|
|
"one-api/relay/channel"
|
|
"one-api/relay/channel/ai360"
|
|
"one-api/relay/channel/lingyiwanwu"
|
|
"one-api/relay/channel/minimax"
|
|
"one-api/relay/channel/moonshot"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/constant"
|
|
"strings"
|
|
)
|
|
|
|
// Adaptor relays requests to OpenAI and to the OpenAI-compatible channels
// handled in this file (Azure, MiniMax, 360, Moonshot, LingYiWanWu and
// custom endpoints).
type Adaptor struct {
	// ChannelType is copied from the relay info by Init and selects the
	// channel-specific model list and channel name.
	ChannelType int
	// ResponseFormat is the response_format recorded by ConvertAudioRequest;
	// DoResponse hands it to the speech-to-text handler.
	ResponseFormat string
}
|
|
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
a.ChannelType = info.ChannelType
|
|
}
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
if info.RelayMode == constant.RelayModeRealtime {
|
|
// trim https
|
|
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
|
|
baseUrl = strings.TrimPrefix(baseUrl, "http://")
|
|
baseUrl = "wss://" + baseUrl
|
|
info.BaseUrl = baseUrl
|
|
}
|
|
switch info.ChannelType {
|
|
case common.ChannelTypeAzure:
|
|
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
|
|
requestURL := strings.Split(info.RequestURLPath, "?")[0]
|
|
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, info.ApiVersion)
|
|
task := strings.TrimPrefix(requestURL, "/v1/")
|
|
model_ := info.UpstreamModelName
|
|
model_ = strings.Replace(model_, ".", "", -1)
|
|
// https://github.com/songquanpeng/one-api/issues/67
|
|
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
|
if info.RelayMode == constant.RelayModeRealtime {
|
|
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion)
|
|
}
|
|
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
|
|
case common.ChannelTypeMiniMax:
|
|
return minimax.GetRequestURL(info)
|
|
case common.ChannelTypeCustom:
|
|
url := info.BaseUrl
|
|
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
|
|
return url, nil
|
|
default:
|
|
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
|
|
channel.SetupApiRequestHeader(info, c, header)
|
|
if info.ChannelType == common.ChannelTypeAzure {
|
|
header.Set("api-key", info.ApiKey)
|
|
return nil
|
|
}
|
|
if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization {
|
|
header.Set("OpenAI-Organization", info.Organization)
|
|
}
|
|
if info.RelayMode == constant.RelayModeRealtime {
|
|
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
|
|
if swp != "" {
|
|
items := []string{
|
|
"realtime",
|
|
"openai-insecure-api-key." + info.ApiKey,
|
|
"openai-beta.realtime-v1",
|
|
}
|
|
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
|
|
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
|
|
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
|
|
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
|
|
} else {
|
|
header.Set("openai-beta", "realtime=v1")
|
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
|
}
|
|
} else {
|
|
header.Set("Authorization", "Bearer "+info.ApiKey)
|
|
}
|
|
//if info.ChannelType == common.ChannelTypeOpenRouter {
|
|
// req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
|
|
// req.Header.Set("X-Title", "One API")
|
|
//}
|
|
return nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
|
|
if request == nil {
|
|
return nil, errors.New("request is nil")
|
|
}
|
|
if info.ChannelType != common.ChannelTypeOpenAI {
|
|
request.StreamOptions = nil
|
|
}
|
|
if strings.HasPrefix(request.Model, "o1-") {
|
|
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
|
|
request.MaxCompletionTokens = request.MaxTokens
|
|
request.MaxTokens = 0
|
|
}
|
|
}
|
|
return request, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
return nil, errors.New("not implemented")
|
|
}
|
|
|
|
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
|
a.ResponseFormat = request.ResponseFormat
|
|
if info.RelayMode == constant.RelayModeAudioSpeech {
|
|
jsonData, err := json.Marshal(request)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error marshalling object: %w", err)
|
|
}
|
|
return bytes.NewReader(jsonData), nil
|
|
} else {
|
|
var requestBody bytes.Buffer
|
|
writer := multipart.NewWriter(&requestBody)
|
|
|
|
writer.WriteField("model", request.Model)
|
|
|
|
// 添加文件字段
|
|
file, header, err := c.Request.FormFile("file")
|
|
if err != nil {
|
|
return nil, errors.New("file is required")
|
|
}
|
|
defer file.Close()
|
|
|
|
part, err := writer.CreateFormFile("file", header.Filename)
|
|
if err != nil {
|
|
return nil, errors.New("create form file failed")
|
|
}
|
|
if _, err := io.Copy(part, file); err != nil {
|
|
return nil, errors.New("copy file failed")
|
|
}
|
|
|
|
// 关闭 multipart 编写器以设置分界线
|
|
writer.Close()
|
|
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
|
|
return &requestBody, nil
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
|
return request, nil
|
|
}
|
|
|
|
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
|
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
|
|
return channel.DoFormRequest(a, c, info, requestBody)
|
|
} else if info.RelayMode == constant.RelayModeRealtime {
|
|
return channel.DoWssRequest(a, c, info, requestBody)
|
|
} else {
|
|
return channel.DoApiRequest(a, c, info, requestBody)
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
|
switch info.RelayMode {
|
|
case constant.RelayModeRealtime:
|
|
err, usage = OpenaiRealtimeHandler(c, info)
|
|
case constant.RelayModeAudioSpeech:
|
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
|
case constant.RelayModeAudioTranslation:
|
|
fallthrough
|
|
case constant.RelayModeAudioTranscription:
|
|
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
|
|
case constant.RelayModeImagesGenerations:
|
|
err, usage = OpenaiTTSHandler(c, resp, info)
|
|
default:
|
|
if info.IsStream {
|
|
err, usage = OaiStreamHandler(c, resp, info)
|
|
} else {
|
|
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func (a *Adaptor) GetModelList() []string {
|
|
switch a.ChannelType {
|
|
case common.ChannelType360:
|
|
return ai360.ModelList
|
|
case common.ChannelTypeMoonshot:
|
|
return moonshot.ModelList
|
|
case common.ChannelTypeLingYiWanWu:
|
|
return lingyiwanwu.ModelList
|
|
case common.ChannelTypeMiniMax:
|
|
return minimax.ModelList
|
|
default:
|
|
return ModelList
|
|
}
|
|
}
|
|
|
|
func (a *Adaptor) GetChannelName() string {
|
|
switch a.ChannelType {
|
|
case common.ChannelType360:
|
|
return ai360.ChannelName
|
|
case common.ChannelTypeMoonshot:
|
|
return moonshot.ChannelName
|
|
case common.ChannelTypeLingYiWanWu:
|
|
return lingyiwanwu.ChannelName
|
|
case common.ChannelTypeMiniMax:
|
|
return minimax.ChannelName
|
|
default:
|
|
return ChannelName
|
|
}
|
|
}
|