diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 430b66bb..6f534d7b 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -4,6 +4,7 @@ import ( "chatplus/core" "chatplus/core/types" "chatplus/handler" + "chatplus/service" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -17,10 +18,11 @@ import ( type UserHandler struct { handler.BaseHandler + licenseService *service.LicenseService } -func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler { - return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}} +func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler { + return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService} } // List 用户列表 @@ -75,6 +77,13 @@ func (h *UserHandler) Save(c *gin.Context) { resp.ERROR(c, types.InvalidArgs) return } + // 检测最大注册人数 + var totalUser int64 + h.DB.Model(&model.User{}).Count(&totalUser) + if int(totalUser) >= h.licenseService.GetLicense().UserNum { + resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") + return + } var user = model.User{} var res *gorm.DB var userVo vo.User diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index db4d4066..f86e0ed6 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -3,6 +3,7 @@ package handler import ( "chatplus/core" "chatplus/core/types" + "chatplus/service" "chatplus/store/model" "chatplus/store/vo" "chatplus/utils" @@ -21,16 +22,23 @@ import ( type UserHandler struct { BaseHandler - searcher *xdb.Searcher - redis *redis.Client + searcher *xdb.Searcher + redis *redis.Client + licenseService *service.LicenseService } func NewUserHandler( app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher, - client *redis.Client) *UserHandler { - return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client} + client *redis.Client, + licenseService *service.LicenseService) *UserHandler { + return &UserHandler{ + BaseHandler: BaseHandler{DB: db, App: app}, + searcher: searcher, + redis: client, + licenseService: licenseService, + } } // Register user register @@ -53,6 +61,14 @@ func (h *UserHandler) Register(c *gin.Context) { return } + // 检测最大注册人数 + var totalUser int64 + h.DB.Model(&model.User{}).Count(&totalUser) + if int(totalUser) >= h.licenseService.GetLicense().UserNum { + resp.ERROR(c, "当前注册用户数已达上限,请请升级 License") + return + } + // 检查验证码 var key string if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" { diff --git a/api/main.go b/api/main.go index e1c05ee5..b060d675 100644 --- a/api/main.go +++ b/api/main.go @@ -166,6 +166,9 @@ func main() { fx.Provide(service.NewSmtpService), // License 服务 fx.Provide(service.NewLicenseService), + fx.Invoke(func(licenseService *service.LicenseService) { + licenseService.SyncLicense() + }), // 微信机器人服务 fx.Provide(wx.NewWeChatBot), diff --git a/api/service/license_service.go b/api/service/license_service.go index 73a0bbe9..f30ce098 100644 --- a/api/service/license_service.go +++ b/api/service/license_service.go @@ -13,10 +13,11 @@ import ( ) type LicenseService struct { - config types.ApiConfig - levelDB *store.LevelDB - license types.License - machineId string + config types.ApiConfig + levelDB *store.LevelDB + license types.License + urlWhiteList []string + machineId string } func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) * LicenseService { @@ -28,25 +29,27 @@ func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) * License machineId = info.HostID } return &LicenseService{ - config: server.Config.ApiConfig, - levelDB: levelDB, - license: license, + config: server.Config.ApiConfig, + levelDB: levelDB, + license: license, machineId: machineId, } } +type License struct { + Name string `json:"name"` + Value string `json:"license"` + Mid string `json:"mid"` + ExpiredAt int64 `json:"expired_at"` + UserNum int `json:"user_num"` +} + // ActiveLicense 激活 License func (s *LicenseService) ActiveLicense(license string, machineId string) error { var res struct { Code types.BizCode `json:"code"` Message string `json:"message"` - Data struct { - Name string `json:"name"` - License string `json:"license"` - Mid string `json:"mid"` - ExpiredAt int64 `json:"expired_at"` - UserNum int `json:"user_num"` - } + Data License `json:"data"` } apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active") response, err := req.C().R(). @@ -75,10 +78,54 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error { if err != nil { return fmt.Errorf("保存许可证书失败:%v", err) } - return nil } +// SyncLicense 定期同步 License +func (s *LicenseService) SyncLicense() { + go func() { + for { + var res struct { + Code types.BizCode `json:"code"` + Message string `json:"message"` + Data struct { + License License `json:"license"` + Urls []string `json:"urls"` + } + } + apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check") + response, err := req.C().R(). + SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}). + SetSuccessResult(&res).Post(apiURL) + if err != nil { + logger.Errorf("发送激活请求失败: %v", err) + goto next + } + if response.IsErrorState() { + logger.Errorf("激活失败:%v", response.Status) + goto next + } + if res.Code != types.Success { + logger.Errorf("激活失败:%v", res.Message) + s.license.IsActive = false + goto next + } + + s.license = types.License{ + Key: res.Data.License.Value, + MachineId: res.Data.License.Mid, + UserNum: res.Data.License.UserNum, + ExpiredAt: res.Data.License.ExpiredAt, + IsActive: true, + } + s.urlWhiteList = res.Data.Urls + logger.Debugf("同步 License 成功:%v\n%v", s.license, s.urlWhiteList) + next: + time.Sleep(time.Second * 10) + } + }() +} + // GetLicense 获取许可信息 func (s *LicenseService) GetLicense() types.License { return s.license @@ -98,12 +145,10 @@ func (s *LicenseService) IsValidApiURL(uri string) error { return nil } - if !strings.HasPrefix(uri, "https://gpt.bemore.lol") && - !strings.HasPrefix(uri, "https://api.openai.com") && - !strings.HasPrefix(uri, "http://cdn.chat-plus.net") && - !strings.HasPrefix(uri, "https://api.chat-plus.net") { - return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。",uri) + for _, v := range s.urlWhiteList { + if strings.HasPrefix(uri, v) { + return nil + } } - - return nil + return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri) } \ No newline at end of file diff --git a/web/src/views/admin/Users.vue b/web/src/views/admin/Users.vue index 4df4bdbe..f6fe6354 100644 --- a/web/src/views/admin/Users.vue +++ b/web/src/views/admin/Users.vue @@ -73,7 +73,7 @@ - + @@ -186,8 +186,17 @@ const models = ref([]) const showUserEditDialog = ref(false) const showResetPassDialog = ref(false) const rules = reactive({ - username: [{required: true, message: '请输入账号', trigger: 'change',}], - password: [{required: true, message: '请输入密码', trigger: 'change',}], + username: [{required: true, message: '请输入账号', trigger: 'blur',}], + password: [ + { + required: true, + validator: (rule, value) => { + return !(value.length > 16 || value.length < 8); + + }, message: '密码必须为8-16', + trigger: 'blur' + } + ], calls: [ {required: true, message: '请输入提问次数'}, {type: 'number', message: '请输入有效数字'},