mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-06 09:13:47 +08:00
refactor: 调整项目目录结构,移除其他语言 API 目录
This commit is contained in:
198
api/core/app_server.go
Normal file
198
api/core/app_server.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"chatplus/core/types"
|
||||
"chatplus/store/model"
|
||||
"chatplus/utils"
|
||||
"chatplus/utils/resp"
|
||||
"context"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-contrib/sessions/cookie"
|
||||
"github.com/gin-contrib/sessions/memstore"
|
||||
"github.com/gin-contrib/sessions/redis"
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AppServer struct {
|
||||
Debug bool
|
||||
AppConfig *types.AppConfig
|
||||
Engine *gin.Engine
|
||||
ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
|
||||
ChatConfig *types.ChatConfig // 聊天配置
|
||||
|
||||
// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
|
||||
// 防止第三方直接连接 socket 调用 OpenAI API
|
||||
ChatSession *types.LMap[string, types.ChatSession] //map[sessionId]UserId
|
||||
ChatClients *types.LMap[string, *types.WsClient] // Websocket 连接集合
|
||||
ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
|
||||
}
|
||||
|
||||
func NewServer(appConfig *types.AppConfig) *AppServer {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
gin.DefaultWriter = io.Discard
|
||||
return &AppServer{
|
||||
Debug: false,
|
||||
AppConfig: appConfig,
|
||||
Engine: gin.Default(),
|
||||
ChatContexts: types.NewLMap[string, []types.Message](),
|
||||
ChatSession: types.NewLMap[string, types.ChatSession](),
|
||||
ChatClients: types.NewLMap[string, *types.WsClient](),
|
||||
ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AppServer) Init(debug bool) {
|
||||
if debug { // 调试模式允许跨域请求 API
|
||||
s.Debug = debug
|
||||
logger.Info("Enabled debug mode")
|
||||
s.Engine.Use(corsMiddleware())
|
||||
}
|
||||
|
||||
s.Engine.Use(sessionMiddleware(s.AppConfig))
|
||||
s.Engine.Use(authorizeMiddleware(s))
|
||||
s.Engine.Use(errorHandler)
|
||||
// 添加静态资源访问
|
||||
s.Engine.Static("/static", s.AppConfig.StaticDir)
|
||||
}
|
||||
|
||||
func (s *AppServer) Run(db *gorm.DB) error {
|
||||
// load chat config from database
|
||||
var config model.Config
|
||||
res := db.Where("marker", "chat").First(&config)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
}
|
||||
err := utils.JsonDecode(config.Config, &s.ChatConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
logger.Infof("http://%s", s.AppConfig.Listen)
|
||||
return s.Engine.Run(s.AppConfig.Listen)
|
||||
}
|
||||
|
||||
// 全局异常处理
|
||||
func errorHandler(c *gin.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logger.Error("Handler Panic: %v\n", r)
|
||||
debug.PrintStack()
|
||||
c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
//加载完 defer recover,继续后续接口调用
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// 会话处理
|
||||
func sessionMiddleware(config *types.AppConfig) gin.HandlerFunc {
|
||||
// encrypt the cookie
|
||||
var store sessions.Store
|
||||
var err error
|
||||
switch config.Session.Driver {
|
||||
case types.SessionDriverMem:
|
||||
store = memstore.NewStore([]byte(config.Session.SecretKey))
|
||||
break
|
||||
case types.SessionDriverRedis:
|
||||
store, err = redis.NewStore(10, "tcp", config.Redis.Url(), config.Redis.Password, []byte(config.Session.SecretKey))
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
}
|
||||
break
|
||||
case types.SessionDriverCookie:
|
||||
store = cookie.NewStore([]byte(config.Session.SecretKey))
|
||||
break
|
||||
default:
|
||||
config.Session.Driver = types.SessionDriverCookie
|
||||
store = cookie.NewStore([]byte(config.Session.SecretKey))
|
||||
}
|
||||
|
||||
logger.Info("Session driver: ", config.Session.Driver)
|
||||
|
||||
store.Options(sessions.Options{
|
||||
Path: config.Session.Path,
|
||||
Domain: config.Session.Domain,
|
||||
MaxAge: config.Session.MaxAge,
|
||||
Secure: config.Session.Secure,
|
||||
HttpOnly: config.Session.HttpOnly,
|
||||
SameSite: config.Session.SameSite,
|
||||
})
|
||||
return sessions.Sessions(config.Session.Name, store)
|
||||
}
|
||||
|
||||
// 跨域中间件设置
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
origin := c.Request.Header.Get("Origin")
|
||||
if origin != "" {
|
||||
// 设置允许的请求源
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
|
||||
//允许跨域设置可以返回其他子段,可以自定义字段
|
||||
c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-TOKEN, ADMIN-SESSION-TOKEN")
|
||||
// 允许浏览器(客户端)可以解析的头部 (重要)
|
||||
c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
|
||||
//设置缓存时间
|
||||
c.Header("Access-Control-Max-Age", "172800")
|
||||
//允许客户端传递校验信息比如 cookie (重要)
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
|
||||
if method == http.MethodOptions {
|
||||
c.JSON(http.StatusOK, "ok!")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
logger.Info("Panic info is: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.URL.Path == "/api/user/login" ||
|
||||
c.Request.URL.Path == "/api/admin/login" ||
|
||||
c.Request.URL.Path == "/api/user/register" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/static/") ||
|
||||
c.Request.URL.Path == "/api/admin/config/get" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// WebSocket 连接请求验证
|
||||
if c.Request.URL.Path == "/api/chat" {
|
||||
sessionId := c.Query("sessionId")
|
||||
session := s.ChatSession.Get(sessionId)
|
||||
if session.ClientIP == c.ClientIP() {
|
||||
c.Next()
|
||||
} else {
|
||||
c.Abort()
|
||||
}
|
||||
return
|
||||
}
|
||||
session := sessions.Default(c)
|
||||
var value interface{}
|
||||
if strings.Contains(c.Request.URL.Path, "/api/admin/") {
|
||||
value = session.Get(types.SessionAdmin)
|
||||
} else {
|
||||
value = session.Get(types.SessionUser)
|
||||
}
|
||||
if value != nil {
|
||||
c.Next()
|
||||
} else {
|
||||
resp.NotAuth(c)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
}
|
||||
70
api/core/config.go
Normal file
70
api/core/config.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"chatplus/core/types"
|
||||
logger2 "chatplus/logger"
|
||||
"chatplus/utils"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
var logger = logger2.GetLogger()
|
||||
|
||||
func NewDefaultConfig() *types.AppConfig {
|
||||
return &types.AppConfig{
|
||||
Listen: "0.0.0.0:5678",
|
||||
ProxyURL: "",
|
||||
Manager: types.Manager{Username: "admin", Password: "admin123"},
|
||||
StaticDir: "./static",
|
||||
StaticUrl: "http://localhost/5678/static",
|
||||
Redis: types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
|
||||
|
||||
Session: types.Session{
|
||||
Driver: types.SessionDriverCookie,
|
||||
SecretKey: utils.RandString(64),
|
||||
Name: "CHAT_PLUS_SESSION",
|
||||
Domain: "",
|
||||
Path: "/",
|
||||
MaxAge: 86400,
|
||||
Secure: true,
|
||||
HttpOnly: false,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func LoadConfig(configFile string) (*types.AppConfig, error) {
|
||||
var config *types.AppConfig
|
||||
_, err := os.Stat(configFile)
|
||||
if err != nil {
|
||||
logger.Info("creating new config file: ", configFile)
|
||||
config = NewDefaultConfig()
|
||||
config.Path = configFile
|
||||
// save config
|
||||
err := SaveConfig(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
_, err = toml.DecodeFile(configFile, &config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, err
|
||||
}
|
||||
|
||||
func SaveConfig(config *types.AppConfig) error {
|
||||
buf := new(bytes.Buffer)
|
||||
encoder := toml.NewEncoder(buf)
|
||||
if err := encoder.Encode(&config); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(config.Path, buf.Bytes(), 0644)
|
||||
}
|
||||
47
api/core/types/chat.go
Normal file
47
api/core/types/chat.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package types
|
||||
|
||||
// ApiRequest API 请求实体
|
||||
type ApiRequest struct {
|
||||
Model string `json:"model"`
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
Stream bool `json:"stream"`
|
||||
Messages []Message `json:"messages"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ApiResponse struct {
|
||||
Choices []ChoiceItem `json:"choices"`
|
||||
}
|
||||
|
||||
// ChoiceItem API 响应实体
|
||||
type ChoiceItem struct {
|
||||
Delta Message `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
}
|
||||
|
||||
// ChatSession 聊天会话对象
|
||||
type ChatSession struct {
|
||||
SessionId string `json:"session_id"`
|
||||
ClientIP string `json:"client_ip"` // 客户端 IP
|
||||
Username string `json:"username"` // 当前登录的 username
|
||||
UserId uint `json:"user_id"` // 当前登录的 user ID
|
||||
ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段
|
||||
Model string `json:"model"` // GPT 模型
|
||||
}
|
||||
|
||||
type ApiError struct {
|
||||
Error struct {
|
||||
Message string
|
||||
Type string
|
||||
Param interface{}
|
||||
Code string
|
||||
}
|
||||
}
|
||||
|
||||
const PromptMsg = "prompt" // prompt message
|
||||
const ReplyMsg = "reply" // reply message
|
||||
61
api/core/types/client.go
Normal file
61
api/core/types/client.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/gorilla/websocket"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var ErrConClosed = errors.New("connection closed")
|
||||
|
||||
type Client interface {
|
||||
Close()
|
||||
}
|
||||
|
||||
// WsClient websocket client
|
||||
type WsClient struct {
|
||||
Conn *websocket.Conn
|
||||
lock sync.Mutex
|
||||
mt int
|
||||
closed bool
|
||||
}
|
||||
|
||||
func NewWsClient(conn *websocket.Conn) *WsClient {
|
||||
return &WsClient{
|
||||
Conn: conn,
|
||||
lock: sync.Mutex{},
|
||||
mt: 2, // fixed bug for 'Invalid UTF-8 in text frame'
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (wc *WsClient) Send(message []byte) error {
|
||||
wc.lock.Lock()
|
||||
defer wc.lock.Unlock()
|
||||
|
||||
if wc.closed {
|
||||
return ErrConClosed
|
||||
}
|
||||
|
||||
return wc.Conn.WriteMessage(wc.mt, message)
|
||||
}
|
||||
|
||||
func (wc *WsClient) Receive() (int, []byte, error) {
|
||||
if wc.closed {
|
||||
return 0, nil, ErrConClosed
|
||||
}
|
||||
|
||||
return wc.Conn.ReadMessage()
|
||||
}
|
||||
|
||||
func (wc *WsClient) Close() {
|
||||
wc.lock.Lock()
|
||||
defer wc.lock.Unlock()
|
||||
|
||||
if wc.closed {
|
||||
return
|
||||
}
|
||||
|
||||
_ = wc.Conn.Close()
|
||||
wc.closed = true
|
||||
}
|
||||
75
api/core/types/config.go
Normal file
75
api/core/types/config.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AppConfig struct {
|
||||
Path string `toml:"-"`
|
||||
Listen string
|
||||
Session Session
|
||||
ProxyURL string
|
||||
MysqlDns string // mysql 连接地址
|
||||
Manager Manager // 后台管理员账户信息
|
||||
StaticDir string // 静态资源目录
|
||||
StaticUrl string // 静态资源 URL
|
||||
Redis RedisConfig // redis 连接信息
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
}
|
||||
|
||||
func (c RedisConfig) Url() string {
|
||||
return fmt.Sprintf("%s:%d", c.Host, c.Port)
|
||||
}
|
||||
|
||||
// Manager 管理员
|
||||
type Manager struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type SessionDriver string
|
||||
|
||||
const (
|
||||
SessionDriverMem = SessionDriver("mem")
|
||||
SessionDriverRedis = SessionDriver("redis")
|
||||
SessionDriverCookie = SessionDriver("cookie")
|
||||
)
|
||||
|
||||
// Session configs struct
|
||||
type Session struct {
|
||||
Driver SessionDriver // session 存储驱动 mem|cookie|redis
|
||||
SecretKey string // session encryption key
|
||||
Name string
|
||||
Path string
|
||||
Domain string
|
||||
MaxAge int
|
||||
Secure bool
|
||||
HttpOnly bool
|
||||
SameSite http.SameSite
|
||||
}
|
||||
|
||||
// ChatConfig 系统默认的聊天配置
|
||||
type ChatConfig struct {
|
||||
ApiURL string `json:"api_url,omitempty"`
|
||||
Model string `json:"model"` // 默认模型
|
||||
Temperature float32 `json:"temperature"`
|
||||
MaxTokens int `json:"max_tokens"`
|
||||
EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
|
||||
EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
|
||||
ApiKey string `json:"api_key"` // OpenAI API key
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
Title string `json:"title"`
|
||||
AdminTitle string `json:"admin_title"`
|
||||
Models []string `json:"models"`
|
||||
UserInitCalls int `json:"user_init_calls"` // 新用户注册默认总送多少次调用
|
||||
}
|
||||
|
||||
const UserInitCalls = 1000
|
||||
63
api/core/types/locked_map.go
Normal file
63
api/core/types/locked_map.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type MKey interface {
|
||||
string | int
|
||||
}
|
||||
type MValue interface {
|
||||
*WsClient | ChatSession | []Message | context.CancelFunc
|
||||
}
|
||||
type LMap[K MKey, T MValue] struct {
|
||||
lock sync.RWMutex
|
||||
data map[K]T
|
||||
}
|
||||
|
||||
func NewLMap[K MKey, T MValue]() *LMap[K, T] {
|
||||
return &LMap[K, T]{
|
||||
lock: sync.RWMutex{},
|
||||
data: make(map[K]T),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *LMap[K, T]) Put(key K, value T) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
m.data[key] = value
|
||||
}
|
||||
|
||||
func (m *LMap[K, T]) Get(key K) T {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
|
||||
return m.data[key]
|
||||
}
|
||||
|
||||
func (m *LMap[K, T]) Has(key K) bool {
|
||||
m.lock.RLock()
|
||||
defer m.lock.RUnlock()
|
||||
_, ok := m.data[key]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (m *LMap[K, T]) Delete(key K) {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
delete(m.data, key)
|
||||
}
|
||||
|
||||
func (m *LMap[K, T]) ToList() []T {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
var s = make([]T, 0)
|
||||
for _, v := range m.data {
|
||||
s = append(s, v)
|
||||
}
|
||||
return s
|
||||
}
|
||||
6
api/core/types/session.go
Normal file
6
api/core/types/session.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package types
|
||||
|
||||
const SessionName = "ChatGPT-TOKEN"
|
||||
const SessionUser = "SESSION_USER" // 存储用户信息的 session key
|
||||
const SessionAdmin = "SESSION_ADMIN" //存储管理员信息的 session key
|
||||
const LoginUserCache = "LOGIN_USER_CACHE" // 已登录用户缓存
|
||||
36
api/core/types/web.go
Normal file
36
api/core/types/web.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package types
|
||||
|
||||
// BizVo 业务返回 VO
|
||||
type BizVo struct {
|
||||
Code BizCode `json:"code"`
|
||||
Page int `json:"page,omitempty"`
|
||||
PageSize int `json:"page_size,omitempty"`
|
||||
Total int `json:"total,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// WsMessage Websocket message
|
||||
type WsMessage struct {
|
||||
Type WsMsgType `json:"type"` // 消息类别,start, end
|
||||
Content string `json:"content"`
|
||||
}
|
||||
type WsMsgType string
|
||||
|
||||
const (
|
||||
WsStart = WsMsgType("start")
|
||||
WsMiddle = WsMsgType("middle")
|
||||
WsEnd = WsMsgType("end")
|
||||
)
|
||||
|
||||
type BizCode int
|
||||
|
||||
const (
|
||||
Success = BizCode(0)
|
||||
Failed = BizCode(1)
|
||||
NotAuthorized = BizCode(400) // 未授权
|
||||
|
||||
OkMsg = "Success"
|
||||
ErrorMsg = "系统开小差了"
|
||||
InvalidArgs = "非法参数或参数解析失败"
|
||||
)
|
||||
Reference in New Issue
Block a user