refactor: refactor controller handler module and admin module

This commit is contained in:
RockYang 2023-06-19 07:06:59 +08:00
parent cd809d17d3
commit 120e54fb29
22 changed files with 436 additions and 300 deletions

View File

@ -2,7 +2,6 @@ package core
import ( import (
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
@ -18,8 +17,6 @@ import (
"strings" "strings"
) )
var logger = logger2.GetLogger()
type AppServer struct { type AppServer struct {
AppConfig *types.AppConfig AppConfig *types.AppConfig
Engine *gin.Engine Engine *gin.Engine
@ -111,7 +108,7 @@ func corsMiddleware() gin.HandlerFunc {
c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Origin", origin)
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
//允许跨域设置可以返回其他子段,可以自定义字段 //允许跨域设置可以返回其他子段,可以自定义字段
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ACCESS-KEY") c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ADMIN-SESSION-TOKEN")
// 允许浏览器(客户端)可以解析的头部 (重要) // 允许浏览器(客户端)可以解析的头部 (重要)
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
//设置缓存时间 //设置缓存时间
@ -138,10 +135,10 @@ func corsMiddleware() gin.HandlerFunc {
func authorizeMiddleware(s *AppServer) gin.HandlerFunc { func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if c.Request.URL.Path == "/api/user/login" || if c.Request.URL.Path == "/api/user/login" ||
c.Request.URL.Path == "/api/admin/login" ||
c.Request.URL.Path == "/api/user/register" || c.Request.URL.Path == "/api/user/register" ||
c.Request.URL.Path == "/api/apikey/add" || c.Request.URL.Path == "/api/apikey/add" ||
//c.Request.URL.Path == "/api/apikey/list" { c.Request.URL.Path == "/api/apikey/list" {
strings.Contains(c.Request.URL.Path, "/api/config/") { // TODO 后台 API 暂时放行,用于调试
c.Next() c.Next()
return return
} }
@ -158,7 +155,12 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
return return
} }
session := sessions.Default(c) session := sessions.Default(c)
value := session.Get(types.SessionUserId) var value interface{}
if strings.Contains(c.Request.URL.Path, "/api/admin/") {
value = session.Get(types.SessionAdmin)
} else {
value = session.Get(types.SessionUser)
}
if value != nil { if value != nil {
c.Next() c.Next()
} else { } else {

View File

@ -3,12 +3,16 @@ package core
import ( import (
"bytes" "bytes"
"chatplus/core/types" "chatplus/core/types"
logger2 "chatplus/logger"
"chatplus/utils" "chatplus/utils"
"github.com/BurntSushi/toml"
"net/http" "net/http"
"os" "os"
"github.com/BurntSushi/toml"
) )
var logger = logger2.GetLogger()
func NewDefaultConfig() *types.AppConfig { func NewDefaultConfig() *types.AppConfig {
return &types.AppConfig{ return &types.AppConfig{
Listen: "0.0.0.0:5678", Listen: "0.0.0.0:5678",
@ -17,7 +21,7 @@ func NewDefaultConfig() *types.AppConfig {
Session: types.Session{ Session: types.Session{
SecretKey: utils.RandString(64), SecretKey: utils.RandString(64),
Name: "CHAT_SESSION_ID", Name: "CHAT_PLUS_SESSION",
Domain: "", Domain: "",
Path: "/", Path: "/",
MaxAge: 86400, MaxAge: 86400,

View File

@ -1,5 +1,7 @@
package types package types
const TokenSessionName = "ChatGPT-TOKEN" const SessionName = "ChatGPT-TOKEN"
const SessionUserId = "SESSION_USER_ID" const SessionUser = "SESSION_USER" // 存储用户信息的 session key
const LoginUserCache = "LOGIN_USER_CACHE" // const SessionAdmin = "SESSION_ADMIN" //存储管理员信息的 session key
const LoginUserCache = "LOGIN_USER_CACHE" // 已登录用户缓存
const AdminUserCache = "ADMIN_USER_CACHE" // 管理员用户信息缓存

View File

@ -0,0 +1,66 @@
package admin
import (
"chatplus/core"
"chatplus/core/types"
"chatplus/handler"
logger2 "chatplus/logger"
"chatplus/utils/resp"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type ManagerHandler struct {
handler.BaseHandler
db *gorm.DB
}
func NewAdminHandler(app *core.AppServer, db *gorm.DB) *ManagerHandler {
h := ManagerHandler{db: db}
h.App = app
return &h
}
// Login 登录
func (h *ManagerHandler) Login(c *gin.Context) {
var data types.Manager
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
return
}
manager := h.App.AppConfig.Manager
if data.Username == manager.Username && data.Password == manager.Password {
manager.Password = "" // 清空密码
resp.SUCCESS(c, manager)
} else {
resp.ERROR(c, "用户名或者密码错误")
}
}
// Logout 注销
func (h *ManagerHandler) Logout(c *gin.Context) {
session := sessions.Default(c)
session.Delete(types.SessionAdmin)
err := session.Save()
if err != nil {
resp.ERROR(c, "Save session failed")
} else {
resp.SUCCESS(c)
}
}
// Session 会话检测
func (h *ManagerHandler) Session(c *gin.Context) {
session := sessions.Default(c)
admin := session.Get(types.SessionAdmin)
if admin == nil {
resp.NotAuth(c)
} else {
resp.SUCCESS(c)
}
}

View File

@ -1,28 +1,28 @@
package handler package admin
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/handler"
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp" "chatplus/utils/resp"
"net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
"net/http"
) )
type ApiKeyHandler struct { type ApiKeyHandler struct {
BaseHandler handler.BaseHandler
db *gorm.DB db *gorm.DB
} }
func NewApiKeyHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ApiKeyHandler { func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
handler := ApiKeyHandler{db: db} h := ApiKeyHandler{db: db}
handler.app = app h.App = app
handler.config = config return &h
return &handler
} }
func (h *ApiKeyHandler) Add(c *gin.Context) { func (h *ApiKeyHandler) Add(c *gin.Context) {
@ -49,8 +49,8 @@ func (h *ApiKeyHandler) Add(c *gin.Context) {
} }
func (h *ApiKeyHandler) List(c *gin.Context) { func (h *ApiKeyHandler) List(c *gin.Context) {
page := param.GetInt(c, "page", 1) page := h.GetInt(c, "page", 1)
pageSize := param.GetInt(c, "page_size", 20) pageSize := h.GetInt(c, "page_size", 20)
offset := (page - 1) * pageSize offset := (page - 1) * pageSize
var items []model.ApiKey var items []model.ApiKey
var keys = make([]vo.ApiKey, 0) var keys = make([]vo.ApiKey, 0)

View File

@ -1,29 +0,0 @@
package handler
import (
"chatplus/core"
"chatplus/core/types"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
type AdminHandler struct {
BaseHandler
db *gorm.DB
}
func NewAdminHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *AdminHandler {
handler := AdminHandler{db: db}
handler.app = app
handler.config = config
return &handler
}
// Login 登录
func (h *AdminHandler) Login(c *gin.Context) {
}
// Logout 注销
func (h *AdminHandler) Logout(c *gin.Context) {
}

View File

@ -2,13 +2,65 @@ package handler
import ( import (
"chatplus/core" "chatplus/core"
"chatplus/core/types"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"strconv"
"strings"
"github.com/gin-gonic/gin"
) )
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
type BaseHandler struct { type BaseHandler struct {
app *core.AppServer App *core.AppServer
config *types.AppConfig }
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
return strings.TrimSpace(c.Query(key))
}
func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.PostForm(key), defaultValue)
}
func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.Query(key), defaultValue)
}
func intValue(str string, defaultValue int) int {
value, err := strconv.Atoi(str)
if err != nil {
return defaultValue
}
return value
}
func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 {
return floatValue(c.Query(key))
}
func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 {
return floatValue(c.PostForm(key))
}
func floatValue(str string) float64 {
value, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return value
}
func (h *BaseHandler) GetBool(c *gin.Context, key string) bool {
return boolValue(c.Query(key))
}
func (h *BaseHandler) PostBool(c *gin.Context, key string) bool {
return boolValue(c.PostForm(key))
}
func boolValue(str string) bool {
value, err := strconv.ParseBool(str)
if err != nil {
return false
}
return value
} }

View File

@ -8,21 +8,21 @@ import (
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp" "chatplus/utils/resp"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"io" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"time" "time"
"unicode/utf8" "unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
) )
const ErrorMsg = "抱歉AI 助手开小差了,请马上联系管理员去盘它。" const ErrorMsg = "抱歉AI 助手开小差了,请马上联系管理员去盘它。"
@ -32,12 +32,9 @@ type ChatHandler struct {
db *gorm.DB db *gorm.DB
} }
func NewChatHandler(config *types.AppConfig, func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
app *core.AppServer,
db *gorm.DB) *ChatHandler {
handler := ChatHandler{db: db} handler := ChatHandler{db: db}
handler.app = app handler.App = app
handler.config = config
return &handler return &handler
} }
@ -49,11 +46,11 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
return return
} }
sessionId := c.Query("session_id") sessionId := c.Query("session_id")
roleId := param.GetInt(c, "role_id", 0) roleId := h.GetInt(c, "role_id", 0)
chatId := c.Query("chat_id") chatId := c.Query("chat_id")
chatModel := c.Query("model") chatModel := c.Query("model")
session := h.app.ChatSession.Get(sessionId) session := h.App.ChatSession.Get(sessionId)
if session.SessionId == "" { if session.SessionId == "" {
logger.Info("用户未登录") logger.Info("用户未登录")
c.Abort() c.Abort()
@ -81,21 +78,21 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
} }
// 保存会话连接 // 保存会话连接
h.app.ChatClients.Put(sessionId, client) h.App.ChatClients.Put(sessionId, client)
go func() { go func() {
for { for {
_, message, err := client.Receive() _, message, err := client.Receive()
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
client.Close() client.Close()
h.app.ChatClients.Delete(sessionId) h.App.ChatClients.Delete(sessionId)
h.app.ReqCancelFunc.Delete(sessionId) h.App.ReqCancelFunc.Delete(sessionId)
return return
} }
logger.Info("Receive a message: ", string(message)) logger.Info("Receive a message: ", string(message))
//replyMessage(client, "这是一条测试消息!") //replyMessage(client, "这是一条测试消息!")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
h.app.ReqCancelFunc.Put(sessionId, cancel) h.App.ReqCancelFunc.Put(sessionId, cancel)
// 回复消息 // 回复消息
err = h.sendMessage(ctx, session, chatRole, string(message), client) err = h.sendMessage(ctx, session, chatRole, string(message), client)
if err != nil { if err != nil {
@ -153,8 +150,8 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
// 加载聊天上下文 // 加载聊天上下文
var chatCtx []types.Message var chatCtx []types.Message
if userVo.ChatConfig.EnableContext { if userVo.ChatConfig.EnableContext {
if h.app.ChatContexts.Has(session.ChatId) { if h.App.ChatContexts.Has(session.ChatId) {
chatCtx = h.app.ChatContexts.Get(session.ChatId) chatCtx = h.App.ChatContexts.Get(session.ChatId)
} else { } else {
// 加载角色信息 // 加载角色信息
var messages []types.Message var messages []types.Message
@ -263,7 +260,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
if userVo.ChatConfig.EnableContext { if userVo.ChatConfig.EnableContext {
chatCtx = append(chatCtx, useMsg) // 提问消息 chatCtx = append(chatCtx, useMsg) // 提问消息
chatCtx = append(chatCtx, message) // 回复消息 chatCtx = append(chatCtx, message) // 回复消息
h.app.ChatContexts.Put(session.ChatId, chatCtx) h.App.ChatContexts.Put(session.ChatId, chatCtx)
} }
// 追加聊天记录 // 追加聊天记录
@ -349,11 +346,11 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session types.ChatSession
} else if strings.Contains(res.Error.Message, "This model's maximum context length") { } else if strings.Contains(res.Error.Message, "This model's maximum context length") {
replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!") replyMessage(ws, "当前会话上下文长度超出限制,已为您删减会话上下文!")
// 只保留最近的三条记录 // 只保留最近的三条记录
chatContext := h.app.ChatContexts.Get(session.ChatId) chatContext := h.App.ChatContexts.Get(session.ChatId)
if len(chatContext) > 3 { if len(chatContext) > 3 {
chatContext = chatContext[len(chatContext)-3:] chatContext = chatContext[len(chatContext)-3:]
} }
h.app.ChatContexts.Put(session.ChatId, chatContext) h.App.ChatContexts.Put(session.ChatId, chatContext)
return h.sendMessage(ctx, session, role, prompt, ws) return h.sendMessage(ctx, session, role, prompt, ws)
} else { } else {
replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message) replyMessage(ws, "请求 OpenAI API 失败:"+res.Error.Message)
@ -372,7 +369,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
return nil, err return nil, err
} }
// 创建 HttpClient 请求对象 // 创建 HttpClient 请求对象
request, err := http.NewRequest(http.MethodPost, h.app.ChatConfig.ApiURL, bytes.NewBuffer(requestBody)) request, err := http.NewRequest(http.MethodPost, h.App.ChatConfig.ApiURL, bytes.NewBuffer(requestBody))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -380,7 +377,7 @@ func (h *ChatHandler) doRequest(ctx context.Context, user vo.User, apiKey *strin
request = request.WithContext(ctx) request = request.WithContext(ctx)
request.Header.Add("Content-Type", "application/json") request.Header.Add("Content-Type", "application/json")
proxyURL := h.config.ProxyURL proxyURL := h.App.AppConfig.ProxyURL
if proxyURL == "" { if proxyURL == "" {
client = &http.Client{} client = &http.Client{}
} else { // 使用代理 } else { // 使用代理
@ -448,9 +445,9 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
// StopGenerate 停止生成 // StopGenerate 停止生成
func (h *ChatHandler) StopGenerate(c *gin.Context) { func (h *ChatHandler) StopGenerate(c *gin.Context) {
sessionId := c.Query("session_id") sessionId := c.Query("session_id")
if h.app.ReqCancelFunc.Has(sessionId) { if h.App.ReqCancelFunc.Has(sessionId) {
h.app.ReqCancelFunc.Get(sessionId)() h.App.ReqCancelFunc.Get(sessionId)()
h.app.ReqCancelFunc.Delete(sessionId) h.App.ReqCancelFunc.Delete(sessionId)
} }
resp.SUCCESS(c, types.OkMsg) resp.SUCCESS(c, types.OkMsg)
} }

View File

@ -5,14 +5,14 @@ import (
"chatplus/store/model" "chatplus/store/model"
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/param"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// List 获取会话列表 // List 获取会话列表
func (h *ChatHandler) List(c *gin.Context) { func (h *ChatHandler) List(c *gin.Context) {
userId := param.GetInt(c, "user_id", 0) userId := h.GetInt(c, "user_id", 0)
if userId == 0 { if userId == 0 {
resp.ERROR(c, "The parameter 'user_id' is needed.") resp.ERROR(c, "The parameter 'user_id' is needed.")
return return
@ -71,7 +71,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
// Remove 删除会话 // Remove 删除会话
func (h *ChatHandler) Remove(c *gin.Context) { func (h *ChatHandler) Remove(c *gin.Context) {
chatId := param.GetTrim(c, "chat_id") chatId := h.GetTrim(c, "chat_id")
if chatId == "" { if chatId == "" {
resp.ERROR(c, types.InvalidArgs) resp.ERROR(c, types.InvalidArgs)
return return
@ -89,7 +89,7 @@ func (h *ChatHandler) Remove(c *gin.Context) {
} }
// 清空会话上下文 // 清空会话上下文
h.app.ChatContexts.Delete(chatId) h.App.ChatContexts.Delete(chatId)
resp.SUCCESS(c, types.OkMsg) resp.SUCCESS(c, types.OkMsg)
} }
@ -144,7 +144,7 @@ func (h *ChatHandler) Clear(c *gin.Context) {
logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId) logger.Warnf("Failed to delele chat history for ChatID: %s", chat.ChatId)
} }
// 清空会话上下文 // 清空会话上下文
h.app.ChatContexts.Delete(chat.ChatId) h.App.ChatContexts.Delete(chat.ChatId)
} }
// 删除所有的会话记录 // 删除所有的会话记录
res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{}) res = h.db.Where("user_id = ?", user.Id).Delete(&model.ChatItem{})

View File

@ -8,8 +8,9 @@ import (
"chatplus/store/vo" "chatplus/store/vo"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin"
"strconv" "strconv"
"github.com/gin-gonic/gin"
) )
type ChatRoleHandler struct { type ChatRoleHandler struct {
@ -17,10 +18,9 @@ type ChatRoleHandler struct {
service *service.ChatRoleService service *service.ChatRoleService
} }
func NewChatRoleHandler(config *types.AppConfig, app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler { func NewChatRoleHandler(app *core.AppServer, service *service.ChatRoleService) *ChatRoleHandler {
handler := &ChatRoleHandler{service: service} handler := &ChatRoleHandler{service: service}
handler.app = app handler.App = app
handler.config = config
return handler return handler
} }

View File

@ -6,6 +6,7 @@ import (
"chatplus/store/model" "chatplus/store/model"
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -15,10 +16,9 @@ type ConfigHandler struct {
db *gorm.DB db *gorm.DB
} }
func NewConfigHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB) *ConfigHandler { func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
handler := ConfigHandler{db: db} handler := ConfigHandler{db: db}
handler.app = app handler.App = app
handler.config = config
return &handler return &handler
} }

View File

@ -8,12 +8,13 @@ import (
"chatplus/utils" "chatplus/utils"
"chatplus/utils/resp" "chatplus/utils/resp"
"fmt" "fmt"
"strings"
"time"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/lionsoul2014/ip2region/binding/golang/xdb" "github.com/lionsoul2014/ip2region/binding/golang/xdb"
"gorm.io/gorm" "gorm.io/gorm"
"strings"
"time"
) )
type UserHandler struct { type UserHandler struct {
@ -22,10 +23,9 @@ type UserHandler struct {
searcher *xdb.Searcher searcher *xdb.Searcher
} }
func NewUserHandler(config *types.AppConfig, app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler { func NewUserHandler(app *core.AppServer, db *gorm.DB, searcher *xdb.Searcher) *UserHandler {
handler := &UserHandler{db: db, searcher: searcher} handler := &UserHandler{db: db, searcher: searcher}
handler.app = app handler.App = app
handler.config = config
return handler return handler
} }
@ -77,11 +77,11 @@ func (h *UserHandler) Register(c *gin.Context) {
Status: true, Status: true,
ChatRoles: utils.JsonEncode(roleMap), ChatRoles: utils.JsonEncode(roleMap),
ChatConfig: utils.JsonEncode(types.ChatConfig{ ChatConfig: utils.JsonEncode(types.ChatConfig{
Temperature: h.app.ChatConfig.Temperature, Temperature: h.App.ChatConfig.Temperature,
MaxTokens: h.app.ChatConfig.MaxTokens, MaxTokens: h.App.ChatConfig.MaxTokens,
EnableContext: h.app.ChatConfig.EnableContext, EnableContext: h.App.ChatConfig.EnableContext,
EnableHistory: true, EnableHistory: true,
Model: h.app.ChatConfig.Model, Model: h.App.ChatConfig.Model,
ApiKey: "", ApiKey: "",
}), }),
} }
@ -159,16 +159,15 @@ func (h *UserHandler) Login(c *gin.Context) {
h.db.Model(&user).Updates(user) h.db.Model(&user).Updates(user)
sessionId := utils.RandString(42) sessionId := utils.RandString(42)
c.Header(types.TokenSessionName, sessionId) err := utils.SetLoginUser(c, user)
err := utils.SetLoginUser(c, user.Id)
if err != nil { if err != nil {
resp.ERROR(c, "保存会话失败") resp.ERROR(c, "保存会话失败")
logger.Error("Error for save session: ", err) logger.Error("Error for save session: ", err)
return return
} }
// 记录登录信息在服务 // 记录登录信息在服务
h.app.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId}) h.App.ChatSession.Put(sessionId, types.ChatSession{ClientIP: c.ClientIP(), UserId: user.Id, Username: data.Username, SessionId: sessionId})
// 加载用户订阅的聊天角色 // 加载用户订阅的聊天角色
var roleMap map[string]int var roleMap map[string]int
@ -229,17 +228,17 @@ func (h *UserHandler) Login(c *gin.Context) {
// Logout 注 销 // Logout 注 销
func (h *UserHandler) Logout(c *gin.Context) { func (h *UserHandler) Logout(c *gin.Context) {
sessionId := c.GetHeader(types.TokenSessionName) sessionId := c.GetHeader(types.SessionName)
session := sessions.Default(c) session := sessions.Default(c)
session.Delete(sessionId) session.Delete(types.SessionUser)
err := session.Save() err := session.Save()
if err != nil { if err != nil {
logger.Error("Error for save session: ", err) logger.Error("Error for save session: ", err)
} }
// 删除 websocket 会话列表 // 删除 websocket 会话列表
h.app.ChatSession.Delete(sessionId) h.App.ChatSession.Delete(sessionId)
// 关闭 socket 连接 // 关闭 socket 连接
client := h.app.ChatClients.Get(sessionId) client := h.App.ChatClients.Get(sessionId)
if client != nil { if client != nil {
client.Close() client.Close()
} }
@ -248,8 +247,8 @@ func (h *UserHandler) Logout(c *gin.Context) {
// Session 获取/验证会话 // Session 获取/验证会话
func (h *UserHandler) Session(c *gin.Context) { func (h *UserHandler) Session(c *gin.Context) {
sessionId := c.GetHeader(types.TokenSessionName) sessionId := c.GetHeader(types.SessionName)
session := h.app.ChatSession.Get(sessionId) session := h.App.ChatSession.Get(sessionId)
if session.ClientIP == c.ClientIP() { if session.ClientIP == c.ClientIP() {
resp.SUCCESS(c, session) resp.SUCCESS(c, session)
} else { } else {

View File

@ -4,20 +4,22 @@ import (
"chatplus/core" "chatplus/core"
"chatplus/core/types" "chatplus/core/types"
"chatplus/handler" "chatplus/handler"
"chatplus/handler/admin"
logger2 "chatplus/logger" logger2 "chatplus/logger"
"chatplus/service" "chatplus/service"
"chatplus/store" "chatplus/store"
"context" "context"
"flag" "flag"
"fmt" "fmt"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"go.uber.org/fx"
"gorm.io/gorm"
"log" "log"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"time" "time"
"github.com/lionsoul2014/ip2region/binding/golang/xdb"
"go.uber.org/fx"
"gorm.io/gorm"
) )
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
@ -80,19 +82,15 @@ func main() {
fx.Invoke(core.InitChatRoles), fx.Invoke(core.InitChatRoles),
// 创建控制器 // 创建控制器
fx.Provide(handler.NewAdminHandler),
fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewChatRoleHandler),
fx.Provide(handler.NewUserHandler), fx.Provide(handler.NewUserHandler),
fx.Provide(handler.NewChatHandler), fx.Provide(handler.NewChatHandler),
fx.Provide(handler.NewApiKeyHandler),
fx.Provide(handler.NewConfigHandler), fx.Provide(handler.NewConfigHandler),
fx.Provide(admin.NewAdminHandler),
fx.Provide(admin.NewApiKeyHandler),
// 注册路由 // 注册路由
fx.Invoke(func(s *core.AppServer, h *handler.AdminHandler) {
group := s.Engine.Group("/api/admin/")
group.POST("login", h.Login)
group.GET("logout", h.Logout)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ChatRoleHandler) {
group := s.Engine.Group("/api/chat/role/") group := s.Engine.Group("/api/chat/role/")
group.GET("list", h.List) group.GET("list", h.List)
@ -119,11 +117,6 @@ func main() {
group.GET("tokens", h.Tokens) group.GET("tokens", h.Tokens)
group.GET("stop", h.StopGenerate) group.GET("stop", h.StopGenerate)
}), }),
fx.Invoke(func(s *core.AppServer, h *handler.ApiKeyHandler) {
group := s.Engine.Group("/api/apikey/")
group.POST("add", h.Add)
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) { fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
group := s.Engine.Group("/api/config/") group := s.Engine.Group("/api/config/")
group.POST("update", h.Update) group.POST("update", h.Update)
@ -131,6 +124,17 @@ func main() {
group.GET("models", h.AllGptModels) group.GET("models", h.AllGptModels)
}), }),
fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
group := s.Engine.Group("/api/admin/")
group.POST("login", h.Login)
group.GET("logout", h.Logout)
}),
fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
group := s.Engine.Group("/api/admin/apikey/")
group.POST("add", h.Add)
group.GET("list", h.List)
}),
fx.Invoke(func(s *core.AppServer, db *gorm.DB) { fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
err := s.Run(db) err := s.Run(db)
if err != nil { if err != nil {

View File

@ -1,57 +0,0 @@
package param
import (
"github.com/gin-gonic/gin"
"strconv"
"strings"
)
func GetTrim(c *gin.Context, key string) string {
return strings.TrimSpace(c.Query(key))
}
func GetInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.Query(key), defaultValue)
}
func PostInt(c *gin.Context, key string, defaultValue int) int {
return intValue(c.PostForm(key), defaultValue)
}
func intValue(str string, defaultValue int) int {
value, err := strconv.Atoi(str)
if err != nil {
return defaultValue
}
return value
}
func GetFloat(c *gin.Context, key string) float64 {
return floatValue(c.Query(key))
}
func PostFloat(c *gin.Context, key string) float64 {
return floatValue(c.PostForm(key))
}
func floatValue(str string) float64 {
value, err := strconv.ParseFloat(str, 64)
if err != nil {
return 0
}
return value
}
func GetBool(c *gin.Context, key string) bool {
return boolValue(c.Query(key))
}
func PostBool(c *gin.Context, key string) bool {
return boolValue(c.PostForm(key))
}
func boolValue(str string) bool {
value, err := strconv.ParseBool(str)
if err != nil {
return false
}
return value
}

View File

@ -4,14 +4,22 @@ import (
"chatplus/core/types" "chatplus/core/types"
"chatplus/store/model" "chatplus/store/model"
"errors" "errors"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"gorm.io/gorm" "gorm.io/gorm"
) )
func SetLoginUser(c *gin.Context, userId uint) error { func SetLoginUser(c *gin.Context, user model.User) error {
session := sessions.Default(c) session := sessions.Default(c)
session.Set(types.SessionUserId, userId) session.Set(types.SessionUser, user.Id)
// TODO: 后期用户数量增加,考虑将用户数据存储到 leveldb避免每次查询数据库
return session.Save()
}
func SetLoginAdmin(c *gin.Context, admin types.Manager) error {
session := sessions.Default(c)
session.Set(types.SessionAdmin, admin)
return session.Save() return session.Save()
} }
@ -22,7 +30,7 @@ func GetLoginUser(c *gin.Context, db *gorm.DB) (model.User, error) {
} }
session := sessions.Default(c) session := sessions.Default(c)
userId := session.Get(types.SessionUserId) userId := session.Get(types.SessionUser)
if userId == nil { if userId == nil {
return model.User{}, errors.New("user not login") return model.User{}, errors.New("user not login")
} }

View File

@ -5,7 +5,7 @@
<meta charset="utf-8"> <meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge"> <meta http-equiv="X-UA-Compatible" content="IE=edge">
<meta name="viewport" content="width=device-width,initial-scale=1.0"> <meta name="viewport" content="width=device-width,initial-scale=1.0">
<link rel="icon" href="favicon.ico" type="image/x-icon"> <link rel="icon" href="/favicon.ico" type="image/x-icon">
<title>ChatGPT-Plus</title> <title>ChatGPT-Plus</title>
</head> </head>

Binary file not shown.

After

Width:  |  Height:  |  Size: 101 KiB

View File

@ -7,47 +7,20 @@ import ChatPlus from "@/views/ChatPlus.vue";
import NotFound from './views/404.vue' import NotFound from './views/404.vue'
import TestPage from './views/Test.vue' import TestPage from './views/Test.vue'
import Home from "@/views/Home.vue"; import Home from "@/views/Home.vue";
import Admin from "@/views/Admin.vue"; import Admin from "@/views/admin/Admin.vue";
import Login from "@/views/Login.vue" import Login from "@/views/Login.vue"
import Register from "@/views/Register.vue"; import Register from "@/views/Register.vue";
import AdminLogin from "@/views/admin/Login.vue"
const routes = [ const routes = [
{ {name: 'home', path: '/', component: Home, meta: {title: 'ChatGPT-Plus'}},
name: 'home', path: '/', component: Home, meta: { {name: 'login', path: '/login', component: Login, meta: {title: '用户登录'}},
title: 'ChatGPT-Plus' {name: 'register', path: '/register', component: Register, meta: {title: '用户注册'}},
} {name: 'plus', path: '/chat', component: ChatPlus, meta: {title: 'ChatGPT-智能助手V3'}},
}, {name: 'admin', path: '/admin', component: Admin, meta: {title: 'Chat-Plus 控制台'}},
{ {name: 'admin/login', path: '/admin/login', component: AdminLogin, meta: {title: 'Chat-Plus 控制台登录'}},
name: 'login', path: '/login', component: Login, meta: { {name: 'test', path: '/test', component: TestPage, meta: {title: '测试页面'}},
title: '用户登录' {name: 'NotFound', path: '/:all(.*)', component: NotFound, meta: {title: '页面没有找到'}},
}
},
{
name: 'register', path: '/register', component: Register, meta: {
title: '用户注册'
}
},
{
name: 'plus', path: '/chat', component: ChatPlus, meta: {
title: 'ChatGPT-智能助手V3'
}
},
{
name: 'admin', path: '/admin', component: Admin, meta: {
title: 'Chat-Plus 控制台'
}
},
{
name: 'test', path: '/test', component: TestPage, meta: {
title: '测试页面'
}
},
{
name: 'NotFound', path: '/:all(.*)', component: NotFound, meta: {
title: '页面没有找到'
}
},
] ]
const router = createRouter({ const router = createRouter({

View File

@ -1,20 +1,26 @@
/* eslint-disable no-constant-condition */ /* eslint-disable no-constant-condition */
import {dateFormat, removeArrayItem} from "@/utils/libs";
import Storage from 'good-storage'
/** /**
* storage handler * storage handler
*/ */
const SessionUserKey = 'LOGIN_USER'; const SessionUserKey = 'LOGIN_USER';
const ChatHistoryKey = 'CHAT_HISTORY'; const SessionAdminKey = 'LOGIN_ADMIN';
const ChatListKey = 'CHAT_LIST';
export function getSessionId() { export function getSessionId() {
const user = getLoginUser(); const user = getLoginUser();
return user ? user['session_id'] : ''; return user ? user['session_id'] : '';
} }
export function getLoginAdmin() {
const value = sessionStorage.getItem(SessionAdminKey);
if (value) {
return JSON.parse(value);
} else {
return null;
}
}
export function getLoginUser() { export function getLoginUser() {
const value = sessionStorage.getItem(SessionUserKey); const value = sessionStorage.getItem(SessionUserKey);
if (value) { if (value) {
@ -28,73 +34,6 @@ export function setLoginUser(user) {
sessionStorage.setItem(SessionUserKey, JSON.stringify(user)) sessionStorage.setItem(SessionUserKey, JSON.stringify(user))
} }
export function getUserInfo() { export function setLoginAdmin(admin) {
const data = getLoginUser(); sessionStorage.setItem(SessionAdminKey, JSON.stringify(admin))
if (data !== null) {
const user = data["user"];
user['active_time'] = dateFormat(user['active_time']);
user['expired_time'] = dateFormat(user['expired_time']);
return user;
}
return {}
}
// 追加历史记录
export function appendChatHistory(chatId, message) {
let history = Storage.get(ChatHistoryKey);
if (!history) {
history = {};
}
if (!history[chatId]) {
history[chatId] = [message];
} else {
history[chatId].push(message);
}
Storage.set(ChatHistoryKey, history);
}
export function clearChatHistory() {
Storage.remove(ChatHistoryKey);
Storage.remove(ChatListKey);
}
// 获取指定会话的历史记录
export function getChatHistory(chatId) {
const history = Storage.get(ChatHistoryKey);
if (!history) {
return null;
}
return history[chatId] ? history[chatId] : null;
}
export function getChatList() {
const list = Storage.get(ChatListKey);
if (list) {
if (typeof list.reverse !== 'function') {
Storage.remove(ChatListKey)
return null;
}
return list.reverse();
}
}
export function setChat(chat) {
let chatList = Storage.get(ChatListKey);
if (!chatList) {
chatList = [];
}
chatList.push(chat);
Storage.set(ChatListKey, chatList);
}
export function removeChat(chat) {
let chatList = Storage.get(ChatListKey);
if (chatList) {
chatList = removeArrayItem(chatList, chat, function (v1, v2) {
return v1.id === v2.id
})
Storage.set(ChatListKey, chatList);
}
} }

View File

@ -140,13 +140,14 @@ const validateMobile = function (mobile) {
width 90% width 90%
max-width 400px max-width 400px
transform translate(-50%, -50%) transform translate(-50%, -50%)
padding 20px; padding 20px 40px;
color #ffffff color #ffffff
border-radius 10px; border-radius 10px;
background rgba(255, 255, 255, 0.3) background rgba(255, 255, 255, 0.3)
.logo { .logo {
text-align center text-align center
.el-image { .el-image {
width 120px; width 120px;
} }

View File

@ -0,0 +1,175 @@
<template>
<div>
<div class="bg"></div>
<div class="main">
<div class="contain">
<div class="logo">
<el-image src="../images/logo.png" fit="cover"/>
</div>
<div class="header">{{ title }}</div>
<div class="content">
<div class="block">
<el-input placeholder="请输入用户名" size="large" v-model="username" autocomplete="off">
<template #prefix>
<el-icon>
<UserFilled/>
</el-icon>
</template>
</el-input>
</div>
<div class="block">
<el-input placeholder="请输入密码" size="large" v-model="password" show-password autocomplete="off">
<template #prefix>
<el-icon>
<Lock/>
</el-icon>
</template>
</el-input>
</div>
<el-row class="btn-row">
<el-button class="login-btn" size="large" type="primary" @click="login">登录</el-button>
</el-row>
</div>
</div>
<footer class="footer">
<footer-bar/>
</footer>
</div>
</div>
</template>
<script setup>
import {onMounted, ref} from "vue";
import {Lock, UserFilled} from "@element-plus/icons-vue";
import {httpPost} from "@/utils/http";
import {ElMessage} from "element-plus";
import {setLoginUser} from "@/utils/storage";
import {useRouter} from "vue-router";
import FooterBar from "@/components/FooterBar.vue";
const router = useRouter();
const title = ref('ChatGPT-PLUS 控制台登录');
const username = ref(process.env.VUE_APP_ADMIN_USER);
const password = ref(process.env.VUE_APP_ADMIN_PASS);
onMounted(() => {
document.addEventListener('keyup', (e) => {
if (e.key === 'Enter') {
login();
}
});
})
const login = function () {
if (username.value === '') {
return ElMessage.error('请输入用户名');
}
if (password.value.trim() === '') {
return ElMessage.error('请输入密码');
}
httpPost('/api/admin/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
setLoginUser(res.data)
router.push("admin")
}).catch((e) => {
ElMessage.error('登录失败,' + e.message)
})
}
</script>
<style lang="stylus" scoped>
.bg {
position fixed
left 0
right 0
top 0
bottom 0
background-color #313237
background-image url("~@/assets/img/admin-login-bg.jpg")
background-size cover
background-position center
background-repeat no-repeat
filter: blur(10px); /* 调整模糊程度,可以根据需要修改值 */
}
.main {
.contain {
position fixed
left 50%
top 40%
width 90%
max-width 400px;
transform translate(-50%, -50%)
padding 20px 40px;
color #ffffff
border-radius 10px;
background rgba(255, 255, 255, 0.3)
.logo {
text-align center
.el-image {
width 120px;
}
}
.header {
width 100%
margin-bottom 24px
font-size 24px
color $white_v1
letter-space 2px
text-align center
}
.content {
width 100%
height: auto
border-radius 3px
.block {
margin-bottom 16px
.el-input__inner {
border 1px solid $gray-v6 !important
.el-icon-user, .el-icon-lock {
font-size 20px
}
}
}
.btn-row {
padding-top 10px;
.login-btn {
width 100%
font-size 16px
letter-spacing 2px
}
}
.text-line {
justify-content center
padding-top 10px;
font-size 14px;
}
}
}
.footer {
color #ffffff;
.container {
padding 20px;
}
}
}
</style>