mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 10:13:44 +08:00
refactor: refactor controller handler module and admin module
This commit is contained in:
@@ -2,7 +2,6 @@ package core
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
@@ -18,8 +17,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type AppServer struct {
|
||||
AppConfig *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
@@ -111,7 +108,7 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ACCESS-KEY")
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ADMIN-SESSION-TOKEN")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
@@ -138,10 +135,10 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
c.Request.URL.Path == "/api/apikey/add" ||
|
||||
//c.Request.URL.Path == "/api/apikey/list" {
|
||||
strings.Contains(c.Request.URL.Path, "/api/config/") { // TODO: 后台 API 暂时放行,用于调试
|
||||
c.Request.URL.Path == "/api/apikey/list" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -158,7 +155,12 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
value := session.Get(types.SessionUserId)
|
||||
var value interface{}
|
||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") {
|
||||
value = session.Get(types.SessionAdmin)
|
||||
} else {
|
||||
value = session.Get(types.SessionUser)
|
||||
}
|
||||
if value != nil {
|
||||
c.Next()
|
||||
} else {
|
||||
|
||||
@@ -3,12 +3,16 @@ package core
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/utils"
|
||||
"github.com/BurntSushi/toml"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewDefaultConfig() *types.AppConfig {
|
||||
return &types.AppConfig{
|
||||
Listen: "0.0.0.0:5678",
|
||||
@@ -17,7 +21,7 @@ func NewDefaultConfig() *types.AppConfig {
|
||||
|
||||
Session: types.Session{
|
||||
SecretKey: utils.RandString(64),
|
||||
Name: "CHAT_SESSION_ID",
|
||||
Name: "CHAT_PLUS_SESSION",
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package types
|
||||
|
||||
const TokenSessionName = "ChatGPT-TOKEN"
|
||||
const SessionUserId = "SESSION_USER_ID"
|
||||
const LoginUserCache = "LOGIN_USER_CACHE" //
|
||||
const SessionName = "ChatGPT-TOKEN"
|
||||
const SessionUser = "SESSION_USER" // 存储用户信息的 session key
|
||||
const SessionAdmin = "SESSION_ADMIN" //存储管理员信息的 session key
|
||||
const LoginUserCache = "LOGIN_USER_CACHE" // 已登录用户缓存
|
||||
const AdminUserCache = "ADMIN_USER_CACHE" // 管理员用户信息缓存
|
||||
|
||||
66
api/go/handler/admin/admin_handler.go
Normal file
66
api/go/handler/admin/admin_handler.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/utils/resp"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type ManagerHandler struct {
|
||||
handler.BaseHandler
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB) *ManagerHandler {
|
||||
h := ManagerHandler{db: db}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
// Login 登录
|
||||
func (h *ManagerHandler) Login(c *gin.Context) {
|
||||
var data types.Manager
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
manager := h.App.AppConfig.Manager
|
||||
if data.Username == manager.Username && data.Password == manager.Password {
|
||||
manager.Password = "" // 清空密码
|
||||
resp.SUCCESS(c, manager)
|
||||
} else {
|
||||
resp.ERROR(c, "用户名或者密码错误")
|
||||
}
|
||||
}
|
||||
|
||||
// Logout 注销
|
||||
func (h *ManagerHandler) Logout(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
session.Delete(types.SessionAdmin)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
resp.ERROR(c, "Save session failed")
|
||||
} else {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
}
|
||||
|
||||
// Session 会话检测
|
||||
func (h *ManagerHandler) Session(c *gin.Context) {
|
||||
session := sessions.Default(c)
|
||||
admin := session.Get(types.SessionAdmin)
|
||||
if admin == nil {
|
||||
resp.NotAuth(c)
|
||||
} else {
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
}
|
||||
@@ -1,28 +1,28 @@
|
||||
package handler
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/param"
|
||||
"chatplus/utils/resp"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ApiKeyHandler struct {
|
||||
BaseHandler
|
||||
handler.BaseHandler
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewApiKeyHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||
handler := ApiKeyHandler{db: db}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
return &handler
|
||||
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
|
||||
h := ApiKeyHandler{db: db}
|
||||
h.App = app
|
||||
return &h
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) Add(c *gin.Context) {
|
||||
@@ -49,8 +49,8 @@ func (h *ApiKeyHandler) Add(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *ApiKeyHandler) List(c *gin.Context) {
|
||||
page := param.GetInt(c, "page", 1)
|
||||
pageSize := param.GetInt(c, "page_size", 20)
|
||||
page := h.GetInt(c, "page", 1)
|
||||
pageSize := h.GetInt(c, "page_size", 20)
|
||||
offset := (page - 1) * pageSize
|
||||
var items []model.ApiKey
|
||||
var keys = make([]vo.ApiKey, 0)
|
||||
@@ -1,29 +0,0 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AdminHandler struct {
|
||||
BaseHandler
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewAdminHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *AdminHandler {
|
||||
handler := AdminHandler{db: db}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
return &handler
|
||||
}
|
||||
|
||||
// Login 登录
|
||||
func (h *AdminHandler) Login(c *gin.Context) {
|
||||
|
||||
}
|
||||
|
||||
// Logout 注销
|
||||
func (h *AdminHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
@@ -2,13 +2,65 @@ package handler
|
||||
|
||||
import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
type BaseHandler struct {
|
||||
app *core.AppServer
|
||||
config *types.AppConfig
|
||||
App *core.AppServer
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
|
||||
return strings.TrimSpace(c.Query(key))
|
||||
}
|
||||
|
||||
func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int {
|
||||
return intValue(c.PostForm(key), defaultValue)
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int {
|
||||
return intValue(c.Query(key), defaultValue)
|
||||
}
|
||||
|
||||
func intValue(str string, defaultValue int) int {
|
||||
value, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 {
|
||||
return floatValue(c.Query(key))
|
||||
}
|
||||
func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 {
|
||||
return floatValue(c.PostForm(key))
|
||||
}
|
||||
|
||||
func floatValue(str string) float64 {
|
||||
value, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func (h *BaseHandler) GetBool(c *gin.Context, key string) bool {
|
||||
return boolValue(c.Query(key))
|
||||
}
|
||||
func (h *BaseHandler) PostBool(c *gin.Context, key string) bool {
|
||||
return boolValue(c.PostForm(key))
|
||||
}
|
||||
|
||||
func boolValue(str string) bool {
|
||||
value, err := strconv.ParseBool(str)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -8,21 +8,21 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/param"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const ErrorMsg = "抱歉,AI 助手开小差了,请马上联系管理员去盘它。"
|
||||
@@ -32,12 +32,9 @@ type ChatHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewChatHandler(config *types.AppConfig,
|
||||
app *core.AppServer,
|
||||
db *gorm.DB) *ChatHandler {
|
||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
|
||||
handler := ChatHandler{db: db}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
handler.App = app
|
||||
return &handler
|
||||
}
|
||||
|
||||
@@ -49,11 +46,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
sessionId := c.Query("session_id")
|
||||
roleId := param.GetInt(c, "role_id", 0)
|
||||
roleId := h.GetInt(c, "role_id", 0)
|
||||
chatId := c.Query("chat_id")
|
||||
chatModel := c.Query("model")
|
||||
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
session := h.App.ChatSession.Get(sessionId)
|
||||
if session.SessionId == "" {
|
||||
logger.Info("用户未登录")
|
||||
c.Abort()
|
||||
@@ -81,21 +78,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 保存会话连接
|
||||
h.app.ChatClients.Put(sessionId, client)
|
||||
h.App.ChatClients.Put(sessionId, client)
|
||||
go func() {
|
||||
for {
|
||||
_, message, err := client.Receive()
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
client.Close()
|
||||
h.app.ChatClients.Delete(sessionId)
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
h.App.ChatClients.Delete(sessionId)
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
return
|
||||
}
|
||||
logger.Info("Receive a message: ", string(message))
|
||||
//replyMessage(client, "这是一条测试消息!")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
h.app.ReqCancelFunc.Put(sessionId, cancel)
|
||||
h.App.ReqCancelFunc.Put(sessionId, cancel)
|
||||
// 回复消息
|
||||
err = h.sendMessage(ctx, session, chatRole, string(message), client)
|
||||
if err != nil {
|
||||
@@ -153,8 +150,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
// 加载聊天上下文
|
||||
var chatCtx []types.Message
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
if h.app.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.app.ChatContexts.Get(session.ChatId)
|
||||
if h.App.ChatContexts.Has(session.ChatId) {
|
||||
chatCtx = h.App.ChatContexts.Get(session.ChatId)
|
||||
} else {
|
||||
// 加载角色信息
|
||||
var messages []types.Message
|
||||
@@ -263,7 +260,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
if userVo.ChatConfig.EnableContext {
|
||||
chatCtx = append(chatCtx, useMsg) // 提问消息
|
||||
chatCtx = append(chatCtx, message) // 回复消息
|
||||
h.app.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
h.App.ChatContexts.Put(session.ChatId, chatCtx)
|
||||
}
|
||||
|
||||
// 追加聊天记录
|
||||
@@ -349,11 +346,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
|
||||
} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
|
||||
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
|
||||
// 只保留最近的三条记录
|
||||
chatContext := h.app.ChatContexts.Get(session.ChatId)
|
||||
chatContext := h.App.ChatContexts.Get(session.ChatId)
|
||||
if len(chatContext) > 3 {
|
||||
chatContext = chatContext[len(chatContext)-3:]
|
||||
}
|
||||
h.app.ChatContexts.Put(session.ChatId, chatContext)
|
||||
h.App.ChatContexts.Put(session.ChatId, chatContext)
|
||||
return h.sendMessage(ctx, session, role, prompt, ws)
|
||||
} else {
|
||||
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
|
||||
@@ -372,7 +369,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
return nil, err
|
||||
}
|
||||
// 创建 HttpClient 请求对象
|
||||
request, err := http.NewRequest(http.MethodPost, h.app.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
|
||||
request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -380,7 +377,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
|
||||
request = request.WithContext(ctx)
|
||||
request.Header.Add("Content-Type", "application/json")
|
||||
|
||||
proxyURL := h.config.ProxyURL
|
||||
proxyURL := h.App.AppConfig.ProxyURL
|
||||
if proxyURL == "" {
|
||||
client = &http.Client{}
|
||||
} else { // 使用代理
|
||||
@@ -448,9 +445,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
|
||||
// StopGenerate 停止生成
|
||||
func (h *ChatHandler) StopGenerate(c *gin.Context) {
|
||||
sessionId := c.Query("session_id")
|
||||
if h.app.ReqCancelFunc.Has(sessionId) {
|
||||
h.app.ReqCancelFunc.Get(sessionId)()
|
||||
h.app.ReqCancelFunc.Delete(sessionId)
|
||||
if h.App.ReqCancelFunc.Has(sessionId) {
|
||||
h.App.ReqCancelFunc.Get(sessionId)()
|
||||
h.App.ReqCancelFunc.Delete(sessionId)
|
||||
}
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -5,14 +5,14 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/param"
|
||||
"chatplus/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// List 获取会话列表
|
||||
func (h *ChatHandler) List(c *gin.Context) {
|
||||
userId := param.GetInt(c, "user_id", 0)
|
||||
userId := h.GetInt(c, "user_id", 0)
|
||||
if userId == 0 {
|
||||
resp.ERROR(c, "The parameter 'user_id' is needed.")
|
||||
return
|
||||
@@ -71,7 +71,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
|
||||
|
||||
// Remove 删除会话
|
||||
func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
chatId := param.GetTrim(c, "chat_id")
|
||||
chatId := h.GetTrim(c, "chat_id")
|
||||
if chatId == "" {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
@@ -89,7 +89,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 清空会话上下文
|
||||
h.app.ChatContexts.Delete(chatId)
|
||||
h.App.ChatContexts.Delete(chatId)
|
||||
resp.SUCCESS(c, types.OkMsg)
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
|
||||
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
|
||||
}
|
||||
// 清空会话上下文
|
||||
h.app.ChatContexts.Delete(chat.ChatId)
|
||||
h.App.ChatContexts.Delete(chat.ChatId)
|
||||
}
|
||||
// 删除所有的会话记录
|
||||
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
"chatplus/store/vo"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatRoleHandler struct {
|
||||
@@ -17,10 +18,9 @@ type ChatRoleHandler struct {
|
||||
service *service.ChatRoleService
|
||||
}
|
||||
|
||||
func NewChatRoleHandler(config *types.AppConfig, app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler {
|
||||
func NewChatRoleHandler(app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler {
|
||||
handler := &ChatRoleHandler{service: service}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
handler.App = app
|
||||
return handler
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -15,10 +16,9 @@ type ConfigHandler struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewConfigHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
|
||||
handler := ConfigHandler{db: db}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
handler.App = app
|
||||
return &handler
|
||||
}
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@ import (
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"gorm.io/gorm"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserHandler struct {
|
||||
@@ -22,10 +23,9 @@ type UserHandler struct {
|
||||
searcher *xdb.Searcher
|
||||
}
|
||||
|
||||
func NewUserHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler {
|
||||
func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler {
|
||||
handler := &UserHandler{db: db, searcher: searcher}
|
||||
handler.app = app
|
||||
handler.config = config
|
||||
handler.App = app
|
||||
return handler
|
||||
}
|
||||
|
||||
@@ -77,11 +77,11 @@ func (h *UserHandler) Register(c *gin.Context) {
|
||||
Status: true,
|
||||
ChatRoles: utils.JsonEncode(roleMap),
|
||||
ChatConfig: utils.JsonEncode(types.ChatConfig{
|
||||
Temperature: h.app.ChatConfig.Temperature,
|
||||
MaxTokens: h.app.ChatConfig.MaxTokens,
|
||||
EnableContext: h.app.ChatConfig.EnableContext,
|
||||
Temperature: h.App.ChatConfig.Temperature,
|
||||
MaxTokens: h.App.ChatConfig.MaxTokens,
|
||||
EnableContext: h.App.ChatConfig.EnableContext,
|
||||
EnableHistory: true,
|
||||
Model: h.app.ChatConfig.Model,
|
||||
Model: h.App.ChatConfig.Model,
|
||||
ApiKey: "",
|
||||
}),
|
||||
}
|
||||
@@ -159,16 +159,15 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
h.db.Model(&user).Updates(user)
|
||||
|
||||
sessionId := utils.RandString(42)
|
||||
c.Header(types.TokenSessionName, sessionId)
|
||||
err := utils.SetLoginUser(c, user.Id)
|
||||
err := utils.SetLoginUser(c, user)
|
||||
if err != nil {
|
||||
resp.ERROR(c, "保存会话失败")
|
||||
logger.Error("Error for save session: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 记录登录信息在服务器
|
||||
h.app.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
|
||||
// 记录登录信息在服务端
|
||||
h.App.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
|
||||
|
||||
// 加载用户订阅的聊天角色
|
||||
var roleMap map[string]int
|
||||
@@ -229,17 +228,17 @@ func (h *UserHandler) Login(c *gin.Context) {
|
||||
|
||||
// Logout 注 销
|
||||
func (h *UserHandler) Logout(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.TokenSessionName)
|
||||
sessionId := c.GetHeader(types.SessionName)
|
||||
session := sessions.Default(c)
|
||||
session.Delete(sessionId)
|
||||
session.Delete(types.SessionUser)
|
||||
err := session.Save()
|
||||
if err != nil {
|
||||
logger.Error("Error for save session: ", err)
|
||||
}
|
||||
// 删除 websocket 会话列表
|
||||
h.app.ChatSession.Delete(sessionId)
|
||||
h.App.ChatSession.Delete(sessionId)
|
||||
// 关闭 socket 连接
|
||||
client := h.app.ChatClients.Get(sessionId)
|
||||
client := h.App.ChatClients.Get(sessionId)
|
||||
if client != nil {
|
||||
client.Close()
|
||||
}
|
||||
@@ -248,8 +247,8 @@ func (h *UserHandler) Logout(c *gin.Context) {
|
||||
|
||||
// Session 获取/验证会话
|
||||
func (h *UserHandler) Session(c *gin.Context) {
|
||||
sessionId := c.GetHeader(types.TokenSessionName)
|
||||
session := h.app.ChatSession.Get(sessionId)
|
||||
sessionId := c.GetHeader(types.SessionName)
|
||||
session := h.App.ChatSession.Get(sessionId)
|
||||
if session.ClientIP == c.ClientIP() {
|
||||
resp.SUCCESS(c, session)
|
||||
} else {
|
||||
|
||||
@@ -4,20 +4,22 @@ import (
|
||||
"chatplus/core"
|
||||
"chatplus/core/types"
|
||||
"chatplus/handler"
|
||||
"chatplus/handler/admin"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/service"
|
||||
"chatplus/store"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"go.uber.org/fx"
|
||||
"gorm.io/gorm"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
|
||||
"go.uber.org/fx"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
@@ -80,19 +82,15 @@ func main() {
|
||||
fx.Invoke(core.InitChatRoles),
|
||||
|
||||
// 创建控制器
|
||||
fx.Provide(handler.NewAdminHandler),
|
||||
fx.Provide(handler.NewChatRoleHandler),
|
||||
fx.Provide(handler.NewUserHandler),
|
||||
fx.Provide(handler.NewChatHandler),
|
||||
fx.Provide(handler.NewApiKeyHandler),
|
||||
fx.Provide(handler.NewConfigHandler),
|
||||
|
||||
fx.Provide(admin.NewAdminHandler),
|
||||
fx.Provide(admin.NewApiKeyHandler),
|
||||
|
||||
// 注册路由
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.AdminHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
|
||||
group := s.Engine.Group("/api/chat/role/")
|
||||
group.GET("list", h.List)
|
||||
@@ -119,11 +117,6 @@ func main() {
|
||||
group.GET("tokens", h.Tokens)
|
||||
group.GET("stop", h.StopGenerate)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ApiKeyHandler) {
|
||||
group := s.Engine.Group("/api/apikey/")
|
||||
group.POST("add", h.Add)
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
|
||||
group := s.Engine.Group("/api/config/")
|
||||
group.POST("update", h.Update)
|
||||
@@ -131,6 +124,17 @@ func main() {
|
||||
group.GET("models", h.AllGptModels)
|
||||
}),
|
||||
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
|
||||
group := s.Engine.Group("/api/admin/")
|
||||
group.POST("login", h.Login)
|
||||
group.GET("logout", h.Logout)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
|
||||
group := s.Engine.Group("/api/admin/apikey/")
|
||||
group.POST("add", h.Add)
|
||||
group.GET("list", h.List)
|
||||
}),
|
||||
|
||||
fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
|
||||
err := s.Run(db)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package param
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetTrim(c *gin.Context, key string) string {
|
||||
return strings.TrimSpace(c.Query(key))
|
||||
}
|
||||
|
||||
func GetInt(c *gin.Context, key string, defaultValue int) int {
|
||||
return intValue(c.Query(key), defaultValue)
|
||||
}
|
||||
|
||||
func PostInt(c *gin.Context, key string, defaultValue int) int {
|
||||
return intValue(c.PostForm(key), defaultValue)
|
||||
}
|
||||
|
||||
func intValue(str string, defaultValue int) int {
|
||||
value, err := strconv.Atoi(str)
|
||||
if err != nil {
|
||||
return defaultValue
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func GetFloat(c *gin.Context, key string) float64 {
|
||||
return floatValue(c.Query(key))
|
||||
}
|
||||
func PostFloat(c *gin.Context, key string) float64 {
|
||||
return floatValue(c.PostForm(key))
|
||||
}
|
||||
|
||||
func floatValue(str string) float64 {
|
||||
value, err := strconv.ParseFloat(str, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func GetBool(c *gin.Context, key string) bool {
|
||||
return boolValue(c.Query(key))
|
||||
}
|
||||
func PostBool(c *gin.Context, key string) bool {
|
||||
return boolValue(c.PostForm(key))
|
||||
}
|
||||
|
||||
func boolValue(str string) bool {
|
||||
value, err := strconv.ParseBool(str)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return value
|
||||
}
|
||||
@@ -4,14 +4,22 @@ import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"errors"
|
||||
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func SetLoginUser(c *gin.Context, userId uint) error {
|
||||
func SetLoginUser(c *gin.Context, user model.User) error {
|
||||
session := sessions.Default(c)
|
||||
session.Set(types.SessionUserId, userId)
|
||||
session.Set(types.SessionUser, user.Id)
|
||||
// TODO: 后期用户数量增加,考虑将用户数据存储到 leveldb,避免每次查询数据库
|
||||
return session.Save()
|
||||
}
|
||||
|
||||
func SetLoginAdmin(c *gin.Context, admin types.Manager) error {
|
||||
session := sessions.Default(c)
|
||||
session.Set(types.SessionAdmin, admin)
|
||||
return session.Save()
|
||||
}
|
||||
|
||||
@@ -22,7 +30,7 @@ func GetLoginUser(c *gin.Context, db *gorm.DB) (model.User, error) {
|
||||
}
|
||||
|
||||
session := sessions.Default(c)
|
||||
userId := session.Get(types.SessionUserId)
|
||||
userId := session.Get(types.SessionUser)
|
||||
if userId == nil {
|
||||
return model.User{}, errors.New("user not login")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user