refactor: refactor controller handler module and admin module

This commit is contained in:
RockYang
2023-06-19 07:06:59 +08:00
parent 831dd3e2e0
commit 90bce1d437
22 changed files with 436 additions and 300 deletions

View File

@@ -0,0 +1,66 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/utils/resp"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB) *ManagerHandler {
h := ManagerHandler{db: db}
h.App = app
return &h
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.AppConfig.Manager
if data.Username == manager.Username && data.Password == manager.Password {
manager.Password = "" // 清空密码
resp.SUCCESS(c, manager)
} else {
resp.ERROR(c, "用户名或者密码错误")
}
}
// Logout 注销
func (h *ManagerHandler) Logout(c *gin.Context) {
session := sessions.Default(c)
session.Delete(types.SessionAdmin)
err := session.Save()
if err != nil {
resp.ERROR(c, "Save session failed")
} else {
resp.SUCCESS(c)
}
}
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
session := sessions.Default(c)
admin := session.Get(types.SessionAdmin)
if admin == nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
}
}

View File

@@ -1,28 +1,28 @@
package handler
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp"
"net/http"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
"net/http"
)
type ApiKeyHandler struct {
BaseHandler
handler.BaseHandler
db *gorm.DB
}
func NewApiKeyHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
handler := ApiKeyHandler{db: db}
handler.app = app
handler.config = config
return &handler
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
h := ApiKeyHandler{db: db}
h.App = app
return &h
}
func (h *ApiKeyHandler) Add(c *gin.Context) {
@@ -49,8 +49,8 @@ func (h *ApiKeyHandler) Add(c *gin.Context) {
}
func (h *ApiKeyHandler) List(c *gin.Context) {
page := param.GetInt(c, "page", 1)
pageSize := param.GetInt(c, "page_size", 20)
page := h.GetInt(c, "page", 1)
pageSize := h.GetInt(c, "page_size", 20)
offset := (page - 1) * pageSize
var items []model.ApiKey
var keys = make([]vo.ApiKey, 0)

View File

@@ -1,29 +0,0 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type AdminHandler struct {
BaseHandler
db *gorm.DB
}
func NewAdminHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *AdminHandler {
handler := AdminHandler{db: db}
handler.app = app
handler.config = config
return &handler
}
// Login 登录
func (h *AdminHandler) Login(c *gin.Context) {
}
// Logout 注销
func (h *AdminHandler) Logout(c *gin.Context) {
}

View File

@@ -2,13 +2,65 @@ package handler
import (
"chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
var logger = logger2.GetLogger()
type BaseHandler struct {
app *core.AppServer
config *types.AppConfig
App *core.AppServer
}
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
return strings.TrimSpace(c.Query(key))
}
func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.PostForm(key), defaultValue)
}
func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.Query(key), defaultValue)
}
func intValue(str string, defaultValue int) int {
value, err := strconv.Atoi(str)
if err != nil {
return defaultValue
}
return value
}
func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 {
return floatValue(c.Query(key))
}
func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 {
return floatValue(c.PostForm(key))
}
func floatValue(str string) float64 {
value, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return value
}
func (h *BaseHandler) GetBool(c *gin.Context, key string) bool {
return boolValue(c.Query(key))
}
func (h *BaseHandler) PostBool(c *gin.Context, key string) bool {
return boolValue(c.PostForm(key))
}
func boolValue(str string) bool {
value, err := strconv.ParseBool(str)
if err != nil {
return false
}
return value
}

View File

@@ -8,21 +8,21 @@ import (
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp"
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io"
"net/http"
"net/url"
"strings"
"time"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
)
const ErrorMsg = "抱歉AI 助手开小差了,请马上联系管理员去盘它。"
@@ -32,12 +32,9 @@ type ChatHandler struct {
db *gorm.DB
}
func NewChatHandler(config *types.AppConfig,
app *core.AppServer,
db *gorm.DB) *ChatHandler {
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db}
handler.app = app
handler.config = config
handler.App = app
return &handler
}
@@ -49,11 +46,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return
}
sessionId := c.Query("session_id")
roleId := param.GetInt(c, "role_id", 0)
roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id")
chatModel := c.Query("model")
session := h.app.ChatSession.Get(sessionId)
session := h.App.ChatSession.Get(sessionId)
if session.SessionId == "" {
logger.Info("用户未登录")
c.Abort()
@@ -81,21 +78,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
}
// 保存会话连接
h.app.ChatClients.Put(sessionId, client)
h.App.ChatClients.Put(sessionId, client)
go func() {
for {
_, message, err := client.Receive()
if err != nil {
logger.Error(err)
client.Close()
h.app.ChatClients.Delete(sessionId)
h.app.ReqCancelFunc.Delete(sessionId)
h.App.ChatClients.Delete(sessionId)
h.App.ReqCancelFunc.Delete(sessionId)
return
}
logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background())
h.app.ReqCancelFunc.Put(sessionId, cancel)
h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil {
@@ -153,8 +150,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 加载聊天上下文
var chatCtx []types.Message
if userVo.ChatConfig.EnableContext {
if h.app.ChatContexts.Has(session.ChatId) {
chatCtx = h.app.ChatContexts.Get(session.ChatId)
if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else {
// 加载角色信息
var messages []types.Message
@@ -263,7 +260,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息
h.app.ChatContexts.Put(session.ChatId, chatCtx)
h.App.ChatContexts.Put(session.ChatId, chatCtx)
}
// 追加聊天记录
@@ -349,11 +346,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录
chatContext := h.app.ChatContexts.Get(session.ChatId)
chatContext := h.App.ChatContexts.Get(session.ChatId)
if len(chatContext) > 3 {
chatContext = chatContext[len(chatContext)-3:]
}
h.app.ChatContexts.Put(session.ChatId, chatContext)
h.App.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws)
} else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
@@ -372,7 +369,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return nil, err
}
// 创建 HttpClient 请求对象
request, err := http.NewRequest(http.MethodPost, h.app.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
if err != nil {
return nil, err
}
@@ -380,7 +377,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json")
proxyURL := h.config.ProxyURL
proxyURL := h.App.AppConfig.ProxyURL
if proxyURL == "" {
client = &http.Client{}
} else { // 使用代理
@@ -448,9 +445,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id")
if h.app.ReqCancelFunc.Has(sessionId) {
h.app.ReqCancelFunc.Get(sessionId)()
h.app.ReqCancelFunc.Delete(sessionId)
if h.App.ReqCancelFunc.Has(sessionId) {
h.App.ReqCancelFunc.Get(sessionId)()
h.App.ReqCancelFunc.Delete(sessionId)
}
resp.SUCCESS(c, types.OkMsg)
}

View File

@@ -5,14 +5,14 @@ import (
"chatplus/store/model"
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
)
// List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) {
userId := param.GetInt(c, "user_id", 0)
userId := h.GetInt(c, "user_id", 0)
if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.")
return
@@ -71,7 +71,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Remove 删除会话
func (h *ChatHandler) Remove(c *gin.Context) {
chatId := param.GetTrim(c, "chat_id")
chatId := h.GetTrim(c, "chat_id")
if chatId == "" {
resp.ERROR(c, types.InvalidArgs)
return
@@ -89,7 +89,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
}
// 清空会话上下文
h.app.ChatContexts.Delete(chatId)
h.App.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg)
}
@@ -144,7 +144,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
}
// 清空会话上下文
h.app.ChatContexts.Delete(chat.ChatId)
h.App.ChatContexts.Delete(chat.ChatId)
}
// 删除所有的会话记录
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})

View File

@@ -8,8 +8,9 @@ import (
"chatplus/store/vo"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"strconv"
"github.com/gin-gonic/gin"
)
type ChatRoleHandler struct {
@@ -17,10 +18,9 @@ type ChatRoleHandler struct {
service *service.ChatRoleService
}
func NewChatRoleHandler(config *types.AppConfig, app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler {
func NewChatRoleHandler(app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler {
handler := &ChatRoleHandler{service: service}
handler.app = app
handler.config = config
handler.App = app
return handler
}

View File

@@ -6,6 +6,7 @@ import (
"chatplus/store/model"
"chatplus/utils"
"chatplus/utils/resp"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
@@ -15,10 +16,9 @@ type ConfigHandler struct {
db *gorm.DB
}
func NewConfigHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ConfigHandler {
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
handler := ConfigHandler{db: db}
handler.app = app
handler.config = config
handler.App = app
return &handler
}

View File

@@ -8,12 +8,13 @@ import (
"chatplus/utils"
"chatplus/utils/resp"
"fmt"
"strings"
"time"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"gorm.io/gorm"
"strings"
"time"
)
type UserHandler struct {
@@ -22,10 +23,9 @@ type UserHandler struct {
searcher *xdb.Searcher
}
func NewUserHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler {
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher}
handler.app = app
handler.config = config
handler.App = app
return handler
}
@@ -77,11 +77,11 @@ func (h *UserHandler) Register(c *gin.Context) {
Status: true,
ChatRoles: utils.JsonEncode(roleMap),
ChatConfig: utils.JsonEncode(types.ChatConfig{
Temperature: h.app.ChatConfig.Temperature,
MaxTokens: h.app.ChatConfig.MaxTokens,
EnableContext: h.app.ChatConfig.EnableContext,
Temperature: h.App.ChatConfig.Temperature,
MaxTokens: h.App.ChatConfig.MaxTokens,
EnableContext: h.App.ChatConfig.EnableContext,
EnableHistory: true,
Model: h.app.ChatConfig.Model,
Model: h.App.ChatConfig.Model,
ApiKey: "",
}),
}
@@ -159,16 +159,15 @@ func (h *UserHandler) Login(c *gin.Context) {
h.db.Model(&user).Updates(user)
sessionId := utils.RandString(42)
c.Header(types.TokenSessionName, sessionId)
err := utils.SetLoginUser(c, user.Id)
err := utils.SetLoginUser(c, user)
if err != nil {
resp.ERROR(c, "保存会话失败")
logger.Error("Error for save session: ", err)
return
}
// 记录登录信息在服务
h.app.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
// 记录登录信息在服务
h.App.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
// 加载用户订阅的聊天角色
var roleMap map[string]int
@@ -229,17 +228,17 @@ func (h *UserHandler) Login(c *gin.Context) {
// Logout 注 销
func (h *UserHandler) Logout(c *gin.Context) {
sessionId := c.GetHeader(types.TokenSessionName)
sessionId := c.GetHeader(types.SessionName)
session := sessions.Default(c)
session.Delete(sessionId)
session.Delete(types.SessionUser)
err := session.Save()
if err != nil {
logger.Error("Error for save session: ", err)
}
// 删除 websocket 会话列表
h.app.ChatSession.Delete(sessionId)
h.App.ChatSession.Delete(sessionId)
// 关闭 socket 连接
client := h.app.ChatClients.Get(sessionId)
client := h.App.ChatClients.Get(sessionId)
if client != nil {
client.Close()
}
@@ -248,8 +247,8 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) {
sessionId := c.GetHeader(types.TokenSessionName)
session := h.app.ChatSession.Get(sessionId)
sessionId := c.GetHeader(types.SessionName)
session := h.App.ChatSession.Get(sessionId)
if session.ClientIP == c.ClientIP() {
resp.SUCCESS(c, session)
} else {