🐛 fix: 修复余额的问题

This commit is contained in:
MartialBE
2023-12-02 19:54:21 +08:00
parent 58fc40a744
commit c97c8a0f65
14 changed files with 158 additions and 37 deletions

View File

@@ -3,6 +3,7 @@ package controller
import (
"errors"
"net/http"
"net/http/httptest"
"one-api/common"
"one-api/model"
"one-api/providers"
@@ -46,7 +47,18 @@ type OpenAIUsageResponse struct {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
provider := providers.GetProvider(channel.Type, nil)
req, err := http.NewRequest("POST", "/balance", nil)
if err != nil {
return 0, err
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
setChannelToContext(c, channel)
req.Header.Set("Content-Type", "application/json")
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return 0, errors.New("provider not found")
}
@@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
return 0, errors.New("provider not implemented")
}
return balanceProvider.BalanceAction(channel)
return balanceProvider.Balance(channel)
}

View File

@@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
c.Request = req
setChannelToContext(c, channel)
switch channel.Type {
case common.ChannelTypePaLM:
request.Model = "PaLM-2"
case common.ChannelTypeAnthropic:
request.Model = "claude-2"
case common.ChannelTypeBaidu:
request.Model = "ERNIE-Bot"
case common.ChannelTypeZhipu:
request.Model = "chatglm_lite"
case common.ChannelTypeAli:
request.Model = "qwen-turbo"
case common.ChannelType360:
request.Model = "360GPT_S2_V9"
case common.ChannelTypeXunfei:
request.Model = "SparkDesk"
c.Set("api_version", channel.Other)
case common.ChannelTypeTencent:
request.Model = "hunyuan"
case common.ChannelTypeAzure:
request.Model = "gpt-3.5-turbo"
c.Set("api_version", channel.Other)
default:
request.Model = "gpt-3.5-turbo"
// 创建映射
channelTypeToModel := map[int]string{
common.ChannelTypePaLM: "PaLM-2",
common.ChannelTypeAnthropic: "claude-2",
common.ChannelTypeBaidu: "ERNIE-Bot",
common.ChannelTypeZhipu: "chatglm_lite",
common.ChannelTypeAli: "qwen-turbo",
common.ChannelType360: "360GPT_S2_V9",
common.ChannelTypeXunfei: "SparkDesk",
common.ChannelTypeTencent: "hunyuan",
common.ChannelTypeAzure: "gpt-3.5-turbo",
}
// 从映射中获取模型名称
model, ok := channelTypeToModel[channel.Type]
if !ok {
model = "gpt-3.5-turbo" // 默认值
}
request.Model = model
provider := providers.GetProvider(channel.Type, c)
if provider == nil {
return errors.New("channel not implemented"), nil
@@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
return errors.New("channel not implemented"), nil
}
isModelMapped := false
modelMap, err := parseModelMapping(channel.GetModelMapping())
if err != nil {
return err, nil
}
if modelMap != nil && modelMap[request.Model] != "" {
request.Model = modelMap[request.Model]
isModelMapped = true
}
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return nil, &openAIErrorWithStatusCode.OpenAIError
}