From 8be9a21efd7399f7169d1cf0df480d7f08a25f10 Mon Sep 17 00:00:00 2001 From: RockYang Date: Fri, 5 Apr 2024 12:51:18 +0800 Subject: [PATCH 01/52] feat: allow bind a chat model for chat role --- api/handler/admin/chat_model_handler.go | 3 +- api/handler/admin/chat_role_handler.go | 23 ++++++++- api/handler/chatimpl/chat_handler.go | 20 ++++---- api/service/sd/service.go | 16 +++++-- api/store/model/chat_role.go | 1 + api/store/vo/chat_role.go | 16 ++++--- database/update-v4.0.3.sql | 1 + web/src/views/ChatPlus.vue | 14 +++--- web/src/views/ImageSd.vue | 4 +- web/src/views/admin/Roles.vue | 64 +++++++++++++++++-------- 10 files changed, 114 insertions(+), 48 deletions(-) create mode 100644 database/update-v4.0.3.sql diff --git a/api/handler/admin/chat_model_handler.go b/api/handler/admin/chat_model_handler.go index 97bb559e..47113055 100644 --- a/api/handler/admin/chat_model_handler.go +++ b/api/handler/admin/chat_model_handler.go @@ -8,9 +8,10 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "time" + "github.com/gin-gonic/gin" "gorm.io/gorm" - "time" ) type ChatModelHandler struct { diff --git a/api/handler/admin/chat_role_handler.go b/api/handler/admin/chat_role_handler.go index 7b72cb44..4d119faf 100644 --- a/api/handler/admin/chat_role_handler.go +++ b/api/handler/admin/chat_role_handler.go @@ -8,9 +8,10 @@ import ( "chatplus/store/vo" "chatplus/utils" "chatplus/utils/resp" + "time" + "github.com/gin-gonic/gin" "gorm.io/gorm" - "time" ) type ChatRoleHandler struct { @@ -63,6 +64,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) { return } + // initialize model mane for role + modelIds := make([]int, 0) + for _, v := range items { + if v.ModelId > 0 { + modelIds = append(modelIds, v.ModelId) + } + } + + modelNameMap := make(map[int]string) + if len(modelIds) > 0 { + var models []model.ChatModel + tx := h.DB.Where("id IN ?", modelIds).Find(&models) + if tx.Error == nil { + for _, m := range models { + modelNameMap[int(m.Id)] = m.Name + } + } + } + for _, v := range items { var role vo.ChatRole err := utils.CopyObject(v, &role) @@ -70,6 +90,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) { role.Id = v.Id role.CreatedAt = v.CreatedAt.Unix() role.UpdatedAt = v.UpdatedAt.Unix() + role.ModelName = modelNameMap[role.ModelId] roles = append(roles, role) } } diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chatimpl/chat_handler.go index c60df32e..cee0987d 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chatimpl/chat_handler.go @@ -68,9 +68,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { modelId := h.GetInt(c, "model_id", 0) client := types.NewWsClient(ws) + var chatRole model.ChatRole + res := h.DB.First(&chatRole, roleId) + if res.Error != nil || !chatRole.Enable { + utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") + c.Abort() + return + } + // if the role bind a model_id, use role's bind model_id + if chatRole.ModelId > 0 { + modelId = chatRole.ModelId + } // get model info var chatModel model.ChatModel - res := h.DB.First(&chatModel, modelId) + res = h.DB.First(&chatModel, modelId) if res.Error != nil || chatModel.Enabled == false { utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!") c.Abort() @@ -113,13 +124,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) { Temperature: chatModel.Temperature, Platform: types.Platform(chatModel.Platform)} logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username) - var chatRole model.ChatRole - res = h.DB.First(&chatRole, roleId) - if res.Error != nil || !chatRole.Enable { - utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!") - c.Abort() - return - } h.Init() diff --git a/api/service/sd/service.go b/api/service/sd/service.go index 4f68f3e0..34d47697 100644 --- a/api/service/sd/service.go +++ b/api/service/sd/service.go @@ -8,10 +8,11 @@ import ( "chatplus/store/model" "chatplus/utils" "fmt" - "github.com/imroc/req/v3" - "gorm.io/gorm" "strings" "time" + + "github.com/imroc/req/v3" + "gorm.io/gorm" ) // SD 绘画服务 @@ -146,7 +147,11 @@ func (s *Service) Txt2Img(task types.SdTask) error { apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL) logger.Debugf("send image request to %s", apiURL) go func() { - response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL) + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetBody(body). + SetSuccessResult(&res). + Post(apiURL) if err != nil { errChan <- err return @@ -207,7 +212,10 @@ func (s *Service) Txt2Img(task types.SdTask) error { func (s *Service) checkTaskProgress() (error, *TaskProgressResp) { apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL) var res TaskProgressResp - response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL) + response, err := s.httpClient.R(). + SetHeader("Authorization", s.config.ApiKey). + SetSuccessResult(&res). + Get(apiURL) if err != nil { return err, nil } diff --git a/api/store/model/chat_role.go b/api/store/model/chat_role.go index cc05cf7d..50e438bf 100644 --- a/api/store/model/chat_role.go +++ b/api/store/model/chat_role.go @@ -9,4 +9,5 @@ type ChatRole struct { Icon string // 角色聊天图标 Enable bool // 是否启用被启用 SortNum int //排序数字 + ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答 } diff --git a/api/store/vo/chat_role.go b/api/store/vo/chat_role.go index 52f696e5..e13d5f0c 100644 --- a/api/store/vo/chat_role.go +++ b/api/store/vo/chat_role.go @@ -4,11 +4,13 @@ import "chatplus/core/types" type ChatRole struct { BaseVo - Key string `json:"key"` // 角色唯一标识 - Name string `json:"name"` // 角色名称 - Context []types.Message `json:"context"` // 角色语料信息 - HelloMsg string `json:"hello_msg"` // 打招呼的消息 - Icon string `json:"icon"` // 角色聊天图标 - Enable bool `json:"enable"` // 是否启用被启用 - SortNum int `json:"sort"` // 排序 + Key string `json:"key"` // 角色唯一标识 + Name string `json:"name"` // 角色名称 + Context []types.Message `json:"context"` // 角色语料信息 + HelloMsg string `json:"hello_msg"` // 打招呼的消息 + Icon string `json:"icon"` // 角色聊天图标 + Enable bool `json:"enable"` // 是否启用被启用 + SortNum int `json:"sort"` // 排序 + ModelId int `json:"model_id"` // 绑定模型 ID + ModelName string `json:"model_name"` // 模型名称 } diff --git a/database/update-v4.0.3.sql b/database/update-v4.0.3.sql new file mode 100644 index 00000000..fb22e6dd --- /dev/null +++ b/database/update-v4.0.3.sql @@ -0,0 +1 @@ +ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`; \ No newline at end of file diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index e14bac7b..8e3593d0 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -82,7 +82,6 @@
- - + { newChat() } } + +const disableModel = ref(false) // 新建会话 const newChat = () => { if (!isLogin.value) { @@ -452,10 +453,11 @@ const newChat = () => { return; } const role = getRoleById(roleId.value) - if (role.key === 'gpt') { - showHello.value = true - } else { - showHello.value = false + showHello.value = role.key === 'gpt'; + // if the role bind a model, disable model change + if (role.model_id > 0) { + modelID.value = role.model_id + disableModel.value = true } // 已有新开的会话 if (newChatItem.value !== null && newChatItem.value['role_id'] === roles.value[0]['role_id']) { diff --git a/web/src/views/ImageSd.vue b/web/src/views/ImageSd.vue index 990fa400..36ceae2f 100644 --- a/web/src/views/ImageSd.vue +++ b/web/src/views/ImageSd.vue @@ -11,7 +11,7 @@