Merge remote-tracking branch 'remotes/origin_songquanpeng/main'

# Conflicts:
#	relay/adaptor/openai/adaptor.go
#	relay/controller/text.go
This commit is contained in:
mlkt
2024-08-07 06:24:20 +08:00
114 changed files with 3392 additions and 1700 deletions

View File

@@ -4,6 +4,10 @@ import (
"bytes"
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/relay/adaptor"
"io"
"net/http"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay"
@@ -14,15 +18,8 @@ import (
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"math"
"net/http"
"os"
"strings"
)
var fixAIProxyGpt4oMaxTokens = os.Getenv("FIX_AI_PROXY_GPT4O_MAX_TOKENS") == "1"
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
@@ -35,12 +32,11 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
meta.IsStream = textRequest.Stream
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := billingratio.GetModelRatio(textRequest.Model)
modelRatio := billingratio.GetModelRatio(textRequest.Model, meta.ChannelType)
groupRatio := billingratio.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
// pre-consume quota
@@ -58,41 +54,9 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
adaptor.Init(meta)
// get request body
var requestBody io.Reader
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped ||
meta.ChannelType == channeltype.Baichuan /*frequency_penalty 0 is not acceptable for baichuan*/ ||
(meta.ChannelType == channeltype.AIProxy && fixAIProxyGpt4oMaxTokens && strings.HasPrefix(textRequest.Model, "gpt-4o"))
if shouldResetRequestBody {
if meta.ChannelType == channeltype.AIProxy {
maxTokens := textRequest.MaxTokens
maxTokens = int(math.Min(float64(maxTokens), 4096))
if maxTokens == 0 {
maxTokens = 4096
}
textRequest.MaxTokens = maxTokens
}
jsonStr, err := json.Marshal(textRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
} else {
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
requestBody, err := getRequestBody(c, meta, textRequest, adaptor)
if err != nil {
return openai.ErrorWrapper(err, "convert_request_failed", http.StatusInternalServerError)
}
// do request
@@ -117,3 +81,26 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
go postConsumeQuota(ctx, usage, meta, textRequest, ratio, preConsumedQuota, modelRatio, groupRatio)
return nil
}
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
return c.Request.Body, nil
}
// get request body
var requestBody io.Reader
convertedRequest, err := adaptor.ConvertRequest(c, meta.Mode, textRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request failed: %s\n", err.Error())
return nil, err
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
logger.Debugf(c.Request.Context(), "converted request json_marshal_failed: %s\n", err.Error())
return nil, err
}
logger.Debugf(c.Request.Context(), "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
return requestBody, nil
}