Refactor codebase, introduce relaymode package, update constants and improve consistency

- Refactor constant definitions and organization
- Clean up package level variables and functions
- Introduce new `relaymode` and `apitype` packages for constant definitions
- Refactor and simplify code in several packages including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, `relay/channeltype`
- Add helper functions in `relay/channeltype` package to convert channel type constants to corresponding API type constants
- Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go`
- Modify code in `relay/util/validation.go` and related files to use new `validator.ValidateTextRequest` function
- Rename `util` package to `relaymode` and update related imports in several packages
This commit is contained in:
Laisky.Cai
2024-04-06 05:18:04 +00:00
parent 7c59a97a6b
commit 50bab08496
157 changed files with 3217 additions and 19160 deletions

View File

@@ -6,20 +6,23 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor/azure"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/client"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -34,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
tokenName := c.GetString("token_name")
var ttsRequest openai.TextToSpeechRequest
if relayMode == constant.RelayModeAudioSpeech {
if relayMode == relaymode.AudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
@@ -48,14 +51,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
modelRatio := common.GetModelRatio(audioModel)
// groupRatio := common.GetGroupRatio(group)
groupRatio := c.GetFloat64("channel_ratio")
modelRatio := billingratio.GetModelRatio(audioModel)
// groupRatio := billingratio.GetGroupRatio(group)
groupRatio := c.GetFloat64("channel_ratio") // get minimal ratio from multiple groups
ratio := modelRatio * groupRatio
var quota int64
var preConsumedQuota int64
switch relayMode {
case constant.RelayModeAudioSpeech:
case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
@@ -117,19 +121,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
baseURL := common.ChannelBaseURLs[channelType]
baseURL := channeltype.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == common.ChannelTypeAzure {
apiVersion := util.GetAzureAPIVersion(c)
if relayMode == constant.RelayModeAudioTranscription {
fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == channeltype.Azure {
apiVersion := azure.GetAPIVersion(c)
if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
} else if relayMode == constant.RelayModeAudioSpeech {
} else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
}
@@ -148,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -160,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
resp, err := client.HTTPClient.Do(req)
if err != nil {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@@ -174,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
if relayMode != constant.RelayModeAudioSpeech {
if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -213,12 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header {

91
relay/controller/error.go Normal file
View File

@@ -0,0 +1,91 @@
package controller
import (
"encoding/json"
"fmt"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strconv"
)
type GeneralErrorResponse struct {
Error model.Error `json:"error"`
Message string `json:"message"`
Msg string `json:"msg"`
Err string `json:"err"`
ErrorMsg string `json:"error_msg"`
Header struct {
Message string `json:"message"`
} `json:"header"`
Response struct {
Error struct {
Message string `json:"message"`
} `json:"error"`
} `json:"response"`
}
func (e GeneralErrorResponse) ToMessage() string {
if e.Error.Message != "" {
return e.Error.Message
}
if e.Message != "" {
return e.Message
}
if e.Msg != "" {
return e.Msg
}
if e.Err != "" {
return e.Err
}
if e.ErrorMsg != "" {
return e.ErrorMsg
}
if e.Header.Message != "" {
return e.Header.Message
}
if e.Response.Error.Message != "" {
return e.Response.Error.Message
}
return ""
}
func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) {
ErrorWithStatusCode = &model.ErrorWithStatusCode{
StatusCode: resp.StatusCode,
Error: model.Error{
Message: "",
Type: "upstream_error",
Code: "bad_response_status_code",
Param: strconv.Itoa(resp.StatusCode),
},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if config.DebugEnabled {
logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
}
err = resp.Body.Close()
if err != nil {
return
}
var errResponse GeneralErrorResponse
err = json.Unmarshal(responseBody, &errResponse)
if err != nil {
return
}
if errResponse.Error.Message != "" {
// OpenAI format error, so we override the default one
ErrorWithStatusCode.Error = errResponse.Error
} else {
ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
}
if ErrorWithStatusCode.Error.Message == "" {
ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
}
return
}

View File

@@ -9,10 +9,13 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller/validator"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
)
@@ -23,21 +26,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil {
return nil, err
}
if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
err = util.ValidateTextRequest(textRequest, relayMode)
err = validator.ValidateTextRequest(textRequest, relayMode)
if err != nil {
return nil, err
}
return textRequest, nil
}
func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
imageRequest := &openai.ImageRequest{}
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
@@ -54,9 +57,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error
return imageRequest, nil
}
func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
func isValidImageSize(model string, size string) bool {
if model == "cogview-3" {
return true
}
_, ok := billingratio.ImageSizeRatios[model][size]
return ok
}
func getImageSizeRatio(model string, size string) float64 {
ratio, ok := billingratio.ImageSizeRatios[model][size]
if !ok {
return 1
}
return ratio
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
@@ -64,27 +83,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
if meta.ChannelType != common.ChannelTypeAzure {
if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
if !hasValidSize {
return 0, errors.Errorf("size not supported for this image model: %s", imageRequest.Size)
}
imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
@@ -97,11 +113,11 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
case constant.RelayModeChatCompletions:
case relaymode.ChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
case constant.RelayModeCompletions:
case relaymode.Completions:
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
case constant.RelayModeModerations:
case relaymode.Moderations:
return openai.CountTokenInput(textRequest.Input, textRequest.Model)
}
return 0
@@ -115,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok
return int64(float64(preConsumedTokens) * ratio)
}
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -144,13 +160,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
var quota int64
completionRatio := common.GetCompletionRatio(textRequest.Model)
completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
@@ -178,3 +194,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
if mapping == nil {
return modelName, false
}
mappedModelName := mapping[modelName]
if mappedModelName != "" {
return mappedModelName, true
}
return modelName, false
}

View File

@@ -5,35 +5,33 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/Laisky/errors/v2"
"io"
"net/http"
"strings"
"github.com/songquanpeng/one-api/common"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false
}
min := constant.DalleGenerationImageAmounts[element][0]
max := constant.DalleGenerationImageAmounts[element][1]
min := billingratio.ImageGenerationAmounts[element][0]
max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := util.GetRelayMeta(c)
meta := meta.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@@ -43,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
@@ -57,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
requestURL := c.Request.URL.String()
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
if meta.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@@ -77,9 +66,32 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageRequest.Model)
// groupRatio := common.GetGroupRatio(meta.Group)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
switch meta.ChannelType {
case channeltype.Ali:
fallthrough
case channeltype.Baidu:
fallthrough
case channeltype.Zhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := billingratio.GetModelRatio(imageRequest.Model)
// groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -89,36 +101,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var imageResponse openai.ImageResponse
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
return
@@ -141,34 +130,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -10,18 +10,20 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := util.GetRelayMeta(c)
meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
@@ -33,12 +35,13 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
modelRatio := common.GetModelRatio(textRequest.Model)
// groupRatio := common.GetGroupRatio(meta.Group)
modelRatio := billingratio.GetModelRatio(textRequest.Model)
// groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := meta.ChannelRatio
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
@@ -49,16 +52,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return bizErr
}
adaptor := helper.GetAdaptor(meta.APIType)
adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
// get request body
var requestBody io.Reader
if meta.APIType == constant.APITypeOpenAI {
if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
@@ -93,10 +96,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q",
resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes)))
return util.RelayErrorHandler(resp)
return RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -104,7 +107,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
// post-consume quota

View File

@@ -0,0 +1,38 @@
package validator
import (
"math"
"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 {
return errors.New("max_tokens is invalid")
}
if textRequest.Model == "" {
return errors.New("model is required")
}
switch relayMode {
case relaymode.Completions:
if textRequest.Prompt == "" {
return errors.New("field prompt is required")
}
case relaymode.ChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required")
}
case relaymode.Embeddings:
case relaymode.Moderations:
if textRequest.Input == "" {
return errors.New("field input is required")
}
case relaymode.Edits:
if textRequest.Instruction == "" {
return errors.New("field instruction is required")
}
}
return nil
}