diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 54917f24..b6b63aa2 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -62,6 +62,7 @@ type ChatModel struct { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` // 绑定 API KEY } type ApiError struct { diff --git a/api/handler/admin/api_key_handler.go b/api/handler/admin/api_key_handler.go index 5566b0c0..7935d0ba 100644 --- a/api/handler/admin/api_key_handler.go +++ b/api/handler/admin/api_key_handler.go @@ -66,9 +66,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) { } func (h *ApiKeyHandler) List(c *gin.Context) { + status := h.GetBool(c, "status") + t := h.GetTrim(c, "type") + + session := h.DB.Session(&gorm.Session{}) + if status { + session = session.Where("enabled", true) + } + if t != "" { + session = session.Where("type", t) + } + var items []model.ApiKey var keys = make([]vo.ApiKey, 0) - res := h.DB.Find(&items) + res := session.Find(&items) if res.Error == nil { for _, item := range items { var key vo.ApiKey diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index 4f6ee23e..9e546ac5 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -8,8 +8,6 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" - "time" - "github.com/gin-gonic/gin" "gorm.io/gorm" ) @@ -35,6 +33,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` CreatedAt int64 `json:"created_at"` } if err := c.ShouldBindJSON(&data); err != nil { @@ -52,12 +51,15 @@ func (h *ChatModelHandler) Save(c *gin.Context) { MaxTokens: data.MaxTokens, MaxContext: data.MaxContext, Temperature: data.Temperature, + KeyId: data.KeyId, Power: data.Power} - item.Id = data.Id - if item.Id > 0 { - item.CreatedAt = time.Unix(data.CreatedAt, 0) + var res *gorm.DB + if data.Id > 0 { + item.Id = data.Id + res = h.DB.Select("*").Omit("created_at").Updates(&item) + } else { + res = h.DB.Create(&item) } - res := h.DB.Save(&item) if res.Error != nil { resp.ERROR(c, "更新数据库失败!") return @@ -84,18 +86,33 @@ func (h *ChatModelHandler) List(c *gin.Context) { var items []model.ChatModel var cms = make([]vo.ChatModel, 0) res := session.Order("sort_num ASC").Find(&items) - if res.Error == nil { - for _, item := range items { - var cm vo.ChatModel - err := utils.CopyObject(item, &cm) - if err == nil { - cm.Id = item.Id - cm.CreatedAt = item.CreatedAt.Unix() - cm.UpdatedAt = item.UpdatedAt.Unix() - cms = append(cms, cm) - } else { - logger.Error(err) - } + if res.Error != nil { + resp.SUCCESS(c, cms) + return + } + + // initialize key name + keyIds := make([]int, 0) + for _, v := range items { + keyIds = append(keyIds, v.KeyId) + } + var keys []model.ApiKey + keyMap := make(map[uint]string) + h.DB.Where("id IN ?", keyIds).Find(&keys) + for _, v := range keys { + keyMap[v.Id] = v.Name + } + for _, item := range items { + var cm vo.ChatModel + err := utils.CopyObject(item, &cm) + if err == nil { + cm.Id = item.Id + cm.CreatedAt = item.CreatedAt.Unix() + cm.UpdatedAt = item.UpdatedAt.Unix() + cm.KeyName = keyMap[uint(item.KeyId)] + cms = append(cms, cm) + } else { + logger.Error(err) } } resp.SUCCESS(c, cms) diff --git a/api/handler/chatimpl/azure_handler.go b/api/handler/chatimpl/azure_handler.go index a040aae6..11b3b69a 100644 --- a/api/handler/chatimpl/azure_handler.go +++ b/api/handler/chatimpl/azure_handler.go @@ -30,7 +30,7 @@ func (h *ChatHandler) sendAzureMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/baidu_handler.go b/api/handler/chatimpl/baidu_handler.go index e39ae455..08809dfe 100644 --- a/api/handler/chatimpl/baidu_handler.go +++ b/api/handler/chatimpl/baidu_handler.go @@ -47,7 +47,7 @@ func (h *ChatHandler) sendBaiduMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index 5da9af69..08785752 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -122,6 +122,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { MaxTokens: chatModel.MaxTokens, MaxContext: chatModel.MaxContext, Temperature: chatModel.Temperature, + KeyId: chatModel.KeyId, Platform: types.Platform(chatModel.Platform)} logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) @@ -463,13 +464,21 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) { // 发送请求到 OpenAI 服务器 // useOwnApiKey: 是否使用了用户自己的 API KEY -func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) { - res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) +func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) { + // if the chat model bind a KEY, use it directly + var res *gorm.DB + if session.Model.KeyId > 0 { + res = h.DB.Where("id", session.Model.KeyId).Find(apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey) + } if res.Error != nil { return nil, errors.New("no available key, please import key") } var apiURL string - switch platform { + switch session.Model.Platform { case types.Azure: md := strings.Replace(req.Model, ".", "", 1) apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1) @@ -492,7 +501,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf // 更新 API KEY 的最后使用时间 h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix()) // 百度文心,需要串接 access_token - if platform == types.Baidu { + if session.Model.Platform == types.Baidu { token, err := h.getBaiduToken(apiKey.Value) if err != nil { return nil, err @@ -527,8 +536,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf } else { client = http.DefaultClient } - logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", platform, apiURL, apiKey.Value, proxyURL, req.Model) - switch platform { + logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, proxyURL, req.Model) + switch session.Model.Platform { case types.Azure: request.Header.Set("api-key", apiKey.Value) break diff --git a/api/handler/chatimpl/chatglm_handler.go b/api/handler/chatimpl/chatglm_handler.go index 678f481d..5f391b3f 100644 --- a/api/handler/chatimpl/chatglm_handler.go +++ b/api/handler/chatimpl/chatglm_handler.go @@ -31,7 +31,7 @@ func (h *ChatHandler) sendChatGLMMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/chatimpl/openai_handler.go index c4a29338..c991f670 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/chatimpl/openai_handler.go @@ -31,7 +31,7 @@ func (h *ChatHandler) sendOpenAiMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/qwen_handler.go b/api/handler/chatimpl/qwen_handler.go index 13b0156d..4484e57b 100644 --- a/api/handler/chatimpl/qwen_handler.go +++ b/api/handler/chatimpl/qwen_handler.go @@ -45,7 +45,7 @@ func (h *ChatHandler) sendQWenMessage( promptCreatedAt := time.Now() // 记录提问时间 start := time.Now() var apiKey = model.ApiKey{} - response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey) + response, err := h.doRequest(ctx, req, session, &apiKey) logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start)) if err != nil { if strings.Contains(err.Error(), "context canceled") { diff --git a/api/handler/chatimpl/xunfei_handler.go b/api/handler/chatimpl/xunfei_handler.go index adb646dc..36a5b785 100644 --- a/api/handler/chatimpl/xunfei_handler.go +++ b/api/handler/chatimpl/xunfei_handler.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "github.com/gorilla/websocket" + "gorm.io/gorm" "html/template" "io" "net/http" @@ -69,7 +70,15 @@ func (h *ChatHandler) sendXunFeiMessage( ws *types.WsClient) error { promptCreatedAt := time.Now() // 记录提问时间 var apiKey model.ApiKey - res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) + var res *gorm.DB + // use the bind key + if session.Model.KeyId > 0 { + res = h.DB.Where("id", session.Model.KeyId).Find(&apiKey) + } + // use the last unused key + if res.Error != nil { + res = h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey) + } if res.Error != nil { utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!") return nil diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index fa8762c9..e0e0f020 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -125,7 +125,7 @@ func (h *MidJourneyHandler) Image(c *gin.Context) { params += fmt.Sprintf(" --c %d", data.Chaos) } if len(data.ImgArr) > 0 && data.Iw > 0 { - params += fmt.Sprintf(" --iw %f", data.Iw) + params += fmt.Sprintf(" --iw %.2f", data.Iw) } if data.Raw { params += " --style raw" diff --git a/api/store/model/chat_model.go b/api/store/model/chat_model.go index 8ddff961..134655f3 100644 --- a/api/store/model/chat_model.go +++ b/api/store/model/chat_model.go @@ -12,4 +12,5 @@ type ChatModel struct { MaxTokens int // 最大响应长度 MaxContext int // 最大上下文长度 Temperature float32 // 模型温度 + KeyId int // 绑定 API KEY ID } diff --git a/api/store/vo/chat_model.go b/api/store/vo/chat_model.go index 81fc18ca..4fb21051 100644 --- a/api/store/vo/chat_model.go +++ b/api/store/vo/chat_model.go @@ -12,4 +12,6 @@ type ChatModel struct { MaxTokens int `json:"max_tokens"` // 最大响应长度 MaxContext int `json:"max_context"` // 最大上下文长度 Temperature float32 `json:"temperature"` // 模型温度 + KeyId int `json:"key_id"` + KeyName string `json:"key_name"` } diff --git a/database/update-v4.0.3.sql b/database/update-v4.0.3.sql index fb22e6dd..219c4187 100644 --- a/database/update-v4.0.3.sql +++ b/database/update-v4.0.3.sql @@ -1 +1,2 @@ -ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`; \ No newline at end of file +ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`; +ALTER TABLE `chatgpt_chat_models` ADD `key_id` INT(11) NOT NULL COMMENT '绑定API KEY ID' AFTER `open`; \ No newline at end of file diff --git a/web/.env.development b/web/.env.development index 330da87b..8474f044 100644 --- a/web/.env.development +++ b/web/.env.development @@ -6,4 +6,4 @@ VUE_APP_ADMIN_USER=admin VUE_APP_ADMIN_PASS=admin123 VUE_APP_KEY_PREFIX=ChatPLUS_DEV_ VUE_APP_TITLE="Geek-AI 创作系统" -VUE_APP_VERSION=v4.0.2 +VUE_APP_VERSION=v4.0.3 diff --git a/web/.env.production b/web/.env.production index c6581695..e1a98fa3 100644 --- a/web/.env.production +++ b/web/.env.production @@ -2,4 +2,4 @@ VUE_APP_API_HOST= VUE_APP_WS_HOST= VUE_APP_KEY_PREFIX=ChatPLUS_ VUE_APP_TITLE="Geek-AI 创作系统" -VUE_APP_VERSION=v4.0.2 +VUE_APP_VERSION=v4.0.3 diff --git a/web/src/components/admin/AdminSidebar.vue b/web/src/components/admin/AdminSidebar.vue index 0393c984..6b52f544 100644 --- a/web/src/components/admin/AdminSidebar.vue +++ b/web/src/components/admin/AdminSidebar.vue @@ -63,7 +63,8 @@ const logo = ref('/images/logo.png') // 加载系统配置 httpGet('/api/admin/config/get?key=system').then(res => { - title.value = res.data['admin_title']; + title.value = res.data['admin_title'] + logo.value = res.data['logo'] }).catch(e => { ElMessage.error("加载系统配置失败: " + e.message) }) @@ -191,9 +192,9 @@ setMenuItems(items) padding 6px 15px; .el-image { - width 30px; - height 30px; - padding-top 8px; + width 36px; + height 36px; + padding-top 5px; border-radius 100% .el-image__inner { diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index 1da5b107..f4d85f56 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -377,16 +377,7 @@ const initData = () => { httpGet(`/api/role/list`).then((res) => { roles.value = res.data; roleId.value = roles.value[0]['id']; - - const chatId = localStorage.getItem("chat_id") - const chat = getChatById(chatId) - if (chat === null) { - // 创建新的对话 - newChat(); - } else { - // 加载对话 - loadChat(chat) - } + newChat(); }).catch((e) => { ElMessage.error('获取聊天角色失败: ' + e.messages) }) diff --git a/web/src/views/Home.vue b/web/src/views/Home.vue index e842ae90..e320098a 100644 --- a/web/src/views/Home.vue +++ b/web/src/views/Home.vue @@ -2,7 +2,7 @@