mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-07 09:13:42 +08:00
refactor: use adaptor to do relay & test
This commit is contained in:
@@ -14,13 +14,14 @@ import (
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
audioModel := "whisper-1"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
|
||||
@@ -11,14 +11,14 @@ import (
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOpenAIRequest, error) {
|
||||
textRequest := &openai.GeneralOpenAIRequest{}
|
||||
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
|
||||
textRequest := &relaymodel.GeneralOpenAIRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -36,7 +36,7 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*openai.GeneralOp
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) int {
|
||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||
switch relayMode {
|
||||
case constant.RelayModeChatCompletions:
|
||||
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
@@ -48,7 +48,7 @@ func getPromptTokens(textRequest *openai.GeneralOpenAIRequest, relayMode int) in
|
||||
return 0
|
||||
}
|
||||
|
||||
func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
|
||||
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
|
||||
preConsumedTokens := config.PreConsumedQuota
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
||||
@@ -56,7 +56,7 @@ func getPreConsumedQuota(textRequest *openai.GeneralOpenAIRequest, promptTokens
|
||||
return int(float64(preConsumedTokens) * ratio)
|
||||
}
|
||||
|
||||
func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *openai.ErrorWithStatusCode) {
|
||||
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
|
||||
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
|
||||
|
||||
userQuota, err := model.CacheGetUserQuota(meta.UserId)
|
||||
@@ -85,7 +85,7 @@ func preConsumeQuota(ctx context.Context, textRequest *openai.GeneralOpenAIReque
|
||||
return preConsumedQuota, nil
|
||||
}
|
||||
|
||||
func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
|
||||
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
|
||||
if usage == nil {
|
||||
logger.Error(ctx, "usage is nil, which is unexpected")
|
||||
return
|
||||
@@ -120,27 +120,3 @@ func postConsumeQuota(ctx context.Context, usage *openai.Usage, meta *util.Relay
|
||||
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
|
||||
}
|
||||
}
|
||||
|
||||
func doRequest(ctx context.Context, c *gin.Context, meta *util.RelayMeta, isStream bool, fullRequestURL string, requestBody io.Reader) (*http.Response, error) {
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
SetupRequestHeaders(c, req, meta, isStream)
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
}
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "close req.Body failed: %+v", err)
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
logger.Warnf(ctx, "close c.Request.Body failed: %+v", err)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -28,7 +29,7 @@ func isWithinRange(element string, value int) bool {
|
||||
return value >= min && value <= max
|
||||
}
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *openai.ErrorWithStatusCode {
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
imageModel := "dall-e-2"
|
||||
imageSize := "1024x1024"
|
||||
|
||||
|
||||
@@ -12,19 +12,21 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/channel/ali"
|
||||
"github.com/songquanpeng/one-api/relay/channel/anthropic"
|
||||
"github.com/songquanpeng/one-api/relay/channel/baidu"
|
||||
"github.com/songquanpeng/one-api/relay/channel/google"
|
||||
"github.com/songquanpeng/one-api/relay/channel/gemini"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/channel/palm"
|
||||
"github.com/songquanpeng/one-api/relay/channel/tencent"
|
||||
"github.com/songquanpeng/one-api/relay/channel/xunfei"
|
||||
"github.com/songquanpeng/one-api/relay/channel/zhipu"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.GeneralOpenAIRequest) (string, error) {
|
||||
func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *model.GeneralOpenAIRequest) (string, error) {
|
||||
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
|
||||
switch meta.APIType {
|
||||
case constant.APITypeOpenAI:
|
||||
@@ -43,7 +45,7 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
|
||||
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
|
||||
fullRequestURL = util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
fullRequestURL = fmt.Sprintf("%s/v1/complete", meta.BaseURL)
|
||||
case constant.APITypeBaidu:
|
||||
switch textRequest.Model {
|
||||
@@ -92,19 +94,10 @@ func GetRequestURL(requestURL string, meta *util.RelayMeta, textRequest *openai.
|
||||
return fullRequestURL, nil
|
||||
}
|
||||
|
||||
func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
|
||||
func GetRequestBody(c *gin.Context, textRequest model.GeneralOpenAIRequest, isModelMapped bool, apiType int, relayMode int) (io.Reader, error) {
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
switch apiType {
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
claudeRequest := anthropic.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(claudeRequest)
|
||||
if err != nil {
|
||||
@@ -127,14 +120,14 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
case constant.APITypePaLM:
|
||||
palmRequest := google.ConvertPaLMRequest(textRequest)
|
||||
palmRequest := palm.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(palmRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
case constant.APITypeGemini:
|
||||
geminiChatRequest := google.ConvertGeminiRequest(textRequest)
|
||||
geminiChatRequest := gemini.ConvertRequest(textRequest)
|
||||
jsonStr, err := json.Marshal(geminiChatRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -187,19 +180,20 @@ func GetRequestBody(c *gin.Context, textRequest openai.GeneralOpenAIRequest, isM
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
default:
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
}
|
||||
return requestBody, nil
|
||||
}
|
||||
|
||||
func SetupRequestHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
|
||||
SetupAuthHeaders(c, req, meta, isStream)
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
if isStream && c.Request.Header.Get("Accept") == "" {
|
||||
req.Header.Set("Accept", "text/event-stream")
|
||||
}
|
||||
}
|
||||
|
||||
func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, isStream bool) {
|
||||
apiKey := meta.APIKey
|
||||
switch meta.APIType {
|
||||
@@ -213,7 +207,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
|
||||
req.Header.Set("X-Title", "One API")
|
||||
}
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
req.Header.Set("x-api-key", apiKey)
|
||||
anthropicVersion := c.Request.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
@@ -242,7 +236,7 @@ func SetupAuthHeaders(c *gin.Context, req *http.Request, meta *util.RelayMeta, i
|
||||
}
|
||||
}
|
||||
|
||||
func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *openai.Usage, err *openai.ErrorWithStatusCode) {
|
||||
func DoResponse(c *gin.Context, textRequest *model.GeneralOpenAIRequest, resp *http.Response, relayMode int, apiType int, isStream bool, promptTokens int) (usage *model.Usage, err *model.ErrorWithStatusCode) {
|
||||
var responseText string
|
||||
switch apiType {
|
||||
case constant.APITypeOpenAI:
|
||||
@@ -251,7 +245,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
} else {
|
||||
err, usage = openai.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeClaude:
|
||||
case constant.APITypeAnthropic:
|
||||
if isStream {
|
||||
err, responseText = anthropic.StreamHandler(c, resp)
|
||||
} else {
|
||||
@@ -270,15 +264,15 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
}
|
||||
case constant.APITypePaLM:
|
||||
if isStream { // PaLM2 API does not support stream
|
||||
err, responseText = google.PaLMStreamHandler(c, resp)
|
||||
err, responseText = palm.StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = google.PaLMHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage = palm.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeGemini:
|
||||
if isStream {
|
||||
err, responseText = google.StreamHandler(c, resp)
|
||||
err, responseText = gemini.StreamHandler(c, resp)
|
||||
} else {
|
||||
err, usage = google.GeminiHandler(c, resp, promptTokens, textRequest.Model)
|
||||
err, usage = gemini.Handler(c, resp, promptTokens, textRequest.Model)
|
||||
}
|
||||
case constant.APITypeZhipu:
|
||||
if isStream {
|
||||
@@ -328,7 +322,7 @@ func DoResponse(c *gin.Context, textRequest *openai.GeneralOpenAIRequest, resp *
|
||||
return nil, err
|
||||
}
|
||||
if usage == nil && responseText != "" {
|
||||
usage = &openai.Usage{}
|
||||
usage = &model.Usage{}
|
||||
usage.PromptTokens = promptTokens
|
||||
usage.CompletionTokens = openai.CountTokenText(responseText, textRequest.Model)
|
||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
|
||||
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
meta := util.GetRelayMeta(c)
|
||||
// get & validate textRequest
|
||||
@@ -21,9 +26,13 @@ func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
|
||||
logger.Errorf(ctx, "getAndValidateTextRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "invalid_text_request", http.StatusBadRequest)
|
||||
}
|
||||
meta.IsStream = textRequest.Stream
|
||||
|
||||
// map model name
|
||||
var isModelMapped bool
|
||||
meta.OriginModelName = textRequest.Model
|
||||
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// get model ratio & group ratio
|
||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
||||
groupRatio := common.GetGroupRatio(meta.Group)
|
||||
@@ -36,35 +45,50 @@ func RelayTextHelper(c *gin.Context) *openai.ErrorWithStatusCode {
|
||||
return bizErr
|
||||
}
|
||||
|
||||
adaptor := helper.GetAdaptor(meta.APIType)
|
||||
if adaptor == nil {
|
||||
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
// get request body
|
||||
requestBody, err := GetRequestBody(c, *textRequest, isModelMapped, meta.APIType, meta.Mode)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_request_body_failed", http.StatusInternalServerError)
|
||||
var requestBody io.Reader
|
||||
if meta.APIType == constant.APITypeOpenAI {
|
||||
// no need to convert request for openai
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(textRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
} else {
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonData)
|
||||
}
|
||||
|
||||
// do request
|
||||
var resp *http.Response
|
||||
isStream := textRequest.Stream
|
||||
if meta.APIType != constant.APITypeXunfei { // cause xunfei use websocket
|
||||
fullRequestURL, err := GetRequestURL(c.Request.URL.String(), meta, textRequest)
|
||||
if err != nil {
|
||||
logger.Error(ctx, fmt.Sprintf("util.GetRequestURL failed: %s", err.Error()))
|
||||
return openai.ErrorWrapper(fmt.Errorf("util.GetRequestURL failed"), "get_request_url_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
resp, err = doRequest(ctx, c, meta, isStream, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "doRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
return util.RelayErrorHandler(resp)
|
||||
}
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
|
||||
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
return util.RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
// do response
|
||||
usage, respErr := DoResponse(c, textRequest, resp, meta.Mode, meta.APIType, isStream, promptTokens)
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
|
||||
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
|
||||
|
||||
Reference in New Issue
Block a user