mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	优化聊天会话管理,支持 websocket 断开重连之后能继续连接会话上下文
This commit is contained in:
		
							
								
								
									
										6
									
								
								main.go
									
									
									
									
									
								
							
							
						
						
									
										6
									
								
								main.go
									
									
									
									
									
								
							@@ -16,6 +16,7 @@ var logger = logger2.GetLogger()
 | 
			
		||||
//go:embed dist
 | 
			
		||||
var webRoot embed.FS
 | 
			
		||||
var configFile string
 | 
			
		||||
var debugMode bool
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	defer func() {
 | 
			
		||||
@@ -49,12 +50,13 @@ func main() {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	s.Run(webRoot, "dist")
 | 
			
		||||
	s.Run(webRoot, "dist", debugMode)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
 | 
			
		||||
	flag.StringVar(&configFile, "config", "", "Config file path (default: ~/.config/chat-gpt/config.toml)")
 | 
			
		||||
	flag.BoolVar(&debugMode, "debug", true, "Enable debug mode (default: true, recommend to set false in production env)")
 | 
			
		||||
	flag.Usage = usage
 | 
			
		||||
	flag.Parse()
 | 
			
		||||
}
 | 
			
		||||
@@ -67,7 +69,7 @@ OPTIONS:
 | 
			
		||||
`, os.Args[0])
 | 
			
		||||
 | 
			
		||||
	flagSet := flag.CommandLine
 | 
			
		||||
	order := []string{"config"}
 | 
			
		||||
	order := []string{"config", "debug"}
 | 
			
		||||
	for _, name := range order {
 | 
			
		||||
		f := flagSet.Lookup(name)
 | 
			
		||||
		fmt.Printf("  --%s => %s\n", f.Name, f.Usage)
 | 
			
		||||
 
 | 
			
		||||
@@ -10,6 +10,7 @@ import (
 | 
			
		||||
	"io"
 | 
			
		||||
	"math/rand"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"openai/types"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -68,6 +69,19 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建 HttpClient 请求对象
 | 
			
		||||
	var client *http.Client
 | 
			
		||||
	if s.Config.ProxyURL == "" {
 | 
			
		||||
		client = &http.Client{}
 | 
			
		||||
	} else { // 使用代理
 | 
			
		||||
		uri := url.URL{}
 | 
			
		||||
		proxy, _ := uri.Parse(s.Config.ProxyURL)
 | 
			
		||||
		client = &http.Client{
 | 
			
		||||
			Transport: &http.Transport{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	request, err := http.NewRequest(http.MethodPost, s.Config.Chat.ApiURL, bytes.NewBuffer(requestBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -86,7 +100,7 @@ func (s *Server) sendMessage(userId string, text string, ws Client) error {
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Use API KEY: %s", apiKey)
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey))
 | 
			
		||||
		response, err = s.Client.Do(request)
 | 
			
		||||
		response, err = client.Do(request)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			break
 | 
			
		||||
		} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -11,12 +11,6 @@ import (
 | 
			
		||||
 | 
			
		||||
// ConfigSetHandle set configs
 | 
			
		||||
func (s *Server) ConfigSetHandle(c *gin.Context) {
 | 
			
		||||
	token := c.Query("token")
 | 
			
		||||
	if token != "RockYang" {
 | 
			
		||||
		c.JSON(http.StatusOK, types.BizVo{Code: types.Failed, Message: types.ErrorMsg})
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data map[string]string
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -24,10 +18,6 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 | 
			
		||||
		c.JSON(http.StatusBadRequest, nil)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// API key
 | 
			
		||||
	if key, ok := data["api_key"]; ok && len(key) > 20 {
 | 
			
		||||
		s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// proxy URL
 | 
			
		||||
	if proxy, ok := data["proxy"]; ok {
 | 
			
		||||
@@ -91,12 +81,6 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 | 
			
		||||
		s.Config.EnableAuth = v
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if token, ok := data["token"]; ok {
 | 
			
		||||
		if !utils.ContainsItem(s.Config.Tokens, token) {
 | 
			
		||||
			s.Config.Tokens = append(s.Config.Tokens, token)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 保存配置文件
 | 
			
		||||
	err = types.SaveConfig(s.Config, s.ConfigPath)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -106,3 +90,62 @@ func (s *Server) ConfigSetHandle(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) AddToken(c *gin.Context) {
 | 
			
		||||
	var data map[string]string
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("Error decode json data: %s", err.Error())
 | 
			
		||||
		c.JSON(http.StatusBadRequest, nil)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if token, ok := data["token"]; ok {
 | 
			
		||||
		if !utils.ContainsItem(s.Config.Tokens, token) {
 | 
			
		||||
			s.Config.Tokens = append(s.Config.Tokens, token)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) RemoveToken(c *gin.Context) {
 | 
			
		||||
	var data map[string]string
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("Error decode json data: %s", err.Error())
 | 
			
		||||
		c.JSON(http.StatusBadRequest, nil)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if token, ok := data["token"]; ok {
 | 
			
		||||
		for i, v := range s.Config.Tokens {
 | 
			
		||||
			if v == token {
 | 
			
		||||
				s.Config.Tokens = append(s.Config.Tokens[:i], s.Config.Tokens[i+1:]...)
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Tokens})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) AddApiKey(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data map[string]string
 | 
			
		||||
	err := json.NewDecoder(c.Request.Body).Decode(&data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("Error decode json data: %s", err.Error())
 | 
			
		||||
		c.JSON(http.StatusBadRequest, nil)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if key, ok := data["api_key"]; ok && len(key) > 20 {
 | 
			
		||||
		s.Config.Chat.ApiKeys = append(s.Config.Chat.ApiKeys, key)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) ListApiKeys(c *gin.Context) {
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Message: types.OkMsg, Data: s.Config.Chat.ApiKeys})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,6 @@ import (
 | 
			
		||||
	"io/fs"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	logger2 "openai/logger"
 | 
			
		||||
	"openai/types"
 | 
			
		||||
	"openai/utils"
 | 
			
		||||
@@ -34,7 +33,6 @@ func (s StaticFile) Open(name string) (fs.File, error) {
 | 
			
		||||
type Server struct {
 | 
			
		||||
	Config     *types.Config
 | 
			
		||||
	ConfigPath string
 | 
			
		||||
	Client     *http.Client
 | 
			
		||||
	History    map[string][]types.Message
 | 
			
		||||
 | 
			
		||||
	// 保存 Websocket 会话 Token, 每个 Token 只能连接一次
 | 
			
		||||
@@ -50,16 +48,8 @@ func NewServer(configPath string) (*Server, error) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	uri := url.URL{}
 | 
			
		||||
	proxy, _ := uri.Parse(config.ProxyURL)
 | 
			
		||||
	client := &http.Client{
 | 
			
		||||
		Transport: &http.Transport{
 | 
			
		||||
			Proxy: http.ProxyURL(proxy),
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	return &Server{
 | 
			
		||||
		Config:           config,
 | 
			
		||||
		Client:           client,
 | 
			
		||||
		ConfigPath:       configPath,
 | 
			
		||||
		History:          make(map[string][]types.Message, 16),
 | 
			
		||||
		WsSession:        make(map[string]string),
 | 
			
		||||
@@ -67,11 +57,13 @@ func NewServer(configPath string) (*Server, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) Run(webRoot embed.FS, path string) {
 | 
			
		||||
func (s *Server) Run(webRoot embed.FS, path string, debug bool) {
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	engine := gin.Default()
 | 
			
		||||
	engine.Use(sessionMiddleware(s.Config))
 | 
			
		||||
	if debug {
 | 
			
		||||
		engine.Use(corsMiddleware())
 | 
			
		||||
	}
 | 
			
		||||
	engine.Use(sessionMiddleware(s.Config))
 | 
			
		||||
	engine.Use(AuthorizeMiddleware(s))
 | 
			
		||||
 | 
			
		||||
	engine.GET("/hello", Hello)
 | 
			
		||||
@@ -79,6 +71,10 @@ func (s *Server) Run(webRoot embed.FS, path string) {
 | 
			
		||||
	engine.POST("/api/login", s.LoginHandle)
 | 
			
		||||
	engine.Any("/api/chat", s.ChatHandle)
 | 
			
		||||
	engine.POST("/api/config/set", s.ConfigSetHandle)
 | 
			
		||||
	engine.POST("api/config/token/add", s.AddToken)
 | 
			
		||||
	engine.POST("api/config/token/remove", s.RemoveToken)
 | 
			
		||||
	engine.POST("api/config/apikey/add", s.AddApiKey)
 | 
			
		||||
	engine.POST("api/config/apikey/list", s.ListApiKeys)
 | 
			
		||||
 | 
			
		||||
	engine.NoRoute(func(c *gin.Context) {
 | 
			
		||||
		if c.Request.URL.Path == "/favicon.ico" {
 | 
			
		||||
@@ -123,7 +119,7 @@ func corsMiddleware() gin.HandlerFunc {
 | 
			
		||||
		origin := c.Request.Header.Get("Origin")
 | 
			
		||||
		if origin != "" {
 | 
			
		||||
			// 设置允许的请求源
 | 
			
		||||
			c.Writer.Header().Set("Access-Control-Allow-Origin", origin)
 | 
			
		||||
			c.Header("Access-Control-Allow-Origin", origin)
 | 
			
		||||
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
 | 
			
		||||
			//允许跨域设置可以返回其他子段,可以自定义字段
 | 
			
		||||
			c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, ChatGPT-Token")
 | 
			
		||||
@@ -154,18 +150,28 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		if !s.Config.EnableAuth ||
 | 
			
		||||
			c.Request.URL.Path == "/api/login" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/config/set" ||
 | 
			
		||||
			!strings.HasPrefix(c.Request.URL.Path, "/api") {
 | 
			
		||||
			c.Next()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.HasPrefix(c.Request.URL.Path, "/api/config") {
 | 
			
		||||
			accessKey := c.Query("access_key")
 | 
			
		||||
			if accessKey != "RockYang" {
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				c.JSON(http.StatusOK, types.BizVo{Code: types.NotAuthorized, Message: "No Permissions"})
 | 
			
		||||
			} else {
 | 
			
		||||
				c.Next()
 | 
			
		||||
			}
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// WebSocket 连接请求验证
 | 
			
		||||
		if c.Request.URL.Path == "/api/chat" {
 | 
			
		||||
			tokenName := c.Query("token")
 | 
			
		||||
			if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
 | 
			
		||||
				// 每个令牌只能连接一次
 | 
			
		||||
				delete(s.WsSession, tokenName)
 | 
			
		||||
				//delete(s.WsSession, tokenName)
 | 
			
		||||
				c.Next()
 | 
			
		||||
			} else {
 | 
			
		||||
				c.Abort()
 | 
			
		||||
@@ -190,7 +196,16 @@ func AuthorizeMiddleware(s *Server) gin.HandlerFunc {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) GetSessionHandle(c *gin.Context) {
 | 
			
		||||
	c.JSON(http.StatusOK, types.BizVo{Code: types.Success})
 | 
			
		||||
	tokenName := c.GetHeader(types.TokenName)
 | 
			
		||||
	if addr, ok := s.WsSession[tokenName]; ok && addr == c.ClientIP() {
 | 
			
		||||
		c.JSON(http.StatusOK, types.BizVo{Code: types.Success, Data: addr})
 | 
			
		||||
	} else {
 | 
			
		||||
		c.JSON(http.StatusOK, types.BizVo{
 | 
			
		||||
			Code:    types.NotAuthorized,
 | 
			
		||||
			Message: "Not Authorized",
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Server) LoginHandle(c *gin.Context) {
 | 
			
		||||
 
 | 
			
		||||
@@ -52,7 +52,7 @@ func NewDefaultConfig() *Config {
 | 
			
		||||
			MaxAge:    86400,
 | 
			
		||||
			Secure:    true,
 | 
			
		||||
			HttpOnly:  false,
 | 
			
		||||
			SameSite:  http.SameSiteNoneMode,
 | 
			
		||||
			SameSite:  http.SameSiteLaxMode,
 | 
			
		||||
		},
 | 
			
		||||
		Chat: Chat{
 | 
			
		||||
			ApiURL:        "https://api.openai.com/v1/chat/completions",
 | 
			
		||||
 
 | 
			
		||||
@@ -1,2 +1,2 @@
 | 
			
		||||
VUE_APP_API_HOST=http://chat.r9it.com:6789
 | 
			
		||||
VUE_APP_WS_HOST=ws://chat.r9it.com:6789
 | 
			
		||||
VUE_APP_API_HOST=https://ai.r9it.com
 | 
			
		||||
VUE_APP_WS_HOST=wss://ai.r9it.com
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										341
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										341
									
								
								web/package-lock.json
									
									
									
										generated
									
									
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							@@ -2,15 +2,14 @@
 | 
			
		||||
/**
 | 
			
		||||
 * storage handler
 | 
			
		||||
 */
 | 
			
		||||
import Storage from 'good-storage'
 | 
			
		||||
 | 
			
		||||
const SessionIdKey = 'ChatGPT_SESSION_ID';
 | 
			
		||||
export const Global = {}
 | 
			
		||||
 | 
			
		||||
export function getSessionId() {
 | 
			
		||||
    return Storage.get(SessionIdKey)
 | 
			
		||||
    return sessionStorage.getItem(SessionIdKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
export function setSessionId(value) {
 | 
			
		||||
    Storage.set(SessionIdKey, value)
 | 
			
		||||
    sessionStorage.setItem(SessionIdKey, value)
 | 
			
		||||
}
 | 
			
		||||
@@ -114,8 +114,6 @@ export default defineComponent({
 | 
			
		||||
      this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
    this.checkSession();
 | 
			
		||||
 | 
			
		||||
    // for (let i = 0; i < 10; i++) {
 | 
			
		||||
    //   this.chatData.push({
 | 
			
		||||
    //     type: "prompt",
 | 
			
		||||
@@ -175,44 +173,11 @@ export default defineComponent({
 | 
			
		||||
      this.chatBoxHeight = window.innerHeight - this.toolBoxHeight;
 | 
			
		||||
    });
 | 
			
		||||
 | 
			
		||||
    this.connect();
 | 
			
		||||
 | 
			
		||||
  },
 | 
			
		||||
 | 
			
		||||
  methods: {
 | 
			
		||||
    // 检查会话
 | 
			
		||||
    checkSession: function () {
 | 
			
		||||
      httpPost("/api/session/get").then(() => {
 | 
			
		||||
        if (this.socket == null) {
 | 
			
		||||
          this.connect();
 | 
			
		||||
        }
 | 
			
		||||
        // 发送心跳
 | 
			
		||||
        //setTimeout(() => this.checkSession(), 5000);
 | 
			
		||||
      }).catch((res) => {
 | 
			
		||||
        if (res.code === 400) {
 | 
			
		||||
          this.showLoginDialog = true;
 | 
			
		||||
        } else {
 | 
			
		||||
          this.connectingMessageBox = ElMessageBox.confirm(
 | 
			
		||||
              '^_^ 会话发生异常,您已经从服务器断开连接!',
 | 
			
		||||
              '注意:',
 | 
			
		||||
              {
 | 
			
		||||
                confirmButtonText: '重连会话',
 | 
			
		||||
                cancelButtonText: '不聊了',
 | 
			
		||||
                type: 'warning',
 | 
			
		||||
                showClose: false,
 | 
			
		||||
                closeOnClickModal: false
 | 
			
		||||
              }
 | 
			
		||||
          ).then(() => {
 | 
			
		||||
            this.connect();
 | 
			
		||||
          }).catch(() => {
 | 
			
		||||
            ElMessage({
 | 
			
		||||
              type: 'info',
 | 
			
		||||
              message: '您关闭了会话',
 | 
			
		||||
            })
 | 
			
		||||
          })
 | 
			
		||||
        }
 | 
			
		||||
      })
 | 
			
		||||
    },
 | 
			
		||||
 | 
			
		||||
    connect: function () {
 | 
			
		||||
      // 初始化 WebSocket 对象
 | 
			
		||||
      const token = getSessionId();
 | 
			
		||||
@@ -264,8 +229,32 @@ export default defineComponent({
 | 
			
		||||
 | 
			
		||||
      });
 | 
			
		||||
      socket.addEventListener('close', () => {
 | 
			
		||||
        // 检查会话,自动登录
 | 
			
		||||
        this.checkSession();
 | 
			
		||||
        // 检查会话
 | 
			
		||||
        httpPost("/api/session/get").then(() => {
 | 
			
		||||
          this.connectingMessageBox = ElMessageBox.confirm(
 | 
			
		||||
              '^_^ 会话发生异常,您已经从服务器断开连接!',
 | 
			
		||||
              '注意:',
 | 
			
		||||
              {
 | 
			
		||||
                confirmButtonText: '重连会话',
 | 
			
		||||
                cancelButtonText: '不聊了',
 | 
			
		||||
                type: 'warning',
 | 
			
		||||
                showClose: false,
 | 
			
		||||
                closeOnClickModal: false
 | 
			
		||||
              }
 | 
			
		||||
          ).then(() => {
 | 
			
		||||
            this.connect();
 | 
			
		||||
          }).catch(() => {
 | 
			
		||||
            ElMessage({
 | 
			
		||||
              type: 'info',
 | 
			
		||||
              message: '您关闭了会话',
 | 
			
		||||
            })
 | 
			
		||||
          })
 | 
			
		||||
        }).catch((res) => {
 | 
			
		||||
          if (res.code === 400) {
 | 
			
		||||
            this.showLoginDialog = true;
 | 
			
		||||
          }
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
      });
 | 
			
		||||
 | 
			
		||||
      this.socket = socket;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user