mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	fix: replace session handler with jwt authorization
This commit is contained in:
		@@ -7,16 +7,16 @@ import (
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-contrib/sessions/cookie"
 | 
			
		||||
	"github.com/gin-contrib/sessions/memstore"
 | 
			
		||||
	"github.com/gin-contrib/sessions/redis"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppServer struct {
 | 
			
		||||
@@ -53,14 +53,13 @@ func NewServer(appConfig *types.AppConfig, functions map[string]function.Functio
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AppServer) Init(debug bool) {
 | 
			
		||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
 | 
			
		||||
	if debug { // 调试模式允许跨域请求 API
 | 
			
		||||
		s.Debug = debug
 | 
			
		||||
		logger.Info("Enabled debug mode")
 | 
			
		||||
	}
 | 
			
		||||
	s.Engine.Use(corsMiddleware())
 | 
			
		||||
	s.Engine.Use(sessionMiddleware(s.Config))
 | 
			
		||||
	s.Engine.Use(authorizeMiddleware(s))
 | 
			
		||||
	s.Engine.Use(authorizeMiddleware(s, client))
 | 
			
		||||
	s.Engine.Use(errorHandler)
 | 
			
		||||
	// 添加静态资源访问
 | 
			
		||||
	s.Engine.Static("/static", s.Config.StaticDir)
 | 
			
		||||
@@ -105,42 +104,6 @@ func errorHandler(c *gin.Context) {
 | 
			
		||||
	c.Next()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 会话处理
 | 
			
		||||
func sessionMiddleware(config *types.AppConfig) gin.HandlerFunc {
 | 
			
		||||
	// encrypt the cookie
 | 
			
		||||
	var store sessions.Store
 | 
			
		||||
	var err error
 | 
			
		||||
	switch config.Session.Driver {
 | 
			
		||||
	case types.SessionDriverMem:
 | 
			
		||||
		store = memstore.NewStore([]byte(config.Session.SecretKey))
 | 
			
		||||
		break
 | 
			
		||||
	case types.SessionDriverRedis:
 | 
			
		||||
		store, err = redis.NewStore(10, "tcp", config.Redis.Url(), config.Redis.Password, []byte(config.Session.SecretKey))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Fatal(err)
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	case types.SessionDriverCookie:
 | 
			
		||||
		store = cookie.NewStore([]byte(config.Session.SecretKey))
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		config.Session.Driver = types.SessionDriverCookie
 | 
			
		||||
		store = cookie.NewStore([]byte(config.Session.SecretKey))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Info("Session driver: ", config.Session.Driver)
 | 
			
		||||
 | 
			
		||||
	store.Options(sessions.Options{
 | 
			
		||||
		Path:     config.Session.Path,
 | 
			
		||||
		Domain:   config.Session.Domain,
 | 
			
		||||
		MaxAge:   config.Session.MaxAge,
 | 
			
		||||
		Secure:   config.Session.Secure,
 | 
			
		||||
		HttpOnly: config.Session.HttpOnly,
 | 
			
		||||
		SameSite: config.Session.SameSite,
 | 
			
		||||
	})
 | 
			
		||||
	return sessions.Sessions(config.Session.Name, store)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 跨域中间件设置
 | 
			
		||||
func corsMiddleware() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
@@ -151,7 +114,7 @@ func corsMiddleware() gin.HandlerFunc {
 | 
			
		||||
			c.Header("Access-Control-Allow-Origin", origin)
 | 
			
		||||
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
 | 
			
		||||
			//允许跨域设置可以返回其他子段,可以自定义字段
 | 
			
		||||
			c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token")
 | 
			
		||||
			c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization")
 | 
			
		||||
			// 允许浏览器(客户端)可以解析的头部 (重要)
 | 
			
		||||
			c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
 | 
			
		||||
			//设置缓存时间
 | 
			
		||||
@@ -175,7 +138,7 @@ func corsMiddleware() gin.HandlerFunc {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 用户授权验证
 | 
			
		||||
func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
 | 
			
		||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		if c.Request.URL.Path == "/api/user/login" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/admin/login" ||
 | 
			
		||||
@@ -190,29 +153,54 @@ func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// WebSocket 连接请求验证
 | 
			
		||||
		if c.Request.URL.Path == "/api/chat" {
 | 
			
		||||
			sessionId := c.Query("sessionId")
 | 
			
		||||
			session := s.ChatSession.Get(sessionId)
 | 
			
		||||
			if session.ClientIP == c.ClientIP() {
 | 
			
		||||
				c.Next()
 | 
			
		||||
			} else {
 | 
			
		||||
				c.Abort()
 | 
			
		||||
			}
 | 
			
		||||
		var tokenString string
 | 
			
		||||
		if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
 | 
			
		||||
			tokenString = c.GetHeader(types.AdminAuthHeader)
 | 
			
		||||
		} else if c.Request.URL.Path == "/api/chat/new" {
 | 
			
		||||
			tokenString = c.Query("token")
 | 
			
		||||
		} else {
 | 
			
		||||
			tokenString = c.GetHeader(types.UserAuthHeader)
 | 
			
		||||
		}
 | 
			
		||||
		if tokenString == "" {
 | 
			
		||||
			resp.ERROR(c, "You should put Authorization in request headers")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		session := sessions.Default(c)
 | 
			
		||||
		var value interface{}
 | 
			
		||||
		if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
 | 
			
		||||
			value = session.Get(types.SessionAdmin)
 | 
			
		||||
		} else {
 | 
			
		||||
			value = session.Get(types.SessionUser)
 | 
			
		||||
		}
 | 
			
		||||
		if value != nil {
 | 
			
		||||
			c.Next()
 | 
			
		||||
		} else {
 | 
			
		||||
			resp.NotAuth(c)
 | 
			
		||||
 | 
			
		||||
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 | 
			
		||||
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | 
			
		||||
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return []byte(s.Config.Session.SecretKey), nil
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		claims, ok := token.Claims.(jwt.MapClaims)
 | 
			
		||||
		if !ok || !token.Valid {
 | 
			
		||||
			resp.NotAuth(c, "Token is invalid")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
 | 
			
		||||
		if expr > 0 && int64(expr) < time.Now().Unix() {
 | 
			
		||||
			resp.NotAuth(c, "Token is expired")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		key := fmt.Sprintf("users/%v", claims["user_id"])
 | 
			
		||||
		if _, err := client.Get(context.Background(), key).Result(); err != nil {
 | 
			
		||||
			resp.NotAuth(c, "Token is not found in redis")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		c.Set(types.LoginUserID, claims["user_id"])
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"github.com/BurntSushi/toml"
 | 
			
		||||
@@ -23,15 +22,8 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
		Redis:         types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
 | 
			
		||||
		AesEncryptKey: utils.RandString(24),
 | 
			
		||||
		Session: types.Session{
 | 
			
		||||
			Driver:    types.SessionDriverCookie,
 | 
			
		||||
			SecretKey: utils.RandString(64),
 | 
			
		||||
			Name:      "CHAT_PLUS_SESSION",
 | 
			
		||||
			Domain:    "",
 | 
			
		||||
			Path:      "/",
 | 
			
		||||
			MaxAge:    86400,
 | 
			
		||||
			Secure:    true,
 | 
			
		||||
			HttpOnly:  false,
 | 
			
		||||
			SameSite:  http.SameSiteLaxMode,
 | 
			
		||||
		},
 | 
			
		||||
		ApiConfig: types.ChatPlusApiConfig{},
 | 
			
		||||
		ExtConfig: types.ChatPlusExtConfig{Token: utils.RandString(32)},
 | 
			
		||||
 
 | 
			
		||||
@@ -2,7 +2,6 @@ package types
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppConfig struct {
 | 
			
		||||
@@ -85,19 +84,6 @@ const (
 | 
			
		||||
	SessionDriverCookie = SessionDriver("cookie")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Session configs struct
 | 
			
		||||
type Session struct {
 | 
			
		||||
	Driver    SessionDriver // session 存储驱动 mem|cookie|redis
 | 
			
		||||
	SecretKey string        // session encryption key
 | 
			
		||||
	Name      string
 | 
			
		||||
	Path      string
 | 
			
		||||
	Domain    string
 | 
			
		||||
	MaxAge    int
 | 
			
		||||
	Secure    bool
 | 
			
		||||
	HttpOnly  bool
 | 
			
		||||
	SameSite  http.SameSite
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatConfig 系统默认的聊天配置
 | 
			
		||||
type ChatConfig struct {
 | 
			
		||||
	OpenAI  ModelAPIConfig `json:"open_ai"`
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,14 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
const SessionName = "ChatGPT-TOKEN"
 | 
			
		||||
const SessionUser = "SESSION_USER"        // 存储用户信息的 session key
 | 
			
		||||
const SessionAdmin = "SESSION_ADMIN"      //存储管理员信息的 session key
 | 
			
		||||
const LoginUserCache = "LOGIN_USER_CACHE" // 已登录用户缓存
 | 
			
		||||
const LoginUserID = "LOGIN_USER_ID"
 | 
			
		||||
const LoginUserCache = "LOGIN_USER_CACHE"
 | 
			
		||||
 | 
			
		||||
const UserAuthHeader = "Authorization"
 | 
			
		||||
const AdminAuthHeader = "Admin-Authorization"
 | 
			
		||||
const ChatTokenHeader = "Chat-Token"
 | 
			
		||||
 | 
			
		||||
// Session configs struct
 | 
			
		||||
type Session struct {
 | 
			
		||||
	SecretKey string
 | 
			
		||||
	MaxAge    int
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -8,9 +8,11 @@ import (
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -20,11 +22,12 @@ var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type ManagerHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
	db    *gorm.DB
 | 
			
		||||
	redis *redis.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB) *ManagerHandler {
 | 
			
		||||
	h := ManagerHandler{db: db}
 | 
			
		||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
 | 
			
		||||
	h := ManagerHandler{db: db, redis: client}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
@@ -38,13 +41,22 @@ func (h *ManagerHandler) Login(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	manager := h.App.Config.Manager
 | 
			
		||||
	if data.Username == manager.Username && data.Password == manager.Password {
 | 
			
		||||
		err := utils.SetLoginAdmin(c, manager)
 | 
			
		||||
		// 创建 token
 | 
			
		||||
		token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
			"user_id": manager.Username,
 | 
			
		||||
			"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)),
 | 
			
		||||
		})
 | 
			
		||||
		tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "Save session failed")
 | 
			
		||||
			resp.ERROR(c, "Failed to generate token, "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		manager.Password = "" // 清空密码]
 | 
			
		||||
		resp.SUCCESS(c, manager)
 | 
			
		||||
		// 保存到 redis
 | 
			
		||||
		if _, err := h.redis.Set(context.Background(), "users/"+manager.Username, tokenString, 0).Result(); err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		resp.SUCCESS(c, tokenString)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.ERROR(c, "用户名或者密码错误")
 | 
			
		||||
	}
 | 
			
		||||
@@ -52,11 +64,9 @@ func (h *ManagerHandler) Login(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Logout 注销
 | 
			
		||||
func (h *ManagerHandler) Logout(c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	session.Delete(types.SessionAdmin)
 | 
			
		||||
	err := session.Save()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "Save session failed")
 | 
			
		||||
	token := c.GetHeader(types.AdminAuthHeader)
 | 
			
		||||
	if _, err := h.redis.Del(c, token).Result(); err != nil {
 | 
			
		||||
		logger.Error("error with delete session: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
	}
 | 
			
		||||
@@ -64,9 +74,8 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Session 会话检测
 | 
			
		||||
func (h *ManagerHandler) Session(c *gin.Context) {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	admin := session.Get(types.SessionAdmin)
 | 
			
		||||
	if admin == nil {
 | 
			
		||||
	token := c.GetHeader(types.AdminAuthHeader)
 | 
			
		||||
	if token == "" {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@ package handler
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -20,47 +20,23 @@ func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) PostInt(c *gin.Context, key string, defaultValue int) int {
 | 
			
		||||
	return intValue(c.PostForm(key), defaultValue)
 | 
			
		||||
	return utils.IntValue(c.PostForm(key), defaultValue)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetInt(c *gin.Context, key string, defaultValue int) int {
 | 
			
		||||
	return intValue(c.Query(key), defaultValue)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func intValue(str string, defaultValue int) int {
 | 
			
		||||
	value, err := strconv.Atoi(str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
	return utils.IntValue(c.Query(key), defaultValue)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetFloat(c *gin.Context, key string) float64 {
 | 
			
		||||
	return floatValue(c.Query(key))
 | 
			
		||||
	return utils.FloatValue(c.Query(key))
 | 
			
		||||
}
 | 
			
		||||
func (h *BaseHandler) PostFloat(c *gin.Context, key string) float64 {
 | 
			
		||||
	return floatValue(c.PostForm(key))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func floatValue(str string) float64 {
 | 
			
		||||
	value, err := strconv.ParseFloat(str, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
	return utils.FloatValue(c.PostForm(key))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetBool(c *gin.Context, key string) bool {
 | 
			
		||||
	return boolValue(c.Query(key))
 | 
			
		||||
	return utils.BoolValue(c.Query(key))
 | 
			
		||||
}
 | 
			
		||||
func (h *BaseHandler) PostBool(c *gin.Context, key string) bool {
 | 
			
		||||
	return boolValue(c.PostForm(key))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func boolValue(str string) bool {
 | 
			
		||||
	value, err := strconv.ParseBool(str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
	return utils.BoolValue(c.PostForm(key))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,10 +9,12 @@ import (
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -23,6 +25,7 @@ type UserHandler struct {
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	searcher      *xdb.Searcher
 | 
			
		||||
	leveldb       *store.LevelDB
 | 
			
		||||
	redis         *redis.Client
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -31,8 +34,9 @@ func NewUserHandler(
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	searcher *xdb.Searcher,
 | 
			
		||||
	levelDB *store.LevelDB,
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	manager *oss.UploaderManager) *UserHandler {
 | 
			
		||||
	handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB, uploadManager: manager}
 | 
			
		||||
	handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB, redis: client, uploadManager: manager}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return handler
 | 
			
		||||
}
 | 
			
		||||
@@ -122,7 +126,7 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
// Login 用户登录
 | 
			
		||||
func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"mobile"`
 | 
			
		||||
		Mobile   string `json:"username"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
@@ -130,7 +134,7 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.db.Where("mobile = ?", data.Username).First(&user)
 | 
			
		||||
	res := h.db.Where("mobile = ?", data.Mobile).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "用户名不存在")
 | 
			
		||||
		return
 | 
			
		||||
@@ -152,13 +156,6 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
	user.LastLoginAt = time.Now().Unix()
 | 
			
		||||
	h.db.Model(&user).Updates(user)
 | 
			
		||||
 | 
			
		||||
	err := utils.SetLoginUser(c, user)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "保存会话失败")
 | 
			
		||||
		logger.Error("Error for save session: ", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.db.Create(&model.UserLoginLog{
 | 
			
		||||
		UserId:       user.Id,
 | 
			
		||||
		Username:     user.Mobile,
 | 
			
		||||
@@ -166,17 +163,32 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
		LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
	// 创建 token
 | 
			
		||||
	expired := time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge))
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": user.Id,
 | 
			
		||||
		"expired": expired,
 | 
			
		||||
	})
 | 
			
		||||
	tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to generate token, "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 保存到 redis
 | 
			
		||||
	key := fmt.Sprintf("users/%d", user.Id)
 | 
			
		||||
	if _, err := h.redis.Set(c, key, tokenString, 0).Result(); err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, tokenString)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Logout 注 销
 | 
			
		||||
func (h *UserHandler) Logout(c *gin.Context) {
 | 
			
		||||
	sessionId := c.GetHeader(types.SessionName)
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	session.Delete(types.SessionUser)
 | 
			
		||||
	err := session.Save()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("Error for save session: ", err)
 | 
			
		||||
	sessionId := c.GetHeader(types.ChatTokenHeader)
 | 
			
		||||
	token := c.GetHeader(types.UserAuthHeader)
 | 
			
		||||
	if _, err := h.redis.Del(c, token).Result(); err != nil {
 | 
			
		||||
		logger.Error("error with delete session: ", err)
 | 
			
		||||
	}
 | 
			
		||||
	// 删除 websocket 会话列表
 | 
			
		||||
	h.App.ChatSession.Delete(sessionId)
 | 
			
		||||
 
 | 
			
		||||
@@ -12,6 +12,7 @@ import (
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"context"
 | 
			
		||||
	"embed"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -81,8 +82,8 @@ func main() {
 | 
			
		||||
		// 创建应用服务
 | 
			
		||||
		fx.Provide(core.NewServer),
 | 
			
		||||
		// 初始化
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer) {
 | 
			
		||||
			s.Init(debug)
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, client *redis.Client) {
 | 
			
		||||
			s.Init(debug, client)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 初始化数据库
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
			
		||||
@@ -113,3 +114,27 @@ func IsEmptyValue(obj interface{}) bool {
 | 
			
		||||
		return reflect.DeepEqual(obj, reflect.Zero(reflect.TypeOf(obj)).Interface())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BoolValue(str string) bool {
 | 
			
		||||
	value, err := strconv.ParseBool(str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func FloatValue(str string) float64 {
 | 
			
		||||
	value, err := strconv.ParseFloat(str, 64)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IntValue(str string, defaultValue int) int {
 | 
			
		||||
	value, err := strconv.Atoi(str)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return defaultValue
 | 
			
		||||
	}
 | 
			
		||||
	return value
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -27,6 +27,10 @@ func HACKER(c *gin.Context) {
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: "Hacker attempt!!!"})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NotAuth(c *gin.Context) {
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "Not Authorized"})
 | 
			
		||||
func NotAuth(c *gin.Context, messages ...string) {
 | 
			
		||||
	if messages != nil {
 | 
			
		||||
		c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: messages[0]})
 | 
			
		||||
	} else {
 | 
			
		||||
		c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "Not Authorized"})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,33 +5,18 @@ import (
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"errors"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-contrib/sessions"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func SetLoginUser(c *gin.Context, user model.User) error {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	session.Set(types.SessionUser, user.Id)
 | 
			
		||||
	// TODO: 后期用户数量增加,考虑将用户数据存储到 leveldb,避免每次查询数据库
 | 
			
		||||
	return session.Save()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SetLoginAdmin(c *gin.Context, admin types.Manager) error {
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	session.Set(types.SessionAdmin, admin.Username)
 | 
			
		||||
	return session.Save()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetLoginUser(c *gin.Context, db *gorm.DB) (model.User, error) {
 | 
			
		||||
	value, exists := c.Get(types.LoginUserCache)
 | 
			
		||||
	if exists {
 | 
			
		||||
		return value.(model.User), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := sessions.Default(c)
 | 
			
		||||
	userId := session.Get(types.SessionUser)
 | 
			
		||||
	if userId == nil {
 | 
			
		||||
	userId, ok := c.Get(types.LoginUserID)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return model.User{}, errors.New("user not login")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,8 @@
 | 
			
		||||
VUE_APP_API_HOST=http://localhost:5678
 | 
			
		||||
VUE_APP_WS_HOST=ws://localhost:5678
 | 
			
		||||
VUE_APP_USER=geekmaster
 | 
			
		||||
VUE_APP_USER=18575670125
 | 
			
		||||
VUE_APP_PASS=12345678
 | 
			
		||||
VUE_APP_ADMIN_USER=admin
 | 
			
		||||
VUE_APP_ADMIN_PASS=admin123
 | 
			
		||||
VUE_APP_KEY_PREFIX=ChatPLUS_
 | 
			
		||||
VUE_APP_TITLE="ChatGPT-PLUS V3"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,4 @@
 | 
			
		||||
VUE_APP_API_HOST=
 | 
			
		||||
VUE_APP_WS_HOST=
 | 
			
		||||
VUE_APP_USER=
 | 
			
		||||
VUE_APP_PASS=
 | 
			
		||||
VUE_APP_ADMIN_USER=
 | 
			
		||||
VUE_APP_ADMIN_PASS=
 | 
			
		||||
VUE_APP_KEY_PREFIX=ChatPLUS_
 | 
			
		||||
VUE_APP_TITLE="ChatGPT-PLUS V3"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import Storage from 'good-storage'
 | 
			
		||||
 | 
			
		||||
const CHAT_CONFIG_KEY = "chat_config"
 | 
			
		||||
const CHAT_CONFIG_KEY = process.env.VUE_APP_KEY_PREFIX + "chat_config"
 | 
			
		||||
 | 
			
		||||
export function getChatConfig() {
 | 
			
		||||
    return Storage.get(CHAT_CONFIG_KEY)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,16 @@
 | 
			
		||||
/* eslint-disable no-constant-condition */
 | 
			
		||||
 | 
			
		||||
import {randString} from "@/utils/libs";
 | 
			
		||||
import Storage from "good-storage";
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * storage handler
 | 
			
		||||
 */
 | 
			
		||||
 | 
			
		||||
const SessionIDKey = 'SESSION_ID';
 | 
			
		||||
const SessionIDKey = process.env.VUE_APP_KEY_PREFIX + 'SESSION_ID';
 | 
			
		||||
const UserTokenKey = process.env.VUE_APP_KEY_PREFIX + "Authorization";
 | 
			
		||||
const AdminTokenKey = process.env.VUE_APP_KEY_PREFIX + "Admin-Authorization"
 | 
			
		||||
 | 
			
		||||
export function getSessionId() {
 | 
			
		||||
    let sessionId = sessionStorage.getItem(SessionIDKey)
 | 
			
		||||
    let sessionId = Storage.get(SessionIDKey)
 | 
			
		||||
    if (!sessionId) {
 | 
			
		||||
        sessionId = randString(42)
 | 
			
		||||
        setSessionId(sessionId)
 | 
			
		||||
@@ -17,10 +18,34 @@ export function getSessionId() {
 | 
			
		||||
    return sessionId
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function removeLoginUser() {
 | 
			
		||||
    sessionStorage.removeItem(SessionIDKey)
 | 
			
		||||
export function removeSessionId() {
 | 
			
		||||
    Storage.remove(SessionIDKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function setSessionId(sessionId) {
 | 
			
		||||
    sessionStorage.setItem(SessionIDKey, sessionId)
 | 
			
		||||
    Storage.set(SessionIDKey, sessionId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function getUserToken() {
 | 
			
		||||
    return Storage.get(UserTokenKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function setUserToken(token) {
 | 
			
		||||
    Storage.set(UserTokenKey, token)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function removeUserToken() {
 | 
			
		||||
    Storage.remove(UserTokenKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function getAdminToken() {
 | 
			
		||||
    return Storage.get(AdminTokenKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function setAdminToken(token) {
 | 
			
		||||
    Storage.set(AdminTokenKey, token)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function removeAdminToken() {
 | 
			
		||||
    Storage.remove(AdminTokenKey)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
import Storage from "good-storage";
 | 
			
		||||
 | 
			
		||||
const MOBILE_THEME = "MOBILE_THEME"
 | 
			
		||||
const MOBILE_THEME = process.env.VUE_APP_KEY_PREFIX + "MOBILE_THEME"
 | 
			
		||||
 | 
			
		||||
export function getMobileTheme() {
 | 
			
		||||
    return Storage.get(MOBILE_THEME) ? Storage.get(MOBILE_THEME) : 'light'
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
import axios from 'axios'
 | 
			
		||||
import {getSessionId} from "@/store/session";
 | 
			
		||||
import {getAdminToken, getSessionId, getUserToken} from "@/store/session";
 | 
			
		||||
 | 
			
		||||
axios.defaults.timeout = 10000
 | 
			
		||||
axios.defaults.baseURL = process.env.VUE_APP_API_HOST
 | 
			
		||||
@@ -11,6 +11,8 @@ axios.interceptors.request.use(
 | 
			
		||||
    config => {
 | 
			
		||||
        // set token
 | 
			
		||||
        config.headers['Chat-Token'] = getSessionId();
 | 
			
		||||
        config.headers['Authorization'] = getUserToken();
 | 
			
		||||
        config.headers['Admin-Authorization'] = getAdminToken();
 | 
			
		||||
        return config
 | 
			
		||||
    }, error => {
 | 
			
		||||
        return Promise.reject(error)
 | 
			
		||||
 
 | 
			
		||||
@@ -271,7 +271,7 @@ import 'highlight.js/styles/a11y-dark.css'
 | 
			
		||||
import {dateFormat, isMobile, randString, removeArrayItem, renderInputText, UUID} from "@/utils/libs";
 | 
			
		||||
import {ElMessage, ElMessageBox} from "element-plus";
 | 
			
		||||
import hl from "highlight.js";
 | 
			
		||||
import {getSessionId, removeLoginUser} from "@/store/session";
 | 
			
		||||
import {getSessionId, getUserToken, removeUserToken} from "@/store/session";
 | 
			
		||||
import {httpGet, httpPost} from "@/utils/http";
 | 
			
		||||
import {useRouter} from "vue-router";
 | 
			
		||||
import Clipboard from "clipboard";
 | 
			
		||||
@@ -319,9 +319,6 @@ onMounted(() => {
 | 
			
		||||
  checkSession().then((user) => {
 | 
			
		||||
    loginUser.value = user
 | 
			
		||||
    isLogin.value = true
 | 
			
		||||
    if (user.chat_config?.model !== '') {
 | 
			
		||||
      modelID.value = user.chat_config.model
 | 
			
		||||
    }
 | 
			
		||||
    // 加载角色列表
 | 
			
		||||
    httpGet(`/api/role/list?user_id=${user.id}`).then((res) => {
 | 
			
		||||
      roles.value = res.data;
 | 
			
		||||
@@ -400,7 +397,7 @@ const newChat = function () {
 | 
			
		||||
    chat_id: "",
 | 
			
		||||
    icon: icon,
 | 
			
		||||
    role_id: roleId.value,
 | 
			
		||||
    model: modelID.value,
 | 
			
		||||
    model_id: modelID.value,
 | 
			
		||||
    title: '',
 | 
			
		||||
    edit: false,
 | 
			
		||||
    removing: false,
 | 
			
		||||
@@ -419,7 +416,7 @@ const changeChat = function (chat) {
 | 
			
		||||
  activeChat.value = chat
 | 
			
		||||
  newChatItem.value = null;
 | 
			
		||||
  roleId.value = chat.role_id;
 | 
			
		||||
  modelID.value = chat.model;
 | 
			
		||||
  modelID.value = chat.model_id;
 | 
			
		||||
  showStopGenerate.value = false;
 | 
			
		||||
  showReGenerate.value = false;
 | 
			
		||||
  connect(chat.chat_id, chat.role_id)
 | 
			
		||||
@@ -510,7 +507,7 @@ const connect = function (chat_id, role_id) {
 | 
			
		||||
      host = 'ws://' + location.host;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${role_id}&chat_id=${chat_id}&model_id=${modelID.value}`);
 | 
			
		||||
  const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${role_id}&chat_id=${chat_id}&model_id=${modelID.value}&token=${getUserToken()}`);
 | 
			
		||||
  _socket.addEventListener('open', () => {
 | 
			
		||||
    chatData.value = []; // 初始化聊天数据
 | 
			
		||||
    previousText.value = '';
 | 
			
		||||
@@ -740,7 +737,7 @@ const clearAllChats = function () {
 | 
			
		||||
const logout = function () {
 | 
			
		||||
  activelyClose.value = true;
 | 
			
		||||
  httpGet('/api/user/logout').then(() => {
 | 
			
		||||
    removeLoginUser();
 | 
			
		||||
    removeUserToken();
 | 
			
		||||
    router.push('login');
 | 
			
		||||
  }).catch(() => {
 | 
			
		||||
    ElMessage.error('注销失败!');
 | 
			
		||||
 
 | 
			
		||||
@@ -56,6 +56,7 @@ import {useRouter} from "vue-router";
 | 
			
		||||
import FooterBar from "@/components/FooterBar.vue";
 | 
			
		||||
import {isMobile} from "@/utils/libs";
 | 
			
		||||
import {checkSession} from "@/action/session";
 | 
			
		||||
import {setUserToken} from "@/store/session";
 | 
			
		||||
 | 
			
		||||
const router = useRouter();
 | 
			
		||||
const title = ref('ChatGPT-PLUS 用户登录');
 | 
			
		||||
@@ -87,7 +88,8 @@ const login = function () {
 | 
			
		||||
    return ElMessage.error('请输入密码');
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then(() => {
 | 
			
		||||
  httpPost('/api/user/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
 | 
			
		||||
    setUserToken(res.data)
 | 
			
		||||
    if (isMobile()) {
 | 
			
		||||
      router.push('/mobile')
 | 
			
		||||
    } else {
 | 
			
		||||
 
 | 
			
		||||
@@ -46,6 +46,7 @@ import {httpPost} from "@/utils/http";
 | 
			
		||||
import {ElMessage} from "element-plus";
 | 
			
		||||
import {useRouter} from "vue-router";
 | 
			
		||||
import FooterBar from "@/components/FooterBar.vue";
 | 
			
		||||
import {setAdminToken} from "@/store/session";
 | 
			
		||||
 | 
			
		||||
const router = useRouter();
 | 
			
		||||
const title = ref('ChatGPT Plus Admin');
 | 
			
		||||
@@ -68,7 +69,8 @@ const login = function () {
 | 
			
		||||
    return ElMessage.error('请输入密码');
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  httpPost('/api/admin/login', {username: username.value.trim(), password: password.value.trim()}).then((res) => {
 | 
			
		||||
  httpPost('/api/admin/login', {username: username.value.trim(), password: password.value.trim()}).then(res => {
 | 
			
		||||
    setAdminToken(res.data)
 | 
			
		||||
    router.push("/admin")
 | 
			
		||||
  }).catch((e) => {
 | 
			
		||||
    ElMessage.error('登录失败,' + e.message)
 | 
			
		||||
 
 | 
			
		||||
@@ -200,7 +200,7 @@ const connect = function (chat_id, role_id) {
 | 
			
		||||
      host = 'ws://' + location.host;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${role_id}&chat_id=${chat_id}&model_id=${model}`);
 | 
			
		||||
  const _socket = new WebSocket(host + `/api/chat/new?session_id=${_sessionId}&role_id=${role_id}&chat_id=${chat_id}&model_id=${model}&token=${getUserToken()}`);
 | 
			
		||||
  _socket.addEventListener('open', () => {
 | 
			
		||||
    loading.value = false
 | 
			
		||||
    previousText.value = '';
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user