feat: load preview page do not require user to login

This commit is contained in:
RockYang
2024-03-19 18:25:01 +08:00
parent 8f47474edd
commit 58e44b7d12
57 changed files with 758 additions and 1550 deletions

View File

@@ -136,7 +136,7 @@ func (h *ChatHandler) sendAzureMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -158,7 +158,7 @@ func (h *ChatHandler) sendAzureMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -168,7 +168,7 @@ func (h *ChatHandler) sendAzureMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -180,7 +180,7 @@ func (h *ChatHandler) sendAzureMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -160,7 +160,7 @@ func (h *ChatHandler) sendBaiduMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -182,7 +182,7 @@ func (h *ChatHandler) sendBaiduMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -191,7 +191,7 @@ func (h *ChatHandler) sendBaiduMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -203,7 +203,7 @@ func (h *ChatHandler) sendBaiduMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -35,19 +35,16 @@ var logger = logger2.GetLogger()
type ChatHandler struct {
handler.BaseHandler
db *gorm.DB
redis *redis.Client
uploadManager *oss.UploaderManager
}
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
h := ChatHandler{
db: db,
return &ChatHandler{
BaseHandler: handler.BaseHandler{App: app, DB: db},
redis: redis,
uploadManager: manager,
}
h.App = app
return &h
}
func (h *ChatHandler) Init() {
@@ -73,7 +70,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
client := types.NewWsClient(ws)
// get model info
var chatModel model.ChatModel
res := h.db.First(&chatModel, modelId)
res := h.DB.First(&chatModel, modelId)
if res.Error != nil || chatModel.Enabled == false {
utils.ReplyMessage(client, "当前AI模型暂未启用连接已关闭")
c.Abort()
@@ -82,7 +79,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
session := h.App.ChatSession.Get(sessionId)
if session == nil {
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
logger.Info("用户未登录")
c.Abort()
@@ -99,7 +96,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
// use old chat data override the chat model and role ID
var chat model.ChatItem
res = h.db.Where("chat_id = ?", chatId).First(&chat)
res = h.DB.Where("chat_id = ?", chatId).First(&chat)
if res.Error == nil {
chatModel.Id = chat.ModelId
roleId = int(chat.RoleId)
@@ -116,7 +113,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
Platform: types.Platform(chatModel.Platform)}
logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
var chatRole model.ChatRole
res = h.db.First(&chatRole, roleId)
res = h.DB.First(&chatRole, roleId)
if res.Error != nil || !chatRole.Enable {
utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
c.Abort()
@@ -181,7 +178,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
}
var user model.User
res := h.db.Model(&model.User{}).First(&user, session.UserId)
res := h.DB.Model(&model.User{}).First(&user, session.UserId)
if res.Error != nil {
utils.ReplyMessage(ws, "非法用户,请联系管理员!")
return res.Error
@@ -238,7 +235,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
req.MaxTokens = session.Model.MaxTokens
// OpenAI 支持函数功能
var items []model.Function
res := h.db.Where("enabled", true).Find(&items)
res := h.DB.Where("enabled", true).Find(&items)
if res.Error != nil {
break
}
@@ -290,7 +287,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
_ = utils.JsonDecode(role.Context, &messages)
if h.App.SysConfig.ContextDeep > 0 {
var historyMessages []model.ChatMessage
res := h.db.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
if res.Error == nil {
for i := len(historyMessages) - 1; i >= 0; i-- {
msg := historyMessages[i]
@@ -382,7 +379,7 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
if data.Text == "" && data.ChatId != "" {
var item model.ChatMessage
userId, _ := c.Get(types.LoginUserID)
res := h.db.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
if res.Error != nil {
resp.ERROR(c, res.Error.Error())
return
@@ -433,7 +430,7 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
// 发送请求到 OpenAI 服务器
// useOwnApiKey: 是否使用了用户自己的 API KEY
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *model.ApiKey) (*http.Response, error) {
res := h.db.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
res := h.DB.Where("platform = ?", platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(apiKey)
if res.Error != nil {
return nil, errors.New("no available key, please import key")
}
@@ -459,7 +456,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
apiURL = apiKey.ApiURL
}
// 更新 API KEY 的最后使用时间
h.db.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 百度文心,需要串接 access_token
if platform == types.Baidu {
token, err := h.getBaiduToken(apiKey.Value)
@@ -527,10 +524,10 @@ func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, p
if session.Model.Power > 0 {
power = session.Model.Power
}
res := h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
if res.Error == nil {
// 记录算力消费日志
h.db.Create(&model.PowerLog{
h.DB.Create(&model.PowerLog{
UserId: userVo.Id,
Username: userVo.Username,
Type: types.PowerConsume,

View File

@@ -13,28 +13,22 @@ import (
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
return
}
// fix: 只能读取本人的消息列表
if uint(userId) != h.GetLoginUserId(c) {
resp.ERROR(c, "Hacker attempt, you can ONLY get yourself chats.")
if !h.IsLogin(c) {
resp.SUCCESS(c)
return
}
userId := h.GetLoginUserId(c)
var items = make([]vo.ChatItem, 0)
var chats []model.ChatItem
res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
if res.Error == nil {
var roleIds = make([]uint, 0)
for _, chat := range chats {
roleIds = append(roleIds, chat.RoleId)
}
var roles []model.ChatRole
res = h.db.Find(&roles, roleIds)
res = h.DB.Find(&roles, roleIds)
if res.Error == nil {
roleMap := make(map[uint]model.ChatRole)
for _, role := range roles {
@@ -66,7 +60,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
@@ -78,14 +72,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Clear 清空所有聊天记录
func (h *ChatHandler) Clear(c *gin.Context) {
// 获取当前登录用户所有的聊天会话
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
var chats []model.ChatItem
res := h.db.Where("user_id = ?", user.Id).Find(&chats)
res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
if res.Error != nil {
resp.ERROR(c, "No chats found")
return
@@ -97,13 +91,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
// 清空会话上下文
h.App.ChatContexts.Delete(chat.ChatId)
}
err = h.db.Transaction(func(tx *gorm.DB) error {
res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
err = h.DB.Transaction(func(tx *gorm.DB) error {
res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
if res.Error != nil {
return res.Error
}
res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
if res.Error != nil {
return res.Error
}
@@ -126,7 +120,7 @@ func (h *ChatHandler) History(c *gin.Context) {
chatId := c.Query("chat_id") // 会话 ID
var items []model.ChatMessage
var messages = make([]vo.HistoryMessage, 0)
res := h.db.Where("chat_id = ?", chatId).Find(&items)
res := h.DB.Where("chat_id = ?", chatId).Find(&items)
if res.Error != nil {
resp.ERROR(c, "No history message")
return
@@ -152,20 +146,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
resp.ERROR(c, types.InvalidArgs)
return
}
user, err := utils.GetLoginUser(c, h.db)
user, err := h.GetLoginUser(c)
if err != nil {
resp.NotAuth(c)
return
}
res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to update database")
return
}
// 删除当前会话的聊天记录
res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
if res.Error != nil {
resp.ERROR(c, "Failed to remove chat from database.")
return
@@ -187,7 +181,7 @@ func (h *ChatHandler) Detail(c *gin.Context) {
}
var chatItem model.ChatItem
res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
if res.Error != nil {
resp.ERROR(c, "No chat found")
return

View File

@@ -139,7 +139,7 @@ func (h *ChatHandler) sendChatGLMMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -161,7 +161,7 @@ func (h *ChatHandler) sendChatGLMMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -171,7 +171,7 @@ func (h *ChatHandler) sendChatGLMMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -183,7 +183,7 @@ func (h *ChatHandler) sendChatGLMMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -100,7 +100,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
if !utils.IsEmptyValue(tool) {
res := h.db.Where("name = ?", tool.Function.Name).First(&function)
res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
if res.Error == nil {
toolCall = true
utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
@@ -210,7 +210,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -240,7 +240,7 @@ func (h *ChatHandler) sendOpenAiMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -250,7 +250,7 @@ func (h *ChatHandler) sendOpenAiMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -262,18 +262,19 @@ func (h *ChatHandler) sendOpenAiMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {
body, err := io.ReadAll(response.Body)
if err != nil {
utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
return fmt.Errorf("error with reading response: %v", err)
}
var res types.ApiError
err = json.Unmarshal(body, &res)
if err != nil {
logger.Debug(string(body))
utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
return fmt.Errorf("error with decode response: %v", err)
}
@@ -281,7 +282,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 所关联的账户被禁用。")
// 移除当前 API key
h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
utils.ReplyMessage(ws, "请求 OpenAI API 失败API KEY 触发并发限制,请稍后再试。")
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {

View File

@@ -172,7 +172,7 @@ func (h *ChatHandler) sendQWenMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -194,7 +194,7 @@ func (h *ChatHandler) sendQWenMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -204,7 +204,7 @@ func (h *ChatHandler) sendQWenMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -216,7 +216,7 @@ func (h *ChatHandler) sendQWenMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}
} else {

View File

@@ -69,13 +69,13 @@ func (h *ChatHandler) sendXunFeiMessage(
ws *types.WsClient) error {
promptCreatedAt := time.Now() // 记录提问时间
var apiKey model.ApiKey
res := h.db.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
res := h.DB.Where("platform = ?", session.Model.Platform).Where("type = ?", "chat").Where("enabled = ?", true).Order("last_used_at ASC").First(&apiKey)
if res.Error != nil {
utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY请联系管理员")
return nil
}
// 更新 API KEY 的最后使用时间
h.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
d := websocket.Dialer{
HandshakeTimeout: 5 * time.Second,
@@ -200,7 +200,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
historyUserMsg.CreatedAt = promptCreatedAt
historyUserMsg.UpdatedAt = promptCreatedAt
res := h.db.Save(&historyUserMsg)
res := h.DB.Save(&historyUserMsg)
if res.Error != nil {
logger.Error("failed to save prompt history message: ", res.Error)
}
@@ -222,7 +222,7 @@ func (h *ChatHandler) sendXunFeiMessage(
}
historyReplyMsg.CreatedAt = replyCreatedAt
historyReplyMsg.UpdatedAt = replyCreatedAt
res = h.db.Create(&historyReplyMsg)
res = h.DB.Create(&historyReplyMsg)
if res.Error != nil {
logger.Error("failed to save reply history message: ", res.Error)
}
@@ -232,7 +232,7 @@ func (h *ChatHandler) sendXunFeiMessage(
// 保存当前会话
var chatItem model.ChatItem
res = h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
if res.Error != nil {
chatItem.ChatId = session.ChatId
chatItem.UserId = session.UserId
@@ -244,7 +244,7 @@ func (h *ChatHandler) sendXunFeiMessage(
chatItem.Title = prompt
}
chatItem.Model = req.Model
h.db.Create(&chatItem)
h.DB.Create(&chatItem)
}
}