diff --git a/common/model-ratio.go b/common/model-ratio.go index 79a2589..7650755 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -70,9 +70,17 @@ var DefaultModelRatio = map[string]float64{ "claude-3-haiku-20240307": 0.125, // $0.25 / 1M tokens "claude-3-sonnet-20240229": 1.5, // $3 / 1M tokens "claude-3-opus-20240229": 7.5, // $15 / 1M tokens - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens + "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens //renamed to ERNIE-3.5-8K + "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens //renamed to ERNIE-Lite-8K + "ERNIE-Bot-4": 8.572, // ¥0.12 / 1k tokens //renamed to ERNIE-4.0-8K + "ERNIE-4.0-8K": 8.572, // ¥0.12 / 1k tokens + "ERNIE-3.5-8K": 0.8572, // ¥0.012 / 1k tokens + "ERNIE-Speed-8K": 0.2858, // ¥0.004 / 1k tokens + "ERNIE-Speed-128K": 0.2858, // ¥0.004 / 1k tokens + "ERNIE-Lite-8K": 0.2143, // ¥0.003 / 1k tokens + "ERNIE-Tiny-8K": 0.0715, // ¥0.001 / 1k tokens + "ERNIE-Character-8K": 0.2858, // ¥0.004 / 1k tokens + "ERNIE-Functions-8K": 0.2858, // ¥0.004 / 1k tokens "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens @@ -80,6 +88,7 @@ var DefaultModelRatio = map[string]float64{ "gemini-1.0-pro-vision-001": 1, "gemini-1.0-pro-001": 1, "gemini-1.5-pro-latest": 1, + "gemini-1.5-flash-latest": 1, "gemini-1.0-pro-latest": 1, "gemini-1.0-pro-vision-latest": 1, "gemini-ultra": 1, @@ -98,6 +107,9 @@ var DefaultModelRatio = map[string]float64{ "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens + "360gpt-turbo": 0.0858, // ¥0.0012 / 1k tokens + "360gpt-turbo-responsibility-8k": 0.8572, // ¥0.012 / 1k tokens + "360gpt-pro": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens 
// StrToMap decodes the JSON object held in str into a generic map.
//
// It returns nil when str is empty, when it is not valid JSON, when the
// top-level JSON value is not an object (an array or a scalar), and for the
// JSON literal null. Callers must nil-check the result before writing to it.
// Nested objects decode as map[string]interface{} and numbers as float64,
// following encoding/json's defaults for interface values.
func StrToMap(str string) map[string]interface{} {
	// An empty string can never decode; skip the decoder entirely.
	if str == "" {
		return nil
	}

	// json.Unmarshal allocates the map itself, so no up-front make is needed
	// (it would only be thrown away on the error path).
	var m map[string]interface{}
	if err := json.Unmarshal([]byte(str), &m); err != nil {
		// The decode error is deliberately not surfaced: the contract of this
		// helper is "usable map or nil".
		return nil
	}
	return m
}
range openAIModels { - openAIModelsMap[model.Id] = model + for _, aiModel := range openAIModels { + openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) for i := 1; i <= common.ChannelTypeDummy; i++ { @@ -174,8 +174,8 @@ func DashboardListModels(c *gin.Context) { func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { - c.JSON(200, model) + if aiModel, ok := openAIModelsMap[modelId]; ok { + c.JSON(200, aiModel) } else { openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), @@ -191,12 +191,12 @@ func RetrieveModel(c *gin.Context) { func GetPricing(c *gin.Context) { userId := c.GetInt("id") - user, _ := model.GetUserById(userId, true) + group, err := model.CacheGetUserGroup(userId) groupRatio := common.GetGroupRatio("default") - if user != nil { - groupRatio = common.GetGroupRatio(user.Group) + if err != nil { + groupRatio = common.GetGroupRatio(group) } - pricing := model.GetPricing(user, openAIModels) + pricing := model.GetPricing(group) c.JSON(200, gin.H{ "success": true, "data": pricing, diff --git a/controller/relay.go b/controller/relay.go index 0bbd409..a066e5d 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -43,7 +43,7 @@ func Relay(c *gin.Context) { group := c.GetString("group") originalModel := c.GetString("original_model") openaiErr := relayHandler(c, relayMode) - useChannel := []int{channelId} + c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)}) if openaiErr != nil { go processChannelError(c, channelId, openaiErr) } else { @@ -56,7 +56,9 @@ func Relay(c *gin.Context) { break } channelId = channel.Id - useChannel = append(useChannel, channelId) + useChannel := c.GetStringSlice("use_channel") + useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) + c.Set("use_channel", useChannel) common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i)) 
middleware.SetupContextForSelectedChannel(c, channel, originalModel) @@ -67,6 +69,7 @@ func Relay(c *gin.Context) { go processChannelError(c, channelId, openaiErr) } } + useChannel := c.GetStringSlice("use_channel") if len(useChannel) > 1 { retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) common.LogInfo(c.Request.Context(), retryLogStr) diff --git a/model/log.go b/model/log.go index c7d6e13..01165ee 100644 --- a/model/log.go +++ b/model/log.go @@ -155,6 +155,16 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int err = tx.Order("id desc").Limit(num).Offset(startIdx).Omit("id").Find(&logs).Error return logs, total, err + for i := range logs { + var otherMap map[string]interface{} + otherMap = common.StrToMap(logs[i].Other) + if otherMap != nil { + // delete admin + delete(otherMap, "admin_info") + } + logs[i].Other = common.MapToJsonStr(otherMap) + } + return logs, total, err } func SearchAllLogs(keyword string) (logs []*Log, err error) { diff --git a/model/main.go b/model/main.go index 7b1cd3d..b6ad2cb 100644 --- a/model/main.go +++ b/model/main.go @@ -93,12 +93,12 @@ func InitDB() (err error) { if !common.IsMasterNode { return nil } - if common.UsingMySQL { - _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded - _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY action VARCHAR(40);") // TODO: delete this line when most users have upgraded - _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded - _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded - } + //if common.UsingMySQL { + // _, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY 
action VARCHAR(40);") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY progress VARCHAR(30);") // TODO: delete this line when most users have upgraded + // _, _ = sqlDB.Exec("ALTER TABLE midjourneys MODIFY status VARCHAR(20);") // TODO: delete this line when most users have upgraded + //} common.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { diff --git a/model/pricing.go b/model/pricing.go index c9685f3..90d8bc7 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -13,16 +13,16 @@ var ( updatePricingLock sync.Mutex ) -func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing { +func GetPricing(group string) []dto.ModelPricing { updatePricingLock.Lock() defer updatePricingLock.Unlock() if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { - updatePricing(openAIModels) + updatePricing() } - if user != nil { + if group != "" { userPricingMap := make([]dto.ModelPricing, 0) - models := GetGroupModels(user.Group) + models := GetGroupModels(group) for _, pricing := range pricingMap { if !common.StringsContains(models, pricing.ModelName) { pricing.Available = false @@ -34,28 +34,19 @@ func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing return pricingMap } -func updatePricing(openAIModels []dto.OpenAIModels) { - modelRatios := common.GetModelRatios() +func updatePricing() { + //modelRatios := common.GetModelRatios() enabledModels := GetEnabledModels() - allModels := make(map[string]string) - for _, openAIModel := range openAIModels { - if common.StringsContains(enabledModels, openAIModel.Id) { - allModels[openAIModel.Id] = openAIModel.OwnedBy - } - } - for model, _ := range modelRatios { - if common.StringsContains(enabledModels, model) { - if _, ok := allModels[model]; !ok { - allModels[model] = "custom" - } - } + allModels := make(map[string]int) + for i, model := range enabledModels { + 
allModels[model] = i } + pricingMap = make([]dto.ModelPricing, 0) - for model, ownerBy := range allModels { + for model, _ := range allModels { pricing := dto.ModelPricing{ Available: true, ModelName: model, - OwnerBy: ownerBy, } modelPrice, findPrice := common.GetModelPrice(model, false) if findPrice { diff --git a/model/token.go b/model/token.go index 1bbf6c4..056156e 100644 --- a/model/token.go +++ b/model/token.go @@ -11,7 +11,7 @@ import ( type Token struct { Id int `json:"id"` - UserId int `json:"user_id"` + UserId int `json:"user_id" gorm:"index"` Key string `json:"key" gorm:"type:char(48);uniqueIndex"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index" ` diff --git a/relay/channel/ai360/constants.go b/relay/channel/ai360/constants.go index 82698fa..824231d 100644 --- a/relay/channel/ai360/constants.go +++ b/relay/channel/ai360/constants.go @@ -1,6 +1,9 @@ package ai360 var ModelList = []string{ + "360gpt-turbo", + "360gpt-turbo-responsibility-8k", + "360gpt-pro", "360GPT_S2_V9", "embedding-bert-512-v1", "embedding_s1_v1", diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index d2571dc..44c5e3f 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" relaycommon "one-api/relay/common" "one-api/relay/constant" + "strings" ) type Adaptor struct { @@ -33,8 +34,24 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" case "BLOOMZ-7B": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" + case "ERNIE-4.0-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" + case "ERNIE-3.5-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" + case "ERNIE-Speed-8K": + 
fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_speed" + case "ERNIE-Character-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-char-8k" + case "ERNIE-Functions-8K": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-func-8k" + case "ERNIE-Lite-8K-0922": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" + case "Yi-34B-Chat": + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat" case "Embedding-V1": fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" + default: + fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + strings.ToLower(info.UpstreamModelName) } var accessToken string var err error diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go index a0162bb..3cb96fc 100644 --- a/relay/channel/baidu/constants.go +++ b/relay/channel/baidu/constants.go @@ -1,11 +1,19 @@ package baidu var ModelList = []string{ - "ERNIE-Bot-4", - "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", + "ERNIE-3.5-8K", + "ERNIE-4.0-8K", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K", + "ERNIE-Tiny-8K", + "ERNIE-Character-8K", + "ERNIE-Functions-8K", + //"ERNIE-Bot-4", + //"ERNIE-Bot-8K", + //"ERNIE-Bot", + //"ERNIE-Speed", + //"ERNIE-Bot-turbo", "Embedding-V1", } diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index daaadc5..d372d82 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -21,6 +21,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo, request dto.GeneralOpenAIReq // 定义一个映射,存储模型名称和对应的版本 var modelVersionMap = map[string]string{ "gemini-1.5-pro-latest": "v1beta", + "gemini-1.5-flash-latest": "v1beta", "gemini-ultra": "v1beta", } diff --git 
a/relay/channel/gemini/constant.go b/relay/channel/gemini/constant.go index 5e833bc..621336b 100644 --- a/relay/channel/gemini/constant.go +++ b/relay/channel/gemini/constant.go @@ -5,7 +5,7 @@ const ( ) var ModelList = []string{ - "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-ultra", + "gemini-1.0-pro-latest", "gemini-1.0-pro-001", "gemini-1.5-pro-latest", "gemini-1.5-flash-latest", "gemini-ultra", "gemini-1.0-pro-vision-latest", "gemini-1.0-pro-vision-001", } diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go index 8a1dbd6..943c407 100644 --- a/relay/constant/api_type.go +++ b/relay/constant/api_type.go @@ -15,7 +15,7 @@ const ( APITypeAIProxyLibrary APITypeTencent APITypeGemini - APITypeZhipu_v4 + APITypeZhipuV4 APITypeOllama APITypePerplexity APITypeAws @@ -48,7 +48,7 @@ func ChannelType2APIType(channelType int) (int, bool) { case common.ChannelTypeGemini: apiType = APITypeGemini case common.ChannelTypeZhipu_v4: - apiType = APITypeZhipu_v4 + apiType = APITypeZhipuV4 case common.ChannelTypeOllama: apiType = APITypeOllama case common.ChannelTypePerplexity: diff --git a/relay/relay-text.go b/relay/relay-text.go index 12b610f..8e2ab5f 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -323,6 +323,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe other["group_ratio"] = groupRatio other["completion_ratio"] = completionRatio other["model_price"] = modelPrice + adminInfo := make(map[string]interface{}) + adminInfo["use_channel"] = ctx.GetStringSlice("use_channel") + other["admin_info"] = adminInfo model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, promptTokens, completionTokens, logModel, tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) //if quota != 0 { diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 01e9cec..cf63054 100644 --- a/relay/relay_adaptor.go +++ 
b/relay/relay_adaptor.go @@ -41,7 +41,7 @@ func GetAdaptor(apiType int) channel.Adaptor { return &xunfei.Adaptor{} case constant.APITypeZhipu: return &zhipu.Adaptor{} - case constant.APITypeZhipu_v4: + case constant.APITypeZhipuV4: return &zhipu_4v.Adaptor{} case constant.APITypeOllama: return &ollama.Adaptor{} diff --git a/router/api-router.go b/router/api-router.go index c700216..26d5248 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -20,7 +20,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/about", controller.GetAbout) apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) - apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing) + apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 452309c..ad53999 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -6,6 +6,7 @@ import { showError, showInfo, showSuccess, + showWarning, timestamp2string, } from '../helpers'; @@ -309,6 +310,12 @@ const ChannelsTable = () => { const setChannelFormat = (channels) => { for (let i = 0; i < channels.length; i++) { + if (channels[i].type === 8) { + showWarning( + '检测到您使用了“自定义渠道”类型,请更换为“OpenAI”渠道类型!', + ); + showWarning('下个版本将不再支持“自定义渠道”类型!'); + } channels[i].key = '' + channels[i].id; let test_models = []; channels[i].models.split(',').forEach((item, index) => { diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 
0cc55e4..fdc72c4 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -294,6 +294,30 @@ const LogsTable = () => { ); }, }, + { + title: '重试', + dataIndex: 'retry', + className: isAdmin() ? 'tableShow' : 'tableHiddle', + render: (text, record, index) => { + let content = '渠道:' + record.channel; + if (record.other !== '') { + let other = JSON.parse(record.other); + if (other.admin_info !== undefined) { + if ( + other.admin_info.use_channel !== null && + other.admin_info.use_channel !== undefined && + other.admin_info.use_channel !== '' + ) { + // channel id array + let useChannel = other.admin_info.use_channel; + let useChannelStr = useChannel.join('->'); + content = `渠道:${useChannelStr}`; + } + } + } + return isAdminUser ?