mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-18 01:06:39 +08:00
270 lines
7.4 KiB
Go
270 lines
7.4 KiB
Go
package admin
|
||
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||
// * Use of this source code is governed by a Apache-2.0 license
|
||
// * that can be found in the LICENSE file.
|
||
// * @Author yangjian102621@163.com
|
||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||
|
||
import (
	"time"

	"geekai/core"
	"geekai/core/types"
	"geekai/handler"
	"geekai/store/model"
	"geekai/store/vo"
	"geekai/utils"
	"geekai/utils/resp"

	"github.com/gin-gonic/gin"
	"gorm.io/gorm"
)
|
||
|
||
type ChatHandler struct {
|
||
handler.BaseHandler
|
||
}
|
||
|
||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||
return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
|
||
}
|
||
|
||
type chatItemVo struct {
|
||
Username string `json:"username"`
|
||
UserId uint `json:"user_id"`
|
||
ChatId string `json:"chat_id"`
|
||
Title string `json:"title"`
|
||
Role vo.ChatRole `json:"role"`
|
||
Model string `json:"model"`
|
||
Token int `json:"token"`
|
||
CreatedAt int64 `json:"created_at"`
|
||
MsgNum int `json:"msg_num"` // 消息数量
|
||
}
|
||
|
||
func (h *ChatHandler) List(c *gin.Context) {
|
||
var data struct {
|
||
Title string `json:"title"`
|
||
UserId uint `json:"user_id"`
|
||
Model string `json:"model"`
|
||
CreateAt []string `json:"created_time"`
|
||
Page int `json:"page"`
|
||
PageSize int `json:"page_size"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
session := h.DB.Session(&gorm.Session{})
|
||
if data.Title != "" {
|
||
session = session.Where("title LIKE ?", "%"+data.Title+"%")
|
||
}
|
||
if data.UserId > 0 {
|
||
session = session.Where("user_id = ?", data.UserId)
|
||
}
|
||
if data.Model != "" {
|
||
session = session.Where("model = ?", data.Model)
|
||
}
|
||
if len(data.CreateAt) == 2 {
|
||
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
|
||
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
|
||
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||
}
|
||
|
||
var total int64
|
||
session.Model(&model.ChatItem{}).Count(&total)
|
||
var items []model.ChatItem
|
||
var list = make([]chatItemVo, 0)
|
||
offset := (data.Page - 1) * data.PageSize
|
||
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||
if res.Error == nil {
|
||
userIds := make([]uint, 0)
|
||
chatIds := make([]string, 0)
|
||
roleIds := make([]uint, 0)
|
||
for _, item := range items {
|
||
userIds = append(userIds, item.UserId)
|
||
chatIds = append(chatIds, item.ChatId)
|
||
roleIds = append(roleIds, item.RoleId)
|
||
}
|
||
var messages []model.ChatMessage
|
||
var users []model.User
|
||
var roles []model.ChatRole
|
||
h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
|
||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||
h.DB.Where("id IN ?", roleIds).Find(&roles)
|
||
|
||
tokenMap := make(map[string]int)
|
||
userMap := make(map[uint]string)
|
||
msgMap := make(map[string]int)
|
||
roleMap := make(map[uint]vo.ChatRole)
|
||
for _, msg := range messages {
|
||
tokenMap[msg.ChatId] += msg.Tokens
|
||
msgMap[msg.ChatId] += 1
|
||
}
|
||
for _, user := range users {
|
||
userMap[user.Id] = user.Username
|
||
}
|
||
for _, r := range roles {
|
||
var roleVo vo.ChatRole
|
||
err := utils.CopyObject(r, &roleVo)
|
||
if err != nil {
|
||
continue
|
||
}
|
||
roleMap[r.Id] = roleVo
|
||
}
|
||
for _, item := range items {
|
||
list = append(list, chatItemVo{
|
||
UserId: item.UserId,
|
||
Username: userMap[item.UserId],
|
||
ChatId: item.ChatId,
|
||
Title: item.Title,
|
||
Model: item.Model,
|
||
Token: tokenMap[item.ChatId],
|
||
MsgNum: msgMap[item.ChatId],
|
||
Role: roleMap[item.RoleId],
|
||
CreatedAt: item.CreatedAt.Unix(),
|
||
})
|
||
}
|
||
}
|
||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||
}
|
||
|
||
// chatMessageVo is one row of the admin message list returned by Messages.
// Field order is significant: it determines the key order of the JSON output.
type chatMessageVo struct {
	Id        uint   `json:"id"`         // message id
	UserId    uint   `json:"user_id"`    // sender's user id
	Username  string `json:"username"`   // sender's username; empty if the user row was not found
	Content   string `json:"content"`    // message text
	Type      string `json:"type"`       // message type as stored on model.ChatMessage
	Model     string `json:"model"`      // model name stored on the message
	Token     int    `json:"token"`      // the message's Tokens value
	Icon      string `json:"icon"`       // icon stored on the message
	CreatedAt int64  `json:"created_at"` // creation time as Unix seconds
}
|
||
|
||
// Messages 读取聊天记录列表
|
||
func (h *ChatHandler) Messages(c *gin.Context) {
|
||
var data struct {
|
||
UserId uint `json:"user_id"`
|
||
Content string `json:"content"`
|
||
Model string `json:"model"`
|
||
CreateAt []string `json:"created_time"`
|
||
Page int `json:"page"`
|
||
PageSize int `json:"page_size"`
|
||
}
|
||
if err := c.ShouldBindJSON(&data); err != nil {
|
||
resp.ERROR(c, types.InvalidArgs)
|
||
return
|
||
}
|
||
|
||
session := h.DB.Session(&gorm.Session{})
|
||
if data.Content != "" {
|
||
session = session.Where("content LIKE ?", "%"+data.Content+"%")
|
||
}
|
||
if data.UserId > 0 {
|
||
session = session.Where("user_id = ?", data.UserId)
|
||
}
|
||
if data.Model != "" {
|
||
session = session.Where("model = ?", data.Model)
|
||
}
|
||
if len(data.CreateAt) == 2 {
|
||
start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
|
||
end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
|
||
session = session.Where("created_at >= ? AND created_at <= ?", start, end)
|
||
}
|
||
|
||
var total int64
|
||
session.Model(&model.ChatMessage{}).Count(&total)
|
||
var items []model.ChatMessage
|
||
var list = make([]chatMessageVo, 0)
|
||
offset := (data.Page - 1) * data.PageSize
|
||
res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
|
||
if res.Error == nil {
|
||
userIds := make([]uint, 0)
|
||
for _, item := range items {
|
||
userIds = append(userIds, item.UserId)
|
||
}
|
||
var users []model.User
|
||
h.DB.Where("id IN ?", userIds).Find(&users)
|
||
userMap := make(map[uint]string)
|
||
for _, user := range users {
|
||
userMap[user.Id] = user.Username
|
||
}
|
||
for _, item := range items {
|
||
list = append(list, chatMessageVo{
|
||
Id: item.Id,
|
||
UserId: item.UserId,
|
||
Username: userMap[item.UserId],
|
||
Content: item.Content,
|
||
Model: item.Model,
|
||
Token: item.Tokens,
|
||
Icon: item.Icon,
|
||
Type: item.Type,
|
||
CreatedAt: item.CreatedAt.Unix(),
|
||
})
|
||
}
|
||
}
|
||
resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
|
||
}
|
||
|
||
// History 获取聊天历史记录
|
||
func (h *ChatHandler) History(c *gin.Context) {
|
||
chatId := c.Query("chat_id") // 会话 ID
|
||
var items []model.ChatMessage
|
||
var messages = make([]vo.HistoryMessage, 0)
|
||
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "No history message")
|
||
return
|
||
} else {
|
||
for _, item := range items {
|
||
var v vo.HistoryMessage
|
||
err := utils.CopyObject(item, &v)
|
||
v.CreatedAt = item.CreatedAt.Unix()
|
||
v.UpdatedAt = item.UpdatedAt.Unix()
|
||
if err == nil {
|
||
messages = append(messages, v)
|
||
}
|
||
}
|
||
}
|
||
|
||
resp.SUCCESS(c, messages)
|
||
}
|
||
|
||
// RemoveChat 删除对话
|
||
func (h *ChatHandler) RemoveChat(c *gin.Context) {
|
||
chatId := h.GetTrim(c, "chat_id")
|
||
if chatId == "" {
|
||
resp.ERROR(c, "请传入 ChatId")
|
||
return
|
||
}
|
||
|
||
tx := h.DB.Begin()
|
||
// 删除聊天记录
|
||
res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
|
||
if res.Error != nil {
|
||
resp.ERROR(c, "failed to remove chat message")
|
||
return
|
||
}
|
||
|
||
// 删除对话
|
||
res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
|
||
if res.Error != nil {
|
||
tx.Rollback() // 回滚
|
||
resp.ERROR(c, "failed to remove chat")
|
||
return
|
||
}
|
||
|
||
tx.Commit()
|
||
resp.SUCCESS(c)
|
||
}
|
||
|
||
// RemoveMessage 删除聊天记录
|
||
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
|
||
id := h.GetInt(c, "id", 0)
|
||
tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
|
||
if tx.Error != nil {
|
||
logger.Error("error with update database:", tx.Error)
|
||
resp.ERROR(c, "更新数据库失败!")
|
||
return
|
||
}
|
||
resp.SUCCESS(c)
|
||
}
|