diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go
index 430b66bb..6f534d7b 100644
--- a/api/handler/admin/user_handler.go
+++ b/api/handler/admin/user_handler.go
@@ -4,6 +4,7 @@ import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
+ "chatplus/service"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -17,10 +18,11 @@ import (
type UserHandler struct {
handler.BaseHandler
+ licenseService *service.LicenseService
}
-func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
- return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
+func NewUserHandler(app *core.AppServer, db *gorm.DB, licenseService *service.LicenseService) *UserHandler {
+ return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, licenseService: licenseService}
}
// List 用户列表
@@ -75,6 +77,13 @@ func (h *UserHandler) Save(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
+ // 检测最大注册人数
+ var totalUser int64
+ h.DB.Model(&model.User{}).Count(&totalUser)
+ if int(totalUser) >= h.licenseService.GetLicense().UserNum {
+ resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
+ return
+ }
var user = model.User{}
var res *gorm.DB
var userVo vo.User
diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go
index db4d4066..f86e0ed6 100644
--- a/api/handler/user_handler.go
+++ b/api/handler/user_handler.go
@@ -3,6 +3,7 @@ package handler
import (
"chatplus/core"
"chatplus/core/types"
+ "chatplus/service"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
@@ -21,16 +22,23 @@ import (
type UserHandler struct {
BaseHandler
- searcher *xdb.Searcher
- redis *redis.Client
+ searcher *xdb.Searcher
+ redis *redis.Client
+ licenseService *service.LicenseService
}
func NewUserHandler(
app *core.AppServer,
db *gorm.DB,
searcher *xdb.Searcher,
- client *redis.Client) *UserHandler {
- return &UserHandler{BaseHandler: BaseHandler{DB: db, App: app}, searcher: searcher, redis: client}
+ client *redis.Client,
+ licenseService *service.LicenseService) *UserHandler {
+ return &UserHandler{
+ BaseHandler: BaseHandler{DB: db, App: app},
+ searcher: searcher,
+ redis: client,
+ licenseService: licenseService,
+ }
}
// Register user register
@@ -53,6 +61,14 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
+ // 检测最大注册人数
+ var totalUser int64
+ h.DB.Model(&model.User{}).Count(&totalUser)
+ if int(totalUser) >= h.licenseService.GetLicense().UserNum {
+ resp.ERROR(c, "当前注册用户数已达上限,请请升级 License")
+ return
+ }
+
// 检查验证码
var key string
if data.RegWay == "email" || data.RegWay == "mobile" || data.Code != "" {
diff --git a/api/main.go b/api/main.go
index e1c05ee5..b060d675 100644
--- a/api/main.go
+++ b/api/main.go
@@ -166,6 +166,9 @@ func main() {
fx.Provide(service.NewSmtpService),
// License 服务
fx.Provide(service.NewLicenseService),
+ fx.Invoke(func(licenseService *service.LicenseService) {
+ licenseService.SyncLicense()
+ }),
// 微信机器人服务
fx.Provide(wx.NewWeChatBot),
diff --git a/api/service/license_service.go b/api/service/license_service.go
index 73a0bbe9..f30ce098 100644
--- a/api/service/license_service.go
+++ b/api/service/license_service.go
@@ -13,10 +13,11 @@ import (
)
type LicenseService struct {
- config types.ApiConfig
- levelDB *store.LevelDB
- license types.License
- machineId string
+ config types.ApiConfig
+ levelDB *store.LevelDB
+ license types.License
+ urlWhiteList []string
+ machineId string
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) * LicenseService {
@@ -28,25 +29,27 @@ func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) * License
machineId = info.HostID
}
return &LicenseService{
- config: server.Config.ApiConfig,
- levelDB: levelDB,
- license: license,
+ config: server.Config.ApiConfig,
+ levelDB: levelDB,
+ license: license,
machineId: machineId,
}
}
+type License struct {
+ Name string `json:"name"`
+ Value string `json:"license"`
+ Mid string `json:"mid"`
+ ExpiredAt int64 `json:"expired_at"`
+ UserNum int `json:"user_num"`
+}
+
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
- Data struct {
- Name string `json:"name"`
- License string `json:"license"`
- Mid string `json:"mid"`
- ExpiredAt int64 `json:"expired_at"`
- UserNum int `json:"user_num"`
- }
+ Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
response, err := req.C().R().
@@ -75,10 +78,54 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
}
-
return nil
}
+// SyncLicense 定期同步 License
+func (s *LicenseService) SyncLicense() {
+ go func() {
+ for {
+ var res struct {
+ Code types.BizCode `json:"code"`
+ Message string `json:"message"`
+ Data struct {
+ License License `json:"license"`
+ Urls []string `json:"urls"`
+ }
+ }
+ apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
+ response, err := req.C().R().
+ SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
+ SetSuccessResult(&res).Post(apiURL)
+ if err != nil {
+ logger.Errorf("发送激活请求失败: %v", err)
+ goto next
+ }
+ if response.IsErrorState() {
+ logger.Errorf("激活失败:%v", response.Status)
+ goto next
+ }
+ if res.Code != types.Success {
+ logger.Errorf("激活失败:%v", res.Message)
+ s.license.IsActive = false
+ goto next
+ }
+
+ s.license = types.License{
+ Key: res.Data.License.Value,
+ MachineId: res.Data.License.Mid,
+ UserNum: res.Data.License.UserNum,
+ ExpiredAt: res.Data.License.ExpiredAt,
+ IsActive: true,
+ }
+ s.urlWhiteList = res.Data.Urls
+ logger.Debugf("同步 License 成功:%v\n%v", s.license, s.urlWhiteList)
+ next:
+ time.Sleep(time.Second * 10)
+ }
+ }()
+}
+
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() types.License {
return s.license
@@ -98,12 +145,10 @@ func (s *LicenseService) IsValidApiURL(uri string) error {
return nil
}
- if !strings.HasPrefix(uri, "https://gpt.bemore.lol") &&
- !strings.HasPrefix(uri, "https://api.openai.com") &&
- !strings.HasPrefix(uri, "http://cdn.chat-plus.net") &&
- !strings.HasPrefix(uri, "https://api.chat-plus.net") {
- return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。",uri)
+ for _, v := range s.urlWhiteList {
+ if strings.HasPrefix(uri, v) {
+ return nil
+ }
}
-
- return nil
+ return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}
\ No newline at end of file
diff --git a/web/src/views/admin/Users.vue b/web/src/views/admin/Users.vue
index 4df4bdbe..f6fe6354 100644
--- a/web/src/views/admin/Users.vue
+++ b/web/src/views/admin/Users.vue
@@ -73,7 +73,7 @@
-
+
@@ -186,8 +186,17 @@ const models = ref([])
const showUserEditDialog = ref(false)
const showResetPassDialog = ref(false)
const rules = reactive({
- username: [{required: true, message: '请输入账号', trigger: 'change',}],
- password: [{required: true, message: '请输入密码', trigger: 'change',}],
+ username: [{required: true, message: '请输入账号', trigger: 'blur',}],
+ password: [
+ {
+ required: true,
+ validator: (rule, value) => {
+ return !(value.length > 16 || value.length < 8);
+
+ }, message: '密码必须为8-16',
+ trigger: 'blur'
+ }
+ ],
calls: [
{required: true, message: '请输入提问次数'},
{type: 'number', message: '请输入有效数字'},