From 5715fcf8fb24ba7b18fa0ad3a3a6d971cb269102 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 13 May 2024 23:02:35 +0800 Subject: [PATCH] feat: add pricing page --- common/model-ratio.go | 21 +++ common/utils.go | 8 + controller/model.go | 68 ++++----- dto/pricing.go | 37 +++++ middleware/auth.go | 11 ++ model/ability.go | 7 + model/pricing.go | 72 +++++++++ model/usedata.go | 1 + router/api-router.go | 1 + web/src/App.js | 9 ++ web/src/components/LoginForm.js | 3 +- web/src/components/ModelPricing.js | 229 +++++++++++++++++++++++++++++ web/src/components/SiderBar.js | 37 ++--- web/src/helpers/data.js | 33 +++++ web/src/helpers/render.js | 2 +- web/src/pages/Pricing/index.js | 10 ++ 16 files changed, 481 insertions(+), 68 deletions(-) create mode 100644 dto/pricing.go create mode 100644 model/pricing.go create mode 100644 web/src/components/ModelPricing.js create mode 100644 web/src/helpers/data.js create mode 100644 web/src/pages/Pricing/index.js diff --git a/common/model-ratio.go b/common/model-ratio.go index a8db3b3..4510551 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -178,6 +178,13 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { return price, true } +func GetModelPrices() map[string]float64 { + if modelPrice == nil { + modelPrice = DefaultModelPrice + } + return modelPrice +} + func ModelRatio2JSONString() string { if modelRatio == nil { modelRatio = DefaultModelRatio @@ -209,6 +216,13 @@ func GetModelRatio(name string) float64 { return ratio } +func GetModelRatios() map[string]float64 { + if modelRatio == nil { + modelRatio = DefaultModelRatio + } + return modelRatio +} + func CompletionRatio2JSONString() string { if CompletionRatio == nil { CompletionRatio = DefaultCompletionRatio @@ -282,3 +296,10 @@ func GetCompletionRatio(name string) float64 { } return 1 } + +func GetCompletionRatios() map[string]float64 { + if CompletionRatio == nil { + CompletionRatio = DefaultCompletionRatio + } + return CompletionRatio +} diff --git a/common/utils.go b/common/utils.go index 657ffd4..3130020 100644 --- a/common/utils.go +++ b/common/utils.go @@ -250,3 +250,11 @@ func MapToJsonStr(m map[string]interface{}) string { } return string(bytes) } + +func MapToJsonStrFloat(m map[string]float64) string { + bytes, err := json.Marshal(m) + if err != nil { + return "" + } + return string(bytes) +} diff --git a/controller/model.go b/controller/model.go index c9c50db..de86ca3 100644 --- a/controller/model.go +++ b/controller/model.go @@ -18,38 +18,13 @@ import ( // https://platform.openai.com/docs/api-reference/models/list -type OpenAIModelPermission struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - AllowCreateEngine bool `json:"allow_create_engine"` - AllowSampling bool `json:"allow_sampling"` - AllowLogprobs bool `json:"allow_logprobs"` - AllowSearchIndices bool `json:"allow_search_indices"` - AllowView bool `json:"allow_view"` - AllowFineTuning bool `json:"allow_fine_tuning"` - Organization string `json:"organization"` - Group *string `json:"group"` - IsBlocking bool `json:"is_blocking"` -} - -type OpenAIModels struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - OwnedBy string `json:"owned_by"` - Permission []OpenAIModelPermission `json:"permission"` - Root string `json:"root"` - Parent *string `json:"parent"` -} - -var openAIModels []OpenAIModels -var openAIModelsMap map[string]OpenAIModels +var openAIModels []dto.OpenAIModels +var openAIModelsMap map[string]dto.OpenAIModels var channelId2Models map[int][]string -func getPermission() []OpenAIModelPermission { - var permission []OpenAIModelPermission - permission = append(permission, OpenAIModelPermission{ +func getPermission() []dto.OpenAIModelPermission { + var permission []dto.OpenAIModelPermission + permission = append(permission, dto.OpenAIModelPermission{ Id: "modelperm-LwHkVFn8AcMItP432fKKDIKJ", Object: "model_permission", Created: 1626777600, @@ -77,7 +52,7 @@ func init() { channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - openAIModels = append(openAIModels, OpenAIModels{ + openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -89,7 +64,7 @@ func init() { } } for _, modelName := range ai360.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ + openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -100,7 +75,7 @@ func init() { }) } for _, modelName := range moonshot.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ + openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -111,7 +86,7 @@ func init() { }) } for _, modelName := range lingyiwanwu.ModelList { - openAIModels = append(openAIModels, OpenAIModels{ + openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -122,7 +97,7 @@ func init() { }) } for modelName, _ := range constant.MidjourneyModel2Action { - openAIModels = append(openAIModels, OpenAIModels{ + openAIModels = append(openAIModels, dto.OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -132,7 +107,7 @@ func init() { Parent: nil, }) } - openAIModelsMap = make(map[string]OpenAIModels) + openAIModelsMap = make(map[string]dto.OpenAIModels) for _, model := range openAIModels { openAIModelsMap[model.Id] = model } @@ -160,17 +135,17 @@ func ListModels(c *gin.Context) { return } models := model.GetGroupModels(user.Group) - userOpenAiModels := make([]OpenAIModels, 0) + userOpenAiModels := make([]dto.OpenAIModels, 0) permission := getPermission() for _, s := range models { if _, ok := openAIModelsMap[s]; ok { userOpenAiModels = append(userOpenAiModels, openAIModelsMap[s]) } else { - userOpenAiModels = append(userOpenAiModels, OpenAIModels{ + userOpenAiModels = append(userOpenAiModels, dto.OpenAIModels{ Id: s, Object: "model", Created: 1626777600, - OwnedBy: "openai", + OwnedBy: "custom", Permission: permission, Root: s, Parent: nil, @@ -213,3 +188,18 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetPricing(c *gin.Context) { + userId := c.GetInt("id") + user, _ := model.GetUserById(userId, true) + groupRatio := common.GetGroupRatio("default") + if user != nil { + groupRatio = common.GetGroupRatio(user.Group) + } + pricing := model.GetPricing(user, openAIModels) + c.JSON(200, gin.H{ + "success": true, + "data": pricing, + "group_ratio": groupRatio, + }) +} diff --git a/dto/pricing.go b/dto/pricing.go new file mode 100644 index 0000000..b049749 --- /dev/null +++ b/dto/pricing.go @@ -0,0 +1,37 @@ +package dto + +type OpenAIModelPermission struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + AllowCreateEngine bool `json:"allow_create_engine"` + AllowSampling bool `json:"allow_sampling"` + AllowLogprobs bool `json:"allow_logprobs"` + AllowSearchIndices bool `json:"allow_search_indices"` + AllowView bool `json:"allow_view"` + AllowFineTuning bool `json:"allow_fine_tuning"` + Organization string `json:"organization"` + Group *string `json:"group"` + IsBlocking bool `json:"is_blocking"` +} + +type OpenAIModels struct { + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + OwnedBy string `json:"owned_by"` + Permission []OpenAIModelPermission `json:"permission"` + Root string `json:"root"` + Parent *string `json:"parent"` +} + +type ModelPricing struct { + Available bool `json:"available"` + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_group,omitempty"` +} diff --git a/middleware/auth.go b/middleware/auth.go index 686f2d9..d9df9c8 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -64,6 +64,17 @@ func authHelper(c *gin.Context, minRole int) { c.Next() } +func TryUserAuth() func(c *gin.Context) { + return func(c *gin.Context) { + session := sessions.Default(c) + id := session.Get("id") + if id != nil { + c.Set("id", id) + } + c.Next() + } +} + func UserAuth() func(c *gin.Context) { return func(c *gin.Context) { authHelper(c, common.RoleCommonUser) diff --git a/model/ability.go b/model/ability.go index 7fd52bc..8d2d4f8 100644 --- a/model/ability.go +++ b/model/ability.go @@ -29,6 +29,13 @@ func GetGroupModels(group string) []string { return models } +func GetEnabledModels() []string { + var models []string + // Find distinct models + DB.Table("abilities").Where("enabled = ?", true).Distinct("model").Pluck("model", &models) + return models +} + func getPriority(group string, model string, retry int) (int, error) { groupCol := "`group`" trueVal := "1" diff --git a/model/pricing.go b/model/pricing.go new file mode 100644 index 0000000..c9685f3 --- /dev/null +++ b/model/pricing.go @@ -0,0 +1,72 @@ +package model + +import ( + "one-api/common" + "one-api/dto" + "sync" + "time" +) + +var ( + pricingMap []dto.ModelPricing + lastGetPricingTime time.Time + updatePricingLock sync.Mutex +) + +func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing { + updatePricingLock.Lock() + defer updatePricingLock.Unlock() + + if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { + updatePricing(openAIModels) + } + if user != nil { + userPricingMap := make([]dto.ModelPricing, 0) + models := GetGroupModels(user.Group) + for _, pricing := range pricingMap { + if !common.StringsContains(models, pricing.ModelName) { + pricing.Available = false + } + userPricingMap = append(userPricingMap, pricing) + } + return userPricingMap + } + return pricingMap +} + +func updatePricing(openAIModels []dto.OpenAIModels) { + modelRatios := common.GetModelRatios() + enabledModels := GetEnabledModels() + allModels := make(map[string]string) + for _, openAIModel := range openAIModels { + if common.StringsContains(enabledModels, openAIModel.Id) { + allModels[openAIModel.Id] = openAIModel.OwnedBy + } + } + for model, _ := range modelRatios { + if common.StringsContains(enabledModels, model) { + if _, ok := allModels[model]; !ok { + allModels[model] = "custom" + } + } + } + pricingMap = make([]dto.ModelPricing, 0) + for model, ownerBy := range allModels { + pricing := dto.ModelPricing{ + Available: true, + ModelName: model, + OwnerBy: ownerBy, + } + modelPrice, findPrice := common.GetModelPrice(model, false) + if findPrice { + pricing.ModelPrice = modelPrice + pricing.QuotaType = 1 + } else { + pricing.ModelRatio = common.GetModelRatio(model) + pricing.CompletionRatio = common.GetCompletionRatio(model) + pricing.QuotaType = 0 + } + pricingMap = append(pricingMap, pricing) + } + lastGetPricingTime = time.Now() +} diff --git a/model/usedata.go b/model/usedata.go index b2f3025..4735333 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -45,6 +45,7 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int, if ok { quotaData.Count += 1 quotaData.Quota += quota + quotaData.TokenUsed += tokenUsed } else { quotaData = &QuotaData{ UserID: userId, diff --git a/router/api-router.go b/router/api-router.go index 8c0ae30..add5c5f 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -20,6 +20,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/about", controller.GetAbout) //apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) + apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) diff --git a/web/src/App.js b/web/src/App.js index a3b0660..1b63def 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -22,6 +22,7 @@ import Log from './pages/Log'; import Chat from './pages/Chat'; import { Layout } from '@douyinfe/semi-ui'; import Midjourney from './pages/Midjourney'; +import Pricing from './pages/Pricing/index.js'; // import Detail from './pages/Detail'; const Home = lazy(() => import('./pages/Home')); @@ -219,6 +220,14 @@ function App() { } /> + }> + + + } + /> { const [inputs, setInputs] = useState({ @@ -99,7 +100,7 @@ const LoginForm = () => { const { success, message, data } = res.data; if (success) { userDispatch({ type: 'login', payload: data }); - localStorage.setItem('user', JSON.stringify(data)); + setUserData(data); showSuccess('登录成功!'); if (username === 'root' && password === '123456') { Modal.error({ diff --git a/web/src/components/ModelPricing.js b/web/src/components/ModelPricing.js new file mode 100644 index 0000000..708d79e --- /dev/null +++ b/web/src/components/ModelPricing.js @@ -0,0 +1,229 @@ +import React, { useContext, useEffect, useState } from 'react'; +import { API, copy, showError, showSuccess } from '../helpers'; + +import { Banner, Layout, Modal, Table, Tag, Tooltip } from '@douyinfe/semi-ui'; +import { stringToColor } from '../helpers/render.js'; +import { UserContext } from '../context/User/index.js'; +import Text from '@douyinfe/semi-ui/lib/es/typography/text'; + +function renderQuotaType(type) { + // Ensure all cases are string literals by adding quotes. + switch (type) { + case 1: + return ( + + 按次计费 + + ); + case 0: + return ( + + 按量计费 + + ); + default: + return ( + + 未知 + + ); + } +} + +function renderAvailable(available) { + return available ? ( + + 可用 + + ) : ( + + + 不可用 + + + ); +} + +const ModelPricing = () => { + const columns = [ + { + title: '可用性', + dataIndex: 'available', + render: (text, record, index) => { + return renderAvailable(text); + }, + }, + { + title: '提供者', + dataIndex: 'owner_by', + render: (text, record, index) => { + return ( + <> + + {text} + + + ); + }, + }, + { + title: '模型名称', + dataIndex: 'model_name', // 以finish_time作为dataIndex + render: (text, record, index) => { + return ( + <> + { + copyText(text); + }} + > + {text} + + + ); + }, + }, + { + title: '计费类型', + dataIndex: 'quota_type', + render: (text, record, index) => { + return renderQuotaType(parseInt(text)); + }, + }, + { + title: '模型倍率', + dataIndex: 'model_ratio', + render: (text, record, index) => { + return
{record.quota_type === 0 ? text : 'N/A'}
; + }, + }, + { + title: '补全倍率', + dataIndex: 'completion_ratio', + render: (text, record, index) => { + let ratio = parseFloat(text.toFixed(3)); + return
{record.quota_type === 0 ? ratio : 'N/A'}
; + }, + }, + { + title: '模型价格', + dataIndex: 'model_price', + render: (text, record, index) => { + let content = text; + if (record.quota_type === 0) { + let inputRatioPrice = record.model_ratio * 2.0 * record.group_ratio; + let completionRatioPrice = + record.model_ratio * + record.completion_ratio * + 2.0 * + record.group_ratio; + content = ( + <> + 提示 ${inputRatioPrice} / 1M tokens +
+ 补全 ${completionRatioPrice} / 1M tokens + + ); + } else { + let price = parseFloat(text) * record.group_ratio; + content = <>模型价格:${price}; + } + return
{content}
; + }, + }, + ]; + + const [models, setModels] = useState([]); + const [loading, setLoading] = useState(true); + const [userState, userDispatch] = useContext(UserContext); + const [groupRatio, setGroupRatio] = useState(1); + + const setModelsFormat = (models, groupRatio) => { + for (let i = 0; i < models.length; i++) { + models[i].key = i; + models[i].group_ratio = groupRatio; + } + // sort by quota_type + models.sort((a, b) => { + return a.quota_type - b.quota_type; + }); + + // sort by owner_by, openai is max, other use localeCompare + models.sort((a, b) => { + if (a.owner_by === 'openai') { + return -1; + } else if (b.owner_by === 'openai') { + return 1; + } else { + return a.owner_by.localeCompare(b.owner_by); + } + }); + + setModels(models); + }; + + const loadPricing = async () => { + setLoading(true); + + let url = ''; + url = `/api/pricing`; + const res = await API.get(url); + const { success, message, data, group_ratio } = res.data; + if (success) { + setGroupRatio(group_ratio); + setModelsFormat(data, group_ratio); + } else { + showError(message); + } + setLoading(false); + }; + + const refresh = async () => { + await loadPricing(); + }; + + const copyText = async (text) => { + if (await copy(text)) { + showSuccess('已复制:' + text); + } else { + // setSearchKeyword(text); + Modal.error({ title: '无法复制到剪贴板,请手动复制', content: text }); + } + }; + + useEffect(() => { + refresh().then(); + }, []); + + return ( + <> + + {userState.user ? ( + + ) : ( + + )} + + + + ); +}; + +export default ModelPricing; diff --git a/web/src/components/SiderBar.js b/web/src/components/SiderBar.js index 96d7e7e..d542a43 100644 --- a/web/src/components/SiderBar.js +++ b/web/src/components/SiderBar.js @@ -23,10 +23,12 @@ import { IconImage, IconKey, IconLayers, + IconPriceTag, IconSetting, IconUser, } from '@douyinfe/semi-icons'; import { Layout, Nav } from '@douyinfe/semi-ui'; +import { setStatusData } from '../helpers/data.js'; // HeaderBar Buttons @@ -55,6 +57,7 @@ const SiderBar = () => { about: '/about', chat: '/chat', detail: '/detail', + pricing: '/pricing', }; const headerButtons = useMemo( @@ -100,6 +103,12 @@ const SiderBar = () => { to: '/topup', icon: , }, + { + text: '模型价格', + itemKey: 'pricing', + to: '/pricing', + icon: , + }, { text: '用户管理', itemKey: 'user', @@ -161,34 +170,8 @@ const SiderBar = () => { } const { success, data } = res.data; if (success) { - localStorage.setItem('status', JSON.stringify(data)); statusDispatch({ type: 'set', payload: data }); - localStorage.setItem('system_name', data.system_name); - localStorage.setItem('logo', data.logo); - localStorage.setItem('footer_html', data.footer_html); - localStorage.setItem('quota_per_unit', data.quota_per_unit); - localStorage.setItem('display_in_currency', data.display_in_currency); - localStorage.setItem('enable_drawing', data.enable_drawing); - localStorage.setItem('enable_data_export', data.enable_data_export); - localStorage.setItem( - 'data_export_default_time', - data.data_export_default_time, - ); - localStorage.setItem( - 'default_collapse_sidebar', - data.default_collapse_sidebar, - ); - localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled); - if (data.chat_link) { - localStorage.setItem('chat_link', data.chat_link); - } else { - localStorage.removeItem('chat_link'); - } - if (data.chat_link2) { - localStorage.setItem('chat_link2', data.chat_link2); - } else { - localStorage.removeItem('chat_link2'); - } + setStatusData(data); } else { showError('无法正常连接至服务器!'); } diff --git a/web/src/helpers/data.js b/web/src/helpers/data.js new file mode 100644 index 0000000..750b670 --- /dev/null +++ b/web/src/helpers/data.js @@ -0,0 +1,33 @@ +export function setStatusData(data) { + localStorage.setItem('status', JSON.stringify(data)); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_data_export', data.enable_data_export); + localStorage.setItem( + 'data_export_default_time', + data.data_export_default_time, + ); + localStorage.setItem( + 'default_collapse_sidebar', + data.default_collapse_sidebar, + ); + localStorage.setItem('mj_notify_enabled', data.mj_notify_enabled); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if (data.chat_link2) { + localStorage.setItem('chat_link2', data.chat_link2); + } else { + localStorage.removeItem('chat_link2'); + } +} + +export function setUserData(data) { + localStorage.setItem('user', JSON.stringify(data)); +} diff --git a/web/src/helpers/render.js b/web/src/helpers/render.js index 3113fed..d84b2eb 100644 --- a/web/src/helpers/render.js +++ b/web/src/helpers/render.js @@ -159,7 +159,7 @@ export function renderModelPrice(

提示 ${inputRatioPrice} / 1M tokens

补全 ${completionRatioPrice} / 1M tokens

-

+

提示 {inputTokens} tokens / 1M tokens * ${inputRatioPrice} + 补全{' '} {completionTokens} tokens / 1M tokens * ${completionRatioPrice} = $ diff --git a/web/src/pages/Pricing/index.js b/web/src/pages/Pricing/index.js new file mode 100644 index 0000000..cb56a47 --- /dev/null +++ b/web/src/pages/Pricing/index.js @@ -0,0 +1,10 @@ +import React from 'react'; +import ModelPricing from '../../components/ModelPricing.js'; + +const Pricing = () => ( + <> + + +); + +export default Pricing;