Compare commits

...

3 Commits

Author SHA1 Message Date
Laisky.Cai
b83e400297 fix: change GetAudioTokens to return float64 and update related functions 2025-01-26 12:20:01 +00:00
Laisky.Cai
bcba9bf3a1 feat: only allow gpt-audio stream mode when EnforceIncludeUsage is true 2025-01-26 08:21:59 +00:00
Laisky.Cai
010bc72304 fix: whisper model billing
- Refactor model name handling across multiple controllers to improve clarity and maintainability.
- Enhance error logging and handling for better debugging and request processing robustness.
- Update pricing models in accordance with new calculations, ensuring accuracy in the billing logic.
2025-01-26 08:10:46 +00:00
12 changed files with 64 additions and 41 deletions

View File

@@ -161,4 +161,5 @@ var RelayProxy = env.String("RELAY_PROXY", "")
var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "")
var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30)
// EnforceIncludeUsage is used to determine whether to include usage in the response
var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false)

View File

@@ -21,4 +21,5 @@ const (
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
SystemPrompt = "system_prompt"
Meta = "meta"
)

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"io"
"math"
"os"
"os/exec"
"strconv"
@@ -33,7 +32,7 @@ func SaveTmpFile(filename string, data io.Reader) (string, error) {
}
// GetAudioTokens returns the number of tokens in an audio file.
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (int, error) {
func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond float64) (float64, error) {
filename, err := SaveTmpFile("audio", audio)
if err != nil {
return 0, errors.Wrap(err, "failed to save audio to temporary file")
@@ -45,7 +44,7 @@ func GetAudioTokens(ctx context.Context, audio io.Reader, tokensPerSecond int) (
return 0, errors.Wrap(err, "failed to get audio tokens")
}
return int(math.Ceil(duration)) * tokensPerSecond, nil
return duration * tokensPerSecond, nil
}
// GetAudioDuration returns the duration of an audio file in seconds.

View File

@@ -8,6 +8,7 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/doubao"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
@@ -97,10 +98,10 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}(request.Messages)
}
if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") {
if request.Stream && strings.HasPrefix(request.Model, "gpt-4o-audio") && !config.EnforceIncludeUsage {
// TODO: Since it is not clear how to implement billing in stream mode,
// it is temporarily not supported
return nil, errors.New("stream mode is not supported for gpt-4o-audio")
return nil, errors.New("set ENFORCE_INCLUDE_USAGE=true to enable stream mode for gpt-4o-audio")
}
return request, nil

View File

@@ -93,7 +93,9 @@ func CountTokenMessages(ctx context.Context,
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
var totalAudioTokens float64
for _, message := range messages {
tokenNum += tokensPerMessage
contents := message.ParseContent()
@@ -117,17 +119,19 @@ func CountTokenMessages(ctx context.Context,
logger.SysError("error decoding audio data: " + err.Error())
}
tokens, err := helper.GetAudioTokens(ctx,
audioTokens, err := helper.GetAudioTokens(ctx,
bytes.NewReader(audioData),
ratio.GetAudioPromptTokensPerSecond(actualModel))
if err != nil {
logger.SysError("error counting audio tokens: " + err.Error())
} else {
tokenNum += tokens
totalAudioTokens += audioTokens
}
}
}
tokenNum += int(math.Ceil(totalAudioTokens))
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Name != nil {
tokenNum += tokensPerName

View File

@@ -60,7 +60,6 @@ func (a *Adaptor) GetChannelName() string {
// GetRequestURL builds the upstream request URL: it strips the
// "/v1/oneapi/proxy/<channel id>" prefix from the incoming request path and
// appends what remains to the channel's base URL. The returned error is
// always nil.
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
	proxyPrefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
	upstreamPath := strings.TrimPrefix(meta.RequestURLPath, proxyPrefix)
	return meta.BaseURL + upstreamPath, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {

View File

@@ -3,7 +3,6 @@ package ratio
import (
"encoding/json"
"fmt"
"math"
"strings"
"github.com/songquanpeng/one-api/common/logger"
@@ -71,7 +70,7 @@ var ModelRatio = map[string]float64{
"text-davinci-003": 10,
"text-davinci-edit-001": 10,
"code-davinci-edit-001": 10,
"whisper-1": 15, // $0.006 / minute -> $0.006 / 150 words -> $0.006 / 200 tokens -> $0.03 / 1k tokens
"whisper-1": 15,
"tts-1": 7.5, // $0.015 / 1K characters
"tts-1-1106": 7.5,
"tts-1-hd": 15, // $0.030 / 1K characters
@@ -380,8 +379,9 @@ func GetAudioCompletionRatio(actualModelName string) float64 {
// AudioPromptTokensPerSecond is the number of audio prompt tokens per second for each model.
var AudioPromptTokensPerSecond = map[string]float64{
// $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens
"whisper-1": 1000 / 20,
// The whisper API price is $0.0001/sec, and one-api's historical ratio of 15 corresponds to $0.03 per 1K tokens,
// so the equivalent number of tokens per second is 0.0001/0.03*1000 = 3.3333
"whisper-1": 0.0001 / 0.03 * 1000,
// gpt-4o-audio series processes 10 tokens per second
"gpt-4o-audio-preview": 10,
"gpt-4o-audio-preview-2024-12-17": 10,
@@ -390,7 +390,7 @@ var AudioPromptTokensPerSecond = map[string]float64{
// GetAudioPromptTokensPerSecond returns the number of audio tokens per second
// for the given model.
func GetAudioPromptTokensPerSecond(actualModelName string) int {
func GetAudioPromptTokensPerSecond(actualModelName string) float64 {
var v float64
if tokensPerSecond, ok := AudioPromptTokensPerSecond[actualModelName]; ok {
v = tokensPerSecond
@@ -398,7 +398,7 @@ func GetAudioPromptTokensPerSecond(actualModelName string) int {
v = 10
}
return int(math.Ceil(v))
return v
}
var CompletionRatio = map[string]float64{

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"strings"
@@ -33,7 +34,7 @@ type commonAudioRequest struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
func countAudioTokens(c *gin.Context) (int, error) {
func countAudioTokens(c *gin.Context) (float64, error) {
body, err := common.GetRequestBody(c)
if err != nil {
return 0, errors.WithStack(err)
@@ -101,7 +102,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError)
}
preConsumedQuota = int64(float64(audioTokens) * ratio)
preConsumedQuota = int64(math.Ceil(audioTokens * ratio))
quota = preConsumedQuota
default:
return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError)

View File

@@ -129,17 +129,6 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.M
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
// getMappedModelName looks modelName up in mapping. When mapping holds a
// non-empty entry for modelName it returns that entry and true; otherwise it
// returns modelName unchanged and false.
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping != nil {
		if target := mapping[modelName]; target != "" {
			return target, true
		}
	}
	return modelName, false
}
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
if resp == nil {
if meta.ChannelType == channeltype.AwsClaude {

View File

@@ -18,7 +18,7 @@ import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
metalib "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
@@ -65,7 +65,7 @@ func getImageSizeRatio(model string, size string) float64 {
return 1
}
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *meta.Meta) *relaymodel.ErrorWithStatusCode {
func validateImageRequest(imageRequest *relaymodel.ImageRequest, _ *metalib.Meta) *relaymodel.ErrorWithStatusCode {
// check prompt length
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
@@ -104,7 +104,7 @@ func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
meta := metalib.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@@ -114,7 +114,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
imageRequest.Model = meta.ActualModelName
isModelMapped = meta.OriginModelName != meta.ActualModelName
meta.ActualModelName = imageRequest.Model
// model validation
@@ -130,7 +131,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
imageModel := imageRequest.Model
// Convert the original image model
imageRequest.Model, _ = getMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
imageRequest.Model = metalib.GetMappedModelName(imageRequest.Model, billingratio.ImageOriginModelName)
c.Set("response_format", imageRequest.ResponseFormat)
var requestBody io.Reader

View File

@@ -17,13 +17,13 @@ import (
"github.com/songquanpeng/one-api/relay/billing"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
metalib "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
func RelayTextHelper(c *gin.Context) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
meta := meta.GetByContext(c)
meta := metalib.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
@@ -34,7 +34,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
meta.OriginModelName = textRequest.Model
textRequest.Model, _ = getMappedModelName(textRequest.Model, meta.ModelMapping)
textRequest.Model = meta.ActualModelName
meta.ActualModelName = textRequest.Model
// set system prompt if not empty
systemPromptReset := setSystemPrompt(ctx, textRequest, meta.SystemPrompt)
@@ -86,9 +86,12 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return nil
}
func getRequestBody(c *gin.Context, meta *meta.Meta, textRequest *model.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if !config.EnforceIncludeUsage && meta.APIType == apitype.OpenAI && meta.OriginModelName == meta.ActualModelName && meta.ChannelType != channeltype.Baichuan {
// no need to convert request for openai
func getRequestBody(c *gin.Context, meta *metalib.Meta, textRequest *relaymodel.GeneralOpenAIRequest, adaptor adaptor.Adaptor) (io.Reader, error) {
if !config.EnforceIncludeUsage &&
meta.APIType == apitype.OpenAI &&
meta.OriginModelName == meta.ActualModelName &&
meta.ChannelType != channeltype.OpenAI && // openai also need to convert request
meta.ChannelType != channeltype.Baichuan {
return c.Request.Body, nil
}

View File

@@ -1,12 +1,13 @@
package meta
import (
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/relaymode"
"strings"
)
type Meta struct {
@@ -33,6 +34,20 @@ type Meta struct {
SystemPrompt string
}
// GetMappedModelName returns the name that modelName maps to in mapping.
// If mapping is nil, has no entry for modelName, or maps it to an empty
// string, modelName is returned unchanged.
func GetMappedModelName(modelName string, mapping map[string]string) string {
	// Reading from a nil map is safe in Go and yields the zero value, so no
	// explicit nil check is required.
	if mapped := mapping[modelName]; mapped != "" {
		return mapped
	}
	return modelName
}
func GetByContext(c *gin.Context) *Meta {
meta := Meta{
Mode: relaymode.GetByPath(c.Request.URL.Path),
@@ -44,6 +59,7 @@ func GetByContext(c *gin.Context) *Meta {
Group: c.GetString(ctxkey.Group),
ModelMapping: c.GetStringMapString(ctxkey.ModelMapping),
OriginModelName: c.GetString(ctxkey.RequestModel),
ActualModelName: c.GetString(ctxkey.RequestModel),
BaseURL: c.GetString(ctxkey.BaseURL),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
RequestURLPath: c.Request.URL.String(),
@@ -57,5 +73,13 @@ func GetByContext(c *gin.Context) *Meta {
meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
}
meta.APIType = channeltype.ToAPIType(meta.ChannelType)
meta.ActualModelName = GetMappedModelName(meta.OriginModelName, meta.ModelMapping)
Set2Context(c, &meta)
return &meta
}
// Set2Context stores meta in the gin context c under the ctxkey.Meta key.
func Set2Context(c *gin.Context, meta *Meta) {
c.Set(ctxkey.Meta, meta)
}