diff --git a/api/core/app_server.go b/api/core/app_server.go index 20e3c0b1..95cf680d 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -51,9 +51,9 @@ func NewServer(appConfig *types.AppConfig) *AppServer { func (s *AppServer) Init(debug bool, client *redis.Client) { if debug { // 调试模式允许跨域请求 API s.Debug = debug + s.Engine.Use(corsMiddleware()) logger.Info("Enabled debug mode") } - s.Engine.Use(corsMiddleware()) s.Engine.Use(staticResourceMiddleware()) s.Engine.Use(authorizeMiddleware(s, client)) s.Engine.Use(parameterHandlerMiddleware()) @@ -101,9 +101,9 @@ func corsMiddleware() gin.HandlerFunc { c.Header("Access-Control-Allow-Origin", origin) c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") //允许跨域设置可以返回其他子段,可以自定义字段 - c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, Content-Type, Chat-Token, Admin-Authorization") + c.Header("Access-Control-Allow-Headers", "Authorization, Body-Length, Body-Type, Admin-Authorization,content-type") // 允许浏览器(客户端)可以解析的头部 (重要) - c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") + c.Header("Access-Control-Expose-Headers", "Body-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers") //设置缓存时间 c.Header("Access-Control-Max-Age", "172800") //允许客户端传递校验信息比如 cookie (重要) @@ -131,7 +131,7 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") if isAdminApi { // 后台管理 API tokenString = c.GetHeader(types.AdminAuthHeader) - } else if c.Request.URL.Path == "/api/chat/new" { + } else if c.Request.URL.Path == "/api/ws" { // Websocket 连接 tokenString = c.Query("token") } else { tokenString = c.GetHeader(types.UserAuthHeader) @@ -209,23 +209,18 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/app/list/user" || c.Request.URL.Path == "/api/model/list" || c.Request.URL.Path == "/api/mj/imgWall" || - c.Request.URL.Path == "/api/mj/client" || c.Request.URL.Path == "/api/mj/notify" || c.Request.URL.Path == "/api/invite/hits" || c.Request.URL.Path == "/api/sd/imgWall" || - c.Request.URL.Path == "/api/sd/client" || c.Request.URL.Path == "/api/dall/imgWall" || - c.Request.URL.Path == "/api/dall/client" || c.Request.URL.Path == "/api/product/list" || c.Request.URL.Path == "/api/menu/list" || c.Request.URL.Path == "/api/markMap/client" || c.Request.URL.Path == "/api/payment/doPay" || c.Request.URL.Path == "/api/payment/payWays" || - c.Request.URL.Path == "/api/suno/client" || c.Request.URL.Path == "/api/suno/detail" || c.Request.URL.Path == "/api/suno/play" || c.Request.URL.Path == "/api/download" || - c.Request.URL.Path == "/api/video/client" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") || strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") || diff --git a/api/core/types/chat.go b/api/core/types/chat.go index 63f18622..56e2639f 100644 --- a/api/core/types/chat.go +++ b/api/core/types/chat.go @@ -52,14 +52,13 @@ type Delta struct { // ChatSession 聊天会话对象 type ChatSession struct { - SessionId string `json:"session_id"` - UserId uint `json:"user_id"` - ClientIP string `json:"client_ip"` // 客户端 IP - ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 - Model ChatModel `json:"model"` // GPT 模型 - Start int64 `json:"start"` // 开始请求时间戳 - Tools []int `json:"tools"` // 工具函数列表 - Stream bool `json:"stream"` // 是否采用流式输出 + UserId uint `json:"user_id"` + ClientIP string `json:"client_ip"` // 客户端 IP + ChatId string `json:"chat_id"` // 客户端聊天会话 ID, 多会话模式专用字段 + Model ChatModel `json:"model"` // GPT 模型 + Start int64 `json:"start"` // 开始请求时间戳 + Tools []int `json:"tools"` // 工具函数列表 + Stream bool `json:"stream"` // 是否采用流式输出 } type ChatModel struct { diff --git a/api/core/types/client.go b/api/core/types/client.go index 5f65ac59..bba71165 100644 --- a/api/core/types/client.go +++ b/api/core/types/client.go @@ -17,15 +17,17 @@ var ErrConClosed = errors.New("connection Closed") // WsClient websocket client type WsClient struct { + Id string Conn *websocket.Conn lock sync.Mutex mt int Closed bool } -func NewWsClient(conn *websocket.Conn) *WsClient { +func NewWsClient(conn *websocket.Conn, id string) *WsClient { return &WsClient{ Conn: conn, + Id: id, lock: sync.Mutex{}, mt: 2, // fixed bug for 'Invalid UTF-8 in text frame' Closed: false, diff --git a/api/core/types/web.go b/api/core/types/web.go index 514b9b95..e2b5e636 100644 --- a/api/core/types/web.go +++ b/api/core/types/web.go @@ -19,35 +19,44 @@ type BizVo struct { // ReplyMessage 对话回复消息结构 type ReplyMessage struct { - Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat - Type WsMsgType `json:"type"` // 消息类别 - Content interface{} `json:"content"` + Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat + ClientId string `json:"clientId"` // 客户端ID + Type WsMsgType `json:"type"` // 消息类别 + Body interface{} `json:"body"` } type WsMsgType string type WsChannel string const ( - WsMsgTypeContent = WsMsgType("content") // 输出内容 - WsMsgTypeEnd = WsMsgType("end") - WsMsgTypeErr = WsMsgType("error") - WsMsgTypePing = WsMsgType("ping") // 心跳消息 + MsgTypeText = WsMsgType("text") // 输出内容 + MsgTypeEnd = WsMsgType("end") + MsgTypeErr = WsMsgType("error") + MsgTypePing = WsMsgType("ping") // 心跳消息 - WsChat = WsChannel("chat") - WsMj = WsChannel("mj") - WsSd = WsChannel("sd") - WsDall = WsChannel("dall") - WsSuno = WsChannel("suno") - WsLuma = WsChannel("luma") + ChPing = WsChannel("ping") + ChChat = WsChannel("chat") + ChMj = WsChannel("mj") + ChSd = WsChannel("sd") + ChDall = WsChannel("dall") + ChSuno = WsChannel("suno") + ChLuma = WsChannel("luma") ) // InputMessage 对话输入消息结构 type InputMessage struct { - Channel WsChannel `json:"channel"` // 消息频道 - Type WsMsgType `json:"type"` // 消息类别 - Content string `json:"content"` - Tools []int `json:"tools"` // 允许调用工具列表 - Stream bool `json:"stream"` // 是否采用流式输出 + Channel WsChannel `json:"channel"` // 消息频道 + Type WsMsgType `json:"type"` // 消息类别 + Body interface{} `json:"body"` +} + +type ChatMessage struct { + Tools []int `json:"tools,omitempty"` // 允许调用工具列表 + Stream bool `json:"stream,omitempty"` // 是否采用流式输出 + RoleId int `json:"role_id"` + ModelId int `json:"model_id"` + ChatId string `json:"chat_id"` + Content string `json:"content"` } type BizCode int diff --git a/api/handler/chatimpl/chat_handler.go b/api/handler/chat_handler.go similarity index 83% rename from api/handler/chatimpl/chat_handler.go rename to api/handler/chat_handler.go index 863ab15f..4edffa78 100644 --- a/api/handler/chatimpl/chat_handler.go +++ b/api/handler/chat_handler.go @@ -1,4 +1,4 @@ -package chatimpl +package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * Copyright 2023 The Geek-AI Authors. All rights reserved. @@ -15,8 +15,6 @@ import ( "fmt" "geekai/core" "geekai/core/types" - "geekai/handler" - logger2 "geekai/logger" "geekai/service" "geekai/service/oss" "geekai/store/model" @@ -33,14 +31,11 @@ import ( "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" - "github.com/gorilla/websocket" "gorm.io/gorm" ) -var logger = logger2.GetLogger() - type ChatHandler struct { - handler.BaseHandler + BaseHandler redis *redis.Client uploadManager *oss.UploaderManager licenseService *service.LicenseService @@ -51,7 +46,7 @@ type ChatHandler struct { func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager, licenseService *service.LicenseService, userService *service.UserService) *ChatHandler { return &ChatHandler{ - BaseHandler: handler.BaseHandler{App: app, DB: db}, + BaseHandler: BaseHandler{App: app, DB: db}, redis: redis, uploadManager: manager, licenseService: licenseService, @@ -61,106 +56,6 @@ func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manag } } -// ChatHandle 处理聊天 WebSocket 请求 -func (h *ChatHandler) ChatHandle(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - return - } - - sessionId := c.Query("session_id") - roleId := h.GetInt(c, "role_id", 0) - chatId := c.Query("chat_id") - modelId := h.GetInt(c, "model_id", 0) - - client := types.NewWsClient(ws) - var chatRole model.ChatRole - res := h.DB.First(&chatRole, roleId) - if res.Error != nil || !chatRole.Enable { - utils.ReplyErrorMessage(client, "当前聊天角色不存在或者未启用,对话已关闭!!!") - c.Abort() - return - } - // if the role bind a model_id, use role's bind model_id - if chatRole.ModelId > 0 { - modelId = chatRole.ModelId - } - // get model info - var chatModel model.ChatModel - res = h.DB.First(&chatModel, modelId) - if res.Error != nil || chatModel.Enabled == false { - utils.ReplyErrorMessage(client, "当前AI模型暂未启用,对话已关闭!!!") - c.Abort() - return - } - - session := &types.ChatSession{ - SessionId: sessionId, - ClientIP: c.ClientIP(), - UserId: h.GetLoginUserId(c), - } - - // use old chat data override the chat model and role ID - var chat model.ChatItem - res = h.DB.Where("chat_id = ?", chatId).First(&chat) - if res.Error == nil { - chatModel.Id = chat.ModelId - roleId = int(chat.RoleId) - } - - session.ChatId = chatId - session.Model = types.ChatModel{ - Id: chatModel.Id, - Name: chatModel.Name, - Value: chatModel.Value, - Power: chatModel.Power, - MaxTokens: chatModel.MaxTokens, - MaxContext: chatModel.MaxContext, - Temperature: chatModel.Temperature, - KeyId: chatModel.KeyId} - logger.Infof("New websocket connected, IP: %s", c.ClientIP()) - - go func() { - for { - _, msg, err := client.Receive() - if err != nil { - logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) - client.Close() - cancelFunc := h.ReqCancelFunc.Get(sessionId) - if cancelFunc != nil { - cancelFunc() - h.ReqCancelFunc.Delete(sessionId) - } - return - } - - var message types.InputMessage - err = utils.JsonDecode(string(msg), &message) - if err != nil { - continue - } - - logger.Infof("Receive a message:%+v", message) - - session.Tools = message.Tools - session.Stream = message.Stream - ctx, cancel := context.WithCancel(context.Background()) - h.ReqCancelFunc.Put(sessionId, cancel) - // 回复消息 - err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) - if err != nil { - logger.Error(err) - utils.SendMessage(client, err.Error()) - } else { - utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) - logger.Infof("回答完毕: %v", message.Content) - } - - } - }() -} - func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error { if !h.App.Debug { defer func() { @@ -206,7 +101,7 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio } // 兼容 GPT-O1 模型 if strings.HasPrefix(session.Model.Value, "o1-") { - utils.ReplyContent(ws, "AI 正在思考...\n") + utils.SendChunkMsg(ws, "AI 正在思考...\n") req.Stream = false session.Start = time.Now().Unix() } else { diff --git a/api/handler/chatimpl/chat_item_handler.go b/api/handler/chat_item_handler.go similarity index 82% rename from api/handler/chatimpl/chat_item_handler.go rename to api/handler/chat_item_handler.go index bce39249..f08be3fe 100644 --- a/api/handler/chatimpl/chat_item_handler.go +++ b/api/handler/chat_item_handler.go @@ -1,4 +1,4 @@ -package chatimpl +package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * Copyright 2023 The Geek-AI Authors. All rights reserved. @@ -28,31 +28,40 @@ func (h *ChatHandler) List(c *gin.Context) { userId := h.GetLoginUserId(c) var items = make([]vo.ChatItem, 0) var chats []model.ChatItem - res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats) - if res.Error == nil { - var roleIds = make([]uint, 0) - for _, chat := range chats { - roleIds = append(roleIds, chat.RoleId) - } - var roles []model.ChatRole - res = h.DB.Find(&roles, roleIds) - if res.Error == nil { - roleMap := make(map[uint]model.ChatRole) - for _, role := range roles { - roleMap[role.Id] = role - } + h.DB.Where("user_id", userId).Order("id DESC").Find(&chats) + if len(chats) == 0 { + resp.SUCCESS(c, items) + return + } - for _, chat := range chats { - var item vo.ChatItem - err := utils.CopyObject(chat, &item) - if err == nil { - item.Id = chat.Id - item.Icon = roleMap[chat.RoleId].Icon - items = append(items, item) - } - } - } + var roleIds = make([]uint, 0) + var modelValues = make([]string, 0) + for _, chat := range chats { + roleIds = append(roleIds, chat.RoleId) + modelValues = append(modelValues, chat.Model) + } + var roles []model.ChatRole + var models []model.ChatModel + roleMap := make(map[uint]model.ChatRole) + modelMap := make(map[string]model.ChatModel) + h.DB.Where("id IN ?", roleIds).Find(&roles) + h.DB.Where("value IN ?", modelValues).Find(&models) + for _, role := range roles { + roleMap[role.Id] = role + } + for _, m := range models { + modelMap[m.Value] = m + } + for _, chat := range chats { + var item vo.ChatItem + err := utils.CopyObject(chat, &item) + if err == nil { + item.Id = chat.Id + item.Icon = roleMap[chat.RoleId].Icon + item.ModelId = modelMap[chat.Model].Id + items = append(items, item) + } } resp.SUCCESS(c, items) } diff --git a/api/handler/dalle_handler.go b/api/handler/dalle_handler.go index 80b993ee..816086f6 100644 --- a/api/handler/dalle_handler.go +++ b/api/handler/dalle_handler.go @@ -20,9 +20,7 @@ import ( "geekai/utils/resp" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" - "github.com/gorilla/websocket" "gorm.io/gorm" - "net/http" ) type DallJobHandler struct { @@ -45,49 +43,6 @@ func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, } } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *DallJobHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - userId := h.GetInt(c, "user_id", 0) - if userId == 0 { - logger.Info("Invalid user ID") - c.Abort() - return - } - - client := types.NewWsClient(ws) - h.dallService.Clients.Put(uint(userId), client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) - go func() { - for { - _, msg, err := client.Receive() - if err != nil { - client.Close() - h.dallService.Clients.Delete(uint(userId)) - return - } - - var message types.ReplyMessage - err = utils.JsonDecode(string(msg), &message) - if err != nil { - continue - } - - // 心跳消息 - if message.Type == "heartbeat" { - logger.Debug("收到 DallE 心跳消息:", message.Content) - continue - } - } - }() -} - func (h *DallJobHandler) preCheck(c *gin.Context) bool { user, err := h.GetLoginUser(c) if err != nil { diff --git a/api/handler/markmap_handler.go b/api/handler/markmap_handler.go index 0cfa1743..9337d996 100644 --- a/api/handler/markmap_handler.go +++ b/api/handler/markmap_handler.go @@ -19,7 +19,6 @@ import ( "geekai/store/model" "geekai/utils" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "gorm.io/gorm" "io" "net/http" @@ -43,55 +42,9 @@ func NewMarkMapHandler(app *core.AppServer, db *gorm.DB, userService *service.Us } } -func (h *MarkMapHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - return - } +// Generate 生成思维导图 +func (h *MarkMapHandler) Generate(c *gin.Context) { - modelId := h.GetInt(c, "model_id", 0) - userId := h.GetInt(c, "user_id", 0) - - client := types.NewWsClient(ws) - h.clients.Put(userId, client) - go func() { - for { - _, msg, err := client.Receive() - if err != nil { - client.Close() - h.clients.Delete(userId) - return - } - - var message types.ReplyMessage - err = utils.JsonDecode(string(msg), &message) - if err != nil { - continue - } - - // 心跳消息 - if message.Type == "heartbeat" { - logger.Debug("收到 MarkMap 心跳消息:", message.Content) - continue - } - // change model - if message.Type == "model_id" { - modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0) - continue - } - - logger.Info("Receive a message: ", message.Content) - err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId) - if err != nil { - logger.Error(err) - utils.ReplyErrorMessage(client, err.Error()) - } else { - utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) - } - - } - }() } func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error { @@ -170,13 +123,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode break } - utils.SendChunkMessage(client, types.ReplyMessage{ - Type: types.WsMsgTypeContent, - Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), + utils.SendMsg(client, types.ReplyMessage{ + Type: types.MsgTypeText, + Body: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), }) } // end for - utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd}) + utils.SendMsg(client, types.ReplyMessage{Type: types.MsgTypeEnd}) } else { body, _ := io.ReadAll(response.Body) diff --git a/api/handler/mj_handler.go b/api/handler/mj_handler.go index 34996c81..c758032b 100644 --- a/api/handler/mj_handler.go +++ b/api/handler/mj_handler.go @@ -19,12 +19,10 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "net/http" "strings" "time" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "gorm.io/gorm" ) @@ -65,27 +63,6 @@ func (h *MidJourneyHandler) preCheck(c *gin.Context) bool { } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *MidJourneyHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - userId := h.GetInt(c, "user_id", 0) - if userId == 0 { - logger.Info("Invalid user ID") - c.Abort() - return - } - - client := types.NewWsClient(ws) - h.mjService.Clients.Put(uint(userId), client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) -} - // Image 创建一个绘画任务 func (h *MidJourneyHandler) Image(c *gin.Context) { var data struct { diff --git a/api/handler/chatimpl/openai_handler.go b/api/handler/openai_handler.go similarity index 87% rename from api/handler/chatimpl/openai_handler.go rename to api/handler/openai_handler.go index a66dcd30..302b3138 100644 --- a/api/handler/chatimpl/openai_handler.go +++ b/api/handler/openai_handler.go @@ -1,4 +1,4 @@ -package chatimpl +package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ // * Copyright 2023 The Geek-AI Authors. All rights reserved. @@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage( } if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { - utils.SendMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") + utils.SendChunkMsg(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。") break } @@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage( if res.Error == nil { toolCall = true callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) - utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg}) + utils.SendChunkMsg(ws, callMsg) contents = append(contents, callMsg) } continue @@ -153,10 +153,7 @@ func (h *ChatHandler) sendOpenAiMessage( } else { content := responseBody.Choices[0].Delta.Content contents = append(contents, utils.InterfaceToString(content)) - utils.SendChunkMessage(ws, types.ReplyMessage{ - Type: types.WsMsgTypeContent, - Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), - }) + utils.SendChunkMsg(ws, responseBody.Choices[0].Delta.Content) } } // end for @@ -174,7 +171,7 @@ func (h *ChatHandler) sendOpenAiMessage( logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params) params["user_id"] = userVo.Id var apiRes types.BizVo - r, err := req2.C().R().SetHeader("Content-Type", "application/json"). + r, err := req2.C().R().SetHeader("Body-Type", "application/json"). SetHeader("Authorization", function.Token). SetBody(params). SetSuccessResult(&apiRes).Post(function.Action) @@ -185,19 +182,13 @@ func (h *ChatHandler) sendOpenAiMessage( errMsg = r.Status } if errMsg != "" || apiRes.Code != types.Success { - msg := "调用函数工具出错:" + apiRes.Message + errMsg - utils.SendChunkMessage(ws, types.ReplyMessage{ - Type: types.WsMsgTypeContent, - Content: msg, - }) - contents = append(contents, msg) + errMsg = "调用函数工具出错:" + apiRes.Message + errMsg + contents = append(contents, errMsg) } else { - utils.SendChunkMessage(ws, types.ReplyMessage{ - Type: types.WsMsgTypeContent, - Content: apiRes.Data, - }) - contents = append(contents, utils.InterfaceToString(apiRes.Data)) + errMsg = utils.InterfaceToString(apiRes.Data) + contents = append(contents, errMsg) } + utils.SendChunkMsg(ws, errMsg) } // 消息发送成功 @@ -226,7 +217,7 @@ func (h *ChatHandler) sendOpenAiMessage( if strings.HasPrefix(req.Model, "o1-") { content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) } - utils.SendMessage(ws, content) + utils.SendChunkMsg(ws, content) respVo.Usage.Prompt = prompt respVo.Usage.Content = content h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now()) diff --git a/api/handler/sd_handler.go b/api/handler/sd_handler.go index 6471e94e..2f658e73 100644 --- a/api/handler/sd_handler.go +++ b/api/handler/sd_handler.go @@ -19,11 +19,8 @@ import ( "geekai/store/vo" "geekai/utils" "geekai/utils/resp" - "net/http" "time" - "github.com/gorilla/websocket" - "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "gorm.io/gorm" @@ -59,27 +56,6 @@ func NewSdJobHandler(app *core.AppServer, } } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *SdJobHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - userId := h.GetInt(c, "user_id", 0) - if userId == 0 { - logger.Info("Invalid user ID") - c.Abort() - return - } - - client := types.NewWsClient(ws) - h.sdService.Clients.Put(uint(userId), client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) -} - func (h *SdJobHandler) preCheck(c *gin.Context) bool { user, err := h.GetLoginUser(c) if err != nil { diff --git a/api/handler/suno_handler.go b/api/handler/suno_handler.go index 7aaeab72..8df60385 100644 --- a/api/handler/suno_handler.go +++ b/api/handler/suno_handler.go @@ -19,9 +19,7 @@ import ( "geekai/utils" "geekai/utils/resp" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "gorm.io/gorm" - "net/http" "time" ) @@ -44,27 +42,6 @@ func NewSunoHandler(app *core.AppServer, db *gorm.DB, service *suno.Service, upl } } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *SunoHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - userId := h.GetInt(c, "user_id", 0) - if userId == 0 { - logger.Info("Invalid user ID") - c.Abort() - return - } - - client := types.NewWsClient(ws) - h.sunoService.Clients.Put(uint(userId), client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) -} - func (h *SunoHandler) Create(c *gin.Context) { var data struct { diff --git a/api/handler/test_handler.go b/api/handler/test_handler.go index 88e95a2f..3ee0c622 100644 --- a/api/handler/test_handler.go +++ b/api/handler/test_handler.go @@ -19,7 +19,7 @@ func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.GeekP } func (h *TestHandler) SseTest(c *gin.Context) { - //c.Header("Content-Type", "text/event-stream") + //c.Header("Body-Type", "text/event-stream") //c.Header("Cache-Control", "no-cache") //c.Header("Connection", "keep-alive") // diff --git a/api/handler/video_handler.go b/api/handler/video_handler.go index 31c34e57..f9a911ab 100644 --- a/api/handler/video_handler.go +++ b/api/handler/video_handler.go @@ -19,9 +19,7 @@ import ( "geekai/utils" "geekai/utils/resp" "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" "gorm.io/gorm" - "net/http" "time" ) @@ -44,27 +42,6 @@ func NewVideoHandler(app *core.AppServer, db *gorm.DB, service *video.Service, u } } -// Client WebSocket 客户端,用于通知任务状态变更 -func (h *VideoHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) - if err != nil { - logger.Error(err) - c.Abort() - return - } - - userId := h.GetInt(c, "user_id", 0) - if userId == 0 { - logger.Info("Invalid user ID") - c.Abort() - return - } - - client := types.NewWsClient(ws) - h.videoService.Clients.Put(uint(userId), client) - logger.Infof("New websocket connected, IP: %s", c.RemoteIP()) -} - func (h *VideoHandler) LumaCreate(c *gin.Context) { var data struct { diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go index 8227c98e..05933116 100644 --- a/api/handler/ws_handler.go +++ b/api/handler/ws_handler.go @@ -8,6 +8,7 @@ package handler // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import ( + "context" "geekai/core" "geekai/core/types" "geekai/service" @@ -15,6 +16,7 @@ import ( "geekai/utils" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" + "gorm.io/gorm" "net/http" ) @@ -22,12 +24,14 @@ import ( type WebsocketHandler struct { BaseHandler - wsService *service.WebsocketService + wsService *service.WebsocketService + chatHandler *ChatHandler } -func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler { +func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *gorm.DB, chatHandler *ChatHandler) *WebsocketHandler { return &WebsocketHandler{ - BaseHandler: BaseHandler{App: app}, + BaseHandler: BaseHandler{App: app, DB: db}, + chatHandler: chatHandler, wsService: s, } } @@ -40,9 +44,9 @@ func (h *WebsocketHandler) Client(c *gin.Context) { return } - userId := h.GetInt(c, "user_id", 0) - clientId := c.Query("client") - client := types.NewWsClient(ws) + clientId := c.Query("client_id") + client := types.NewWsClient(ws, clientId) + userId := h.GetLoginUserId(c) if userId == 0 { _ = client.Send([]byte("Invalid user_id")) c.Abort() @@ -63,6 +67,8 @@ func (h *WebsocketHandler) Client(c *gin.Context) { if err != nil { logger.Debugf("close connection: %s", client.Conn.RemoteAddr()) client.Close() + h.wsService.Clients.Delete(clientId) + break } var message types.InputMessage @@ -72,12 +78,66 @@ func (h *WebsocketHandler) Client(c *gin.Context) { } logger.Infof("Receive a message:%+v", message) - if message.Type == types.WsMsgTypePing { - _ = client.Send([]byte(`{"type":"pong"}`)) + if message.Type == types.MsgTypePing { + utils.SendChannelMsg(client, types.ChPing, "pong") + continue } - switch message.Channel { + // 当前只处理聊天消息,其他消息全部丢弃 + var chatMessage types.ChatMessage + err = utils.JsonDecode(utils.JsonEncode(message.Body), &chatMessage) + if err != nil || message.Channel != types.ChChat { + logger.Warnf("invalid message body:%+v", message.Body) + continue + } + var chatRole model.ChatRole + err = h.DB.First(&chatRole, chatMessage.RoleId).Error + if err != nil || !chatRole.Enable { + utils.SendAndFlush(client, "当前聊天角色不存在或者未启用,请更换角色之后再发起对话!!!") + continue + } + // if the role bind a model_id, use role's bind model_id + if chatRole.ModelId > 0 { + chatMessage.RoleId = chatRole.ModelId + } + // get model info + var chatModel model.ChatModel + err = h.DB.Where("id", chatMessage.ModelId).First(&chatModel).Error + if err != nil || chatModel.Enabled == false { + utils.SendAndFlush(client, "当前AI模型暂未启用,请更换模型后再发起对话!!!") + continue + } + session := &types.ChatSession{ + ClientIP: c.ClientIP(), + UserId: userId, + } + + // use old chat data override the chat model and role ID + var chat model.ChatItem + h.DB.Where("chat_id", chatMessage.ChatId).First(&chat) + if chat.Id > 0 { + chatModel.Id = chat.ModelId + chatMessage.RoleId = int(chat.RoleId) + } + + session.ChatId = chatMessage.ChatId + session.Tools = chatMessage.Tools + session.Stream = chatMessage.Stream + // 复制模型数据 + err = utils.CopyObject(chatModel, &session.Model) + if err != nil { + logger.Error(err, chatModel) + } + ctx, cancel := context.WithCancel(context.Background()) + h.chatHandler.ReqCancelFunc.Put(clientId, cancel) + err = h.chatHandler.sendMessage(ctx, session, chatRole, chatMessage.Content, client) + if err != nil { + logger.Error(err) + utils.SendAndFlush(client, err.Error()) + } else { + utils.SendMsg(client, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd}) + logger.Infof("回答完毕: %v", message.Body) } } diff --git a/api/main.go b/api/main.go index 7de6f330..6d1f1058 100644 --- a/api/main.go +++ b/api/main.go @@ -14,7 +14,6 @@ import ( "geekai/core/types" "geekai/handler" "geekai/handler/admin" - "geekai/handler/chatimpl" logger2 "geekai/logger" "geekai/service" "geekai/service/dalle" @@ -128,7 +127,7 @@ func main() { // 创建控制器 fx.Provide(handler.NewChatRoleHandler), fx.Provide(handler.NewUserHandler), - fx.Provide(chatimpl.NewChatHandler), + fx.Provide(handler.NewChatHandler), fx.Provide(handler.NewNetHandler), fx.Provide(handler.NewSmsHandler), fx.Provide(handler.NewRedeemHandler), @@ -246,9 +245,8 @@ func main() { group.GET("clogin", h.CLogin) group.GET("clogin/callback", h.CLoginCallback) }), - fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) { + fx.Invoke(func(s *core.AppServer, h *handler.ChatHandler) { group := s.Engine.Group("/api/chat/") - group.Any("new", h.ChatHandle) group.GET("list", h.List) group.GET("detail", h.Detail) group.POST("update", h.Update) @@ -281,7 +279,6 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) { group := s.Engine.Group("/api/mj/") - group.Any("client", h.Client) group.POST("image", h.Image) group.POST("upscale", h.Upscale) group.POST("variation", h.Variation) @@ -292,7 +289,6 @@ func main() { }), fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) { group := s.Engine.Group("/api/sd") - group.Any("client", h.Client) group.POST("image", h.Image) group.GET("jobs", h.JobList) group.GET("imgWall", h.ImgWall) @@ -467,13 +463,11 @@ func main() { }), fx.Provide(handler.NewMarkMapHandler), fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) { - group := s.Engine.Group("/api/markMap/") - group.Any("client", h.Client) + s.Engine.POST("/api/markMap/gen", h.Generate) }), fx.Provide(handler.NewDallJobHandler), fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) { group := s.Engine.Group("/api/dall") - group.Any("client", h.Client) group.POST("image", h.Image) group.GET("jobs", h.JobList) group.GET("imgWall", h.ImgWall) @@ -483,7 +477,6 @@ func main() { fx.Provide(handler.NewSunoHandler), fx.Invoke(func(s *core.AppServer, h *handler.SunoHandler) { group := s.Engine.Group("/api/suno") - group.Any("client", h.Client) group.POST("create", h.Create) group.GET("list", h.List) group.GET("remove", h.Remove) @@ -496,7 +489,6 @@ func main() { fx.Provide(handler.NewVideoHandler), fx.Invoke(func(s *core.AppServer, h *handler.VideoHandler) { group := s.Engine.Group("/api/video") - group.Any("client", h.Client) group.POST("luma/create", h.LumaCreate) group.GET("list", h.List) group.GET("remove", h.Remove) @@ -521,6 +513,11 @@ func main() { group := s.Engine.Group("/api/test") group.Any("sse", h.PostTest, h.SseTest) }), + fx.Provide(service.NewWebsocketService), + fx.Provide(handler.NewWebsocketHandler), + fx.Invoke(func(s *core.AppServer, h *handler.WebsocketHandler) { + s.Engine.Any("/api/ws", h.Client) + }), fx.Invoke(func(s *core.AppServer, db *gorm.DB) { go func() { err := s.Run(db) diff --git a/api/service/dalle/service.go b/api/service/dalle/service.go index 4ea1082e..d0732413 100644 --- a/api/service/dalle/service.go +++ b/api/service/dalle/service.go @@ -158,7 +158,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) { Quality: task.Quality, } logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody) - r, err := s.httpClient.R().SetHeader("Content-Type", "application/json"). + r, err := s.httpClient.R().SetHeader("Body-Type", "application/json"). SetHeader("Authorization", "Bearer "+apiKey.Value). SetBody(reqBody). SetErrorResult(&errRes). diff --git a/api/service/oss/minio_oss.go b/api/service/oss/minio_oss.go index d095127b..0e346097 100644 --- a/api/service/oss/minio_oss.go +++ b/api/service/oss/minio_oss.go @@ -89,7 +89,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) { fileExt := utils.GetImgExt(file.Filename) filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt) info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{ - ContentType: file.Header.Get("Content-Type"), + ContentType: file.Header.Get("Body-Type"), }) if err != nil { return File{}, fmt.Errorf("error uploading to MinIO: %v", err) diff --git a/api/service/ws_service.go b/api/service/ws_service.go index 65def76f..d049f6bd 100644 --- a/api/service/ws_service.go +++ b/api/service/ws_service.go @@ -5,3 +5,9 @@ import "geekai/core/types" type WebsocketService struct { Clients *types.LMap[string, *types.WsClient] // clientId => Client } + +func NewWebsocketService() *WebsocketService { + return &WebsocketService{ + Clients: types.NewLMap[string, *types.WsClient](), + } +} diff --git a/api/utils/net.go b/api/utils/net.go index f103a870..5e8a0985 100644 --- a/api/utils/net.go +++ b/api/utils/net.go @@ -19,8 +19,9 @@ import ( var logger = logger2.GetLogger() -// SendChunkMessage 回复客户片段端消息 -func SendChunkMessage(client *types.WsClient, message interface{}) { +// SendMsg 回复客户片段端消息 +func SendMsg(client *types.WsClient, message types.ReplyMessage) { + message.ClientId = client.Id msg, err := json.Marshal(message) if err != nil { logger.Errorf("Error for decoding json data: %v", err.Error()) @@ -32,19 +33,23 @@ func SendChunkMessage(client *types.WsClient, message interface{}) { } } -// SendMessage 回复客户端一条完整的消息 -func SendMessage(ws *types.WsClient, message interface{}) { - SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message}) - SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd}) +// SendAndFlush 回复客户端一条完整的消息 +func SendAndFlush(ws *types.WsClient, message interface{}) { + SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message}) + SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeEnd}) } -func ReplyContent(ws *types.WsClient, message interface{}) { - SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message}) +func SendChunkMsg(ws *types.WsClient, message interface{}) { + SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeText, Body: message}) } -// ReplyErrorMessage 向客户端发送错误消息 -func ReplyErrorMessage(ws *types.WsClient, message interface{}) { - SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message}) +// SendErrMsg 向客户端发送错误消息 +func SendErrMsg(ws *types.WsClient, message interface{}) { + SendMsg(ws, types.ReplyMessage{Channel: types.ChChat, Type: types.MsgTypeErr, Body: message}) +} + +func SendChannelMsg(ws *types.WsClient, channel types.WsChannel, message interface{}) { + SendMsg(ws, types.ReplyMessage{Channel: channel, Type: types.MsgTypeText, Body: message}) } func DownloadImage(imageURL string, proxy string) ([]byte, error) { @@ -68,7 +73,9 @@ func DownloadImage(imageURL string, proxy string) ([]byte, error) { if err != nil { return nil, err } - defer resp.Body.Close() + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) imageBytes, err := io.ReadAll(resp.Body) if err != nil { diff --git a/api/utils/openai.go b/api/utils/openai.go index e1bccd67..c9d7363a 100644 --- a/api/utils/openai.go +++ b/api/utils/openai.go @@ -65,7 +65,7 @@ func OpenAIRequest(db *gorm.DB, prompt string, modelName string) (string, error) } apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL) logger.Debugf("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, modelName) - r, err := client.R().SetHeader("Content-Type", "application/json"). + r, err := client.R().SetHeader("Body-Type", "application/json"). SetHeader("Authorization", "Bearer "+apiKey.Value). SetBody(types.ApiRequest{ Model: modelName, diff --git a/web/src/App.vue b/web/src/App.vue index f40f87d3..d999a96e 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -6,10 +6,13 @@ diff --git a/web/src/store/cache.js b/web/src/store/cache.js index 368a770e..d388f2d6 100644 --- a/web/src/store/cache.js +++ b/web/src/store/cache.js @@ -1,5 +1,6 @@ import {httpGet} from "@/utils/http"; import Storage from "good-storage"; +import {randString} from "@/utils/libs"; const userDataKey = "USER_INFO_CACHE_KEY" const adminDataKey = "ADMIN_INFO_CACHE_KEY" @@ -70,4 +71,14 @@ export function getLicenseInfo() { resolve(err) }) }) +} + +export function getClientId() { + let clientId = Storage.get('client_id') + if (clientId) { + return clientId + } + clientId = randString(42) + Storage.set('client_id', clientId) + return clientId } \ No newline at end of file diff --git a/web/src/store/sharedata.js b/web/src/store/sharedata.js index c43f7a13..dcbdd0ea 100644 --- a/web/src/store/sharedata.js +++ b/web/src/store/sharedata.js @@ -6,6 +6,8 @@ export const useSharedStore = defineStore('shared', { showLoginDialog: false, chatListStyle: Storage.get("chat_list_style","chat"), chatStream: Storage.get("chat_stream",true), + socket: WebSocket, + messageHandlers:{}, }), getters: {}, actions: { @@ -19,6 +21,36 @@ export const useSharedStore = defineStore('shared', { setChatStream(value) { this.chatStream = value; Storage.set("chat_stream", value); + }, + setSocket(value) { + this.socket = value; + }, + addMessageHandler(key, callback) { + if (!this.messageHandlers[key]) { + this.messageHandlers[key] = callback; + this.setMessageHandler(callback) + } + }, + setMessageHandler(callback) { + if (this.socket instanceof WebSocket && this.socket.readyState === WebSocket.OPEN) { + this.socket.addEventListener('message', (event) => { + try { + if (event.data instanceof Blob) { + const reader = new FileReader(); + reader.readAsText(event.data, "UTF-8"); + reader.onload = () => { + callback(JSON.parse(String(reader.result))) + } + } + } catch (e) { + console.warn(e) + } + }) + } else { + setTimeout(() => { + this.setMessageHandler(callback) + }, 1000) + } } } }); diff --git a/web/src/store/system.js b/web/src/store/system.js index 2c6c15ae..57685bbb 100644 --- a/web/src/store/system.js +++ b/web/src/store/system.js @@ -6,7 +6,6 @@ // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ import Storage from "good-storage"; -import {useRouter} from "vue-router"; const MOBILE_THEME = process.env.VUE_APP_KEY_PREFIX + "MOBILE_THEME" const ADMIN_THEME = process.env.VUE_APP_KEY_PREFIX + "ADMIN_THEME" @@ -71,4 +70,4 @@ export function setRoute(path) { export function getRoute() { return Storage.get(process.env.VUE_APP_KEY_PREFIX + 'ROUTE_') -} \ No newline at end of file +} diff --git a/web/src/utils/http.js b/web/src/utils/http.js index 3329591f..efac3d5a 100644 --- a/web/src/utils/http.js +++ b/web/src/utils/http.js @@ -17,7 +17,6 @@ axios.defaults.headers.post['Content-Type'] = 'application/json' axios.interceptors.request.use( config => { // set token - config.headers['Chat-Token'] = getSessionId(); config.headers['Authorization'] = getUserToken(); config.headers['Admin-Authorization'] = getAdminToken(); return config diff --git a/web/src/views/ChatPlus.vue b/web/src/views/ChatPlus.vue index ecf61b76..7a1c573e 100644 --- a/web/src/views/ChatPlus.vue +++ b/web/src/views/ChatPlus.vue @@ -213,7 +213,7 @@