mirror of
https://github.com/songquanpeng/one-api.git
synced 2026-04-13 13:34:28 +08:00
Merge 47918f3143 into ea0721d525
This commit is contained in:
@@ -5,21 +5,24 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/client"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/ctxkey"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
@@ -27,6 +30,35 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/relaymode"
|
||||
)
|
||||
|
||||
type commonAudioRequest struct {
|
||||
File *multipart.FileHeader `form:"file" binding:"required"`
|
||||
}
|
||||
|
||||
func countAudioTokens(c *gin.Context) (float64, error) {
|
||||
body, err := common.GetRequestBody(c)
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
reqBody := new(commonAudioRequest)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
if err = c.ShouldBind(reqBody); err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
|
||||
reqFp, err := reqBody.File.Open()
|
||||
if err != nil {
|
||||
return 0, errors.WithStack(err)
|
||||
}
|
||||
defer reqFp.Close()
|
||||
|
||||
ctxMeta := meta.GetByContext(c)
|
||||
|
||||
return helper.GetAudioTokens(c.Request.Context(),
|
||||
reqFp,
|
||||
ratio.GetAudioPromptTokensPerSecond(ctxMeta.ActualModelName))
|
||||
}
|
||||
|
||||
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
meta := meta.GetByContext(c)
|
||||
@@ -63,9 +95,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
case relaymode.AudioSpeech:
|
||||
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
|
||||
quota = preConsumedQuota
|
||||
case relaymode.AudioTranscription,
|
||||
relaymode.AudioTranslation:
|
||||
audioTokens, err := countAudioTokens(c)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
preConsumedQuota = int64(math.Ceil(audioTokens * ratio))
|
||||
quota = preConsumedQuota
|
||||
default:
|
||||
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
|
||||
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
userQuota, err := model.CacheGetUserQuota(ctx, userId)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
@@ -139,7 +181,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
|
||||
responseFormat := c.DefaultPostForm("response_format", "json")
|
||||
// responseFormat := c.DefaultPostForm("response_format", "json")
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
@@ -172,47 +214,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if relayMode != relaymode.AudioSpeech {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
// https://github.com/Laisky/one-api/pull/21
|
||||
// Commenting out the following code because Whisper's transcription
|
||||
// only charges for the length of the input audio, not for the output.
|
||||
// -------------------------------------
|
||||
// if relayMode != relaymode.AudioSpeech {
|
||||
// responseBody, err := io.ReadAll(resp.Body)
|
||||
// if err != nil {
|
||||
// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
// }
|
||||
// err = resp.Body.Close()
|
||||
// if err != nil {
|
||||
// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
// }
|
||||
|
||||
var openAIErr openai.SlimTextResponse
|
||||
if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
||||
if openAIErr.Error.Message != "" {
|
||||
return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
// var openAIErr openai.SlimTextResponse
|
||||
// if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
|
||||
// if openAIErr.Error.Message != "" {
|
||||
// return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
|
||||
// }
|
||||
// }
|
||||
|
||||
// var text string
|
||||
// switch responseFormat {
|
||||
// case "json":
|
||||
// text, err = getTextFromJSON(responseBody)
|
||||
// case "text":
|
||||
// text, err = getTextFromText(responseBody)
|
||||
// case "srt":
|
||||
// text, err = getTextFromSRT(responseBody)
|
||||
// case "verbose_json":
|
||||
// text, err = getTextFromVerboseJSON(responseBody)
|
||||
// case "vtt":
|
||||
// text, err = getTextFromVTT(responseBody)
|
||||
// default:
|
||||
// return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
||||
// }
|
||||
// if err != nil {
|
||||
// return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||
// }
|
||||
// quota = int64(openai.CountTokenText(text, audioModel))
|
||||
// resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
// }
|
||||
|
||||
var text string
|
||||
switch responseFormat {
|
||||
case "json":
|
||||
text, err = getTextFromJSON(responseBody)
|
||||
case "text":
|
||||
text, err = getTextFromText(responseBody)
|
||||
case "srt":
|
||||
text, err = getTextFromSRT(responseBody)
|
||||
case "verbose_json":
|
||||
text, err = getTextFromVerboseJSON(responseBody)
|
||||
case "vtt":
|
||||
text, err = getTextFromVTT(responseBody)
|
||||
default:
|
||||
return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
|
||||
}
|
||||
if err != nil {
|
||||
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
|
||||
}
|
||||
quota = int64(openai.CountTokenText(text, audioModel))
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return RelayErrorHandler(resp)
|
||||
}
|
||||
|
||||
succeed = true
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
defer func(ctx context.Context) {
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/constant/role"
|
||||
"github.com/songquanpeng/one-api/relay/controller/validator"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
@@ -44,10 +45,10 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
|
||||
return textRequest, nil
|
||||
}
|
||||
|
||||
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||
func getPromptTokens(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
|
||||
switch relayMode {
|
||||
case relaymode.ChatCompletions:
|
||||
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
|
||||
return openai.CountTokenMessages(ctx, textRequest.Messages, textRequest.Model)
|
||||
case relaymode.Completions:
|
||||
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||
case relaymode.Moderations:
|
||||
@@ -131,17 +132,6 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
|
||||
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
|
||||
}
|
||||
|
||||
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
|
||||
if mapping == nil {
|
||||
return modelName, false
|
||||
}
|
||||
mappedModelName := mapping[modelName]
|
||||
if mappedModelName != "" {
|
||||
return mappedModelName, true
|
||||
}
|
||||
return modelName, false
|
||||
}
|
||||
|
||||
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
|
||||
if resp == nil {
|
||||
if meta.ChannelType == channeltype.AwsClaude {
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/adaptor/openai"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymeta "github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
@@ -66,7 +66,7 @@ func getImageSizeRatio(model string, size string) float64 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *relaymeta.Meta) *relaymodel.ErrorWithStatusCode {
|
||||
// check prompt length
|
||||
if imageRequest.Prompt == "" {
|
||||
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
|
||||
@@ -105,7 +105,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
|
||||
|
||||
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
meta := meta.GetByContext(c)
|
||||
meta := relaymeta.GetByContext(c)
|
||||
imageRequest, err := getImageRequest(c, meta.Mode)
|
||||
if err != nil {
|
||||
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
|
||||
@@ -115,7 +115,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
// map model name
|
||||
var isModelMapped bool
|
||||
meta.OriginModelName = imageRequest.Model
|
||||
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
|
||||
imageRequest.Model = meta.ActualModelName
|
||||
isModelMapped = meta.OriginModelName != meta.ActualModelName
|
||||
meta.ActualModelName = imageRequest.Model
|
||||
|
||||
// model validation
|
||||
@@ -131,7 +132,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
|
||||
|
||||
imageModel := imageRequest.Model
|
||||
// Convert the original image model
|
||||
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
|
||||
imageRequest.Model = relaymeta.GetMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
|
||||
c.Set("response_format", imageRequest.ResponseFormat)
|
||||
|
||||
var requestBody io.Reader
|
||||
|
||||
@@ -4,11 +4,12 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/relay"
|
||||
"github.com/songquanpeng/one-api/relay/adaptor"
|
||||
@@ -17,13 +18,13 @@ import (
|
||||
"github.com/songquanpeng/one-api/relay/billing"
|
||||
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
"github.com/songquanpeng/one-api/relay/meta"
|
||||
"github.com/songquanpeng/one-api/relay/model"
|
||||
relaymeta "github.com/songquanpeng/one-api/relay/meta"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
)
|
||||
|
||||
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode {
|
||||
ctx := c.Request.Context()
|
||||
meta := meta.GetByContext(c)
|
||||
meta := relaymeta.GetByContext(c)
|
||||
// get & validate textRequest
|
||||
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
|
||||
if err != nil {
|
||||
@@ -34,7 +35,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
|
||||
// map model name
|
||||
meta.OriginModelName = textRequest.Model
|
||||
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
|
||||
textRequest.Model = meta.ActualModelName
|
||||
meta.ActualModelName = textRequest.Model
|
||||
// set system prompt if not empty
|
||||
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
|
||||
@@ -43,7 +44,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
groupRatio := billingratio.GetGroupRatio(meta.Group)
|
||||
ratio := modelRatio * groupRatio
|
||||
// pre-consume quota
|
||||
promptTokens := getPromptTokens(textRequest, meta.Mode)
|
||||
promptTokens := getPromptTokens(c.Request.Context(), textRequest, meta.Mode)
|
||||
meta.PromptTokens = promptTokens
|
||||
preConsumedQuota, bizErr := preConsumeQuota(ctx, textRequest, promptTokens, ratio, meta)
|
||||
if bizErr != nil {
|
||||
@@ -86,9 +87,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
||||
if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
|
||||
// no need to convert request for openai
|
||||
func getRequestBody(c *gin.Context, meta *relaymeta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
|
||||
if !config.EnforceIncludeUsage &&
|
||||
meta.APIType == apitype.OpenAI &&
|
||||
meta.OriginModelName == meta.ActualModelName &&
|
||||
meta.ChannelType != channeltype.OpenAI && // openai also need to convert request
|
||||
meta.ChannelType != channeltype.Baichuan {
|
||||
return c.Request.Body, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user