diff --git a/common/constants.go b/common/constants.go index 9ce2003..c89280d 100644 --- a/common/constants.go +++ b/common/constants.go @@ -9,14 +9,19 @@ import ( "github.com/google/uuid" ) +// Pay Settings + +var PayAddress = "" +var CustomCallbackAddress = "" +var EpayId = "" +var EpayKey = "" +var Price = 7.3 +var MinTopUp = 1 + var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change var SystemName = "New API" var ServerAddress = "http://localhost:3000" -var PayAddress = "" -var EpayId = "" -var EpayKey = "" -var Price = 7.3 var Footer = "" var Logo = "" var TopUpLink = "" @@ -29,6 +34,7 @@ var DrawingEnabled = true var DataExportEnabled = true var DataExportInterval = 5 // unit: minute var DataExportDefaultTime = "hour" // unit: minute +var DefaultCollapseSidebar = false // default value of collapse sidebar // Any options with "Secret", "Token" in its key won't be return by GetOptions diff --git a/common/model-ratio.go b/common/model-ratio.go index 30b87ee..648f9fc 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -80,7 +80,10 @@ var ModelRatio = map[string]float64{ "qwen-turbo": 0.8572, // ¥0.012 / 1k tokens "qwen-plus": 10, // ¥0.14 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens - "SparkDesk": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.1": 1.2858, // ¥0.018 / 1k tokens + "SparkDesk-v3.5": 1.2858, // ¥0.018 / 1k tokens "360GPT_S2_V9": 0.8572, // ¥0.012 / 1k tokens "embedding-bert-512-v1": 0.0715, // ¥0.001 / 1k tokens "embedding_s1_v1": 0.0715, // ¥0.001 / 1k tokens diff --git a/controller/channel.go b/controller/channel.go index dd71259..b98af41 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -54,8 +54,9 @@ func FixChannelsAbilities(c *gin.Context) { func SearchChannels(c *gin.Context) { keyword := c.Query("keyword") group := c.Query("group") + modelKeyword := c.Query("model") //idSort, _ := strconv.ParseBool(c.Query("id_sort")) - channels, err := model.SearchChannels(keyword, group) + channels, err := model.SearchChannels(keyword, group, modelKeyword) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 1aabb02..0a4f1d8 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -29,6 +29,7 @@ func GetStatus(c *gin.Context) { "wechat_login": common.WeChatAuthEnabled, "server_address": common.ServerAddress, "price": common.Price, + "min_topup": common.MinTopUp, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, @@ -40,6 +41,7 @@ func GetStatus(c *gin.Context) { "enable_drawing": common.DrawingEnabled, "enable_data_export": common.DataExportEnabled, "data_export_default_time": common.DataExportDefaultTime, + "default_collapse_sidebar": common.DefaultCollapseSidebar, "enable_online_topup": common.PayAddress != "" && common.EpayId != "" && common.EpayKey != "", }, }) diff --git a/controller/model.go b/controller/model.go index 8909de4..38c6c46 100644 --- a/controller/model.go +++ b/controller/model.go @@ -129,6 +129,13 @@ func ListModels(c *gin.Context) { }) } +func ChannelListModels(c *gin.Context) { + c.JSON(200, gin.H{ + "object": "list", + "data": openAIModels, + }) +} + func RetrieveModel(c *gin.Context) { modelId := c.Param("model") if model, ok := openAIModelsMap[modelId]; ok { diff --git a/controller/topup.go b/controller/topup.go index 961ffa2..3203eb8 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -9,6 +9,7 @@ import ( "net/url" "one-api/common" "one-api/model" + "one-api/service" "strconv" "time" ) @@ -55,14 +56,14 @@ func RequestEpay(c *gin.Context) { c.JSON(200, gin.H{"message": err.Error(), "data": 10}) return } - if req.Amount < 1 { - c.JSON(200, gin.H{"message": "充值金额不能小于1", "data": 10}) + if req.Amount < common.MinTopUp { + c.JSON(200, gin.H{"message": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp), "data": 10}) return } id := c.GetInt("id") user, _ := model.GetUserById(id, false) - amount := GetAmount(float64(req.Amount), *user) + payMoney := GetAmount(float64(req.Amount), *user) var payType epay.PurchaseType if req.PaymentMethod == "zfb" { @@ -72,11 +73,10 @@ func RequestEpay(c *gin.Context) { req.PaymentMethod = "wxpay" payType = epay.WechatPay } - + callBackAddress := service.GetCallbackAddress() returnUrl, _ := url.Parse(common.ServerAddress + "/log") - notifyUrl, _ := url.Parse(common.ServerAddress + "/api/user/epay/notify") + notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") tradeNo := strconv.FormatInt(time.Now().Unix(), 10) - payMoney := amount client := GetEpayClient() if client == nil { c.JSON(200, gin.H{"message": "error", "data": "当前管理员未配置支付信息"}) @@ -169,8 +169,8 @@ func RequestAmount(c *gin.Context) { c.JSON(200, gin.H{"message": "error", "data": "参数错误"}) return } - if req.Amount < 1 { - c.JSON(200, gin.H{"message": "error", "data": "充值金额不能小于1"}) + if req.Amount < common.MinTopUp { + c.JSON(200, gin.H{"message": "error", "data": fmt.Sprintf("充值数量不能小于 %d", common.MinTopUp)}) return } id := c.GetInt("id") diff --git a/model/cache.go b/model/cache.go index b1199e2..8294e73 100644 --- a/model/cache.go +++ b/model/cache.go @@ -291,24 +291,27 @@ func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error } } } + + // 平滑系数 + smoothingFactor := 10 // Calculate the total weight of all channels up to endIdx totalWeight := 0 for _, channel := range channels[:endIdx] { - totalWeight += channel.GetWeight() + totalWeight += channel.GetWeight() + smoothingFactor } - if totalWeight == 0 { - // If all weights are 0, select a channel randomly - return channels[rand.Intn(endIdx)], nil - } + //if totalWeight == 0 { + // // If all weights are 0, select a channel randomly + // return channels[rand.Intn(endIdx)], nil + //} // Generate a random value in the range [0, totalWeight) randomWeight := rand.Intn(totalWeight) // Find a channel based on its weight for _, channel := range channels[:endIdx] { - randomWeight -= channel.GetWeight() - if randomWeight <= 0 { + randomWeight -= channel.GetWeight() + smoothingFactor + if randomWeight < 0 { return channel, nil } } diff --git a/model/channel.go b/model/channel.go index 96b3635..7c7b0d9 100644 --- a/model/channel.go +++ b/model/channel.go @@ -43,21 +43,39 @@ func GetAllChannels(startIdx int, num int, selectAll bool, idSort bool) ([]*Chan return channels, err } -func SearchChannels(keyword string, group string) (channels []*Channel, err error) { +func SearchChannels(keyword string, group string, model string) ([]*Channel, error) { + var channels []*Channel keyCol := "`key`" + groupCol := "`group`" + modelsCol := "`models`" + + // 如果是 PostgreSQL,使用双引号 if common.UsingPostgreSQL { keyCol = `"key"` + groupCol = `"group"` + modelsCol = `"models"` } + + // 构造基础查询 + baseQuery := DB.Model(&Channel{}).Omit(keyCol) + + // 构造WHERE子句 + var whereClause string + var args []interface{} if group != "" { - groupCol := "`group`" - if common.UsingPostgreSQL { - groupCol = `"group"` - } - err = DB.Omit("key").Where("(id = ? or name LIKE ? or "+keyCol+" = ?) and "+groupCol+" LIKE ?", common.String2Int(keyword), keyword+"%", keyword, "%"+group+"%").Find(&channels).Error + whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + groupCol + " LIKE ? AND " + modelsCol + " LIKE ?" + args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+group+"%", "%"+model+"%") } else { - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + whereClause = "(id = ? OR name LIKE ? OR " + keyCol + " = ?) AND " + modelsCol + " LIKE ?" + args = append(args, common.String2Int(keyword), "%"+keyword+"%", keyword, "%"+model+"%") } - return channels, err + + // 执行查询 + err := baseQuery.Where(whereClause, args...).Find(&channels).Error + if err != nil { + return nil, err + } + return channels, nil } func GetChannelById(id int, selectAll bool) (*Channel, error) { diff --git a/model/option.go b/model/option.go index 2feb711..9a7ad60 100644 --- a/model/option.go +++ b/model/option.go @@ -57,9 +57,11 @@ func InitOptionMap() { common.OptionMap["Logo"] = common.Logo common.OptionMap["ServerAddress"] = "" common.OptionMap["PayAddress"] = "" + common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" common.OptionMap["EpayKey"] = "" common.OptionMap["Price"] = strconv.FormatFloat(common.Price, 'f', -1, 64) + common.OptionMap["MinTopUp"] = strconv.Itoa(common.MinTopUp) common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientSecret"] = "" @@ -85,6 +87,7 @@ func InitOptionMap() { common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime + common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() @@ -141,7 +144,7 @@ func updateOptionMap(key string, value string) (err error) { common.ImageDownloadPermission = intValue } } - if strings.HasSuffix(key, "Enabled") { + if strings.HasSuffix(key, "Enabled") || key == "DefaultCollapseSidebar" { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": @@ -176,6 +179,8 @@ func updateOptionMap(key string, value string) (err error) { common.DrawingEnabled = boolValue case "DataExportEnabled": common.DataExportEnabled = boolValue + case "DefaultCollapseSidebar": + common.DefaultCollapseSidebar = boolValue } } switch key { @@ -196,12 +201,16 @@ func updateOptionMap(key string, value string) (err error) { common.ServerAddress = value case "PayAddress": common.PayAddress = value + case "CustomCallbackAddress": + common.CustomCallbackAddress = value case "EpayId": common.EpayId = value case "EpayKey": common.EpayKey = value case "Price": common.Price, _ = strconv.ParseFloat(value, 64) + case "MinTopUp": + common.MinTopUp, _ = strconv.Atoi(value) case "TopupGroupRatio": err = common.UpdateTopupGroupRatioByJSONString(value) case "GitHubClientId": diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index f64b8a4..cf14aee 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -71,10 +71,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string - err, responseText = openaiStreamHandler(c, resp, info.RelayMode) + err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = openaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) + err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index c0b6353..b0f3aa5 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -16,7 +16,7 @@ import ( "time" ) -func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { +func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { var responseTextBuilder strings.Builder scanner := bufio.NewScanner(resp.Body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { @@ -111,7 +111,7 @@ func openaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d return nil, responseTextBuilder.String() } -func openaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { +func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { var textResponse dto.TextResponse responseBody, err := io.ReadAll(resp.Body) if err != nil { diff --git a/relay/channel/zhipu_v4/adaptor.go b/relay/channel/zhipu_4v/adaptor.go similarity index 80% rename from relay/channel/zhipu_v4/adaptor.go rename to relay/channel/zhipu_4v/adaptor.go index 076c024..546b048 100644 --- a/relay/channel/zhipu_v4/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -1,4 +1,4 @@ -package zhipu_v4 +package zhipu_4v import ( "errors" @@ -8,7 +8,9 @@ import ( "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/service" ) type Adaptor struct { @@ -41,9 +43,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { - err, usage = zhipuStreamHandler(c, resp) + var responseText string + err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) + usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) } else { - err, usage = zhipuHandler(c, resp) + err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) } return } diff --git a/relay/channel/zhipu_v4/constants.go b/relay/channel/zhipu_4v/constants.go similarity index 58% rename from relay/channel/zhipu_v4/constants.go rename to relay/channel/zhipu_4v/constants.go index 756629a..1b0b0cc 100644 --- a/relay/channel/zhipu_v4/constants.go +++ b/relay/channel/zhipu_4v/constants.go @@ -1,7 +1,7 @@ -package zhipu_v4 +package zhipu_4v var ModelList = []string{ "glm-4", "glm-4v", "glm-3-turbo", } -var ChannelName = "zhipu_v4" +var ChannelName = "zhipu_4v" diff --git a/relay/channel/zhipu_v4/dto.go b/relay/channel/zhipu_4v/dto.go similarity index 99% rename from relay/channel/zhipu_v4/dto.go rename to relay/channel/zhipu_4v/dto.go index b17a89f..4d86767 100644 --- a/relay/channel/zhipu_v4/dto.go +++ b/relay/channel/zhipu_4v/dto.go @@ -1,4 +1,4 @@ -package zhipu_v4 +package zhipu_4v import ( "one-api/dto" diff --git a/relay/channel/zhipu_v4/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go similarity index 99% rename from relay/channel/zhipu_v4/relay-zhipu_v4.go rename to relay/channel/zhipu_4v/relay-zhipu_v4.go index dde1b5f..af9b1d8 100644 --- a/relay/channel/zhipu_v4/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -1,4 +1,4 @@ -package zhipu_v4 +package zhipu_4v import ( "bufio" diff --git a/relay/relay-text.go b/relay/relay-text.go index 63b1ff6..95c29ec 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -59,6 +59,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) } } relayInfo.IsStream = textRequest.Stream + relayInfo.UpstreamModelName = textRequest.Model return textRequest, nil } diff --git a/relay/relay_adaptor.go b/relay/relay_adaptor.go index 328b1e6..cc76270 100644 --- a/relay/relay_adaptor.go +++ b/relay/relay_adaptor.go @@ -11,7 +11,7 @@ import ( "one-api/relay/channel/tencent" "one-api/relay/channel/xunfei" "one-api/relay/channel/zhipu" - "one-api/relay/channel/zhipu_v4" + "one-api/relay/channel/zhipu_4v" "one-api/relay/constant" ) @@ -38,7 +38,7 @@ func GetAdaptor(apiType int) channel.Adaptor { case constant.APITypeZhipu: return &zhipu.Adaptor{} case constant.APITypeZhipu_v4: - return &zhipu_v4.Adaptor{} + return &zhipu_4v.Adaptor{} } return nil } diff --git a/router/api-router.go b/router/api-router.go index f72b9b3..1683a4f 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -75,7 +75,7 @@ func SetApiRouter(router *gin.Engine) { { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) - channelRoute.GET("/models", controller.ListModels) + channelRoute.GET("/models", controller.ChannelListModels) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestAllChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/service/epay.go b/service/epay.go new file mode 100644 index 0000000..7ce4aad --- /dev/null +++ b/service/epay.go @@ -0,0 +1,10 @@ +package service + +import "one-api/common" + +func GetCallbackAddress() string { + if common.CustomCallbackAddress == "" { + return common.ServerAddress + } + return common.CustomCallbackAddress +} diff --git a/web/src/App.js b/web/src/App.js index e980397..86458b0 100644 --- a/web/src/App.js +++ b/web/src/App.js @@ -29,7 +29,7 @@ const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); function App() { const [userState, userDispatch] = useContext(UserContext); - const [statusState, statusDispatch] = useContext(StatusContext); + // const [statusState, statusDispatch] = useContext(StatusContext); const loadUser = () => { let user = localStorage.getItem('user'); @@ -38,47 +38,9 @@ function App() { userDispatch({ type: 'login', payload: data }); } }; - const loadStatus = async () => { - const res = await API.get('/api/status'); - const { success, data } = res.data; - if (success) { - localStorage.setItem('status', JSON.stringify(data)); - statusDispatch({ type: 'set', payload: data }); - localStorage.setItem('system_name', data.system_name); - localStorage.setItem('logo', data.logo); - localStorage.setItem('footer_html', data.footer_html); - localStorage.setItem('quota_per_unit', data.quota_per_unit); - localStorage.setItem('display_in_currency', data.display_in_currency); - localStorage.setItem('enable_drawing', data.enable_drawing); - localStorage.setItem('enable_data_export', data.enable_data_export); - localStorage.setItem('data_export_default_time', data.data_export_default_time); - if (data.chat_link) { - localStorage.setItem('chat_link', data.chat_link); - } else { - localStorage.removeItem('chat_link'); - } - if (data.chat_link2) { - localStorage.setItem('chat_link2', data.chat_link2); - } else { - localStorage.removeItem('chat_link2'); - } - // if ( - // data.version !== process.env.REACT_APP_VERSION && - // data.version !== 'v0.0.0' && - // process.env.REACT_APP_VERSION !== '' - // ) { - // showNotice( - // `新版本可用:${data.version},请使用快捷键 Shift + F5 刷新页面` - // ); - // } - } else { - showError('无法正常连接至服务器!'); - } - }; useEffect(() => { loadUser(); - loadStatus().then(); let systemName = getSystemName(); if (systemName) { document.title = systemName; diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 9fdd597..26da9a3 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -257,6 +257,7 @@ const ChannelsTable = () => { const [idSort, setIdSort] = useState(false); const [searchKeyword, setSearchKeyword] = useState(''); const [searchGroup, setSearchGroup] = useState(''); + const [searchModel, setSearchModel] = useState(''); const [searching, setSearching] = useState(false); const [updatingBalance, setUpdatingBalance] = useState(false); const [pageSize, setPageSize] = useState(ITEMS_PER_PAGE); @@ -440,15 +441,15 @@ const ChannelsTable = () => { } }; - const searchChannels = async (searchKeyword, searchGroup) => { - if (searchKeyword === '' && searchGroup === '') { + const searchChannels = async (searchKeyword, searchGroup, searchModel) => { + if (searchKeyword === '' && searchGroup === '' && searchModel === '') { // if keyword is blank, load files instead. await loadChannels(0, pageSize, idSort); setActivePage(1); return; } setSearching(true); - const res = await API.get(`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}`); + const res = await API.get(`/api/channel/search?keyword=${searchKeyword}&group=${searchGroup}&model=${searchModel}`); const {success, message, data} = res.data; if (success) { setChannels(data); @@ -625,13 +626,12 @@ const ChannelsTable = () => { return ( <> -
{searchChannels(searchKeyword, searchGroup)}} labelPosition='left'> - + {searchChannels(searchKeyword, searchGroup, searchModel)}} labelPosition='left'>
{ setSearchKeyword(v.trim()) }} /> + { + setSearchModel(v.trim()) + }} + /> { setSearchGroup(v) - searchChannels(searchKeyword, v) + searchChannels(searchKeyword, v, searchModel) }}/> +
diff --git a/web/src/components/OperationSetting.js b/web/src/components/OperationSetting.js index a90855f..1a3a4be 100644 --- a/web/src/components/OperationSetting.js +++ b/web/src/components/OperationSetting.js @@ -27,6 +27,7 @@ const OperationSetting = () => { DataExportEnabled: '', DataExportDefaultTime: 'hour', DataExportInterval: 5, + DefaultCollapseSidebar: '', // 默认折叠侧边栏 RetryTimes: 0 }); const [originInputs, setOriginInputs] = useState({}); @@ -65,6 +66,10 @@ const OperationSetting = () => { if (key.endsWith('Enabled')) { value = inputs[key] === 'true' ? 'false' : 'true'; } + if (key === 'DefaultCollapseSidebar') { + value = inputs[key] === 'true' ? 'false' : 'true'; + } + console.log(key, value) const res = await API.put('/api/option/', { key, value @@ -79,7 +84,7 @@ const OperationSetting = () => { }; const handleInputChange = async (e, {name, value}) => { - if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime') { + if (name.endsWith('Enabled') || name === 'DataExportInterval' || name === 'DataExportDefaultTime' || name === 'DefaultCollapseSidebar') { if (name === 'DataExportDefaultTime') { localStorage.setItem('data_export_default_time', value); } @@ -243,6 +248,12 @@ const OperationSetting = () => { name='DrawingEnabled' onChange={handleInputChange} /> + { submitConfig('general').then(); diff --git a/web/src/components/SiderBar.js b/web/src/components/SiderBar.js index e36599b..0d72b97 100644 --- a/web/src/components/SiderBar.js +++ b/web/src/components/SiderBar.js @@ -1,8 +1,8 @@ -import React, {useContext, useMemo, useState} from 'react'; +import React, { useContext, useEffect, useLayoutEffect, useMemo, useState } from 'react'; import {Link, useNavigate} from 'react-router-dom'; import {UserContext} from '../context/User'; -import {API, getLogo, getSystemName, isAdmin, isMobile, showSuccess} from '../helpers'; +import { API, getLogo, getSystemName, isAdmin, isMobile, showError, showSuccess } from '../helpers'; import '../index.css'; import { @@ -24,11 +24,14 @@ import {Nav, Avatar, Dropdown, Layout} from '@douyinfe/semi-ui'; const SiderBar = () => { const [userState, userDispatch] = useContext(UserContext); + const defaultIsCollapsed = isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'; + let navigate = useNavigate(); const [selectedKeys, setSelectedKeys] = useState(['home']); - const [showSidebar, setShowSidebar] = useState(false); const systemName = getSystemName(); const logo = getLogo(); + const [isCollapsed, setIsCollapsed] = useState(defaultIsCollapsed); + const headerButtons = useMemo(() => [ { text: '首页', @@ -110,15 +113,41 @@ const SiderBar = () => { // } ], [localStorage.getItem('enable_data_export'), localStorage.getItem('enable_drawing'), localStorage.getItem('chat_link'), isAdmin()]); + const loadStatus = async () => { + const res = await API.get('/api/status'); + const { success, data } = res.data; + if (success) { + localStorage.setItem('status', JSON.stringify(data)); + // statusDispatch({ type: 'set', payload: data }); + localStorage.setItem('system_name', data.system_name); + localStorage.setItem('logo', data.logo); + localStorage.setItem('footer_html', data.footer_html); + localStorage.setItem('quota_per_unit', data.quota_per_unit); + localStorage.setItem('display_in_currency', data.display_in_currency); + localStorage.setItem('enable_drawing', data.enable_drawing); + localStorage.setItem('enable_data_export', data.enable_data_export); + localStorage.setItem('data_export_default_time', data.data_export_default_time); + localStorage.setItem('default_collapse_sidebar', data.default_collapse_sidebar); + if (data.chat_link) { + localStorage.setItem('chat_link', data.chat_link); + } else { + localStorage.removeItem('chat_link'); + } + if (data.chat_link2) { + localStorage.setItem('chat_link2', data.chat_link2); + } else { + localStorage.removeItem('chat_link2'); + } + } else { + showError('无法正常连接至服务器!'); + } + }; - async function logout() { - setShowSidebar(false); - await API.get('/api/user/logout'); - showSuccess('注销成功!'); - userDispatch({type: 'logout'}); - localStorage.removeItem('user'); - navigate('/login'); - } + useEffect(() => { + loadStatus().then(() => { + setIsCollapsed(isMobile() || localStorage.getItem('default_collapse_sidebar') === 'true'); + }); + },[]) return ( <> @@ -127,7 +156,12 @@ const SiderBar = () => {