diff --git a/api/core/types/config.go b/api/core/types/config.go index 3a1755f7..7045fa22 100644 --- a/api/core/types/config.go +++ b/api/core/types/config.go @@ -113,11 +113,11 @@ type Platform string const OpenAI = Platform("OpenAI") const Azure = Platform("Azure") -const ChatGML = Platform("ChatGML") +const ChatGLM = Platform("ChatGLM") // UserChatConfig 用户的聊天配置 type UserChatConfig struct { - ApiKeys map[Platform]string + ApiKeys map[Platform]string `json:"api_keys"` } type ModelAPIConfig struct { diff --git a/api/handler/admin/user_handler.go b/api/handler/admin/user_handler.go index 6fb5d4b1..d848cb26 100644 --- a/api/handler/admin/user_handler.go +++ b/api/handler/admin/user_handler.go @@ -109,7 +109,7 @@ func (h *UserHandler) Save(c *gin.Context) { ApiKeys: map[types.Platform]string{ types.OpenAI: "", types.Azure: "", - types.ChatGML: "", + types.ChatGLM: "", }, }), Calls: h.App.SysConfig.UserInitCalls, diff --git a/api/handler/azure_handler.go b/api/handler/azure_handler.go index 1776d6a4..39fc8b3c 100644 --- a/api/handler/azure_handler.go +++ b/api/handler/azure_handler.go @@ -28,7 +28,7 @@ func (h *ChatHandler) sendAzureMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey string + var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { @@ -174,7 +174,9 @@ func (h *ChatHandler) sendAzureMessage( // 消息发送成功 if len(contents) > 0 { // 更新用户的对话次数 - h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { + h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + } if message.Role == "" { message.Role = "assistant" @@ -183,14 +185,14 @@ func (h *ChatHandler) sendAzureMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if userVo.ChatConfig.EnableContext && functionCall == false { + if h.App.ChatConfig.EnableContext && functionCall == false { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if userVo.ChatConfig.EnableHistory { + if h.App.ChatConfig.EnableHistory { useContext := true if functionCall { useContext = false @@ -254,8 +256,6 @@ func (h *ChatHandler) sendAzureMessage( } else { totalTokens = replyToken + getTotalTokens(req) } - h.db.Model(&model.User{}).Where("id = ?", userVo.Id). - UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens)) h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } diff --git a/api/handler/chat_handler.go b/api/handler/chat_handler.go index 0dc932e7..0fae34f4 100644 --- a/api/handler/chat_handler.go +++ b/api/handler/chat_handler.go @@ -169,7 +169,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return nil } - if userVo.Calls <= 0 { + if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者点击左下角菜单加入众筹获得100次对话!") utils.ReplyMessage(ws, "![](/images/wx.png)") return nil @@ -189,7 +189,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio req.Temperature = h.App.ChatConfig.Azure.Temperature req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens break - case types.ChatGML: + case types.ChatGLM: req.Temperature = h.App.ChatConfig.ChatGML.Temperature req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens break @@ -208,7 +208,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio // 加载聊天上下文 var chatCtx []interface{} - if userVo.ChatConfig.EnableContext { + if h.App.ChatConfig.EnableContext { if h.App.ChatContexts.Has(session.ChatId) { chatCtx = h.App.ChatContexts.Get(session.ChatId) } else { @@ -269,11 +269,10 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio return h.sendAzureMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) case types.OpenAI: return h.sendOpenAiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) - case types.ChatGML: + case types.ChatGLM: return h.sendChatGLMMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws) } - - return nil + return fmt.Errorf("not supported platform: %s", session.Model.Platform) } // Tokens 统计 token 数量 @@ -336,7 +335,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf md := strings.Replace(req.Model, ".", "", 1) apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1) break - case types.ChatGML: + case types.ChatGLM: apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1) req.Prompt = req.Messages req.Messages = nil @@ -368,21 +367,24 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - var key model.ApiKey - res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) - if res.Error != nil { - return nil, errors.New("no available key, please import key") + if *apiKey == "" { + var key model.ApiKey + res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key) + if res.Error != nil { + return nil, errors.New("no available key, please import key") + } + // 更新 API KEY 的最后使用时间 + h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) + *apiKey = key.Value } - // 更新 API KEY 的最后使用时间 - h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix()) - logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, key.Value, proxyURL, req.Model) + logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model) switch platform { case types.Azure: - request.Header.Set("api-key", key.Value) + request.Header.Set("api-key", *apiKey) break - case types.ChatGML: - token, err := h.getChatGLMToken(key.Value) + case types.ChatGLM: + token, err := h.getChatGLMToken(*apiKey) if err != nil { return nil, err } @@ -390,8 +392,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) break default: - request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key.Value)) + request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) } - *apiKey = key.Value return client.Do(request) } diff --git a/api/handler/chatglm_handler.go b/api/handler/chatglm_handler.go index bedd92e5..c97c2ce5 100644 --- a/api/handler/chatglm_handler.go +++ b/api/handler/chatglm_handler.go @@ -29,7 +29,7 @@ func (h *ChatHandler) sendChatGLMMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey string + var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { @@ -103,7 +103,9 @@ func (h *ChatHandler) sendChatGLMMessage( // 消息发送成功 if len(contents) > 0 { // 更新用户的对话次数 - h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { + h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + } if message.Role == "" { message.Role = "assistant" @@ -112,14 +114,14 @@ func (h *ChatHandler) sendChatGLMMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if userVo.ChatConfig.EnableContext { + if h.App.ChatConfig.EnableContext { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if userVo.ChatConfig.EnableHistory { + if h.App.ChatConfig.EnableHistory { // for prompt promptToken, err := utils.CalcTokens(prompt, req.Model) if err != nil { @@ -167,8 +169,6 @@ func (h *ChatHandler) sendChatGLMMessage( // 计算本次对话消耗的总 token 数量 var totalTokens = 0 totalTokens = replyToken + getTotalTokens(req) - h.db.Model(&model.User{}).Where("id = ?", userVo.Id). - UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens)) h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } @@ -205,7 +205,7 @@ func (h *ChatHandler) sendChatGLMMessage( return fmt.Errorf("error with decode response: %v", err) } if !res.Success { - utils.ReplyMessage(ws, "请求 ChatGML 失败:"+res.Msg) + utils.ReplyMessage(ws, "请求 ChatGLM 失败:"+res.Msg) } } diff --git a/api/handler/openai_handler.go b/api/handler/openai_handler.go index 68a6c6e4..bcafc6ce 100644 --- a/api/handler/openai_handler.go +++ b/api/handler/openai_handler.go @@ -28,7 +28,7 @@ func (h *ChatHandler) sendOpenAiMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() - var apiKey string + var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform] response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { @@ -174,7 +174,9 @@ func (h *ChatHandler) sendOpenAiMessage( // 消息发送成功 if len(contents) > 0 { // 更新用户的对话次数 - h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" { + h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", 1)) + } if message.Role == "" { message.Role = "assistant" @@ -183,14 +185,14 @@ func (h *ChatHandler) sendOpenAiMessage( useMsg := types.Message{Role: "user", Content: prompt} // 更新上下文消息,如果是调用函数则不需要更新上下文 - if userVo.ChatConfig.EnableContext && functionCall == false { + if h.App.ChatConfig.EnableContext && functionCall == false { chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, message) // 回复消息 h.App.ChatContexts.Put(session.ChatId, chatCtx) } // 追加聊天记录 - if userVo.ChatConfig.EnableHistory { + if h.App.ChatConfig.EnableHistory { useContext := true if functionCall { useContext = false @@ -254,8 +256,6 @@ func (h *ChatHandler) sendOpenAiMessage( } else { totalTokens = replyToken + getTotalTokens(req) } - h.db.Model(&model.User{}).Where("id = ?", userVo.Id). - UpdateColumn("tokens", gorm.Expr("tokens + ?", totalTokens)) h.db.Model(&model.User{}).Where("id = ?", userVo.Id). UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", totalTokens)) } diff --git a/api/handler/user_handler.go b/api/handler/user_handler.go index 79016ff8..ff2f57e8 100644 --- a/api/handler/user_handler.go +++ b/api/handler/user_handler.go @@ -100,7 +100,7 @@ func (h *UserHandler) Register(c *gin.Context) { ApiKeys: map[types.Platform]string{ types.OpenAI: "", types.Azure: "", - types.ChatGML: "", + types.ChatGLM: "", }, }), Calls: h.App.SysConfig.UserInitCalls, diff --git a/api/store/vo/user.go b/api/store/vo/user.go index bb13eec2..7b00b873 100644 --- a/api/store/vo/user.go +++ b/api/store/vo/user.go @@ -4,16 +4,16 @@ import "chatplus/core/types" type User struct { BaseVo - Mobile string `json:"mobile"` - Avatar string `json:"avatar"` - Salt string `json:"salt"` // 密码盐 - TotalTokens int64 `json:"total_tokens"` // 总消耗tokens - Calls int `json:"calls"` // 剩余对话次数 - ImgCalls int `json:"img_calls"` - ChatConfig types.ChatConfig `json:"chat_config"` // 聊天配置 - ChatRoles []string `json:"chat_roles"` // 聊天角色集合 - ExpiredTime int64 `json:"expired_time"` // 账户到期时间 - Status bool `json:"status"` // 当前状态 - LastLoginAt int64 `json:"last_login_at"` // 最后登录时间 - LastLoginIp string `json:"last_login_ip"` // 最后登录 IP + Mobile string `json:"mobile"` + Avatar string `json:"avatar"` + Salt string `json:"salt"` // 密码盐 + TotalTokens int64 `json:"total_tokens"` // 总消耗tokens + Calls int `json:"calls"` // 剩余对话次数 + ImgCalls int `json:"img_calls"` + ChatConfig types.UserChatConfig `json:"chat_config"` // 聊天配置 + ChatRoles []string `json:"chat_roles"` // 聊天角色集合 + ExpiredTime int64 `json:"expired_time"` // 账户到期时间 + Status bool `json:"status"` // 当前状态 + LastLoginAt int64 `json:"last_login_at"` // 最后登录时间 + LastLoginIp string `json:"last_login_ip"` // 最后登录 IP } diff --git a/web/src/components/ConfigDialog.vue b/web/src/components/ConfigDialog.vue index e9fef50c..dc8cb332 100644 --- a/web/src/components/ConfigDialog.vue +++ b/web/src/components/ConfigDialog.vue @@ -8,7 +8,7 @@ title="用户设置" >
- + {{ form.mobile }} @@ -34,8 +34,14 @@ {{ form['total_tokens'] }} - - + + + + + + + +
@@ -77,15 +83,16 @@ const form = ref({ mobile: '', calls: 0, tokens: 0, - chat_configs: {} + chat_config: {api_keys: {OpenAI: "", Azure: "", ChatGLM: ""}} }) onMounted(() => { // 获取最新用户信息 httpGet('/api/user/profile').then(res => { form.value = res.data - }).catch(() => { - ElMessage.error("获取用户信息失败") + form.value.chat_config.api_keys = res.data.chat_config.api_keys ?? {OpenAI: "", Azure: "", ChatGLM: ""} + }).catch(e => { + ElMessage.error("获取用户信息失败:" + e.message) }); }) diff --git a/web/src/views/admin/ApiKey.vue b/web/src/views/admin/ApiKey.vue index f7816994..768beddb 100644 --- a/web/src/views/admin/ApiKey.vue +++ b/web/src/views/admin/ApiKey.vue @@ -82,7 +82,7 @@ const rules = reactive({ const loading = ref(true) const formRef = ref(null) const title = ref("") -const platforms = ref(["Azure", "OpenAI", "ChatGML"]) +const platforms = ref(["Azure", "OpenAI", "ChatGLM"]) // 获取数据 httpGet('/api/admin/apikey/list').then((res) => { diff --git a/web/src/views/admin/ChatModel.vue b/web/src/views/admin/ChatModel.vue index d9c2d32c..3c0540b4 100644 --- a/web/src/views/admin/ChatModel.vue +++ b/web/src/views/admin/ChatModel.vue @@ -9,7 +9,7 @@ @@ -47,7 +47,7 @@ - {{item}} + {{ item }} @@ -94,7 +94,7 @@ const rules = reactive({ }) const loading = ref(true) const formRef = ref(null) -const platforms = ref(["Azure","OpenAI","ChatGML"]) +const platforms = ref(["Azure", "OpenAI", "ChatGLM"]) // 获取数据 httpGet('/api/admin/model/list').then((res) => { @@ -127,13 +127,13 @@ onMounted(() => { const sortedData = Array.from(from.children).map(row => row.querySelector('.sort').getAttribute('data-id')); const ids = [] const sorts = [] - sortedData.forEach((id,index) => { + sortedData.forEach((id, index) => { ids.push(parseInt(id)) sorts.push(index) }) - httpPost("/api/admin/model/sort", {ids: ids, sorts:sorts}).catch(e => { - ElMessage.error("排序失败:"+e.message) + httpPost("/api/admin/model/sort", {ids: ids, sorts: sorts}).catch(e => { + ElMessage.error("排序失败:" + e.message) }) } }) @@ -174,7 +174,7 @@ const enable = (row) => { httpPost('/api/admin/model/enable', {id: row.id, enabled: row.enabled}).then(() => { ElMessage.success("操作成功!") }).catch(e => { - ElMessage.error("操作失败:"+e.message) + ElMessage.error("操作失败:" + e.message) }) }