mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-13 04:33:42 +08:00
feat: allow bind a chat model for chat role
This commit is contained in:
@@ -8,9 +8,10 @@ import (
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatModelHandler struct {
|
||||
|
||||
@@ -8,9 +8,10 @@ import (
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatRoleHandler struct {
|
||||
@@ -63,6 +64,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// initialize model mane for role
|
||||
modelIds := make([]int, 0)
|
||||
for _, v := range items {
|
||||
if v.ModelId > 0 {
|
||||
modelIds = append(modelIds, v.ModelId)
|
||||
}
|
||||
}
|
||||
|
||||
modelNameMap := make(map[int]string)
|
||||
if len(modelIds) > 0 {
|
||||
var models []model.ChatModel
|
||||
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||
if tx.Error == nil {
|
||||
for _, m := range models {
|
||||
modelNameMap[int(m.Id)] = m.Name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, v := range items {
|
||||
var role vo.ChatRole
|
||||
err := utils.CopyObject(v, &role)
|
||||
@@ -70,6 +90,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
||||
role.Id = v.Id
|
||||
role.CreatedAt = v.CreatedAt.Unix()
|
||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||
role.ModelName = modelNameMap[role.ModelId]
|
||||
roles = append(roles, role)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,9 +68,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
modelId := h.GetInt(c, "model_id", 0)
|
||||
|
||||
client := types.NewWsClient(ws)
|
||||
var chatRole model.ChatRole
|
||||
res := h.DB.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
// if the role bind a model_id, use role's bind model_id
|
||||
if chatRole.ModelId > 0 {
|
||||
modelId = chatRole.ModelId
|
||||
}
|
||||
// get model info
|
||||
var chatModel model.ChatModel
|
||||
res := h.DB.First(&chatModel, modelId)
|
||||
res = h.DB.First(&chatModel, modelId)
|
||||
if res.Error != nil || chatModel.Enabled == false {
|
||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
@@ -113,13 +124,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
Temperature: chatModel.Temperature,
|
||||
Platform: types.Platform(chatModel.Platform)}
|
||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||
var chatRole model.ChatRole
|
||||
res = h.DB.First(&chatRole, roleId)
|
||||
if res.Error != nil || !chatRole.Enable {
|
||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
h.Init()
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"fmt"
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SD 绘画服务
|
||||
@@ -146,7 +147,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||
logger.Debugf("send image request to %s", apiURL)
|
||||
go func() {
|
||||
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetBody(body).
|
||||
SetSuccessResult(&res).
|
||||
Post(apiURL)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
@@ -207,7 +212,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||
var res TaskProgressResp
|
||||
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
||||
response, err := s.httpClient.R().
|
||||
SetHeader("Authorization", s.config.ApiKey).
|
||||
SetSuccessResult(&res).
|
||||
Get(apiURL)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
|
||||
@@ -9,4 +9,5 @@ type ChatRole struct {
|
||||
Icon string // 角色聊天图标
|
||||
Enable bool // 是否启用被启用
|
||||
SortNum int //排序数字
|
||||
ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答
|
||||
}
|
||||
|
||||
@@ -4,11 +4,13 @@ import "chatplus/core/types"
|
||||
|
||||
type ChatRole struct {
|
||||
BaseVo
|
||||
Key string `json:"key"` // 角色唯一标识
|
||||
Name string `json:"name"` // 角色名称
|
||||
Context []types.Message `json:"context"` // 角色语料信息
|
||||
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
||||
Icon string `json:"icon"` // 角色聊天图标
|
||||
Enable bool `json:"enable"` // 是否启用被启用
|
||||
SortNum int `json:"sort"` // 排序
|
||||
Key string `json:"key"` // 角色唯一标识
|
||||
Name string `json:"name"` // 角色名称
|
||||
Context []types.Message `json:"context"` // 角色语料信息
|
||||
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
||||
Icon string `json:"icon"` // 角色聊天图标
|
||||
Enable bool `json:"enable"` // 是否启用被启用
|
||||
SortNum int `json:"sort"` // 排序
|
||||
ModelId int `json:"model_id"` // 绑定模型 ID
|
||||
ModelName string `json:"model_name"` // 模型名称
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user