refactor: refactor controller handler module and admin module

This commit is contained in:
RockYang
2023-06-19 07:06:59 +08:00
parent cd809d17d3
commit 120e54fb29
22 changed files with 436 additions and 300 deletions

View File

@@ -8,21 +8,21 @@ import (
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
const ErrorMsg = "抱歉AI 助手开小差了,请马上联系管理员去盘它。"
@@ -32,12 +32,9 @@ type ChatHandler struct {
db *gorm.DB
}
func NewChatHandler(config *types.AppConfig,
app *core.AppServer,
db *gorm.DB) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
handler.app = app
handler.config = config
handler.App = app
return &handler
}
@@ -49,11 +46,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
sessionId := c.Query("session_id")
roleId := param.GetInt(c, "role_id", 0)
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
chatModel := c.Query("model")
session := h.app.ChatSession.Get(sessionId)
session := h.App.ChatSession.Get(sessionId)
if session.SessionId == "" {
logger.Info("用户未登录")
c.Abort()
@@ -81,21 +78,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 保存会话连接
h.app.ChatClients.Put(sessionId, client)
h.App.ChatClients.Put(sessionId, client)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
h.app.ChatClients.Delete(sessionId)
h.app.ReqCancelFunc.Delete(sessionId)
h.App.ChatClients.Delete(sessionId)
h.App.ReqCancelFunc.Delete(sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background())
h.app.ReqCancelFunc.Put(sessionId, cancel)
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil {
@@ -153,8 +150,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 加载聊天上下文
var chatCtx []types.Message
if userVo.ChatConfig.EnableContext {
if h.app.ChatContexts.Has(session.ChatId) {
chatCtx = h.app.ChatContexts.Get(session.ChatId)
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else {
// 加载角色信息
var messages []types.Message
@@ -263,7 +260,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.app.ChatContexts.Put(session.ChatId, chatCtx)
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
@@ -349,11 +346,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录
chatContext := h.app.ChatContexts.Get(session.ChatId)
chatContext := h.App.ChatContexts.Get(session.ChatId)
if len(chatContext) > 3 {
chatContext = chatContext[len(chatContext)-3:]
}
h.app.ChatContexts.Put(session.ChatId, chatContext)
h.App.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
@@ -372,7 +369,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return nil, err
}
// 创建 HttpClient 请求对象
request, err := http.NewRequest(http.MethodPost, h.app.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
@@ -380,7 +377,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json")
proxyURL := h.config.ProxyURL
proxyURL := h.App.AppConfig.ProxyURL
if proxyURL == "" {
client = &http.Client{}
} else { // 使用代理
@@ -448,9 +445,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.app.ReqCancelFunc.Has(sessionId) {
h.app.ReqCancelFunc.Get(sessionId)()
h.app.ReqCancelFunc.Delete(sessionId)
if h.App.ReqCancelFunc.Has(sessionId) {
h.App.ReqCancelFunc.Get(sessionId)()
h.App.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}