From ed972eef06d494efb66cf9c7f1310f018421cda3 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sun, 22 Sep 2024 17:44:57 +0800 Subject: [PATCH] feat: pricing page support multi groups #487 --- controller/pricing.go | 11 +----- model/ability.go | 6 +++ model/pricing.go | 48 +++++++++++++----------- web/src/components/HeaderBar.js | 2 +- web/src/components/ModelPricing.js | 60 +++++++++++++++++++++++++----- 5 files changed, 87 insertions(+), 40 deletions(-) diff --git a/controller/pricing.go b/controller/pricing.go index 498cbe2..c298ae5 100644 --- a/controller/pricing.go +++ b/controller/pricing.go @@ -7,18 +7,11 @@ import ( ) func GetPricing(c *gin.Context) { - userId := c.GetInt("id") - // if no login, get default group ratio - groupRatio := common.GetGroupRatio("default") - group, err := model.CacheGetUserGroup(userId) - if err == nil { - groupRatio = common.GetGroupRatio(group) - } - pricing := model.GetPricing(group) + pricing := model.GetPricing() c.JSON(200, gin.H{ "success": true, "data": pricing, - "group_ratio": groupRatio, + "group_ratio": common.GroupRatio, }) } diff --git a/model/ability.go b/model/ability.go index 2733f6c..115ceb1 100644 --- a/model/ability.go +++ b/model/ability.go @@ -36,6 +36,12 @@ func GetEnabledModels() []string { return models } +func GetAllEnableAbilities() []Ability { + var abilities []Ability + DB.Find(&abilities, "enabled = ?", true) + return abilities +} + func getPriority(group string, model string, retry int) (int, error) { groupCol := "`group`" trueVal := "1" diff --git a/model/pricing.go b/model/pricing.go index 7384a2f..8ae5e32 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -7,14 +7,13 @@ import ( ) type Pricing struct { - Available bool `json:"available"` ModelName string `json:"model_name"` QuotaType int `json:"quota_type"` ModelRatio float64 `json:"model_ratio"` ModelPrice float64 `json:"model_price"` OwnerBy string `json:"owner_by"` CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_group,omitempty"` + EnableGroup []string `json:"enable_groups,omitempty"` } var ( @@ -23,40 +22,47 @@ var ( updatePricingLock sync.Mutex ) -func GetPricing(group string) []Pricing { +func GetPricing() []Pricing { updatePricingLock.Lock() defer updatePricingLock.Unlock() if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { updatePricing() } - if group != "" { - userPricingMap := make([]Pricing, 0) - models := GetGroupModels(group) - for _, pricing := range pricingMap { - if !common.StringsContains(models, pricing.ModelName) { - pricing.Available = false - } - userPricingMap = append(userPricingMap, pricing) - } - return userPricingMap - } + //if group != "" { + // userPricingMap := make([]Pricing, 0) + // models := GetGroupModels(group) + // for _, pricing := range pricingMap { + // if !common.StringsContains(models, pricing.ModelName) { + // pricing.Available = false + // } + // userPricingMap = append(userPricingMap, pricing) + // } + // return userPricingMap + //} return pricingMap } func updatePricing() { //modelRatios := common.GetModelRatios() - enabledModels := GetEnabledModels() - allModels := make(map[string]int) - for i, model := range enabledModels { - allModels[model] = i + enableAbilities := GetAllEnableAbilities() + modelGroupsMap := make(map[string][]string) + for _, ability := range enableAbilities { + groups := modelGroupsMap[ability.Model] + if groups == nil { + groups = make([]string, 0) + } + if !common.StringsContains(groups, ability.Group) { + groups = append(groups, ability.Group) + } + modelGroupsMap[ability.Model] = groups } pricingMap = make([]Pricing, 0) - for model, _ := range allModels { + for model, groups := range modelGroupsMap { pricing := Pricing{ - Available: true, - ModelName: model, + ModelName: model, + EnableGroup: groups, } modelPrice, findPrice := common.GetModelPrice(model, false) if findPrice { diff --git a/web/src/components/HeaderBar.js b/web/src/components/HeaderBar.js index b73bb0e..366cbce 100644 --- a/web/src/components/HeaderBar.js +++ b/web/src/components/HeaderBar.js @@ -36,7 +36,7 @@ let buttons = [ text: '首页', itemKey: 'home', to: '/', - icon: , + // icon: , }, // { // text: '模型价格', diff --git a/web/src/components/ModelPricing.js b/web/src/components/ModelPricing.js index 57de598..c289368 100644 --- a/web/src/components/ModelPricing.js +++ b/web/src/components/ModelPricing.js @@ -1,5 +1,5 @@ import React, { useContext, useEffect, useRef, useMemo, useState } from 'react'; -import { API, copy, showError, showSuccess } from '../helpers'; +import { API, copy, showError, showInfo, showSuccess } from '../helpers'; import { Banner, @@ -87,6 +87,7 @@ const ModelPricing = () => { const [selectedRowKeys, setSelectedRowKeys] = useState([]); const [modalImageUrl, setModalImageUrl] = useState(''); const [isModalOpenurl, setIsModalOpenurl] = useState(false); + const [selectedGroup, setSelectedGroup] = useState('default'); const rowSelection = useMemo( () => ({ @@ -120,7 +121,8 @@ const ModelPricing = () => { title: '可用性', dataIndex: 'available', render: (text, record, index) => { - return renderAvailable(text); + // if record.enable_groups contains selectedGroup, then available is true + return renderAvailable(record.enable_groups.includes(selectedGroup)); }, sorter: (a, b) => a.available - b.available, }, @@ -166,6 +168,43 @@ const ModelPricing = () => { }, sorter: (a, b) => a.quota_type - b.quota_type, }, + { + title: '可用分组', + dataIndex: 'enable_groups', + render: (text, record, index) => { + // enable_groups is a string array + return ( + + {text.map((group) => { + if (group === selectedGroup) { + return ( + } + > + {group} + + ); + } else { + return ( + { + setSelectedGroup(group); + showInfo('当前查看的分组为:' + group + ',倍率为:' + groupRatio[group]); + }} + > + {group} + + ); + } + })} + + ); + }, + }, { title: () => ( @@ -201,6 +240,8 @@ const ModelPricing = () => { 模型:{record.quota_type === 0 ? text : '无'}
补全:{record.quota_type === 0 ? completionRatio : '无'} +
+ 分组:{groupRatio[selectedGroup]} ); return
{content}
; @@ -213,11 +254,11 @@ const ModelPricing = () => { let content = text; if (record.quota_type === 0) { // 这里的 *2 是因为 1倍率=0.002刀,请勿删除 - let inputRatioPrice = record.model_ratio * 2 * record.group_ratio; + let inputRatioPrice = record.model_ratio * 2 * groupRatio[selectedGroup]; let completionRatioPrice = record.model_ratio * record.completion_ratio * 2 * - record.group_ratio; + groupRatio[selectedGroup]; content = ( <> 提示 ${inputRatioPrice} / 1M tokens @@ -226,7 +267,7 @@ const ModelPricing = () => { ); } else { - let price = parseFloat(text) * record.group_ratio; + let price = parseFloat(text) * groupRatio[selectedGroup]; content = <>模型价格:${price}; } return
{content}
; @@ -237,12 +278,12 @@ const ModelPricing = () => { const [models, setModels] = useState([]); const [loading, setLoading] = useState(true); const [userState, userDispatch] = useContext(UserContext); - const [groupRatio, setGroupRatio] = useState(1); + const [groupRatio, setGroupRatio] = useState({}); const setModelsFormat = (models, groupRatio) => { for (let i = 0; i < models.length; i++) { models[i].key = models[i].model_name; - models[i].group_ratio = groupRatio; + models[i].group_ratio = groupRatio[models[i].model_name]; } // sort by quota_type models.sort((a, b) => { @@ -275,6 +316,7 @@ const ModelPricing = () => { const { success, message, data, group_ratio } = res.data; if (success) { setGroupRatio(group_ratio); + setSelectedGroup(userState.user ? userState.user.group : 'default') setModelsFormat(data, group_ratio); } else { showError(message); @@ -307,14 +349,14 @@ const ModelPricing = () => { type="success" fullMode={false} closeIcon="null" - description={`您的分组为:${userState.user.group},分组倍率为:${groupRatio}`} + description={`您的默认分组为:${userState.user.group},分组倍率为:${groupRatio[userState.user.group]}`} /> ) : ( )}