mirror of
https://github.com/songquanpeng/one-api.git
synced 2026-02-16 02:44:24 +08:00
🐛 fix: 修复余额的问题
This commit is contained in:
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"one-api/providers"
|
||||
@@ -46,7 +47,18 @@ type OpenAIUsageResponse struct {
|
||||
}
|
||||
|
||||
func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
provider := providers.GetProvider(channel.Type, nil)
|
||||
req, err := http.NewRequest("POST", "/balance", nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = req
|
||||
|
||||
setChannelToContext(c, channel)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
provider := providers.GetProvider(channel.Type, c)
|
||||
if provider == nil {
|
||||
return 0, errors.New("provider not found")
|
||||
}
|
||||
@@ -56,7 +68,7 @@ func updateChannelBalance(channel *model.Channel) (float64, error) {
|
||||
return 0, errors.New("provider not implemented")
|
||||
}
|
||||
|
||||
return balanceProvider.BalanceAction(channel)
|
||||
return balanceProvider.Balance(channel)
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -30,32 +30,26 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
c.Request = req
|
||||
|
||||
setChannelToContext(c, channel)
|
||||
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
request.Model = "PaLM-2"
|
||||
case common.ChannelTypeAnthropic:
|
||||
request.Model = "claude-2"
|
||||
case common.ChannelTypeBaidu:
|
||||
request.Model = "ERNIE-Bot"
|
||||
case common.ChannelTypeZhipu:
|
||||
request.Model = "chatglm_lite"
|
||||
case common.ChannelTypeAli:
|
||||
request.Model = "qwen-turbo"
|
||||
case common.ChannelType360:
|
||||
request.Model = "360GPT_S2_V9"
|
||||
case common.ChannelTypeXunfei:
|
||||
request.Model = "SparkDesk"
|
||||
c.Set("api_version", channel.Other)
|
||||
case common.ChannelTypeTencent:
|
||||
request.Model = "hunyuan"
|
||||
case common.ChannelTypeAzure:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
c.Set("api_version", channel.Other)
|
||||
default:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
// 创建映射
|
||||
channelTypeToModel := map[int]string{
|
||||
common.ChannelTypePaLM: "PaLM-2",
|
||||
common.ChannelTypeAnthropic: "claude-2",
|
||||
common.ChannelTypeBaidu: "ERNIE-Bot",
|
||||
common.ChannelTypeZhipu: "chatglm_lite",
|
||||
common.ChannelTypeAli: "qwen-turbo",
|
||||
common.ChannelType360: "360GPT_S2_V9",
|
||||
common.ChannelTypeXunfei: "SparkDesk",
|
||||
common.ChannelTypeTencent: "hunyuan",
|
||||
common.ChannelTypeAzure: "gpt-3.5-turbo",
|
||||
}
|
||||
|
||||
// 从映射中获取模型名称
|
||||
model, ok := channelTypeToModel[channel.Type]
|
||||
if !ok {
|
||||
model = "gpt-3.5-turbo" // 默认值
|
||||
}
|
||||
request.Model = model
|
||||
|
||||
provider := providers.GetProvider(channel.Type, c)
|
||||
if provider == nil {
|
||||
return errors.New("channel not implemented"), nil
|
||||
@@ -65,18 +59,16 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
|
||||
return errors.New("channel not implemented"), nil
|
||||
}
|
||||
|
||||
isModelMapped := false
|
||||
modelMap, err := parseModelMapping(channel.GetModelMapping())
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if modelMap != nil && modelMap[request.Model] != "" {
|
||||
request.Model = modelMap[request.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
|
||||
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, isModelMapped, promptTokens)
|
||||
_, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
|
||||
if openAIErrorWithStatusCode != nil {
|
||||
return nil, &openAIErrorWithStatusCode.OpenAIError
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user