♻️ refactor: provider refactor (#41)

* ♻️ refactor: provider refactor
* 完善百度/讯飞的函数调用,现在可以在`lobe-chat`中正常调用函数了
This commit is contained in:
Buer
2024-01-19 02:47:10 +08:00
committed by GitHub
parent 0bfe1f5779
commit ef041e28a1
96 changed files with 4339 additions and 3276 deletions

View File

@@ -67,7 +67,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("provider not implemented")
}
return balanceProvider.Balance(channel)
return balanceProvider.Balance()
}

View File

@@ -1,6 +1,7 @@
package controller
import (
"encoding/json"
"errors"
"fmt"
"net/http"
@@ -38,30 +39,36 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
if provider == nil {
return errors.New("channel not implemented"), nil
}
newModelName, err := provider.ModelMappingHandler(request.Model)
if err != nil {
return err, nil
}
request.Model = newModelName
chatProvider, ok := provider.(providers_base.ChatInterface)
if !ok {
return errors.New("channel not implemented"), nil
}
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
return err, nil
}
if modelMap != nil && modelMap[request.Model] != "" {
request.Model = modelMap[request.Model]
}
chatProvider.SetUsage(&types.Usage{})
response, openAIErrorWithStatusCode := chatProvider.CreateChatCompletion(&request)
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
}
if Usage.CompletionTokens == 0 {
usage := chatProvider.GetUsage()
if usage.CompletionTokens == 0 {
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
}
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
// 转换为JSON字符串
jsonBytes, _ := json.Marshal(response)
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, string(jsonBytes)))
return nil, nil
}
@@ -74,9 +81,9 @@ func buildTestRequest() *types.ChatCompletionRequest {
Content: "You just need to output 'hi' next.",
},
},
Model: "",
MaxTokens: 1,
Stream: false,
Model: "",
// MaxTokens: 1,
Stream: false,
}
return testRequest
}

164
controller/quota.go Normal file
View File

@@ -0,0 +1,164 @@
package controller
import (
"context"
"errors"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
// QuotaInfo carries the per-request billing state for one relay call:
// who is being charged, which model is used, the applicable ratios, and
// how much quota was pre-consumed before the upstream request was made.
type QuotaInfo struct {
	modelName         string  // model name used for ratio lookup and consume logging
	promptTokens      int     // input token count estimated before the upstream call
	preConsumedTokens int     // configured token estimate added on top of promptTokens (common.PreConsumedQuota)
	modelRatio        float64 // per-model price multiplier
	groupRatio        float64 // per-user-group price multiplier
	ratio             float64 // modelRatio * groupRatio, the combined billing ratio
	preConsumedQuota  int     // quota actually deducted up front; 0 for trusted users
	userId            int     // charged user, read from gin context key "id"
	channelId         int     // upstream channel, read from gin context key "channel_id"
	tokenId           int     // API token being billed, read from gin context key "token_id"
	// HandelStatus is true once quota was actually pre-consumed and must be
	// refunded on failure. NOTE(review): name looks like a typo for
	// "HandleStatus" — exported and used by callers elsewhere, so renaming
	// needs a coordinated change; confirm before touching.
	HandelStatus bool
}
// generateQuotaInfo builds a QuotaInfo from the authenticated request
// context (user, channel and token ids), computes the billing ratios for
// groupName/modelName, and pre-consumes the estimated quota. On a quota
// failure it returns a nil QuotaInfo together with the OpenAI-style error.
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
	info := &QuotaInfo{
		modelName:    modelName,
		promptTokens: promptTokens,
		userId:       c.GetInt("id"),
		channelId:    c.GetInt("channel_id"),
		tokenId:      c.GetInt("token_id"),
		HandelStatus: false,
	}
	info.initQuotaInfo(c.GetString("group"))
	if errWithCode := info.preQuotaConsumption(); errWithCode != nil {
		return nil, errWithCode
	}
	return info, nil
}
// initQuotaInfo fills in the billing ratios for the request's model and
// user group, and derives the quota amount to pre-consume from the
// estimated prompt tokens plus the configured pre-consumption buffer.
func (q *QuotaInfo) initQuotaInfo(groupName string) {
	q.modelRatio = common.GetModelRatio(q.modelName)
	q.groupRatio = common.GetGroupRatio(groupName)
	q.preConsumedTokens = common.PreConsumedQuota
	q.ratio = q.modelRatio * q.groupRatio
	// Estimated up-front cost: (prompt + buffer) tokens scaled by the ratio.
	q.preConsumedQuota = int(float64(q.promptTokens+q.preConsumedTokens) * q.ratio)
}
// preQuotaConsumption verifies the user can afford the estimated cost and,
// unless the user is trusted (quota far above the estimate), pre-consumes
// preConsumedQuota from the token. Sets HandelStatus when quota was
// actually deducted so a later failure can refund it.
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
	userQuota, err := model.CacheGetUserQuota(q.userId)
	if err != nil {
		return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota < q.preConsumedQuota {
		return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	// NOTE(review): the cached user quota is decreased here, BEFORE the
	// trusted-user branch below may zero preConsumedQuota. For trusted users
	// the cache is therefore reduced even though no token quota is
	// pre-consumed — presumably acceptable because the cache is refreshed
	// later (see CacheUpdateUserQuota in completedQuotaConsumption), but the
	// ordering is load-bearing; confirm before reordering.
	err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
	if err != nil {
		return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*q.preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		q.preConsumedQuota = 0
		// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
	}
	if q.preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
		if err != nil {
			return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
		// Remember that real quota was taken so undo() can refund it.
		q.HandelStatus = true
	}
	return nil
}
// completedQuotaConsumption settles the final quota cost for a finished
// request: it computes the real cost from the reported usage, charges (or
// refunds) the difference against the pre-consumed amount, refreshes the
// cached user quota, and records consumption statistics.
//
// Returns an error when the token charge or the user-quota cache refresh
// fails; statistics recording is best-effort and not error-checked here.
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
	promptTokens := usage.PromptTokens
	completionTokens := usage.CompletionTokens
	completionRatio := common.GetCompletionRatio(q.modelName)

	// Final cost: completion tokens are weighted by the model's completion
	// ratio, then everything is scaled by the combined billing ratio.
	quota := int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
	if q.ratio != 0 && quota <= 0 {
		// Never bill zero for a non-free model: round up to 1.
		quota = 1
	}
	if promptTokens+completionTokens == 0 {
		// in this case, must be some error happened
		// we cannot just return, because we may have to return the pre-consumed quota
		quota = 0
	}

	// Charge the difference vs. what was pre-consumed (may be negative,
	// i.e. a refund).
	quotaDelta := quota - q.preConsumedQuota
	if err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta); err != nil {
		return errors.New("error consuming token remain quota: " + err.Error())
	}
	if err := model.CacheUpdateUserQuota(q.userId); err != nil {
		// Fix: this branch previously reused the PostConsumeTokenQuota
		// message verbatim, which misattributed cache-refresh failures.
		return errors.New("error updating user quota cache: " + err.Error())
	}
	if quota != 0 {
		// Derive the wall-clock request duration from the start time the
		// middleware stored on the context, if present.
		requestTime := 0
		if v := ctx.Value("requestStartTime"); v != nil {
			if requestStartTime, ok := v.(time.Time); ok {
				requestTime = int(time.Since(requestStartTime).Milliseconds())
			}
		}
		logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
		model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
		model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
		model.UpdateChannelUsedQuota(q.channelId, quota)
	}
	return nil
}
// undo refunds any pre-consumed quota after a failed relay request and
// writes the OpenAI-style error response to the client. The refund runs
// asynchronously so the error response is not delayed.
func (q *QuotaInfo) undo(c *gin.Context, errWithCode *types.OpenAIErrorWithStatusCode) {
	if q.HandelStatus {
		// Consistency fix: use q.tokenId, captured from the same context key
		// ("token_id") in generateQuotaInfo, instead of re-reading it from
		// the context here.
		go func(ctx context.Context) {
			// return pre-consumed quota
			if err := model.PostConsumeTokenQuota(q.tokenId, -q.preConsumedQuota); err != nil {
				common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
			}
		}(c.Request.Context())
	}
	errorHelper(c, errWithCode)
}
// consume settles the final quota cost for a successful request. The
// settlement runs asynchronously so the client response is not delayed;
// failures are logged rather than surfaced to the caller.
func (q *QuotaInfo) consume(c *gin.Context, usage *types.Usage) {
	tokenName := c.GetString("token_name")
	ctx := c.Request.Context()
	// On success, bill the actual usage in the background.
	go func() {
		if err := q.completedQuotaConsumption(usage, tokenName, ctx); err != nil {
			common.LogError(ctx, err.Error())
		}
	}()
}

View File

@@ -1,11 +1,11 @@
package controller
import (
"context"
"fmt"
"math"
"net/http"
"one-api/common"
"one-api/model"
"one-api/common/requester"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,33 +20,18 @@ func RelayChat(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, chatRequest.Model)
if pass {
return
}
if chatRequest.MaxTokens < 0 || chatRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[chatRequest.Model] != "" {
chatRequest.Model = modelMap[chatRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeChatCompletions)
if pass {
provider, modelName, fail := getProvider(c, chatRequest.Model)
if fail {
return
}
chatRequest.Model = modelName
chatProvider, ok := provider.(providersBase.ChatInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +41,42 @@ func RelayChat(c *gin.Context) {
// 获取Input Tokens
promptTokens := common.CountTokenMessages(chatRequest.Messages, chatRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, chatRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, chatRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = chatProvider.ChatAction(&chatRequest, isModelMapped, promptTokens)
if chatRequest.Stream {
var response requester.StreamReaderInterface[types.ChatCompletionStreamResponse]
response, errWithCode = chatProvider.CreateChatCompletionStream(&chatRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.ChatCompletionStreamResponse](c, response)
} else {
var response *types.ChatCompletionResponse
response, errWithCode = chatProvider.CreateChatCompletion(&chatRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
}
fmt.Println(usage)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,11 +1,9 @@
package controller
import (
"context"
"math"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,33 +18,18 @@ func RelayCompletions(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, completionRequest.Model)
if pass {
return
}
if completionRequest.MaxTokens < 0 || completionRequest.MaxTokens > math.MaxInt32/2 {
common.AbortWithMessage(c, http.StatusBadRequest, "max_tokens is invalid")
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[completionRequest.Model] != "" {
completionRequest.Model = modelMap[completionRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeCompletions)
if pass {
provider, modelName, fail := getProvider(c, completionRequest.Model)
if fail {
return
}
completionRequest.Model = modelName
completionProvider, ok := provider.(providersBase.CompletionInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -56,39 +39,38 @@ func RelayCompletions(c *gin.Context) {
// 获取Input Tokens
promptTokens := common.CountTokenInput(completionRequest.Prompt, completionRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, completionRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, completionRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = completionProvider.CompleteAction(&completionRequest, isModelMapped, promptTokens)
if completionRequest.Stream {
response, errWithCode := completionProvider.CreateCompletionStream(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseStreamClient[types.CompletionResponse](c, response)
} else {
response, errWithCode := completionProvider.CreateCompletion(&completionRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
}
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
"strings"
@@ -24,28 +22,13 @@ func RelayEmbeddings(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, embeddingsRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeEmbeddings)
if pass {
provider, modelName, fail := getProvider(c, embeddingsRequest.Model)
if fail {
return
}
embeddingsRequest.Model = modelName
embeddingsProvider, ok := provider.(providersBase.EmbeddingsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -55,39 +38,29 @@ func RelayEmbeddings(c *gin.Context) {
// 获取Input Tokens
promptTokens := common.CountTokenInput(embeddingsRequest.Input, embeddingsRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, embeddingsRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
response, errWithCode := embeddingsProvider.CreateEmbeddings(&embeddingsRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -33,28 +31,13 @@ func RelayImageEdits(c *gin.Context) {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeImagesEdits)
if pass {
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
if fail {
return
}
imageEditRequest.Model = modelName
imageEditsProvider, ok := provider.(providersBase.ImageEditsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -68,39 +51,29 @@ func RelayImageEdits(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageEditsProvider.ImageEditsAction(&imageEditRequest, isModelMapped, promptTokens)
response, errWithCode := imageEditsProvider.CreateImageEdits(&imageEditRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -36,28 +34,13 @@ func RelayImageGenerations(c *gin.Context) {
imageRequest.Quality = "standard"
}
channel, pass := fetchChannel(c, imageRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageRequest.Model] != "" {
imageRequest.Model = modelMap[imageRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeImagesGenerations)
if pass {
provider, modelName, fail := getProvider(c, imageRequest.Model)
if fail {
return
}
imageRequest.Model = modelName
imageGenerationsProvider, ok := provider.(providersBase.ImageGenerationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -71,39 +54,29 @@ func RelayImageGenerations(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageGenerationsProvider.ImageGenerationsAction(&imageRequest, isModelMapped, promptTokens)
response, errWithCode := imageGenerationsProvider.CreateImageGenerations(&imageRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -28,28 +26,13 @@ func RelayImageVariations(c *gin.Context) {
imageEditRequest.Size = "1024x1024"
}
channel, pass := fetchChannel(c, imageEditRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[imageEditRequest.Model] != "" {
imageEditRequest.Model = modelMap[imageEditRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeImagesVariations)
if pass {
provider, modelName, fail := getProvider(c, imageEditRequest.Model)
if fail {
return
}
imageEditRequest.Model = modelName
imageVariations, ok := provider.(providersBase.ImageVariationsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -63,39 +46,29 @@ func RelayImageVariations(c *gin.Context) {
return
}
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, imageEditRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = imageVariations.ImageVariationsAction(&imageEditRequest, isModelMapped, promptTokens)
response, errWithCode := imageVariations.CreateImageVariations(&imageEditRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -24,28 +22,13 @@ func RelayModerations(c *gin.Context) {
moderationRequest.Model = "text-moderation-stable"
}
channel, pass := fetchChannel(c, moderationRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
moderationRequest.Model = modelMap[moderationRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeModerations)
if pass {
provider, modelName, fail := getProvider(c, moderationRequest.Model)
if fail {
return
}
moderationRequest.Model = modelName
moderationProvider, ok := provider.(providersBase.ModerationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -55,39 +38,29 @@ func RelayModerations(c *gin.Context) {
// 获取Input Tokens
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, moderationRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, moderationRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
response, errWithCode := moderationProvider.CreateModeration(&moderationRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseJsonClient(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,28 +18,13 @@ func RelaySpeech(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, speechRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[speechRequest.Model] != "" {
speechRequest.Model = modelMap[speechRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeAudioSpeech)
if pass {
provider, modelName, fail := getProvider(c, speechRequest.Model)
if fail {
return
}
speechRequest.Model = modelName
speechProvider, ok := provider.(providersBase.SpeechInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -51,39 +34,29 @@ func RelaySpeech(c *gin.Context) {
// 获取Input Tokens
promptTokens := len(speechRequest.Input)
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, speechRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, speechRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens)
response, errWithCode := speechProvider.CreateSpeech(&speechRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseMultipart(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,28 +18,13 @@ func RelayTranscriptions(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeAudioTranscription)
if pass {
provider, modelName, fail := getProvider(c, audioRequest.Model)
if fail {
return
}
audioRequest.Model = modelName
transcriptionsProvider, ok := provider.(providersBase.TranscriptionsInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -51,39 +34,29 @@ func RelayTranscriptions(c *gin.Context) {
// 获取Input Tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = transcriptionsProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
response, errWithCode := transcriptionsProvider.CreateTranscriptions(&audioRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseCustom(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,10 +1,8 @@
package controller
import (
"context"
"net/http"
"one-api/common"
"one-api/model"
providersBase "one-api/providers/base"
"one-api/types"
@@ -20,28 +18,13 @@ func RelayTranslations(c *gin.Context) {
return
}
channel, pass := fetchChannel(c, audioRequest.Model)
if pass {
return
}
// 解析模型映射
var isModelMapped bool
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
return
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
// 获取供应商
provider, pass := getProvider(c, channel, common.RelayModeAudioTranslation)
if pass {
provider, modelName, fail := getProvider(c, audioRequest.Model)
if fail {
return
}
audioRequest.Model = modelName
translationProvider, ok := provider.(providersBase.TranslationInterface)
if !ok {
common.AbortWithMessage(c, http.StatusNotImplemented, "channel not implemented")
@@ -51,39 +34,29 @@ func RelayTranslations(c *gin.Context) {
// 获取Input Tokens
promptTokens := 0
var quotaInfo *QuotaInfo
var errWithCode *types.OpenAIErrorWithStatusCode
var usage *types.Usage
quotaInfo, errWithCode = generateQuotaInfo(c, audioRequest.Model, promptTokens)
usage := &types.Usage{
PromptTokens: promptTokens,
}
provider.SetUsage(usage)
quotaInfo, errWithCode := generateQuotaInfo(c, audioRequest.Model, promptTokens)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
usage, errWithCode = translationProvider.TranslationAction(&audioRequest, isModelMapped, promptTokens)
response, errWithCode := translationProvider.CreateTranslation(&audioRequest)
if errWithCode != nil {
errorHelper(c, errWithCode)
return
}
errWithCode = responseCustom(c, response)
// 如果报错,则退还配额
if errWithCode != nil {
tokenId := c.GetInt("token_id")
if quotaInfo.HandelStatus {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -quotaInfo.preConsumedQuota)
if err != nil {
common.LogError(ctx, "error return pre-consumed quota: "+err.Error())
}
}(c.Request.Context())
}
errorHelper(c, errWithCode)
quotaInfo.undo(c, errWithCode)
return
} else {
tokenName := c.GetString("token_name")
// 如果没有报错,则消费配额
go func(ctx context.Context) {
err = quotaInfo.completedQuotaConsumption(usage, tokenName, ctx)
if err != nil {
common.LogError(ctx, err.Error())
}
}(c.Request.Context())
}
quotaInfo.consume(c, usage)
}

View File

@@ -1,25 +1,47 @@
package controller
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"io"
"net/http"
"one-api/common"
"one-api/common/requester"
"one-api/model"
"one-api/providers"
providersBase "one-api/providers/base"
"one-api/types"
"reflect"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
// getProvider resolves the channel that serves modeName, constructs its
// provider, and applies the channel's model-name mapping.
//
// fail is true when an HTTP error response has already been written to c
// (missing channel, unimplemented provider, or a mapping failure); callers
// must return immediately in that case without writing anything further.
func getProvider(c *gin.Context, modeName string) (providersBase.ProviderInterface, string, bool) {
	channel, failed := fetchChannel(c, modeName)
	if failed {
		return nil, "", true
	}

	p := providers.GetProvider(channel, c)
	if p == nil {
		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
		return nil, "", true
	}

	// Translate the requested model through the channel's mapping table
	// (e.g. an alias configured on the channel) before relaying upstream.
	mapped, err := p.ModelMappingHandler(modeName)
	if err != nil {
		common.AbortWithMessage(c, http.StatusInternalServerError, err.Error())
		return nil, "", true
	}

	return p, mapped, false
}
func GetValidFieldName(err error, obj interface{}) string {
getObj := reflect.TypeOf(obj)
if errs, ok := err.(validator.ValidationErrors); ok {
@@ -32,17 +54,17 @@ func GetValidFieldName(err error, obj interface{}) string {
return err.Error()
}
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, pass bool) {
func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail bool) {
channelId, ok := c.Get("channelId")
if ok {
channel, pass = fetchChannelById(c, channelId.(int))
if pass {
channel, fail = fetchChannelById(c, channelId.(int))
if fail {
return
}
}
channel, pass = fetchChannelByModel(c, modelName)
if pass {
channel, fail = fetchChannelByModel(c, modelName)
if fail {
return
}
@@ -86,21 +108,6 @@ func fetchChannelByModel(c *gin.Context, modelName string) (*model.Channel, bool
return channel, false
}
// getProvider constructs the provider backing channel and checks that it
// supports the requested relay mode. The boolean result is true when an
// HTTP error response has already been written to c; callers must bail out
// immediately in that case.
func getProvider(c *gin.Context, channel *model.Channel, relayMode int) (providersBase.ProviderInterface, bool) {
	p := providers.GetProvider(channel, c)
	switch {
	case p == nil:
		common.AbortWithMessage(c, http.StatusNotImplemented, "channel not found")
		return nil, true
	case !p.SupportAPI(relayMode):
		common.AbortWithMessage(c, http.StatusNotImplemented, "channel does not support this API")
		return nil, true
	default:
		return p, false
	}
}
func shouldDisableChannel(err *types.OpenAIError, statusCode int) bool {
if !common.AutomaticDisableChannelEnabled {
return false
@@ -130,138 +137,81 @@ func shouldEnableChannel(err error, openAIErr *types.OpenAIError) bool {
return true
}
func parseModelMapping(modelMapping string) (map[string]string, error) {
if modelMapping == "" || modelMapping == "{}" {
return nil, nil
}
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
func responseJsonClient(c *gin.Context, data interface{}) *types.OpenAIErrorWithStatusCode {
// 将data转换为 JSON
responseBody, err := json.Marshal(data)
if err != nil {
return nil, err
}
return modelMap, nil
}
type QuotaInfo struct {
modelName string
promptTokens int
preConsumedTokens int
modelRatio float64
groupRatio float64
ratio float64
preConsumedQuota int
userId int
channelId int
tokenId int
HandelStatus bool
}
func generateQuotaInfo(c *gin.Context, modelName string, promptTokens int) (*QuotaInfo, *types.OpenAIErrorWithStatusCode) {
quotaInfo := &QuotaInfo{
modelName: modelName,
promptTokens: promptTokens,
userId: c.GetInt("id"),
channelId: c.GetInt("channel_id"),
tokenId: c.GetInt("token_id"),
HandelStatus: false,
}
quotaInfo.initQuotaInfo(c.GetString("group"))
errWithCode := quotaInfo.preQuotaConsumption()
if errWithCode != nil {
return nil, errWithCode
return common.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
}
return quotaInfo, nil
}
func (q *QuotaInfo) initQuotaInfo(groupName string) {
modelRatio := common.GetModelRatio(q.modelName)
groupRatio := common.GetGroupRatio(groupName)
preConsumedTokens := common.PreConsumedQuota
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(q.promptTokens+preConsumedTokens) * ratio)
q.preConsumedTokens = preConsumedTokens
q.modelRatio = modelRatio
q.groupRatio = groupRatio
q.ratio = ratio
q.preConsumedQuota = preConsumedQuota
return
}
func (q *QuotaInfo) preQuotaConsumption() *types.OpenAIErrorWithStatusCode {
userQuota, err := model.CacheGetUserQuota(q.userId)
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(http.StatusOK)
_, err = c.Writer.Write(responseBody)
if err != nil {
return common.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
if userQuota < q.preConsumedQuota {
return common.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
err = model.CacheDecreaseUserQuota(q.userId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
}
if userQuota > 100*q.preConsumedQuota {
// in this case, we do not pre-consume quota
// because the user has enough quota
q.preConsumedQuota = 0
// common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota))
}
if q.preConsumedQuota > 0 {
err := model.PreConsumeTokenQuota(q.tokenId, q.preConsumedQuota)
if err != nil {
return common.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
}
q.HandelStatus = true
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil
}
func (q *QuotaInfo) completedQuotaConsumption(usage *types.Usage, tokenName string, ctx context.Context) error {
quota := 0
completionRatio := common.GetCompletionRatio(q.modelName)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * q.ratio))
if q.ratio != 0 && quota <= 0 {
quota = 1
}
totalTokens := promptTokens + completionTokens
if totalTokens == 0 {
// in this case, must be some error happened
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
quotaDelta := quota - q.preConsumedQuota
err := model.PostConsumeTokenQuota(q.tokenId, quotaDelta)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(q.userId)
if err != nil {
return errors.New("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
requestTime := 0
requestStartTimeValue := ctx.Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
func responseStreamClient[T any](c *gin.Context, stream requester.StreamReaderInterface[T]) *types.OpenAIErrorWithStatusCode {
requester.SetEventStreamHeaders(c)
defer stream.Close()
for {
response, err := stream.Recv()
if errors.Is(err, io.EOF) {
if response != nil && len(*response) > 0 {
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", q.modelRatio, q.groupRatio)
model.RecordConsumeLog(ctx, q.userId, q.channelId, promptTokens, completionTokens, q.modelName, tokenName, quota, logContent, requestTime)
model.UpdateUserUsedQuotaAndRequestCount(q.userId, quota)
model.UpdateChannelUsedQuota(q.channelId, quota)
if err != nil {
c.Render(-1, common.CustomEvent{Data: "data: " + err.Error()})
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
break
}
for _, streamResponse := range *response {
responseBody, _ := json.Marshal(streamResponse)
c.Render(-1, common.CustomEvent{Data: "data: " + string(responseBody)})
}
}
return nil
}
// responseMultipart relays an upstream HTTP response to the client verbatim:
// headers first, then the upstream status code, then the raw body streamed
// through io.Copy. The upstream body is always closed before returning.
//
// NOTE(review): only the first value of each upstream header is forwarded
// (values[0]); multi-valued headers are truncated — presumably acceptable
// for the audio endpoints this serves, but worth confirming.
func responseMultipart(c *gin.Context, resp *http.Response) *types.OpenAIErrorWithStatusCode {
	defer resp.Body.Close()

	for key, values := range resp.Header {
		c.Writer.Header().Set(key, values[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	if _, err := io.Copy(c.Writer, resp.Body); err != nil {
		return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}
func responseCustom(c *gin.Context, response *types.AudioResponseWrapper) *types.OpenAIErrorWithStatusCode {
for k, v := range response.Headers {
c.Writer.Header().Set(k, v)
}
c.Writer.WriteHeader(http.StatusOK)
_, err := c.Writer.Write(response.Body)
if err != nil {
return common.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError)
}
return nil