mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 16:23:42 +08:00 
			
		
		
		
	refactor websocket message protocol, keep the only connection for all clients
This commit is contained in:
		@@ -51,9 +51,9 @@ func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
				
			|||||||
func (s *AppServer) Init(debug bool, client *redis.Client) {
 | 
					func (s *AppServer) Init(debug bool, client *redis.Client) {
 | 
				
			||||||
	if debug { // 调试模式允许跨域请求 API
 | 
						if debug { // 调试模式允许跨域请求 API
 | 
				
			||||||
		s.Debug = debug
 | 
							s.Debug = debug
 | 
				
			||||||
 | 
							s.Engine.Use(corsMiddleware())
 | 
				
			||||||
		logger.Info("Enabled debug mode")
 | 
							logger.Info("Enabled debug mode")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	s.Engine.Use(corsMiddleware())
 | 
					 | 
				
			||||||
	s.Engine.Use(staticResourceMiddleware())
 | 
						s.Engine.Use(staticResourceMiddleware())
 | 
				
			||||||
	s.Engine.Use(authorizeMiddleware(s, client))
 | 
						s.Engine.Use(authorizeMiddleware(s, client))
 | 
				
			||||||
	s.Engine.Use(parameterHandlerMiddleware())
 | 
						s.Engine.Use(parameterHandlerMiddleware())
 | 
				
			||||||
@@ -101,9 +101,9 @@ func corsMiddleware() gin.HandlerFunc {
 | 
				
			|||||||
			c.Header("Access-Control-Allow-Origin", origin)
 | 
								c.Header("Access-Control-Allow-Origin", origin)
 | 
				
			||||||
			c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
 | 
								c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE")
 | 
				
			||||||
			//允许跨域设置可以返回其他子段,可以自定义字段
 | 
								//允许跨域设置可以返回其他子段,可以自定义字段
 | 
				
			||||||
			c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization")
 | 
								c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type")
 | 
				
			||||||
			// 允许浏览器(客户端)可以解析的头部 (重要)
 | 
								// 允许浏览器(客户端)可以解析的头部 (重要)
 | 
				
			||||||
			c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
 | 
								c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers")
 | 
				
			||||||
			//设置缓存时间
 | 
								//设置缓存时间
 | 
				
			||||||
			c.Header("Access-Control-Max-Age", "172800")
 | 
								c.Header("Access-Control-Max-Age", "172800")
 | 
				
			||||||
			//允许客户端传递校验信息比如 cookie (重要)
 | 
								//允许客户端传递校验信息比如 cookie (重要)
 | 
				
			||||||
@@ -131,7 +131,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
				
			|||||||
		isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
 | 
							isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
 | 
				
			||||||
		if isAdminApi { // 后台管理 API
 | 
							if isAdminApi { // 后台管理 API
 | 
				
			||||||
			tokenString = c.GetHeader(types.AdminAuthHeader)
 | 
								tokenString = c.GetHeader(types.AdminAuthHeader)
 | 
				
			||||||
		} else if c.Request.URL.Path == "/api/chat/new" {
 | 
							} else if c.Request.URL.Path == "/api/ws" { // Websocket 连接
 | 
				
			||||||
			tokenString = c.Query("token")
 | 
								tokenString = c.Query("token")
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			tokenString = c.GetHeader(types.UserAuthHeader)
 | 
								tokenString = c.GetHeader(types.UserAuthHeader)
 | 
				
			||||||
@@ -209,23 +209,18 @@ func needLogin(c *gin.Context) bool {
 | 
				
			|||||||
		c.Request.URL.Path == "/api/app/list/user" ||
 | 
							c.Request.URL.Path == "/api/app/list/user" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/model/list" ||
 | 
							c.Request.URL.Path == "/api/model/list" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/mj/imgWall" ||
 | 
							c.Request.URL.Path == "/api/mj/imgWall" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/mj/client" ||
 | 
					 | 
				
			||||||
		c.Request.URL.Path == "/api/mj/notify" ||
 | 
							c.Request.URL.Path == "/api/mj/notify" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/invite/hits" ||
 | 
							c.Request.URL.Path == "/api/invite/hits" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/sd/imgWall" ||
 | 
							c.Request.URL.Path == "/api/sd/imgWall" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/sd/client" ||
 | 
					 | 
				
			||||||
		c.Request.URL.Path == "/api/dall/imgWall" ||
 | 
							c.Request.URL.Path == "/api/dall/imgWall" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/dall/client" ||
 | 
					 | 
				
			||||||
		c.Request.URL.Path == "/api/product/list" ||
 | 
							c.Request.URL.Path == "/api/product/list" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/menu/list" ||
 | 
							c.Request.URL.Path == "/api/menu/list" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/markMap/client" ||
 | 
							c.Request.URL.Path == "/api/markMap/client" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/payment/doPay" ||
 | 
							c.Request.URL.Path == "/api/payment/doPay" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/payment/payWays" ||
 | 
							c.Request.URL.Path == "/api/payment/payWays" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/suno/client" ||
 | 
					 | 
				
			||||||
		c.Request.URL.Path == "/api/suno/detail" ||
 | 
							c.Request.URL.Path == "/api/suno/detail" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/suno/play" ||
 | 
							c.Request.URL.Path == "/api/suno/play" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/download" ||
 | 
							c.Request.URL.Path == "/api/download" ||
 | 
				
			||||||
		c.Request.URL.Path == "/api/video/client" ||
 | 
					 | 
				
			||||||
		strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
 | 
							strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
 | 
				
			||||||
		strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
 | 
							strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
 | 
				
			||||||
		strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
 | 
							strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -52,14 +52,13 @@ type Delta struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ChatSession 聊天会话对象
 | 
					// ChatSession 聊天会话对象
 | 
				
			||||||
type ChatSession struct {
 | 
					type ChatSession struct {
 | 
				
			||||||
	SessionId string    `json:"session_id"`
 | 
						UserId   uint      `json:"user_id"`
 | 
				
			||||||
	UserId    uint      `json:"user_id"`
 | 
						ClientIP string    `json:"client_ip"` // 客户端 IP
 | 
				
			||||||
	ClientIP  string    `json:"client_ip"` // 客户端 IP
 | 
						ChatId   string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 | 
				
			||||||
	ChatId    string    `json:"chat_id"`   // 客户端聊天会话 ID, 多会话模式专用字段
 | 
						Model    ChatModel `json:"model"`     // GPT 模型
 | 
				
			||||||
	Model     ChatModel `json:"model"`     // GPT 模型
 | 
						Start    int64     `json:"start"`     // 开始请求时间戳
 | 
				
			||||||
	Start     int64     `json:"start"`     // 开始请求时间戳
 | 
						Tools    []int     `json:"tools"`     // 工具函数列表
 | 
				
			||||||
	Tools     []int     `json:"tools"`     // 工具函数列表
 | 
						Stream   bool      `json:"stream"`    // 是否采用流式输出
 | 
				
			||||||
	Stream    bool      `json:"stream"`    // 是否采用流式输出
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type ChatModel struct {
 | 
					type ChatModel struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed")
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// WsClient websocket client
 | 
					// WsClient websocket client
 | 
				
			||||||
type WsClient struct {
 | 
					type WsClient struct {
 | 
				
			||||||
 | 
						Id     string
 | 
				
			||||||
	Conn   *websocket.Conn
 | 
						Conn   *websocket.Conn
 | 
				
			||||||
	lock   sync.Mutex
 | 
						lock   sync.Mutex
 | 
				
			||||||
	mt     int
 | 
						mt     int
 | 
				
			||||||
	Closed bool
 | 
						Closed bool
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewWsClient(conn *websocket.Conn) *WsClient {
 | 
					func NewWsClient(conn *websocket.Conn, id string) *WsClient {
 | 
				
			||||||
	return &WsClient{
 | 
						return &WsClient{
 | 
				
			||||||
		Conn:   conn,
 | 
							Conn:   conn,
 | 
				
			||||||
 | 
							Id:     id,
 | 
				
			||||||
		lock:   sync.Mutex{},
 | 
							lock:   sync.Mutex{},
 | 
				
			||||||
		mt:     2, // fixed bug for 'Invalid UTF-8 in text frame'
 | 
							mt:     2, // fixed bug for 'Invalid UTF-8 in text frame'
 | 
				
			||||||
		Closed: false,
 | 
							Closed: false,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,35 +19,44 @@ type BizVo struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// ReplyMessage 对话回复消息结构
 | 
					// ReplyMessage 对话回复消息结构
 | 
				
			||||||
type ReplyMessage struct {
 | 
					type ReplyMessage struct {
 | 
				
			||||||
	Channel WsChannel   `json:"channel"` // 消息频道,目前只有 chat
 | 
						Channel  WsChannel   `json:"channel"`  // 消息频道,目前只有 chat
 | 
				
			||||||
	Type    WsMsgType   `json:"type"`    // 消息类别
 | 
						ClientId string      `json:"clientId"` // 客户端ID
 | 
				
			||||||
	Content interface{} `json:"content"`
 | 
						Type     WsMsgType   `json:"type"`     // 消息类别
 | 
				
			||||||
 | 
						Body     interface{} `json:"body"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type WsMsgType string
 | 
					type WsMsgType string
 | 
				
			||||||
type WsChannel string
 | 
					type WsChannel string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const (
 | 
					const (
 | 
				
			||||||
	WsMsgTypeContent = WsMsgType("content") // 输出内容
 | 
						MsgTypeText = WsMsgType("text") // 输出内容
 | 
				
			||||||
	WsMsgTypeEnd     = WsMsgType("end")
 | 
						MsgTypeEnd  = WsMsgType("end")
 | 
				
			||||||
	WsMsgTypeErr     = WsMsgType("error")
 | 
						MsgTypeErr  = WsMsgType("error")
 | 
				
			||||||
	WsMsgTypePing    = WsMsgType("ping") // 心跳消息
 | 
						MsgTypePing = WsMsgType("ping") // 心跳消息
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	WsChat = WsChannel("chat")
 | 
						ChPing = WsChannel("ping")
 | 
				
			||||||
	WsMj   = WsChannel("mj")
 | 
						ChChat = WsChannel("chat")
 | 
				
			||||||
	WsSd   = WsChannel("sd")
 | 
						ChMj   = WsChannel("mj")
 | 
				
			||||||
	WsDall = WsChannel("dall")
 | 
						ChSd   = WsChannel("sd")
 | 
				
			||||||
	WsSuno = WsChannel("suno")
 | 
						ChDall = WsChannel("dall")
 | 
				
			||||||
	WsLuma = WsChannel("luma")
 | 
						ChSuno = WsChannel("suno")
 | 
				
			||||||
 | 
						ChLuma = WsChannel("luma")
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// InputMessage 对话输入消息结构
 | 
					// InputMessage 对话输入消息结构
 | 
				
			||||||
type InputMessage struct {
 | 
					type InputMessage struct {
 | 
				
			||||||
	Channel WsChannel `json:"channel"` // 消息频道
 | 
						Channel WsChannel   `json:"channel"` // 消息频道
 | 
				
			||||||
	Type    WsMsgType `json:"type"`    // 消息类别
 | 
						Type    WsMsgType   `json:"type"`    // 消息类别
 | 
				
			||||||
	Content string    `json:"content"`
 | 
						Body    interface{} `json:"body"`
 | 
				
			||||||
	Tools   []int     `json:"tools"`  // 允许调用工具列表
 | 
					}
 | 
				
			||||||
	Stream  bool      `json:"stream"` // 是否采用流式输出
 | 
					
 | 
				
			||||||
 | 
					type ChatMessage struct {
 | 
				
			||||||
 | 
						Tools   []int  `json:"tools,omitempty"`  // 允许调用工具列表
 | 
				
			||||||
 | 
						Stream  bool   `json:"stream,omitempty"` // 是否采用流式输出
 | 
				
			||||||
 | 
						RoleId  int    `json:"role_id"`
 | 
				
			||||||
 | 
						ModelId int    `json:"model_id"`
 | 
				
			||||||
 | 
						ChatId  string `json:"chat_id"`
 | 
				
			||||||
 | 
						Content string `json:"content"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type BizCode int
 | 
					type BizCode int
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package chatimpl
 | 
					package handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
					// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
				
			||||||
@@ -15,8 +15,6 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"geekai/core"
 | 
						"geekai/core"
 | 
				
			||||||
	"geekai/core/types"
 | 
						"geekai/core/types"
 | 
				
			||||||
	"geekai/handler"
 | 
					 | 
				
			||||||
	logger2 "geekai/logger"
 | 
					 | 
				
			||||||
	"geekai/service"
 | 
						"geekai/service"
 | 
				
			||||||
	"geekai/service/oss"
 | 
						"geekai/service/oss"
 | 
				
			||||||
	"geekai/store/model"
 | 
						"geekai/store/model"
 | 
				
			||||||
@@ -33,14 +31,11 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var logger = logger2.GetLogger()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
type ChatHandler struct {
 | 
					type ChatHandler struct {
 | 
				
			||||||
	handler.BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	redis          *redis.Client
 | 
						redis          *redis.Client
 | 
				
			||||||
	uploadManager  *oss.UploaderManager
 | 
						uploadManager  *oss.UploaderManager
 | 
				
			||||||
	licenseService *service.LicenseService
 | 
						licenseService *service.LicenseService
 | 
				
			||||||
@@ -51,7 +46,7 @@ type ChatHandler struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
 | 
					func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler {
 | 
				
			||||||
	return &ChatHandler{
 | 
						return &ChatHandler{
 | 
				
			||||||
		BaseHandler:    handler.BaseHandler{App: app, DB: db},
 | 
							BaseHandler:    BaseHandler{App: app, DB: db},
 | 
				
			||||||
		redis:          redis,
 | 
							redis:          redis,
 | 
				
			||||||
		uploadManager:  manager,
 | 
							uploadManager:  manager,
 | 
				
			||||||
		licenseService: licenseService,
 | 
							licenseService: licenseService,
 | 
				
			||||||
@@ -61,106 +56,6 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ChatHandle 处理聊天 WebSocket 请求
 | 
					 | 
				
			||||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	sessionId := c.Query("session_id")
 | 
					 | 
				
			||||||
	roleId := h.GetInt(c, "role_id", 0)
 | 
					 | 
				
			||||||
	chatId := c.Query("chat_id")
 | 
					 | 
				
			||||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	var chatRole model.ChatRole
 | 
					 | 
				
			||||||
	res := h.DB.First(&chatRole, roleId)
 | 
					 | 
				
			||||||
	if res.Error != nil || !chatRole.Enable {
 | 
					 | 
				
			||||||
		utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// if the role bind a model_id, use role's bind model_id
 | 
					 | 
				
			||||||
	if chatRole.ModelId > 0 {
 | 
					 | 
				
			||||||
		modelId = chatRole.ModelId
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
	// get model info
 | 
					 | 
				
			||||||
	var chatModel model.ChatModel
 | 
					 | 
				
			||||||
	res = h.DB.First(&chatModel, modelId)
 | 
					 | 
				
			||||||
	if res.Error != nil || chatModel.Enabled == false {
 | 
					 | 
				
			||||||
		utils.ReplyErrorMessage(client, "当前AI模型暂未启用,对话已关闭!!!")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	session := &types.ChatSession{
 | 
					 | 
				
			||||||
		SessionId: sessionId,
 | 
					 | 
				
			||||||
		ClientIP:  c.ClientIP(),
 | 
					 | 
				
			||||||
		UserId:    h.GetLoginUserId(c),
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	// use old chat data override the chat model and role ID
 | 
					 | 
				
			||||||
	var chat model.ChatItem
 | 
					 | 
				
			||||||
	res = h.DB.Where("chat_id = ?", chatId).First(&chat)
 | 
					 | 
				
			||||||
	if res.Error == nil {
 | 
					 | 
				
			||||||
		chatModel.Id = chat.ModelId
 | 
					 | 
				
			||||||
		roleId = int(chat.RoleId)
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	session.ChatId = chatId
 | 
					 | 
				
			||||||
	session.Model = types.ChatModel{
 | 
					 | 
				
			||||||
		Id:          chatModel.Id,
 | 
					 | 
				
			||||||
		Name:        chatModel.Name,
 | 
					 | 
				
			||||||
		Value:       chatModel.Value,
 | 
					 | 
				
			||||||
		Power:       chatModel.Power,
 | 
					 | 
				
			||||||
		MaxTokens:   chatModel.MaxTokens,
 | 
					 | 
				
			||||||
		MaxContext:  chatModel.MaxContext,
 | 
					 | 
				
			||||||
		Temperature: chatModel.Temperature,
 | 
					 | 
				
			||||||
		KeyId:       chatModel.KeyId}
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		for {
 | 
					 | 
				
			||||||
			_, msg, err := client.Receive()
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
					 | 
				
			||||||
				client.Close()
 | 
					 | 
				
			||||||
				cancelFunc := h.ReqCancelFunc.Get(sessionId)
 | 
					 | 
				
			||||||
				if cancelFunc != nil {
 | 
					 | 
				
			||||||
					cancelFunc()
 | 
					 | 
				
			||||||
					h.ReqCancelFunc.Delete(sessionId)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			var message types.InputMessage
 | 
					 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			logger.Infof("Receive a message:%+v", message)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			session.Tools = message.Tools
 | 
					 | 
				
			||||||
			session.Stream = message.Stream
 | 
					 | 
				
			||||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
					 | 
				
			||||||
			h.ReqCancelFunc.Put(sessionId, cancel)
 | 
					 | 
				
			||||||
			// 回复消息
 | 
					 | 
				
			||||||
			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				logger.Error(err)
 | 
					 | 
				
			||||||
				utils.SendMessage(client, err.Error())
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
					 | 
				
			||||||
				logger.Infof("回答完毕: %v", message.Content)
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
 | 
					func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
 | 
				
			||||||
	if !h.App.Debug {
 | 
						if !h.App.Debug {
 | 
				
			||||||
		defer func() {
 | 
							defer func() {
 | 
				
			||||||
@@ -206,7 +101,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	// 兼容 GPT-O1 模型
 | 
						// 兼容 GPT-O1 模型
 | 
				
			||||||
	if strings.HasPrefix(session.Model.Value, "o1-") {
 | 
						if strings.HasPrefix(session.Model.Value, "o1-") {
 | 
				
			||||||
		utils.ReplyContent(ws, "AI 正在思考...\n")
 | 
							utils.SendChunkMsg(ws, "AI 正在思考...\n")
 | 
				
			||||||
		req.Stream = false
 | 
							req.Stream = false
 | 
				
			||||||
		session.Start = time.Now().Unix()
 | 
							session.Start = time.Now().Unix()
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package chatimpl
 | 
					package handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
					// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
				
			||||||
@@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) {
 | 
				
			|||||||
	userId := h.GetLoginUserId(c)
 | 
						userId := h.GetLoginUserId(c)
 | 
				
			||||||
	var items = make([]vo.ChatItem, 0)
 | 
						var items = make([]vo.ChatItem, 0)
 | 
				
			||||||
	var chats []model.ChatItem
 | 
						var chats []model.ChatItem
 | 
				
			||||||
	res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
 | 
						h.DB.Where("user_id", userId).Order("id DESC").Find(&chats)
 | 
				
			||||||
	if res.Error == nil {
 | 
						if len(chats) == 0 {
 | 
				
			||||||
		var roleIds = make([]uint, 0)
 | 
							resp.SUCCESS(c, items)
 | 
				
			||||||
		for _, chat := range chats {
 | 
							return
 | 
				
			||||||
			roleIds = append(roleIds, chat.RoleId)
 | 
						}
 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
		var roles []model.ChatRole
 | 
					 | 
				
			||||||
		res = h.DB.Find(&roles, roleIds)
 | 
					 | 
				
			||||||
		if res.Error == nil {
 | 
					 | 
				
			||||||
			roleMap := make(map[uint]model.ChatRole)
 | 
					 | 
				
			||||||
			for _, role := range roles {
 | 
					 | 
				
			||||||
				roleMap[role.Id] = role
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for _, chat := range chats {
 | 
						var roleIds = make([]uint, 0)
 | 
				
			||||||
				var item vo.ChatItem
 | 
						var modelValues = make([]string, 0)
 | 
				
			||||||
				err := utils.CopyObject(chat, &item)
 | 
						for _, chat := range chats {
 | 
				
			||||||
				if err == nil {
 | 
							roleIds = append(roleIds, chat.RoleId)
 | 
				
			||||||
					item.Id = chat.Id
 | 
							modelValues = append(modelValues, chat.Model)
 | 
				
			||||||
					item.Icon = roleMap[chat.RoleId].Icon
 | 
						}
 | 
				
			||||||
					items = append(items, item)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var roles []model.ChatRole
 | 
				
			||||||
 | 
						var models []model.ChatModel
 | 
				
			||||||
 | 
						roleMap := make(map[uint]model.ChatRole)
 | 
				
			||||||
 | 
						modelMap := make(map[string]model.ChatModel)
 | 
				
			||||||
 | 
						h.DB.Where("id IN ?", roleIds).Find(&roles)
 | 
				
			||||||
 | 
						h.DB.Where("value IN ?", modelValues).Find(&models)
 | 
				
			||||||
 | 
						for _, role := range roles {
 | 
				
			||||||
 | 
							roleMap[role.Id] = role
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, m := range models {
 | 
				
			||||||
 | 
							modelMap[m.Value] = m
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, chat := range chats {
 | 
				
			||||||
 | 
							var item vo.ChatItem
 | 
				
			||||||
 | 
							err := utils.CopyObject(chat, &item)
 | 
				
			||||||
 | 
							if err == nil {
 | 
				
			||||||
 | 
								item.Id = chat.Id
 | 
				
			||||||
 | 
								item.Icon = roleMap[chat.RoleId].Icon
 | 
				
			||||||
 | 
								item.ModelId = modelMap[chat.Model].Id
 | 
				
			||||||
 | 
								items = append(items, item)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resp.SUCCESS(c, items)
 | 
						resp.SUCCESS(c, items)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -20,9 +20,7 @@ import (
 | 
				
			|||||||
	"geekai/utils/resp"
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type DallJobHandler struct {
 | 
					type DallJobHandler struct {
 | 
				
			||||||
@@ -45,49 +43,6 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service,
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
					 | 
				
			||||||
func (h *DallJobHandler) Client(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
	if userId == 0 {
 | 
					 | 
				
			||||||
		logger.Info("Invalid user ID")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.dallService.Clients.Put(uint(userId), client)
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
					 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		for {
 | 
					 | 
				
			||||||
			_, msg, err := client.Receive()
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				client.Close()
 | 
					 | 
				
			||||||
				h.dallService.Clients.Delete(uint(userId))
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			var message types.ReplyMessage
 | 
					 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// 心跳消息
 | 
					 | 
				
			||||||
			if message.Type == "heartbeat" {
 | 
					 | 
				
			||||||
				logger.Debug("收到 DallE 心跳消息:", message.Content)
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
 | 
					func (h *DallJobHandler) preCheck(c *gin.Context) bool {
 | 
				
			||||||
	user, err := h.GetLoginUser(c)
 | 
						user, err := h.GetLoginUser(c)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,7 +19,6 @@ import (
 | 
				
			|||||||
	"geekai/store/model"
 | 
						"geekai/store/model"
 | 
				
			||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"io"
 | 
						"io"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
@@ -43,55 +42,9 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *MarkMapHandler) Client(c *gin.Context) {
 | 
					// Generate 生成思维导图
 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					func (h *MarkMapHandler) Generate(c *gin.Context) {
 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.clients.Put(userId, client)
 | 
					 | 
				
			||||||
	go func() {
 | 
					 | 
				
			||||||
		for {
 | 
					 | 
				
			||||||
			_, msg, err := client.Receive()
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				client.Close()
 | 
					 | 
				
			||||||
				h.clients.Delete(userId)
 | 
					 | 
				
			||||||
				return
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			var message types.ReplyMessage
 | 
					 | 
				
			||||||
			err = utils.JsonDecode(string(msg), &message)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			// 心跳消息
 | 
					 | 
				
			||||||
			if message.Type == "heartbeat" {
 | 
					 | 
				
			||||||
				logger.Debug("收到 MarkMap 心跳消息:", message.Content)
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
			// change model
 | 
					 | 
				
			||||||
			if message.Type == "model_id" {
 | 
					 | 
				
			||||||
				modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
 | 
					 | 
				
			||||||
				continue
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			logger.Info("Receive a message: ", message.Content)
 | 
					 | 
				
			||||||
			err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
 | 
					 | 
				
			||||||
			if err != nil {
 | 
					 | 
				
			||||||
				logger.Error(err)
 | 
					 | 
				
			||||||
				utils.ReplyErrorMessage(client, err.Error())
 | 
					 | 
				
			||||||
			} else {
 | 
					 | 
				
			||||||
				utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}()
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
 | 
					func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
 | 
				
			||||||
@@ -170,13 +123,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
				
			|||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			utils.SendChunkMessage(client, types.ReplyMessage{
 | 
								utils.SendMsg(client, types.ReplyMessage{
 | 
				
			||||||
				Type:    types.WsMsgTypeContent,
 | 
									Type: types.MsgTypeText,
 | 
				
			||||||
				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
									Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
		} // end for
 | 
							} // end for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
							utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		body, _ := io.ReadAll(response.Body)
 | 
							body, _ := io.ReadAll(response.Body)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,12 +19,10 @@ import (
 | 
				
			|||||||
	"geekai/store/vo"
 | 
						"geekai/store/vo"
 | 
				
			||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"geekai/utils/resp"
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -65,27 +63,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
					 | 
				
			||||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
	if userId == 0 {
 | 
					 | 
				
			||||||
		logger.Info("Invalid user ID")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.mjService.Clients.Put(uint(userId), client)
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Image 创建一个绘画任务
 | 
					// Image 创建一个绘画任务
 | 
				
			||||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
					func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
package chatimpl
 | 
					package handler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
					// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
				
			||||||
@@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
 | 
								if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
 | 
				
			||||||
				utils.SendMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
									utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
				
			||||||
				break
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				if res.Error == nil {
 | 
									if res.Error == nil {
 | 
				
			||||||
					toolCall = true
 | 
										toolCall = true
 | 
				
			||||||
					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
										callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
				
			||||||
					utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg})
 | 
										utils.SendChunkMsg(ws, callMsg)
 | 
				
			||||||
					contents = append(contents, callMsg)
 | 
										contents = append(contents, callMsg)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				continue
 | 
									continue
 | 
				
			||||||
@@ -153,10 +153,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				content := responseBody.Choices[0].Delta.Content
 | 
									content := responseBody.Choices[0].Delta.Content
 | 
				
			||||||
				contents = append(contents, utils.InterfaceToString(content))
 | 
									contents = append(contents, utils.InterfaceToString(content))
 | 
				
			||||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
									utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content)
 | 
				
			||||||
					Type:    types.WsMsgTypeContent,
 | 
					 | 
				
			||||||
					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		} // end for
 | 
							} // end for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -174,7 +171,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
 | 
								logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
 | 
				
			||||||
			params["user_id"] = userVo.Id
 | 
								params["user_id"] = userVo.Id
 | 
				
			||||||
			var apiRes types.BizVo
 | 
								var apiRes types.BizVo
 | 
				
			||||||
			r, err := req2.C().R().SetHeader("Content-Type", "application/json").
 | 
								r, err := req2.C().R().SetHeader("Body-Type", "application/json").
 | 
				
			||||||
				SetHeader("Authorization", function.Token).
 | 
									SetHeader("Authorization", function.Token).
 | 
				
			||||||
				SetBody(params).
 | 
									SetBody(params).
 | 
				
			||||||
				SetSuccessResult(&apiRes).Post(function.Action)
 | 
									SetSuccessResult(&apiRes).Post(function.Action)
 | 
				
			||||||
@@ -185,19 +182,13 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
				errMsg = r.Status
 | 
									errMsg = r.Status
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if errMsg != "" || apiRes.Code != types.Success {
 | 
								if errMsg != "" || apiRes.Code != types.Success {
 | 
				
			||||||
				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
									errMsg = "调用函数工具出错:" + apiRes.Message + errMsg
 | 
				
			||||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
									contents = append(contents, errMsg)
 | 
				
			||||||
					Type:    types.WsMsgTypeContent,
 | 
					 | 
				
			||||||
					Content: msg,
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				contents = append(contents, msg)
 | 
					 | 
				
			||||||
			} else {
 | 
								} else {
 | 
				
			||||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
									errMsg = utils.InterfaceToString(apiRes.Data)
 | 
				
			||||||
					Type:    types.WsMsgTypeContent,
 | 
									contents = append(contents, errMsg)
 | 
				
			||||||
					Content: apiRes.Data,
 | 
					 | 
				
			||||||
				})
 | 
					 | 
				
			||||||
				contents = append(contents, utils.InterfaceToString(apiRes.Data))
 | 
					 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
								utils.SendChunkMsg(ws, errMsg)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 消息发送成功
 | 
							// 消息发送成功
 | 
				
			||||||
@@ -226,7 +217,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
				
			|||||||
		if strings.HasPrefix(req.Model, "o1-") {
 | 
							if strings.HasPrefix(req.Model, "o1-") {
 | 
				
			||||||
			content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
 | 
								content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		utils.SendMessage(ws, content)
 | 
							utils.SendChunkMsg(ws, content)
 | 
				
			||||||
		respVo.Usage.Prompt = prompt
 | 
							respVo.Usage.Prompt = prompt
 | 
				
			||||||
		respVo.Usage.Content = content
 | 
							respVo.Usage.Content = content
 | 
				
			||||||
		h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
 | 
							h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
 | 
				
			||||||
@@ -19,11 +19,8 @@ import (
 | 
				
			|||||||
	"geekai/store/vo"
 | 
						"geekai/store/vo"
 | 
				
			||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"geekai/utils/resp"
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/go-redis/redis/v8"
 | 
						"github.com/go-redis/redis/v8"
 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
@@ -59,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer,
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
					 | 
				
			||||||
func (h *SdJobHandler) Client(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
	if userId == 0 {
 | 
					 | 
				
			||||||
		logger.Info("Invalid user ID")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.sdService.Clients.Put(uint(userId), client)
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
 | 
					func (h *SdJobHandler) preCheck(c *gin.Context) bool {
 | 
				
			||||||
	user, err := h.GetLoginUser(c)
 | 
						user, err := h.GetLoginUser(c)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,9 +19,7 @@ import (
 | 
				
			|||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"geekai/utils/resp"
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -44,27 +42,6 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
					 | 
				
			||||||
func (h *SunoHandler) Client(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
	if userId == 0 {
 | 
					 | 
				
			||||||
		logger.Info("Invalid user ID")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.sunoService.Clients.Put(uint(userId), client)
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *SunoHandler) Create(c *gin.Context) {
 | 
					func (h *SunoHandler) Create(c *gin.Context) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,7 +19,7 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekP
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (h *TestHandler) SseTest(c *gin.Context) {
 | 
					func (h *TestHandler) SseTest(c *gin.Context) {
 | 
				
			||||||
	//c.Header("Content-Type", "text/event-stream")
 | 
						//c.Header("Body-Type", "text/event-stream")
 | 
				
			||||||
	//c.Header("Cache-Control", "no-cache")
 | 
						//c.Header("Cache-Control", "no-cache")
 | 
				
			||||||
	//c.Header("Connection", "keep-alive")
 | 
						//c.Header("Connection", "keep-alive")
 | 
				
			||||||
	//
 | 
						//
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,9 +19,7 @@ import (
 | 
				
			|||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"geekai/utils/resp"
 | 
						"geekai/utils/resp"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
					 | 
				
			||||||
	"gorm.io/gorm"
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"net/http"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -44,27 +42,6 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
					 | 
				
			||||||
func (h *VideoHandler) Client(c *gin.Context) {
 | 
					 | 
				
			||||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
					 | 
				
			||||||
	if err != nil {
 | 
					 | 
				
			||||||
		logger.Error(err)
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
					 | 
				
			||||||
	if userId == 0 {
 | 
					 | 
				
			||||||
		logger.Info("Invalid user ID")
 | 
					 | 
				
			||||||
		c.Abort()
 | 
					 | 
				
			||||||
		return
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
					 | 
				
			||||||
	h.videoService.Clients.Put(uint(userId), client)
 | 
					 | 
				
			||||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
					func (h *VideoHandler) LumaCreate(c *gin.Context) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var data struct {
 | 
						var data struct {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,6 +8,7 @@ package handler
 | 
				
			|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"context"
 | 
				
			||||||
	"geekai/core"
 | 
						"geekai/core"
 | 
				
			||||||
	"geekai/core/types"
 | 
						"geekai/core/types"
 | 
				
			||||||
	"geekai/service"
 | 
						"geekai/service"
 | 
				
			||||||
@@ -15,6 +16,7 @@ import (
 | 
				
			|||||||
	"geekai/utils"
 | 
						"geekai/utils"
 | 
				
			||||||
	"github.com/gin-gonic/gin"
 | 
						"github.com/gin-gonic/gin"
 | 
				
			||||||
	"github.com/gorilla/websocket"
 | 
						"github.com/gorilla/websocket"
 | 
				
			||||||
 | 
						"gorm.io/gorm"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -22,12 +24,14 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
type WebsocketHandler struct {
 | 
					type WebsocketHandler struct {
 | 
				
			||||||
	BaseHandler
 | 
						BaseHandler
 | 
				
			||||||
	wsService *service.WebsocketService
 | 
						wsService   *service.WebsocketService
 | 
				
			||||||
 | 
						chatHandler *ChatHandler
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler {
 | 
					func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler {
 | 
				
			||||||
	return &WebsocketHandler{
 | 
						return &WebsocketHandler{
 | 
				
			||||||
		BaseHandler: BaseHandler{App: app},
 | 
							BaseHandler: BaseHandler{App: app, DB: db},
 | 
				
			||||||
 | 
							chatHandler: chatHandler,
 | 
				
			||||||
		wsService:   s,
 | 
							wsService:   s,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -40,9 +44,9 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
 | 
				
			|||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	userId := h.GetInt(c, "user_id", 0)
 | 
						clientId := c.Query("client_id")
 | 
				
			||||||
	clientId := c.Query("client")
 | 
						client := types.NewWsClient(ws, clientId)
 | 
				
			||||||
	client := types.NewWsClient(ws)
 | 
						userId := h.GetLoginUserId(c)
 | 
				
			||||||
	if userId == 0 {
 | 
						if userId == 0 {
 | 
				
			||||||
		_ = client.Send([]byte("Invalid user_id"))
 | 
							_ = client.Send([]byte("Invalid user_id"))
 | 
				
			||||||
		c.Abort()
 | 
							c.Abort()
 | 
				
			||||||
@@ -63,6 +67,8 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
 | 
				
			|||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
				logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
									logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
				
			||||||
				client.Close()
 | 
									client.Close()
 | 
				
			||||||
 | 
									h.wsService.Clients.Delete(clientId)
 | 
				
			||||||
 | 
									break
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			var message types.InputMessage
 | 
								var message types.InputMessage
 | 
				
			||||||
@@ -72,12 +78,66 @@ func (h *WebsocketHandler) Client(c *gin.Context) {
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			logger.Infof("Receive a message:%+v", message)
 | 
								logger.Infof("Receive a message:%+v", message)
 | 
				
			||||||
			if message.Type == types.WsMsgTypePing {
 | 
								if message.Type == types.MsgTypePing {
 | 
				
			||||||
				_ = client.Send([]byte(`{"type":"pong"}`))
 | 
									utils.SendChannelMsg(client, types.ChPing, "pong")
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			switch message.Channel {
 | 
								// 当前只处理聊天消息,其他消息全部丢弃
 | 
				
			||||||
 | 
								var chatMessage types.ChatMessage
 | 
				
			||||||
 | 
								err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage)
 | 
				
			||||||
 | 
								if err != nil || message.Channel != types.ChChat {
 | 
				
			||||||
 | 
									logger.Warnf("invalid message body:%+v", message.Body)
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								var chatRole model.ChatRole
 | 
				
			||||||
 | 
								err = h.DB.First(&chatRole, chatMessage.RoleId).Error
 | 
				
			||||||
 | 
								if err != nil || !chatRole.Enable {
 | 
				
			||||||
 | 
									utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!")
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								// if the role bind a model_id, use role's bind model_id
 | 
				
			||||||
 | 
								if chatRole.ModelId > 0 {
 | 
				
			||||||
 | 
									chatMessage.RoleId = chatRole.ModelId
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								// get model info
 | 
				
			||||||
 | 
								var chatModel model.ChatModel
 | 
				
			||||||
 | 
								err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error
 | 
				
			||||||
 | 
								if err != nil || chatModel.Enabled == false {
 | 
				
			||||||
 | 
									utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!")
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								session := &types.ChatSession{
 | 
				
			||||||
 | 
									ClientIP: c.ClientIP(),
 | 
				
			||||||
 | 
									UserId:   userId,
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// use old chat data override the chat model and role ID
 | 
				
			||||||
 | 
								var chat model.ChatItem
 | 
				
			||||||
 | 
								h.DB.Where("chat_id", chatMessage.ChatId).First(&chat)
 | 
				
			||||||
 | 
								if chat.Id > 0 {
 | 
				
			||||||
 | 
									chatModel.Id = chat.ModelId
 | 
				
			||||||
 | 
									chatMessage.RoleId = int(chat.RoleId)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								session.ChatId = chatMessage.ChatId
 | 
				
			||||||
 | 
								session.Tools = chatMessage.Tools
 | 
				
			||||||
 | 
								session.Stream = chatMessage.Stream
 | 
				
			||||||
 | 
								// 复制模型数据
 | 
				
			||||||
 | 
								err = utils.CopyObject(chatModel, &session.Model)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									logger.Error(err, chatModel)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								ctx, cancel := context.WithCancel(context.Background())
 | 
				
			||||||
 | 
								h.chatHandler.ReqCancelFunc.Put(clientId, cancel)
 | 
				
			||||||
 | 
								err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									logger.Error(err)
 | 
				
			||||||
 | 
									utils.SendAndFlush(client, err.Error())
 | 
				
			||||||
 | 
								} else {
 | 
				
			||||||
 | 
									utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
 | 
				
			||||||
 | 
									logger.Infof("回答完毕: %v", message.Body)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										19
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										19
									
								
								api/main.go
									
									
									
									
									
								
							@@ -14,7 +14,6 @@ import (
 | 
				
			|||||||
	"geekai/core/types"
 | 
						"geekai/core/types"
 | 
				
			||||||
	"geekai/handler"
 | 
						"geekai/handler"
 | 
				
			||||||
	"geekai/handler/admin"
 | 
						"geekai/handler/admin"
 | 
				
			||||||
	"geekai/handler/chatimpl"
 | 
					 | 
				
			||||||
	logger2 "geekai/logger"
 | 
						logger2 "geekai/logger"
 | 
				
			||||||
	"geekai/service"
 | 
						"geekai/service"
 | 
				
			||||||
	"geekai/service/dalle"
 | 
						"geekai/service/dalle"
 | 
				
			||||||
@@ -128,7 +127,7 @@ func main() {
 | 
				
			|||||||
		// 创建控制器
 | 
							// 创建控制器
 | 
				
			||||||
		fx.Provide(handler.NewChatRoleHandler),
 | 
							fx.Provide(handler.NewChatRoleHandler),
 | 
				
			||||||
		fx.Provide(handler.NewUserHandler),
 | 
							fx.Provide(handler.NewUserHandler),
 | 
				
			||||||
		fx.Provide(chatimpl.NewChatHandler),
 | 
							fx.Provide(handler.NewChatHandler),
 | 
				
			||||||
		fx.Provide(handler.NewNetHandler),
 | 
							fx.Provide(handler.NewNetHandler),
 | 
				
			||||||
		fx.Provide(handler.NewSmsHandler),
 | 
							fx.Provide(handler.NewSmsHandler),
 | 
				
			||||||
		fx.Provide(handler.NewRedeemHandler),
 | 
							fx.Provide(handler.NewRedeemHandler),
 | 
				
			||||||
@@ -246,9 +245,8 @@ func main() {
 | 
				
			|||||||
			group.GET("clogin", h.CLogin)
 | 
								group.GET("clogin", h.CLogin)
 | 
				
			||||||
			group.GET("clogin/callback", h.CLoginCallback)
 | 
								group.GET("clogin/callback", h.CLoginCallback)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/chat/")
 | 
								group := s.Engine.Group("/api/chat/")
 | 
				
			||||||
			group.Any("new", h.ChatHandle)
 | 
					 | 
				
			||||||
			group.GET("list", h.List)
 | 
								group.GET("list", h.List)
 | 
				
			||||||
			group.GET("detail", h.Detail)
 | 
								group.GET("detail", h.Detail)
 | 
				
			||||||
			group.POST("update", h.Update)
 | 
								group.POST("update", h.Update)
 | 
				
			||||||
@@ -281,7 +279,6 @@ func main() {
 | 
				
			|||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/mj/")
 | 
								group := s.Engine.Group("/api/mj/")
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
			group.POST("image", h.Image)
 | 
								group.POST("image", h.Image)
 | 
				
			||||||
			group.POST("upscale", h.Upscale)
 | 
								group.POST("upscale", h.Upscale)
 | 
				
			||||||
			group.POST("variation", h.Variation)
 | 
								group.POST("variation", h.Variation)
 | 
				
			||||||
@@ -292,7 +289,6 @@ func main() {
 | 
				
			|||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/sd")
 | 
								group := s.Engine.Group("/api/sd")
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
			group.POST("image", h.Image)
 | 
								group.POST("image", h.Image)
 | 
				
			||||||
			group.GET("jobs", h.JobList)
 | 
								group.GET("jobs", h.JobList)
 | 
				
			||||||
			group.GET("imgWall", h.ImgWall)
 | 
								group.GET("imgWall", h.ImgWall)
 | 
				
			||||||
@@ -467,13 +463,11 @@ func main() {
 | 
				
			|||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Provide(handler.NewMarkMapHandler),
 | 
							fx.Provide(handler.NewMarkMapHandler),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/markMap/")
 | 
								s.Engine.POST("/api/markMap/gen", h.Generate)
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
		fx.Provide(handler.NewDallJobHandler),
 | 
							fx.Provide(handler.NewDallJobHandler),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/dall")
 | 
								group := s.Engine.Group("/api/dall")
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
			group.POST("image", h.Image)
 | 
								group.POST("image", h.Image)
 | 
				
			||||||
			group.GET("jobs", h.JobList)
 | 
								group.GET("jobs", h.JobList)
 | 
				
			||||||
			group.GET("imgWall", h.ImgWall)
 | 
								group.GET("imgWall", h.ImgWall)
 | 
				
			||||||
@@ -483,7 +477,6 @@ func main() {
 | 
				
			|||||||
		fx.Provide(handler.NewSunoHandler),
 | 
							fx.Provide(handler.NewSunoHandler),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/suno")
 | 
								group := s.Engine.Group("/api/suno")
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
			group.POST("create", h.Create)
 | 
								group.POST("create", h.Create)
 | 
				
			||||||
			group.GET("list", h.List)
 | 
								group.GET("list", h.List)
 | 
				
			||||||
			group.GET("remove", h.Remove)
 | 
								group.GET("remove", h.Remove)
 | 
				
			||||||
@@ -496,7 +489,6 @@ func main() {
 | 
				
			|||||||
		fx.Provide(handler.NewVideoHandler),
 | 
							fx.Provide(handler.NewVideoHandler),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) {
 | 
				
			||||||
			group := s.Engine.Group("/api/video")
 | 
								group := s.Engine.Group("/api/video")
 | 
				
			||||||
			group.Any("client", h.Client)
 | 
					 | 
				
			||||||
			group.POST("luma/create", h.LumaCreate)
 | 
								group.POST("luma/create", h.LumaCreate)
 | 
				
			||||||
			group.GET("list", h.List)
 | 
								group.GET("list", h.List)
 | 
				
			||||||
			group.GET("remove", h.Remove)
 | 
								group.GET("remove", h.Remove)
 | 
				
			||||||
@@ -521,6 +513,11 @@ func main() {
 | 
				
			|||||||
			group := s.Engine.Group("/api/test")
 | 
								group := s.Engine.Group("/api/test")
 | 
				
			||||||
			group.Any("sse", h.PostTest, h.SseTest)
 | 
								group.Any("sse", h.PostTest, h.SseTest)
 | 
				
			||||||
		}),
 | 
							}),
 | 
				
			||||||
 | 
							fx.Provide(service.NewWebsocketService),
 | 
				
			||||||
 | 
							fx.Provide(handler.NewWebsocketHandler),
 | 
				
			||||||
 | 
							fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) {
 | 
				
			||||||
 | 
								s.Engine.Any("/api/ws", h.Client)
 | 
				
			||||||
 | 
							}),
 | 
				
			||||||
		fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
							fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
				
			||||||
			go func() {
 | 
								go func() {
 | 
				
			||||||
				err := s.Run(db)
 | 
									err := s.Run(db)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -158,7 +158,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
				
			|||||||
		Quality: task.Quality,
 | 
							Quality: task.Quality,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
 | 
						logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
 | 
				
			||||||
	r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
 | 
						r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
 | 
				
			||||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
							SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
				
			||||||
		SetBody(reqBody).
 | 
							SetBody(reqBody).
 | 
				
			||||||
		SetErrorResult(&errRes).
 | 
							SetErrorResult(&errRes).
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
				
			|||||||
	fileExt := utils.GetImgExt(file.Filename)
 | 
						fileExt := utils.GetImgExt(file.Filename)
 | 
				
			||||||
	filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
						filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
				
			||||||
	info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
 | 
						info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
 | 
				
			||||||
		ContentType: file.Header.Get("Content-Type"),
 | 
							ContentType: file.Header.Get("Body-Type"),
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
 | 
							return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,3 +5,9 @@ import "geekai/core/types"
 | 
				
			|||||||
type WebsocketService struct {
 | 
					type WebsocketService struct {
 | 
				
			||||||
	Clients *types.LMap[string, *types.WsClient] // clientId => Client
 | 
						Clients *types.LMap[string, *types.WsClient] // clientId => Client
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func NewWebsocketService() *WebsocketService {
 | 
				
			||||||
 | 
						return &WebsocketService{
 | 
				
			||||||
 | 
							Clients: types.NewLMap[string, *types.WsClient](),
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,8 +19,9 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
var logger = logger2.GetLogger()
 | 
					var logger = logger2.GetLogger()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SendChunkMessage 回复客户片段端消息
 | 
					// SendMsg 回复客户片段端消息
 | 
				
			||||||
func SendChunkMessage(client *types.WsClient, message interface{}) {
 | 
					func SendMsg(client *types.WsClient, message types.ReplyMessage) {
 | 
				
			||||||
 | 
						message.ClientId = client.Id
 | 
				
			||||||
	msg, err := json.Marshal(message)
 | 
						msg, err := json.Marshal(message)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		logger.Errorf("Error for decoding json data: %v", err.Error())
 | 
							logger.Errorf("Error for decoding json data: %v", err.Error())
 | 
				
			||||||
@@ -32,19 +33,23 @@ func SendChunkMessage(client *types.WsClient, message interface{}) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// SendMessage 回复客户端一条完整的消息
 | 
					// SendAndFlush 回复客户端一条完整的消息
 | 
				
			||||||
func SendMessage(ws *types.WsClient, message interface{}) {
 | 
					func SendAndFlush(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
 | 
						SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message})
 | 
				
			||||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
						SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func ReplyContent(ws *types.WsClient, message interface{}) {
 | 
					func SendChunkMsg(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
 | 
						SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ReplyErrorMessage 向客户端发送错误消息
 | 
					// SendErrMsg 向客户端发送错误消息
 | 
				
			||||||
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
 | 
					func SendErrMsg(ws *types.WsClient, message interface{}) {
 | 
				
			||||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message})
 | 
						SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeErr, Body: message})
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func SendChannelMsg(ws *types.WsClient, channel types.WsChannel, message interface{}) {
 | 
				
			||||||
 | 
						SendMsg(ws, types.ReplyMessage{Channel: channel, Type: types.MsgTypeText, Body: message})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
					func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
				
			||||||
@@ -68,7 +73,9 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer resp.Body.Close()
 | 
						defer func(Body io.ReadCloser) {
 | 
				
			||||||
 | 
							_ = Body.Close()
 | 
				
			||||||
 | 
						}(resp.Body)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	imageBytes, err := io.ReadAll(resp.Body)
 | 
						imageBytes, err := io.ReadAll(resp.Body)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -65,7 +65,7 @@ func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error)
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 | 
						apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
 | 
				
			||||||
	logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, modelName)
 | 
						logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, modelName)
 | 
				
			||||||
	r, err := client.R().SetHeader("Content-Type", "application/json").
 | 
						r, err := client.R().SetHeader("Body-Type", "application/json").
 | 
				
			||||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
							SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
				
			||||||
		SetBody(types.ApiRequest{
 | 
							SetBody(types.ApiRequest{
 | 
				
			||||||
			Model:       modelName,
 | 
								Model:       modelName,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,10 +6,13 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
<script setup>
 | 
					<script setup>
 | 
				
			||||||
import {ElConfigProvider} from 'element-plus';
 | 
					import {ElConfigProvider} from 'element-plus';
 | 
				
			||||||
import {onMounted} from "vue";
 | 
					import {onMounted, ref} from "vue";
 | 
				
			||||||
import {getSystemInfo} from "@/store/cache";
 | 
					import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
 | 
				
			||||||
import {isChrome, isMobile} from "@/utils/libs";
 | 
					import {isChrome, isMobile} from "@/utils/libs";
 | 
				
			||||||
import {showMessageInfo} from "@/utils/dialog";
 | 
					import {showMessageInfo} from "@/utils/dialog";
 | 
				
			||||||
 | 
					import {useSharedStore} from "@/store/sharedata";
 | 
				
			||||||
 | 
					import {getUserToken} from "@/store/session";
 | 
				
			||||||
 | 
					import {clear} from "core-js/internals/task";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const debounce = (fn, delay) => {
 | 
					const debounce = (fn, delay) => {
 | 
				
			||||||
  let timer
 | 
					  let timer
 | 
				
			||||||
@@ -32,6 +35,7 @@ window.ResizeObserver = class ResizeObserver extends _ResizeObserver {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onMounted(() => {
 | 
					onMounted(() => {
 | 
				
			||||||
 | 
					  // 获取系统参数
 | 
				
			||||||
  getSystemInfo().then((res) => {
 | 
					  getSystemInfo().then((res) => {
 | 
				
			||||||
    const link = document.createElement('link')
 | 
					    const link = document.createElement('link')
 | 
				
			||||||
    link.rel = 'shortcut icon'
 | 
					    link.rel = 'shortcut icon'
 | 
				
			||||||
@@ -39,9 +43,50 @@ onMounted(() => {
 | 
				
			|||||||
    document.head.appendChild(link)
 | 
					    document.head.appendChild(link)
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
  if (!isChrome() && !isMobile()) {
 | 
					  if (!isChrome() && !isMobile()) {
 | 
				
			||||||
    showMessageInfo("检测到您使用的浏览器不是 Chrome,可能会导致部分功能无法正常使用,建议使用 Chrome 浏览器。")
 | 
					    showMessageInfo("建议使用 Chrome 浏览器以获得最佳体验。")
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  checkSession().then(() => {
 | 
				
			||||||
 | 
					    connect()
 | 
				
			||||||
 | 
					  }).catch(()=>{})
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					const store = useSharedStore()
 | 
				
			||||||
 | 
					const handler = ref(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 初始化 websocket 连接
 | 
				
			||||||
 | 
					const connect = () => {
 | 
				
			||||||
 | 
					  let host = process.env.VUE_APP_WS_HOST
 | 
				
			||||||
 | 
					  if (host === '') {
 | 
				
			||||||
 | 
					    if (location.protocol === 'https:') {
 | 
				
			||||||
 | 
					      host = 'wss://' + location.host;
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      host = 'ws://' + location.host;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					  const clientId = getClientId()
 | 
				
			||||||
 | 
					  const _socket = new WebSocket(host + `/api/ws?client_id=${clientId}&token=${getUserToken()}`);
 | 
				
			||||||
 | 
					  _socket.addEventListener('open', () => {
 | 
				
			||||||
 | 
					    console.log('WebSocket 已连接')
 | 
				
			||||||
 | 
					    handler.value = setInterval(() => {
 | 
				
			||||||
 | 
					      _socket.send(JSON.stringify({"type":"ping"}))
 | 
				
			||||||
 | 
					    },5000)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (const key in store.messageHandlers) {
 | 
				
			||||||
 | 
					      console.log(key, store.messageHandlers[key])
 | 
				
			||||||
 | 
					      store.setMessageHandler(store.messageHandlers[key])
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  _socket.addEventListener('close', () => {
 | 
				
			||||||
 | 
					    store.setSocket(null)
 | 
				
			||||||
 | 
					    clearInterval(handler.value)
 | 
				
			||||||
 | 
					    connect()
 | 
				
			||||||
 | 
					  });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  store.setSocket(_socket)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
</script>
 | 
					</script>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,6 @@
 | 
				
			|||||||
import {httpGet} from "@/utils/http";
 | 
					import {httpGet} from "@/utils/http";
 | 
				
			||||||
import Storage from "good-storage";
 | 
					import Storage from "good-storage";
 | 
				
			||||||
 | 
					import {randString} from "@/utils/libs";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const userDataKey = "USER_INFO_CACHE_KEY"
 | 
					const userDataKey = "USER_INFO_CACHE_KEY"
 | 
				
			||||||
const adminDataKey = "ADMIN_INFO_CACHE_KEY"
 | 
					const adminDataKey = "ADMIN_INFO_CACHE_KEY"
 | 
				
			||||||
@@ -70,4 +71,14 @@ export function getLicenseInfo() {
 | 
				
			|||||||
            resolve(err)
 | 
					            resolve(err)
 | 
				
			||||||
        })
 | 
					        })
 | 
				
			||||||
    })
 | 
					    })
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					export function getClientId() {
 | 
				
			||||||
 | 
					    let clientId = Storage.get('client_id')
 | 
				
			||||||
 | 
					    if (clientId) {
 | 
				
			||||||
 | 
					        return clientId
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    clientId = randString(42)
 | 
				
			||||||
 | 
					    Storage.set('client_id', clientId)
 | 
				
			||||||
 | 
					    return clientId
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -6,6 +6,8 @@ export const useSharedStore = defineStore('shared', {
 | 
				
			|||||||
        showLoginDialog: false,
 | 
					        showLoginDialog: false,
 | 
				
			||||||
        chatListStyle: Storage.get("chat_list_style","chat"),
 | 
					        chatListStyle: Storage.get("chat_list_style","chat"),
 | 
				
			||||||
        chatStream: Storage.get("chat_stream",true),
 | 
					        chatStream: Storage.get("chat_stream",true),
 | 
				
			||||||
 | 
					        socket: WebSocket,
 | 
				
			||||||
 | 
					        messageHandlers:{},
 | 
				
			||||||
    }),
 | 
					    }),
 | 
				
			||||||
    getters: {},
 | 
					    getters: {},
 | 
				
			||||||
    actions: {
 | 
					    actions: {
 | 
				
			||||||
@@ -19,6 +21,36 @@ export const useSharedStore = defineStore('shared', {
 | 
				
			|||||||
        setChatStream(value) {
 | 
					        setChatStream(value) {
 | 
				
			||||||
            this.chatStream = value;
 | 
					            this.chatStream = value;
 | 
				
			||||||
            Storage.set("chat_stream", value);
 | 
					            Storage.set("chat_stream", value);
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        setSocket(value) {
 | 
				
			||||||
 | 
					            this.socket = value;
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        addMessageHandler(key, callback) {
 | 
				
			||||||
 | 
					            if (!this.messageHandlers[key]) {
 | 
				
			||||||
 | 
					                this.messageHandlers[key] = callback;
 | 
				
			||||||
 | 
					                this.setMessageHandler(callback)
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        },
 | 
				
			||||||
 | 
					        setMessageHandler(callback) {
 | 
				
			||||||
 | 
					            if (this.socket instanceof WebSocket && this.socket.readyState === WebSocket.OPEN) {
 | 
				
			||||||
 | 
					                this.socket.addEventListener('message', (event) => {
 | 
				
			||||||
 | 
					                    try {
 | 
				
			||||||
 | 
					                        if (event.data instanceof Blob) {
 | 
				
			||||||
 | 
					                            const reader = new FileReader();
 | 
				
			||||||
 | 
					                            reader.readAsText(event.data, "UTF-8");
 | 
				
			||||||
 | 
					                            reader.onload = () => {
 | 
				
			||||||
 | 
					                                callback(JSON.parse(String(reader.result)))
 | 
				
			||||||
 | 
					                            }
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    } catch (e) {
 | 
				
			||||||
 | 
					                        console.warn(e)
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                })
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                setTimeout(() => {
 | 
				
			||||||
 | 
					                    this.setMessageHandler(callback)
 | 
				
			||||||
 | 
					                }, 1000)
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,7 +6,6 @@
 | 
				
			|||||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
					// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import Storage from "good-storage";
 | 
					import Storage from "good-storage";
 | 
				
			||||||
import {useRouter} from "vue-router";
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
const MOBILE_THEME = process.env.VUE_APP_KEY_PREFIX + "MOBILE_THEME"
 | 
					const MOBILE_THEME = process.env.VUE_APP_KEY_PREFIX + "MOBILE_THEME"
 | 
				
			||||||
const ADMIN_THEME = process.env.VUE_APP_KEY_PREFIX + "ADMIN_THEME"
 | 
					const ADMIN_THEME = process.env.VUE_APP_KEY_PREFIX + "ADMIN_THEME"
 | 
				
			||||||
@@ -71,4 +70,4 @@ export function setRoute(path) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
export function getRoute() {
 | 
					export function getRoute() {
 | 
				
			||||||
    return Storage.get(process.env.VUE_APP_KEY_PREFIX + 'ROUTE_')
 | 
					    return Storage.get(process.env.VUE_APP_KEY_PREFIX + 'ROUTE_')
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -17,7 +17,6 @@ axios.defaults.headers.post['Content-Type'] = 'application/json'
 | 
				
			|||||||
axios.interceptors.request.use(
 | 
					axios.interceptors.request.use(
 | 
				
			||||||
    config => {
 | 
					    config => {
 | 
				
			||||||
        // set token
 | 
					        // set token
 | 
				
			||||||
        config.headers['Chat-Token'] = getSessionId();
 | 
					 | 
				
			||||||
        config.headers['Authorization'] = getUserToken();
 | 
					        config.headers['Authorization'] = getUserToken();
 | 
				
			||||||
        config.headers['Admin-Authorization'] = getAdminToken();
 | 
					        config.headers['Admin-Authorization'] = getAdminToken();
 | 
				
			||||||
        return config
 | 
					        return config
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -213,7 +213,7 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
</template>
 | 
					</template>
 | 
				
			||||||
<script setup>
 | 
					<script setup>
 | 
				
			||||||
import {nextTick, onMounted, onUnmounted, ref, watch} from 'vue'
 | 
					import {nextTick, onMounted, ref, watch} from 'vue'
 | 
				
			||||||
import ChatPrompt from "@/components/ChatPrompt.vue";
 | 
					import ChatPrompt from "@/components/ChatPrompt.vue";
 | 
				
			||||||
import ChatReply from "@/components/ChatReply.vue";
 | 
					import ChatReply from "@/components/ChatReply.vue";
 | 
				
			||||||
import {Delete, Edit, InfoFilled, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue'
 | 
					import {Delete, Edit, InfoFilled, More, Plus, Promotion, Search, Share, VideoPause} from '@element-plus/icons-vue'
 | 
				
			||||||
@@ -225,11 +225,10 @@ import {
 | 
				
			|||||||
  UUID
 | 
					  UUID
 | 
				
			||||||
} from "@/utils/libs";
 | 
					} from "@/utils/libs";
 | 
				
			||||||
import {ElMessage, ElMessageBox} from "element-plus";
 | 
					import {ElMessage, ElMessageBox} from "element-plus";
 | 
				
			||||||
import {getSessionId, getUserToken} from "@/store/session";
 | 
					 | 
				
			||||||
import {httpGet, httpPost} from "@/utils/http";
 | 
					import {httpGet, httpPost} from "@/utils/http";
 | 
				
			||||||
import {useRouter} from "vue-router";
 | 
					import {useRouter} from "vue-router";
 | 
				
			||||||
import Clipboard from "clipboard";
 | 
					import Clipboard from "clipboard";
 | 
				
			||||||
import {checkSession, getSystemInfo} from "@/store/cache";
 | 
					import {checkSession, getClientId, getSystemInfo} from "@/store/cache";
 | 
				
			||||||
import Welcome from "@/components/Welcome.vue";
 | 
					import Welcome from "@/components/Welcome.vue";
 | 
				
			||||||
import {useSharedStore} from "@/store/sharedata";
 | 
					import {useSharedStore} from "@/store/sharedata";
 | 
				
			||||||
import FileSelect from "@/components/FileSelect.vue";
 | 
					import FileSelect from "@/components/FileSelect.vue";
 | 
				
			||||||
@@ -270,7 +269,6 @@ watch(() => store.chatListStyle, (newValue) => {
 | 
				
			|||||||
});
 | 
					});
 | 
				
			||||||
const tools = ref([])
 | 
					const tools = ref([])
 | 
				
			||||||
const toolSelected = ref([])
 | 
					const toolSelected = ref([])
 | 
				
			||||||
const loadHistory = ref(false)
 | 
					 | 
				
			||||||
const stream = ref(store.chatStream)
 | 
					const stream = ref(store.chatStream)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
watch(() => store.chatStream, (newValue) => {
 | 
					watch(() => store.chatStream, (newValue) => {
 | 
				
			||||||
@@ -337,6 +335,13 @@ httpGet("/api/function/list").then(res => {
 | 
				
			|||||||
  showMessageError("获取工具函数失败:" + e.message)
 | 
					  showMessageError("获取工具函数失败:" + e.message)
 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// 创建 socket 连接
 | 
				
			||||||
 | 
					const prompt = ref('');
 | 
				
			||||||
 | 
					const showStopGenerate = ref(false); // 停止生成
 | 
				
			||||||
 | 
					const lineBuffer = ref(''); // 输出缓冲行
 | 
				
			||||||
 | 
					const canSend = ref(true);
 | 
				
			||||||
 | 
					const isNewMsg = ref(true)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
onMounted(() => {
 | 
					onMounted(() => {
 | 
				
			||||||
  resizeElement();
 | 
					  resizeElement();
 | 
				
			||||||
  initData()
 | 
					  initData()
 | 
				
			||||||
@@ -351,14 +356,73 @@ onMounted(() => {
 | 
				
			|||||||
  })
 | 
					  })
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  window.onresize = () => resizeElement();
 | 
					  window.onresize = () => resizeElement();
 | 
				
			||||||
 | 
					  store.addMessageHandler("chat", (data) => {
 | 
				
			||||||
 | 
					    console.log(data)
 | 
				
			||||||
 | 
					    // 丢去非本频道和本客户端的消息
 | 
				
			||||||
 | 
					    if (data.channel !== 'chat' || data.clientId !== getClientId()) {
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (data.type === 'error') {
 | 
				
			||||||
 | 
					      ElMessage.error(data.body)
 | 
				
			||||||
 | 
					      return
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const chatRole = getRoleById(roleId.value)
 | 
				
			||||||
 | 
					    if (isNewMsg.value && data.type !== 'end') {
 | 
				
			||||||
 | 
					      const prePrompt = chatData.value[chatData.value.length-1]?.content
 | 
				
			||||||
 | 
					      chatData.value.push({
 | 
				
			||||||
 | 
					        type: "reply",
 | 
				
			||||||
 | 
					        id: randString(32),
 | 
				
			||||||
 | 
					        icon: chatRole['icon'],
 | 
				
			||||||
 | 
					        prompt:prePrompt,
 | 
				
			||||||
 | 
					        content: data.body,
 | 
				
			||||||
 | 
					      });
 | 
				
			||||||
 | 
					      isNewMsg.value = false
 | 
				
			||||||
 | 
					      lineBuffer.value = data.body;
 | 
				
			||||||
 | 
					    } else if (data.type === 'end') { // 消息接收完毕
 | 
				
			||||||
 | 
					      // 追加当前会话到会话列表
 | 
				
			||||||
 | 
					      if (newChatItem.value !== null) {
 | 
				
			||||||
 | 
					        newChatItem.value['title'] = tmpChatTitle.value;
 | 
				
			||||||
 | 
					        newChatItem.value['chat_id'] = chatId.value;
 | 
				
			||||||
 | 
					        chatList.value.unshift(newChatItem.value);
 | 
				
			||||||
 | 
					        newChatItem.value = null; // 只追加一次
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      enableInput()
 | 
				
			||||||
 | 
					      lineBuffer.value = ''; // 清空缓冲
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					      // 获取 token
 | 
				
			||||||
 | 
					      const reply = chatData.value[chatData.value.length - 1]
 | 
				
			||||||
 | 
					      httpPost("/api/chat/tokens", {
 | 
				
			||||||
 | 
					        text: "",
 | 
				
			||||||
 | 
					        model: getModelValue(modelID.value),
 | 
				
			||||||
 | 
					        chat_id: chatId.value,
 | 
				
			||||||
 | 
					      }).then(res => {
 | 
				
			||||||
 | 
					        reply['created_at'] = new Date().getTime();
 | 
				
			||||||
 | 
					        reply['tokens'] = res.data;
 | 
				
			||||||
 | 
					        // 将聊天框的滚动条滑动到最底部
 | 
				
			||||||
 | 
					        nextTick(() => {
 | 
				
			||||||
 | 
					          document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
 | 
					      }).catch(() => {
 | 
				
			||||||
 | 
					      })
 | 
				
			||||||
 | 
					      isNewMsg.value = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    } else if (data.type === 'text') {
 | 
				
			||||||
 | 
					      lineBuffer.value += data.body;
 | 
				
			||||||
 | 
					      const reply = chatData.value[chatData.value.length - 1]
 | 
				
			||||||
 | 
					      if (reply) {
 | 
				
			||||||
 | 
					        reply['content'] = lineBuffer.value;
 | 
				
			||||||
 | 
					      }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    // 将聊天框的滚动条滑动到最底部
 | 
				
			||||||
 | 
					    nextTick(() => {
 | 
				
			||||||
 | 
					      document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
 | 
				
			||||||
 | 
					      localStorage.setItem("chat_id", chatId.value)
 | 
				
			||||||
 | 
					    })
 | 
				
			||||||
 | 
					  })
 | 
				
			||||||
});
 | 
					});
 | 
				
			||||||
 | 
					 | 
				
			||||||
onUnmounted(() => {
 | 
					 | 
				
			||||||
  if (socket.value !== null) {
 | 
					 | 
				
			||||||
    socket.value = null
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
})
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// 初始化数据
 | 
					// 初始化数据
 | 
				
			||||||
const initData = () => {
 | 
					const initData = () => {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -492,9 +556,8 @@ const newChat = () => {
 | 
				
			|||||||
    removing: false,
 | 
					    removing: false,
 | 
				
			||||||
  };
 | 
					  };
 | 
				
			||||||
  showStopGenerate.value = false;
 | 
					  showStopGenerate.value = false;
 | 
				
			||||||
 | 
					  loadChatHistory(chatId.value)
 | 
				
			||||||
  router.push(`/chat/${chatId.value}`)
 | 
					  router.push(`/chat/${chatId.value}`)
 | 
				
			||||||
  loadHistory.value = true
 | 
					 | 
				
			||||||
  connect()
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 切换会话
 | 
					// 切换会话
 | 
				
			||||||
@@ -507,14 +570,12 @@ const loadChat = function (chat) {
 | 
				
			|||||||
  if (chatId.value === chat.chat_id) {
 | 
					  if (chatId.value === chat.chat_id) {
 | 
				
			||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					 | 
				
			||||||
  newChatItem.value = null;
 | 
					  newChatItem.value = null;
 | 
				
			||||||
  roleId.value = chat.role_id;
 | 
					  roleId.value = chat.role_id;
 | 
				
			||||||
  modelID.value = chat.model_id;
 | 
					  modelID.value = chat.model_id;
 | 
				
			||||||
  chatId.value = chat.chat_id;
 | 
					  chatId.value = chat.chat_id;
 | 
				
			||||||
  showStopGenerate.value = false;
 | 
					  showStopGenerate.value = false;
 | 
				
			||||||
  loadHistory.value = true
 | 
					  loadChatHistory(chatId.value)
 | 
				
			||||||
  connect()
 | 
					 | 
				
			||||||
  router.replace(`/chat/${chatId.value}`)
 | 
					  router.replace(`/chat/${chatId.value}`)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -587,118 +648,6 @@ const removeChat = function (chat) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// 创建 socket 连接
 | 
					 | 
				
			||||||
const prompt = ref('');
 | 
					 | 
				
			||||||
const showStopGenerate = ref(false); // 停止生成
 | 
					 | 
				
			||||||
const lineBuffer = ref(''); // 输出缓冲行
 | 
					 | 
				
			||||||
const socket = ref(null);
 | 
					 | 
				
			||||||
const canSend = ref(true);
 | 
					 | 
				
			||||||
const sessionId = ref("")
 | 
					 | 
				
			||||||
const isNewMsg = ref(true)
 | 
					 | 
				
			||||||
const connect = function () {
 | 
					 | 
				
			||||||
  const chatRole = getRoleById(roleId.value);
 | 
					 | 
				
			||||||
  // 初始化 WebSocket 对象
 | 
					 | 
				
			||||||
  sessionId.value = getSessionId();
 | 
					 | 
				
			||||||
  let host = process.env.VUE_APP_WS_HOST
 | 
					 | 
				
			||||||
  if (host === '') {
 | 
					 | 
				
			||||||
    if (location.protocol === 'https:') {
 | 
					 | 
				
			||||||
      host = 'wss://' + location.host;
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
      host = 'ws://' + location.host;
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
  }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  loading.value = true
 | 
					 | 
				
			||||||
  const _socket = new WebSocket(host + `/api/chat/new?session_id=${sessionId.value}&role_id=${roleId.value}&chat_id=${chatId.value}&model_id=${modelID.value}&token=${getUserToken()}`);
 | 
					 | 
				
			||||||
  _socket.addEventListener('open', () => {
 | 
					 | 
				
			||||||
    enableInput()
 | 
					 | 
				
			||||||
    if (loadHistory.value) {
 | 
					 | 
				
			||||||
      loadChatHistory(chatId.value)
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    loading.value = false
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('message', event => {
 | 
					 | 
				
			||||||
    try {
 | 
					 | 
				
			||||||
      if (event.data instanceof Blob) {
 | 
					 | 
				
			||||||
        const reader = new FileReader();
 | 
					 | 
				
			||||||
        reader.readAsText(event.data, "UTF-8");
 | 
					 | 
				
			||||||
        reader.onload = () => {
 | 
					 | 
				
			||||||
          const data = JSON.parse(String(reader.result));
 | 
					 | 
				
			||||||
          if (data.type === 'error') {
 | 
					 | 
				
			||||||
            ElMessage.error(data.message)
 | 
					 | 
				
			||||||
            return
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          if (isNewMsg.value && data.type !== 'end') {
 | 
					 | 
				
			||||||
            const prePrompt = chatData.value[chatData.value.length-1]?.content
 | 
					 | 
				
			||||||
            chatData.value.push({
 | 
					 | 
				
			||||||
              type: "reply",
 | 
					 | 
				
			||||||
              id: randString(32),
 | 
					 | 
				
			||||||
              icon: chatRole['icon'],
 | 
					 | 
				
			||||||
              prompt:prePrompt,
 | 
					 | 
				
			||||||
              content: data.content,
 | 
					 | 
				
			||||||
            });
 | 
					 | 
				
			||||||
            isNewMsg.value = false
 | 
					 | 
				
			||||||
            lineBuffer.value = data.content;
 | 
					 | 
				
			||||||
          } else if (data.type === 'end') { // 消息接收完毕
 | 
					 | 
				
			||||||
            // 追加当前会话到会话列表
 | 
					 | 
				
			||||||
            if (newChatItem.value !== null) {
 | 
					 | 
				
			||||||
              newChatItem.value['title'] = tmpChatTitle.value;
 | 
					 | 
				
			||||||
              newChatItem.value['chat_id'] = chatId.value;
 | 
					 | 
				
			||||||
              chatList.value.unshift(newChatItem.value);
 | 
					 | 
				
			||||||
              newChatItem.value = null; // 只追加一次
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            enableInput()
 | 
					 | 
				
			||||||
            lineBuffer.value = ''; // 清空缓冲
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            // 获取 token
 | 
					 | 
				
			||||||
            const reply = chatData.value[chatData.value.length - 1]
 | 
					 | 
				
			||||||
            httpPost("/api/chat/tokens", {
 | 
					 | 
				
			||||||
              text: "",
 | 
					 | 
				
			||||||
              model: getModelValue(modelID.value),
 | 
					 | 
				
			||||||
              chat_id: chatId.value,
 | 
					 | 
				
			||||||
            }).then(res => {
 | 
					 | 
				
			||||||
              reply['created_at'] = new Date().getTime();
 | 
					 | 
				
			||||||
              reply['tokens'] = res.data;
 | 
					 | 
				
			||||||
              // 将聊天框的滚动条滑动到最底部
 | 
					 | 
				
			||||||
              nextTick(() => {
 | 
					 | 
				
			||||||
                document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
 | 
					 | 
				
			||||||
              })
 | 
					 | 
				
			||||||
              isNewMsg.value = true
 | 
					 | 
				
			||||||
            }).catch(() => {
 | 
					 | 
				
			||||||
            })
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
          } else {
 | 
					 | 
				
			||||||
            lineBuffer.value += data.content;
 | 
					 | 
				
			||||||
            const reply = chatData.value[chatData.value.length - 1]
 | 
					 | 
				
			||||||
            if (reply) {
 | 
					 | 
				
			||||||
              reply['content'] = lineBuffer.value;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
          }
 | 
					 | 
				
			||||||
          // 将聊天框的滚动条滑动到最底部
 | 
					 | 
				
			||||||
          nextTick(() => {
 | 
					 | 
				
			||||||
            document.getElementById('chat-box').scrollTo(0, document.getElementById('chat-box').scrollHeight)
 | 
					 | 
				
			||||||
            localStorage.setItem("chat_id", chatId.value)
 | 
					 | 
				
			||||||
          })
 | 
					 | 
				
			||||||
        };
 | 
					 | 
				
			||||||
      }
 | 
					 | 
				
			||||||
    } catch (e) {
 | 
					 | 
				
			||||||
      console.warn(e)
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  _socket.addEventListener('close', () => {
 | 
					 | 
				
			||||||
    disableInput(true)
 | 
					 | 
				
			||||||
    loadHistory.value = false
 | 
					 | 
				
			||||||
    connect()
 | 
					 | 
				
			||||||
  });
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
  socket.value = _socket;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
const disableInput = (force) => {
 | 
					const disableInput = (force) => {
 | 
				
			||||||
  canSend.value = false;
 | 
					  canSend.value = false;
 | 
				
			||||||
  showStopGenerate.value = !force;
 | 
					  showStopGenerate.value = !force;
 | 
				
			||||||
@@ -747,6 +696,11 @@ const sendMessage = function () {
 | 
				
			|||||||
    return;
 | 
					    return;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if (store.socket.readyState !== WebSocket.OPEN) {
 | 
				
			||||||
 | 
					    ElMessage.warning("连接断开,正在重连...");
 | 
				
			||||||
 | 
					    return
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  if (canSend.value === false) {
 | 
					  if (canSend.value === false) {
 | 
				
			||||||
    ElMessage.warning("AI 正在作答中,请稍后...");
 | 
					    ElMessage.warning("AI 正在作答中,请稍后...");
 | 
				
			||||||
    return
 | 
					    return
 | 
				
			||||||
@@ -780,7 +734,18 @@ const sendMessage = function () {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
  showHello.value = false
 | 
					  showHello.value = false
 | 
				
			||||||
  disableInput(false)
 | 
					  disableInput(false)
 | 
				
			||||||
  socket.value.send(JSON.stringify({tools: toolSelected.value, content: content, stream: stream.value}));
 | 
					  store.socket.send(JSON.stringify({
 | 
				
			||||||
 | 
					    channel: 'chat',
 | 
				
			||||||
 | 
					    type:'text',
 | 
				
			||||||
 | 
					    body:{
 | 
				
			||||||
 | 
					      role_id: roleId.value,
 | 
				
			||||||
 | 
					      model_id: modelID.value,
 | 
				
			||||||
 | 
					      chat_id: chatId.value,
 | 
				
			||||||
 | 
					      content: content,
 | 
				
			||||||
 | 
					      tools:toolSelected.value,
 | 
				
			||||||
 | 
					      stream: stream.value
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }));
 | 
				
			||||||
  tmpChatTitle.value = content
 | 
					  tmpChatTitle.value = content
 | 
				
			||||||
  prompt.value = ''
 | 
					  prompt.value = ''
 | 
				
			||||||
  files.value = []
 | 
					  files.value = []
 | 
				
			||||||
@@ -849,7 +814,7 @@ const loadChatHistory = function (chatId) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
const stopGenerate = function () {
 | 
					const stopGenerate = function () {
 | 
				
			||||||
  showStopGenerate.value = false;
 | 
					  showStopGenerate.value = false;
 | 
				
			||||||
  httpGet("/api/chat/stop?session_id=" + sessionId.value).then(() => {
 | 
					  httpGet("/api/chat/stop?session_id=" + getClientId()).then(() => {
 | 
				
			||||||
    enableInput()
 | 
					    enableInput()
 | 
				
			||||||
  })
 | 
					  })
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -865,7 +830,18 @@ const reGenerate = function (prompt) {
 | 
				
			|||||||
    icon: loginUser.value.avatar,
 | 
					    icon: loginUser.value.avatar,
 | 
				
			||||||
    content: text
 | 
					    content: text
 | 
				
			||||||
  });
 | 
					  });
 | 
				
			||||||
  socket.value.send(JSON.stringify({tools: toolSelected.value, content: text, stream: stream.value}));
 | 
					  store.socket.send(JSON.stringify({
 | 
				
			||||||
 | 
					    channel: 'chat',
 | 
				
			||||||
 | 
					    type:'text',
 | 
				
			||||||
 | 
					    body:{
 | 
				
			||||||
 | 
					      role_id: roleId.value,
 | 
				
			||||||
 | 
					      model_id: modelID.value,
 | 
				
			||||||
 | 
					      chat_id: chatId.value,
 | 
				
			||||||
 | 
					      content: text,
 | 
				
			||||||
 | 
					      tools:toolSelected.value,
 | 
				
			||||||
 | 
					      stream: stream.value
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }));
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
const chatName = ref('')
 | 
					const chatName = ref('')
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -31,6 +31,7 @@
 | 
				
			|||||||
            <el-button :color="theme.btnBgColor" :style="{color: theme.btnTextColor}" @click="router.push('/login')" class="shadow" round>登录</el-button>
 | 
					            <el-button :color="theme.btnBgColor" :style="{color: theme.btnTextColor}" @click="router.push('/login')" class="shadow" round>登录</el-button>
 | 
				
			||||||
            <el-button :color="theme.btnBgColor" :style="{color: theme.btnTextColor}" @click="router.push('/register')" class="shadow" round>注册</el-button>
 | 
					            <el-button :color="theme.btnBgColor" :style="{color: theme.btnTextColor}" @click="router.push('/register')" class="shadow" round>注册</el-button>
 | 
				
			||||||
          </span>
 | 
					          </span>
 | 
				
			||||||
 | 
					          <el-button :color="theme.btnBgColor" :style="{color: theme.btnTextColor}" @click="router.push('/test')" class="shadow" round>测试</el-button>
 | 
				
			||||||
        </div>
 | 
					        </div>
 | 
				
			||||||
      </el-menu>
 | 
					      </el-menu>
 | 
				
			||||||
    </div>
 | 
					    </div>
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,20 +5,12 @@
 | 
				
			|||||||
</template>
 | 
					</template>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<script setup>
 | 
					<script setup>
 | 
				
			||||||
import {ref, onMounted, onUpdated} from 'vue';
 | 
					import {onMounted, ref} from "vue";
 | 
				
			||||||
import {Markmap} from 'markmap-view';
 | 
					 | 
				
			||||||
import {loadJS, loadCSS} from 'markmap-common';
 | 
					 | 
				
			||||||
import {Transformer} from 'markmap-lib';
 | 
					 | 
				
			||||||
import {httpPost} from "@/utils/http";
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
const data=ref("")
 | 
					const data = ref('abc')
 | 
				
			||||||
httpPost("/api/test/sse",{
 | 
					
 | 
				
			||||||
  "message":"你是什么模型",
 | 
					onMounted(() => {
 | 
				
			||||||
  "user_id":123
 | 
					  
 | 
				
			||||||
}).then(res=>{
 | 
					 | 
				
			||||||
  // const source = new EventSource("http://localhost:5678/api/test/sse");
 | 
					 | 
				
			||||||
  // source.onmessage = function(event) {
 | 
					 | 
				
			||||||
  //   console.log(event.data)
 | 
					 | 
				
			||||||
  // };
 | 
					 | 
				
			||||||
})
 | 
					})
 | 
				
			||||||
 | 
					
 | 
				
			||||||
</script>
 | 
					</script>
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user