mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-10 11:13:42 +08:00
refactor: refactor controller handler module and admin module
This commit is contained in:
@@ -8,21 +8,21 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/param"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const ErrorMsg = "抱歉,AI 助手开小差了,请马上联系管理员去盘它。"
|
||||
@@ -32,12 +32,9 @@ type ChatHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewChatHandler(config *types.AppConfig,
|
||||
app *core.AppServer,
|
||||
db *gorm.DB) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
handler := ChatHandler{db: db}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
handler.App = app
|
||||
return &handler
|
||||
}
|
||||
|
||||
@@ -49,11 +46,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("session_id")
|
||||
roleId := param.GetInt(c, "role_id", 0)
|
||||
roleId := h.GetInt(c, "role_id", 0)
|
||||
chatId := c.Query("chat_id")
|
||||
chatModel := c.Query("model")
|
||||
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
session := h.App.ChatSession.Get(sessionId)
|
||||
if session.SessionId == "" {
|
||||
logger.Info("用户未登录")
|
||||
c.Abort()
|
||||
@@ -81,21 +78,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存会话连接
|
||||
h.app.ChatClients.Put(sessionId, client)
|
||||
h.App.ChatClients.Put(sessionId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
client.Close()
|
||||
h.app.ChatClients.Delete(sessionId)
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
h.App.ChatClients.Delete(sessionId)
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
return
|
||||
}
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "这是一条测试消息!")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.app.ReqCancelFunc.Put(sessionId, cancel)
|
||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
@@ -153,8 +150,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// 加载聊天上下文
|
||||
var chatCtx []types.Message
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
if h.app.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.app.ChatContexts.Get(session.ChatId)
|
||||
if h.App.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
// 加载角色信息
|
||||
var messages []types.Message
|
||||
@@ -263,7 +260,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.app.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
@@ -349,11 +346,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
|
||||
// 只保留最近的三条记录
|
||||
chatContext := h.app.ChatContexts.Get(session.ChatId)
|
||||
chatContext := h.App.ChatContexts.Get(session.ChatId)
|
||||
if len(chatContext) > 3 {
|
||||
chatContext = chatContext[len(chatContext)-3:]
|
||||
}
|
||||
h.app.ChatContexts.Put(session.ChatId, chatContext)
|
||||
h.App.ChatContexts.Put(session.ChatId, chatContext)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
@@ -372,7 +369,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
return nil, err
|
||||
}
|
||||
// 创建 HttpClient 请求对象
|
||||
request, err := http.NewRequest(http.MethodPost, h.app.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
|
||||
request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -380,7 +377,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
request = request.WithContext(ctx)
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
|
||||
proxyURL := h.config.ProxyURL
|
||||
proxyURL := h.App.AppConfig.ProxyURL
|
||||
if proxyURL == "" {
|
||||
client = &http.Client{}
|
||||
} else { // 使用代理
|
||||
@@ -448,9 +445,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
sessionId := c.Query("session_id")
|
||||
if h.app.ReqCancelFunc.Has(sessionId) {
|
||||
h.app.ReqCancelFunc.Get(sessionId)()
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
if h.App.ReqCancelFunc.Has(sessionId) {
|
||||
h.App.ReqCancelFunc.Get(sessionId)()
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user