package controller

import (
	"context"
	"errors"
	"fmt"
	"math"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/helper"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/constant/role"
	"github.com/songquanpeng/one-api/relay/controller/validator"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

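// getAndValidateTextRequest unmarshals the request body into a
// GeneralOpenAIRequest, fills in a default model for moderation requests and
// the path parameter model for embedding requests when none is given, then
// validates the result for the given relay mode.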
func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.GeneralOpenAIRequest, error) {
	textRequest := &relaymodel.GeneralOpenAIRequest{}
	err := common.UnmarshalBodyReusable(c, textRequest)
	if err != nil {
		return nil, err
	}
	if relayMode == relaymode.Moderations && textRequest.Model == "" {
		textRequest.Model = "text-moderation-latest"
	}
	if relayMode == relaymode.Embeddings && textRequest.Model == "" {
		textRequest.Model = c.Param("model")
	}
	err = validator.ValidateTextRequest(textRequest, relayMode)
	if err != nil {
		return nil, err
	}
	return textRequest, nil
}

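// getPromptTokens counts the prompt tokens of the request for billing,
// dispatching on the relay mode: chat messages, a raw completion prompt, or
// moderation input. Other modes count as zero prompt tokens.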
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
	switch relayMode {
	case relaymode.ChatCompletions:
		return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
	case relaymode.Completions:
		return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
	case relaymode.Moderations:
		return openai.CountTokenInput(textRequest.Input, textRequest.Model)
	}
	return 0
}

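// getPreConsumedQuota estimates the quota to reserve before the upstream call:
// the configured base amount plus the prompt tokens and, when the client sets
// max_tokens, the worst-case completion length, all scaled by the given ratio.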
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
	preConsumedTokens := config.PreConsumedQuota + int64(promptTokens)
	if textRequest.MaxTokens != 0 {
		preConsumedTokens += int64(textRequest.MaxTokens)
	}
	return int64(float64(preConsumedTokens) * ratio)
}

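// preConsumeQuota reserves an estimated quota before the upstream request is
// sent. The user's cached quota is decreased by the estimate; users whose
// quota exceeds 100x the estimate are trusted and no token quota is reserved.
// It returns the amount actually pre-consumed so the caller can settle the
// difference once real usage is known.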
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
	preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)

	userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota-preConsumedQuota < 0 {
		return preConsumedQuota, openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(meta.UserId, preConsumedQuota)
	if err != nil {
		return preConsumedQuota, openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		preConsumedQuota = 0
		logger.Info(ctx, fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", meta.UserId, userQuota))
	}
	if preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(meta.TokenId, preConsumedQuota)
		if err != nil {
			return preConsumedQuota, openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}
	return preConsumedQuota, nil
}

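// postConsumeQuota settles the final bill once usage is known: it computes the
// real quota from prompt and completion tokens plus any tools cost, applies
// the difference against the pre-consumed amount, refreshes the user quota
// cache, and records a consume log for the user and channel.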
func postConsumeQuota(ctx context.Context,
	usage *relaymodel.Usage,
	meta *meta.Meta,
	textRequest *relaymodel.GeneralOpenAIRequest,
	ratio float64,
	preConsumedQuota int64,
	modelRatio float64,
	groupRatio float64,
	systemPromptReset bool) (quota int64) {
	if usage == nil {
		logger.Error(ctx, "usage is nil, which is unexpected")
		return
	}
	completionRatio := billingratio.GetCompletionRatio(textRequest.Model, meta.ChannelType)
	promptTokens := usage.PromptTokens
	// It appears that DeepSeek's official service automatically merges ReasoningTokens into CompletionTokens,
	// but the behavior of third-party providers may differ, so for now we do not add them manually.
	// completionTokens := usage.CompletionTokens + usage.CompletionTokensDetails.ReasoningTokens
	completionTokens := usage.CompletionTokens
	// quota = ceil((promptTokens + completionTokens*completionRatio) * ratio) + toolsCost
	quota = int64(math.Ceil((float64(promptTokens)+float64(completionTokens)*completionRatio)*ratio)) + usage.ToolsCost
	if ratio != 0 && quota <= 0 {
		quota = 1
	}

	totalTokens := promptTokens + completionTokens
	if totalTokens == 0 {
		// in this case, some error must have happened
		// we cannot just return, because we may have to refund the pre-consumed quota
		quota = 0
	}
	quotaDelta := quota - preConsumedQuota
	err := model.PostConsumeTokenQuota(meta.TokenId, quotaDelta)
	if err != nil {
		logger.Error(ctx, "error consuming token remaining quota: "+err.Error())
	}
	err = model.CacheUpdateUserQuota(ctx, meta.UserId)
	if err != nil {
		logger.Error(ctx, "error updating user quota cache: "+err.Error())
	}

	var logContent string
	if usage.ToolsCost == 0 {
		logContent = fmt.Sprintf("ratio: %.2f × %.2f × %.2f", modelRatio, groupRatio, completionRatio)
	} else {
		logContent = fmt.Sprintf("ratio: %.2f × %.2f × %.2f, tools cost %d", modelRatio, groupRatio, completionRatio, usage.ToolsCost)
	}
	model.RecordConsumeLog(ctx, &model.Log{
		UserId:            meta.UserId,
		ChannelId:         meta.ChannelId,
		PromptTokens:      promptTokens,
		CompletionTokens:  completionTokens,
		ModelName:         textRequest.Model,
		TokenName:         meta.TokenName,
		Quota:             int(quota),
		Content:           logContent,
		IsStream:          meta.IsStream,
		ElapsedTime:       helper.CalcElapsedTime(meta.StartTime),
		SystemPromptReset: systemPromptReset,
	})
	model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
	model.UpdateChannelUsedQuota(meta.ChannelId, quota)

	return quota
}

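// getMappedModelName resolves a model name through the channel's model
// mapping. For example, with mapping {"gpt-4": "gpt-4-0613"} the request model
// "gpt-4" becomes "gpt-4-0613"; unmapped names are returned unchanged, and the
// boolean reports whether a mapping was applied.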
func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
	if mapping == nil {
		return modelName, false
	}
	mappedModelName := mapping[modelName]
	if mappedModelName != "" {
		return mappedModelName, true
	}
	return modelName, false
}

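// isErrorHappened reports whether the upstream response should be treated as
// an error, accounting for channel quirks: a nil response is tolerated for AWS
// Claude (presumably because that adaptor does not expose a raw HTTP
// response), Replicate creates tasks with 201 and answers stream requests with
// JSON task info, and DeepL skips the stream content-type check.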
func isErrorHappened(meta *meta.Meta, resp *http.Response) bool {
	if resp == nil {
		if meta.ChannelType == channeltype.AwsClaude {
			return false
		}
		return true
	}
	if resp.StatusCode != http.StatusOK &&
		// Replicate returns 201 when creating a task
		resp.StatusCode != http.StatusCreated {
		return true
	}
	if meta.ChannelType == channeltype.DeepL {
		// skip the stream check for DeepL
		return false
	}

	if meta.IsStream && strings.HasPrefix(resp.Header.Get("Content-Type"), "application/json") &&
		// Even if stream mode is enabled, Replicate will first return task info in JSON format,
		// requiring the client to request the stream endpoint given in the task info
		meta.ChannelType != channeltype.Replicate {
		return true
	}
	return false
}

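// setSystemPrompt forces the channel-configured system prompt onto the
// request: it overwrites an existing leading system message or prepends a new
// one, returning true when the request was modified.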
func setSystemPrompt(ctx context.Context, request *relaymodel.GeneralOpenAIRequest, prompt string) (reset bool) {
	if prompt == "" {
		return false
	}
	if len(request.Messages) == 0 {
		return false
	}
	if request.Messages[0].Role == role.System {
		request.Messages[0].Content = prompt
		logger.Infof(ctx, "rewrite system prompt")
		return true
	}
	request.Messages = append([]relaymodel.Message{{
		Role:    role.System,
		Content: prompt,
	}}, request.Messages...)
	logger.Infof(ctx, "add system prompt")
	return true
}