mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-19 01:36:38 +08:00
feat: allow bind a chat model for chat role
This commit is contained in:
parent
6a3e26b566
commit
8be9a21efd
@ -8,9 +8,10 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatModelHandler struct {
|
type ChatModelHandler struct {
|
||||||
|
@ -8,9 +8,10 @@ import (
|
|||||||
"chatplus/store/vo"
|
"chatplus/store/vo"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"chatplus/utils/resp"
|
"chatplus/utils/resp"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type ChatRoleHandler struct {
|
type ChatRoleHandler struct {
|
||||||
@ -63,6 +64,25 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initialize model mane for role
|
||||||
|
modelIds := make([]int, 0)
|
||||||
|
for _, v := range items {
|
||||||
|
if v.ModelId > 0 {
|
||||||
|
modelIds = append(modelIds, v.ModelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
modelNameMap := make(map[int]string)
|
||||||
|
if len(modelIds) > 0 {
|
||||||
|
var models []model.ChatModel
|
||||||
|
tx := h.DB.Where("id IN ?", modelIds).Find(&models)
|
||||||
|
if tx.Error == nil {
|
||||||
|
for _, m := range models {
|
||||||
|
modelNameMap[int(m.Id)] = m.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for _, v := range items {
|
for _, v := range items {
|
||||||
var role vo.ChatRole
|
var role vo.ChatRole
|
||||||
err := utils.CopyObject(v, &role)
|
err := utils.CopyObject(v, &role)
|
||||||
@ -70,6 +90,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
|
|||||||
role.Id = v.Id
|
role.Id = v.Id
|
||||||
role.CreatedAt = v.CreatedAt.Unix()
|
role.CreatedAt = v.CreatedAt.Unix()
|
||||||
role.UpdatedAt = v.UpdatedAt.Unix()
|
role.UpdatedAt = v.UpdatedAt.Unix()
|
||||||
|
role.ModelName = modelNameMap[role.ModelId]
|
||||||
roles = append(roles, role)
|
roles = append(roles, role)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,9 +68,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
modelId := h.GetInt(c, "model_id", 0)
|
modelId := h.GetInt(c, "model_id", 0)
|
||||||
|
|
||||||
client := types.NewWsClient(ws)
|
client := types.NewWsClient(ws)
|
||||||
|
var chatRole model.ChatRole
|
||||||
|
res := h.DB.First(&chatRole, roleId)
|
||||||
|
if res.Error != nil || !chatRole.Enable {
|
||||||
|
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// if the role bind a model_id, use role's bind model_id
|
||||||
|
if chatRole.ModelId > 0 {
|
||||||
|
modelId = chatRole.ModelId
|
||||||
|
}
|
||||||
// get model info
|
// get model info
|
||||||
var chatModel model.ChatModel
|
var chatModel model.ChatModel
|
||||||
res := h.DB.First(&chatModel, modelId)
|
res = h.DB.First(&chatModel, modelId)
|
||||||
if res.Error != nil || chatModel.Enabled == false {
|
if res.Error != nil || chatModel.Enabled == false {
|
||||||
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
|
||||||
c.Abort()
|
c.Abort()
|
||||||
@ -113,13 +124,6 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
Temperature: chatModel.Temperature,
|
Temperature: chatModel.Temperature,
|
||||||
Platform: types.Platform(chatModel.Platform)}
|
Platform: types.Platform(chatModel.Platform)}
|
||||||
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
|
||||||
var chatRole model.ChatRole
|
|
||||||
res = h.DB.First(&chatRole, roleId)
|
|
||||||
if res.Error != nil || !chatRole.Enable {
|
|
||||||
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Init()
|
h.Init()
|
||||||
|
|
||||||
|
@ -8,10 +8,11 @@ import (
|
|||||||
"chatplus/store/model"
|
"chatplus/store/model"
|
||||||
"chatplus/utils"
|
"chatplus/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/imroc/req/v3"
|
|
||||||
"gorm.io/gorm"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/imroc/req/v3"
|
||||||
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SD 绘画服务
|
// SD 绘画服务
|
||||||
@ -146,7 +147,11 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/txt2img", s.config.ApiURL)
|
||||||
logger.Debugf("send image request to %s", apiURL)
|
logger.Debugf("send image request to %s", apiURL)
|
||||||
go func() {
|
go func() {
|
||||||
response, err := s.httpClient.R().SetBody(body).SetSuccessResult(&res).Post(apiURL)
|
response, err := s.httpClient.R().
|
||||||
|
SetHeader("Authorization", s.config.ApiKey).
|
||||||
|
SetBody(body).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Post(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
@ -207,7 +212,10 @@ func (s *Service) Txt2Img(task types.SdTask) error {
|
|||||||
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
func (s *Service) checkTaskProgress() (error, *TaskProgressResp) {
|
||||||
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
apiURL := fmt.Sprintf("%s/sdapi/v1/progress?skip_current_image=false", s.config.ApiURL)
|
||||||
var res TaskProgressResp
|
var res TaskProgressResp
|
||||||
response, err := s.httpClient.R().SetSuccessResult(&res).Get(apiURL)
|
response, err := s.httpClient.R().
|
||||||
|
SetHeader("Authorization", s.config.ApiKey).
|
||||||
|
SetSuccessResult(&res).
|
||||||
|
Get(apiURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err, nil
|
return err, nil
|
||||||
}
|
}
|
||||||
|
@ -9,4 +9,5 @@ type ChatRole struct {
|
|||||||
Icon string // 角色聊天图标
|
Icon string // 角色聊天图标
|
||||||
Enable bool // 是否启用被启用
|
Enable bool // 是否启用被启用
|
||||||
SortNum int //排序数字
|
SortNum int //排序数字
|
||||||
|
ModelId int // 绑定模型ID,绑定模型ID的角色只能用指定的模型来问答
|
||||||
}
|
}
|
||||||
|
@ -4,11 +4,13 @@ import "chatplus/core/types"
|
|||||||
|
|
||||||
type ChatRole struct {
|
type ChatRole struct {
|
||||||
BaseVo
|
BaseVo
|
||||||
Key string `json:"key"` // 角色唯一标识
|
Key string `json:"key"` // 角色唯一标识
|
||||||
Name string `json:"name"` // 角色名称
|
Name string `json:"name"` // 角色名称
|
||||||
Context []types.Message `json:"context"` // 角色语料信息
|
Context []types.Message `json:"context"` // 角色语料信息
|
||||||
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
HelloMsg string `json:"hello_msg"` // 打招呼的消息
|
||||||
Icon string `json:"icon"` // 角色聊天图标
|
Icon string `json:"icon"` // 角色聊天图标
|
||||||
Enable bool `json:"enable"` // 是否启用被启用
|
Enable bool `json:"enable"` // 是否启用被启用
|
||||||
SortNum int `json:"sort"` // 排序
|
SortNum int `json:"sort"` // 排序
|
||||||
|
ModelId int `json:"model_id"` // 绑定模型 ID
|
||||||
|
ModelName string `json:"model_name"` // 模型名称
|
||||||
}
|
}
|
||||||
|
1
database/update-v4.0.3.sql
Normal file
1
database/update-v4.0.3.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE `chatgpt_chat_roles` ADD `model_id` INT NOT NULL DEFAULT '0' COMMENT '绑定模型ID' AFTER `sort_num`;
|
@ -82,7 +82,6 @@
|
|||||||
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
|
<el-main v-loading="loading" element-loading-background="rgba(122, 122, 122, 0.3)">
|
||||||
<div class="chat-head">
|
<div class="chat-head">
|
||||||
<div class="chat-config">
|
<div class="chat-config">
|
||||||
<!-- <span class="role-select-label">聊天角色:</span>-->
|
|
||||||
<el-select v-model="roleId" filterable placeholder="角色" class="role-select" @change="_newChat">
|
<el-select v-model="roleId" filterable placeholder="角色" class="role-select" @change="_newChat">
|
||||||
<el-option
|
<el-option
|
||||||
v-for="item in roles"
|
v-for="item in roles"
|
||||||
@ -97,7 +96,7 @@
|
|||||||
</el-option>
|
</el-option>
|
||||||
</el-select>
|
</el-select>
|
||||||
|
|
||||||
<el-select v-model="modelID" placeholder="模型" @change="_newChat">
|
<el-select v-model="modelID" placeholder="模型" @change="_newChat" :disabled="disableModel">
|
||||||
<el-option
|
<el-option
|
||||||
v-for="item in models"
|
v-for="item in models"
|
||||||
:key="item.id"
|
:key="item.id"
|
||||||
@ -445,6 +444,8 @@ const _newChat = () => {
|
|||||||
newChat()
|
newChat()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const disableModel = ref(false)
|
||||||
// 新建会话
|
// 新建会话
|
||||||
const newChat = () => {
|
const newChat = () => {
|
||||||
if (!isLogin.value) {
|
if (!isLogin.value) {
|
||||||
@ -452,10 +453,11 @@ const newChat = () => {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const role = getRoleById(roleId.value)
|
const role = getRoleById(roleId.value)
|
||||||
if (role.key === 'gpt') {
|
showHello.value = role.key === 'gpt';
|
||||||
showHello.value = true
|
// if the role bind a model, disable model change
|
||||||
} else {
|
if (role.model_id > 0) {
|
||||||
showHello.value = false
|
modelID.value = role.model_id
|
||||||
|
disableModel.value = true
|
||||||
}
|
}
|
||||||
// 已有新开的会话
|
// 已有新开的会话
|
||||||
if (newChatItem.value !== null && newChatItem.value['role_id'] === roles.value[0]['role_id']) {
|
if (newChatItem.value !== null && newChatItem.value['role_id'] === roles.value[0]['role_id']) {
|
||||||
|
@ -11,7 +11,7 @@
|
|||||||
<el-form-item label="采样方法">
|
<el-form-item label="采样方法">
|
||||||
<template #default>
|
<template #default>
|
||||||
<div class="form-item-inner">
|
<div class="form-item-inner">
|
||||||
<el-select v-model="params.sampler" size="small">
|
<el-select v-model="params.sampler" size="small" style="width:150px">
|
||||||
<el-option v-for="item in samplers" :label="item" :value="item" :key="item"/>
|
<el-option v-for="item in samplers" :label="item" :value="item" :key="item"/>
|
||||||
</el-select>
|
</el-select>
|
||||||
<el-tooltip
|
<el-tooltip
|
||||||
@ -163,7 +163,7 @@
|
|||||||
<el-form-item label="放大算法">
|
<el-form-item label="放大算法">
|
||||||
<template #default>
|
<template #default>
|
||||||
<div class="form-item-inner">
|
<div class="form-item-inner">
|
||||||
<el-select v-model="params.hd_scale_alg" size="small">
|
<el-select v-model="params.hd_scale_alg" size="small" style="width:150px">
|
||||||
<el-option v-for="item in scaleAlg" :label="item" :value="item" :key="item"/>
|
<el-option v-for="item in scaleAlg" :label="item" :value="item" :key="item"/>
|
||||||
</el-select>
|
</el-select>
|
||||||
<el-tooltip
|
<el-tooltip
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
</template>
|
</template>
|
||||||
</el-table-column>
|
</el-table-column>
|
||||||
<el-table-column label="角色标识" prop="key"/>
|
<el-table-column label="角色标识" prop="key"/>
|
||||||
|
<el-table-column label="绑定模型" prop="model_name"/>
|
||||||
<el-table-column label="启用状态">
|
<el-table-column label="启用状态">
|
||||||
<template #default="scope">
|
<template #default="scope">
|
||||||
<el-switch v-model="scope.row['enable']" @change="roleSet('enable',scope.row)"/>
|
<el-switch v-model="scope.row['enable']" @change="roleSet('enable',scope.row)"/>
|
||||||
@ -47,7 +48,7 @@
|
|||||||
|
|
||||||
<el-dialog
|
<el-dialog
|
||||||
v-model="showDialog"
|
v-model="showDialog"
|
||||||
title="编辑角色"
|
:title="optTitle"
|
||||||
:close-on-click-modal="false"
|
:close-on-click-modal="false"
|
||||||
width="50%"
|
width="50%"
|
||||||
>
|
>
|
||||||
@ -73,6 +74,21 @@
|
|||||||
/>
|
/>
|
||||||
</el-form-item>
|
</el-form-item>
|
||||||
|
|
||||||
|
<el-form-item label="绑定模型:" prop="model_id">
|
||||||
|
<el-select
|
||||||
|
v-model="role.model_id"
|
||||||
|
filterable
|
||||||
|
placeholder="请选择模型"
|
||||||
|
>
|
||||||
|
<el-option
|
||||||
|
v-for="item in models"
|
||||||
|
:key="item.id"
|
||||||
|
:label="item.name"
|
||||||
|
:value="item.id"
|
||||||
|
/>
|
||||||
|
</el-select>
|
||||||
|
</el-form-item>
|
||||||
|
|
||||||
<el-form-item label="打招呼信息:" prop="hello_msg">
|
<el-form-item label="打招呼信息:" prop="hello_msg">
|
||||||
<el-input
|
<el-input
|
||||||
v-model="role.hello_msg"
|
v-model="role.hello_msg"
|
||||||
@ -151,7 +167,7 @@ const tableData = ref([])
|
|||||||
const sortedTableData = ref([])
|
const sortedTableData = ref([])
|
||||||
const role = ref({context: []})
|
const role = ref({context: []})
|
||||||
const formRef = ref(null)
|
const formRef = ref(null)
|
||||||
const editRow = ref({})
|
const optTitle = ref({})
|
||||||
const loading = ref(true)
|
const loading = ref(true)
|
||||||
|
|
||||||
const rules = reactive({
|
const rules = reactive({
|
||||||
@ -165,18 +181,30 @@ const rules = reactive({
|
|||||||
hello_msg: [{required: true, message: '请输入打招呼信息', trigger: 'change',}]
|
hello_msg: [{required: true, message: '请输入打招呼信息', trigger: 'change',}]
|
||||||
})
|
})
|
||||||
|
|
||||||
// 获取角色列表
|
const models = ref([])
|
||||||
httpGet('/api/admin/role/list').then((res) => {
|
onMounted(() => {
|
||||||
tableData.value = res.data
|
fetchData()
|
||||||
sortedTableData.value = copyObj(tableData.value)
|
|
||||||
loading.value = false
|
// get chat models
|
||||||
}).catch(() => {
|
httpGet('/api/admin/model/list?enable=1').then((res) => {
|
||||||
ElMessage.error("获取聊天角色失败");
|
models.value = res.data
|
||||||
|
}).catch(() => {
|
||||||
|
ElMessage.error("获取AI模型数据失败");
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
onMounted(() => {
|
const fetchData = () => {
|
||||||
const drawBodyWrapper = document.querySelector('.el-table__body tbody')
|
// 获取角色列表
|
||||||
|
httpGet('/api/admin/role/list').then((res) => {
|
||||||
|
tableData.value = res.data
|
||||||
|
sortedTableData.value = copyObj(tableData.value)
|
||||||
|
loading.value = false
|
||||||
|
}).catch(() => {
|
||||||
|
ElMessage.error("获取聊天角色失败");
|
||||||
|
})
|
||||||
|
|
||||||
|
const drawBodyWrapper = document.querySelector('.el-table__body tbody')
|
||||||
// 初始化拖动排序插件
|
// 初始化拖动排序插件
|
||||||
Sortable.create(drawBodyWrapper, {
|
Sortable.create(drawBodyWrapper, {
|
||||||
sort: true,
|
sort: true,
|
||||||
@ -199,7 +227,7 @@ onMounted(() => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
}
|
||||||
|
|
||||||
const roleSet = (filed, row) => {
|
const roleSet = (filed, row) => {
|
||||||
httpPost('/api/admin/role/set', {id: row.id, filed: filed, value: row[filed]}).then(() => {
|
httpPost('/api/admin/role/set', {id: row.id, filed: filed, value: row[filed]}).then(() => {
|
||||||
@ -212,12 +240,14 @@ const roleSet = (filed, row) => {
|
|||||||
// 编辑
|
// 编辑
|
||||||
const curIndex = ref(0)
|
const curIndex = ref(0)
|
||||||
const rowEdit = function (index, row) {
|
const rowEdit = function (index, row) {
|
||||||
|
optTitle.value = "修改角色"
|
||||||
curIndex.value = index
|
curIndex.value = index
|
||||||
role.value = copyObj(row)
|
role.value = copyObj(row)
|
||||||
showDialog.value = true
|
showDialog.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
const addRole = function () {
|
const addRole = function () {
|
||||||
|
optTitle.value = "添加新角色"
|
||||||
role.value = {context: []}
|
role.value = {context: []}
|
||||||
showDialog.value = true
|
showDialog.value = true
|
||||||
}
|
}
|
||||||
@ -226,14 +256,9 @@ const save = function () {
|
|||||||
formRef.value.validate((valid) => {
|
formRef.value.validate((valid) => {
|
||||||
if (valid) {
|
if (valid) {
|
||||||
showDialog.value = false
|
showDialog.value = false
|
||||||
httpPost('/api/admin/role/save', role.value).then((res) => {
|
httpPost('/api/admin/role/save', role.value).then(() => {
|
||||||
ElMessage.success('操作成功')
|
ElMessage.success('操作成功')
|
||||||
// 更新当前数据行
|
fetchData()
|
||||||
if (role.value.id) {
|
|
||||||
tableData.value[curIndex.value] = role.value
|
|
||||||
} else {
|
|
||||||
tableData.value.push(res.data)
|
|
||||||
}
|
|
||||||
}).catch((e) => {
|
}).catch((e) => {
|
||||||
ElMessage.error('操作失败,' + e.message)
|
ElMessage.error('操作失败,' + e.message)
|
||||||
})
|
})
|
||||||
@ -263,6 +288,7 @@ const removeContext = function (index) {
|
|||||||
role.value.context.splice(index, 1);
|
role.value.context.splice(index, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style lang="stylus" scoped>
|
<style lang="stylus" scoped>
|
||||||
|
Loading…
Reference in New Issue
Block a user