diff --git a/api/go/core/types/chat.go b/api/go/core/types/chat.go index a4466d2e..315031c7 100644 --- a/api/go/core/types/chat.go +++ b/api/go/core/types/chat.go @@ -43,5 +43,5 @@ type ApiError struct { } } -const PROMPT_MSG = "prompt" // prompt message -const REPLY_MSG = "reply" // reply message +const PromptMsg = "prompt" // prompt message +const ReplyMsg = "reply" // reply message diff --git a/api/go/core/types/config.go b/api/go/core/types/config.go index e98a6ca6..4b55f1b7 100644 --- a/api/go/core/types/config.go +++ b/api/go/core/types/config.go @@ -33,7 +33,7 @@ type Session struct { // ChatConfig 系统默认的聊天配置 type ChatConfig struct { - ApiURL string `json:"api_url"` + ApiURL string `json:"api_url,omitempty"` Model string `json:"model"` // 默认模型 Temperature float32 `json:"temperature"` MaxTokens int `json:"max_tokens"` @@ -43,9 +43,12 @@ type ChatConfig struct { } type SystemConfig struct { - Title string `json:"title"` - AdminTitle string `json:"admin_title"` - Models []string `json:"models"` + Title string `json:"title"` + AdminTitle string `json:"admin_title"` + Models []string `json:"models"` + UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用 } var GptModels = []string{"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613", "gpt-4", "gpt-4-0613", "gpt-4-32k", "gpt-4-32k-0613"} + +const UserInitCalls = 1000 diff --git a/api/go/handler/chat_handler.go b/api/go/handler/chat_handler.go index 9c3f3908..698a9954 100644 --- a/api/go/handler/chat_handler.go +++ b/api/go/handler/chat_handler.go @@ -167,7 +167,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession if res.Error == nil { for _, msg := range historyMessages { ms := types.Message{Role: "user", Content: msg.Content} - if msg.Type == types.REPLY_MSG { + if msg.Type == types.ReplyMsg { ms.Role = "assistant" } chatCtx = append(chatCtx, ms) @@ -276,7 +276,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession UserId: userVo.Id, ChatId: session.ChatId, RoleId: role.Id, - Type: types.PROMPT_MSG, + Type: types.PromptMsg, Icon: user.Avatar, Content: prompt, Tokens: token, @@ -297,7 +297,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession UserId: userVo.Id, ChatId: session.ChatId, RoleId: role.Id, - Type: types.REPLY_MSG, + Type: types.ReplyMsg, Icon: role.Icon, Content: message.Content, Tokens: token, diff --git a/api/go/handler/config_handler.go b/api/go/handler/config_handler.go index da5b455d..7549c492 100644 --- a/api/go/handler/config_handler.go +++ b/api/go/handler/config_handler.go @@ -52,6 +52,7 @@ func (h *ConfigHandler) Update(c *gin.Context) { resp.SUCCESS(c, config) } +// Get 获取指定的系统配置 func (h *ConfigHandler) Get(c *gin.Context) { key := c.Query("key") var config model.Config diff --git a/api/go/handler/user_handler.go b/api/go/handler/user_handler.go index 18afffd7..6a20c8f0 100644 --- a/api/go/handler/user_handler.go +++ b/api/go/handler/user_handler.go @@ -85,6 +85,16 @@ func (h *UserHandler) Register(c *gin.Context) { ApiKey: "", }), } + // 初始化调用次数 + var cfg model.Config + h.db.Where("marker = ?", "system").First(&cfg) + var config types.SystemConfig + err := utils.JsonDecode(cfg.Config, &config) + if err != nil || config.UserInitCalls <= 0 { + user.Calls = types.UserInitCalls + } else { + user.Calls = config.UserInitCalls + } res := h.db.Create(&user) if res.Error != nil { resp.ERROR(c, "保存数据失败")