mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	add ws handler
This commit is contained in:
		@@ -19,23 +19,35 @@ type BizVo struct {
 | 
			
		||||
 | 
			
		||||
// ReplyMessage 对话回复消息结构
 | 
			
		||||
type ReplyMessage struct {
 | 
			
		||||
	Type    WsMsgType   `json:"type"` // 消息类别,start, end, img
 | 
			
		||||
	Channel WsChannel   `json:"channel"` // 消息频道,目前只有 chat
 | 
			
		||||
	Type    WsMsgType   `json:"type"`    // 消息类别
 | 
			
		||||
	Content interface{} `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WsMsgType string
 | 
			
		||||
type WsChannel string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	WsContent = WsMsgType("content") // 输出内容
 | 
			
		||||
	WsEnd     = WsMsgType("end")
 | 
			
		||||
	WsErr     = WsMsgType("error")
 | 
			
		||||
	WsMsgTypeContent = WsMsgType("content") // 输出内容
 | 
			
		||||
	WsMsgTypeEnd     = WsMsgType("end")
 | 
			
		||||
	WsMsgTypeErr     = WsMsgType("error")
 | 
			
		||||
	WsMsgTypePing    = WsMsgType("ping") // 心跳消息
 | 
			
		||||
 | 
			
		||||
	WsChat = WsChannel("chat")
 | 
			
		||||
	WsMj   = WsChannel("mj")
 | 
			
		||||
	WsSd   = WsChannel("sd")
 | 
			
		||||
	WsDall = WsChannel("dall")
 | 
			
		||||
	WsSuno = WsChannel("suno")
 | 
			
		||||
	WsLuma = WsChannel("luma")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// InputMessage 对话输入消息结构
 | 
			
		||||
type InputMessage struct {
 | 
			
		||||
	Content string `json:"content"`
 | 
			
		||||
	Tools   []int  `json:"tools"`  // 允许调用工具列表
 | 
			
		||||
	Stream  bool   `json:"stream"` // 是否采用流式输出
 | 
			
		||||
	Channel WsChannel `json:"channel"` // 消息频道
 | 
			
		||||
	Type    WsMsgType `json:"type"`    // 消息类别
 | 
			
		||||
	Content string    `json:"content"`
 | 
			
		||||
	Tools   []int     `json:"tools"`  // 允许调用工具列表
 | 
			
		||||
	Stream  bool      `json:"stream"` // 是否采用流式输出
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type BizCode int
 | 
			
		||||
 
 | 
			
		||||
@@ -151,9 +151,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				utils.ReplyMessage(client, err.Error())
 | 
			
		||||
				utils.SendMessage(client, err.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 | 
			
		||||
				utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
			
		||||
				logger.Infof("回答完毕: %v", message.Content)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
			
		||||
				utils.SendMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
@@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					toolCall = true
 | 
			
		||||
					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
 | 
			
		||||
					utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg})
 | 
			
		||||
					contents = append(contents, callMsg)
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
@@ -153,8 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			} else {
 | 
			
		||||
				content := responseBody.Choices[0].Delta.Content
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(content))
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsContent,
 | 
			
		||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsMsgTypeContent,
 | 
			
		||||
					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
@@ -186,14 +186,14 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			}
 | 
			
		||||
			if errMsg != "" || apiRes.Code != types.Success {
 | 
			
		||||
				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsContent,
 | 
			
		||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsMsgTypeContent,
 | 
			
		||||
					Content: msg,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, msg)
 | 
			
		||||
			} else {
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsContent,
 | 
			
		||||
				utils.SendChunkMessage(ws, types.ReplyMessage{
 | 
			
		||||
					Type:    types.WsMsgTypeContent,
 | 
			
		||||
					Content: apiRes.Data,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(apiRes.Data))
 | 
			
		||||
@@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
		if strings.HasPrefix(req.Model, "o1-") {
 | 
			
		||||
			content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyMessage(ws, content)
 | 
			
		||||
		utils.SendMessage(ws, content)
 | 
			
		||||
		respVo.Usage.Prompt = prompt
 | 
			
		||||
		respVo.Usage.Content = content
 | 
			
		||||
		h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
 | 
			
		||||
 
 | 
			
		||||
@@ -87,7 +87,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				utils.ReplyErrorMessage(client, err.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
 | 
			
		||||
				utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
@@ -170,13 +170,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			utils.ReplyChunkMessage(client, types.ReplyMessage{
 | 
			
		||||
				Type:    types.WsContent,
 | 
			
		||||
			utils.SendChunkMessage(client, types.ReplyMessage{
 | 
			
		||||
				Type:    types.WsMsgTypeContent,
 | 
			
		||||
				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
			})
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
 | 
			
		||||
		utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		body, _ := io.ReadAll(response.Body)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										85
									
								
								api/handler/ws_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										85
									
								
								api/handler/ws_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,85 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"net/http"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Websocket 连接处理 handler
 | 
			
		||||
 | 
			
		||||
type WebsocketHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	wsService *service.WebsocketService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler {
 | 
			
		||||
	return &WebsocketHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{App: app},
 | 
			
		||||
		wsService:   s,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *WebsocketHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	clientId := c.Query("client")
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		_ = client.Send([]byte("Invalid user_id"))
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
 | 
			
		||||
		_ = client.Send([]byte("Invalid user_id"))
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.wsService.Clients.Put(clientId, client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
			
		||||
				client.Close()
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var message types.InputMessage
 | 
			
		||||
			err = utils.JsonDecode(string(msg), &message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Infof("Receive a message:%+v", message)
 | 
			
		||||
			if message.Type == types.WsMsgTypePing {
 | 
			
		||||
				_ = client.Send([]byte(`{"type":"pong"}`))
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			switch message.Channel {
 | 
			
		||||
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										7
									
								
								api/service/ws_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								api/service/ws_service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
 | 
			
		||||
type WebsocketService struct {
 | 
			
		||||
	Clients *types.LMap[string, *types.WsClient] // clientId => Client
 | 
			
		||||
}
 | 
			
		||||
@@ -19,8 +19,8 @@ import (
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// ReplyChunkMessage 回复客户片段端消息
 | 
			
		||||
func ReplyChunkMessage(client *types.WsClient, message interface{}) {
 | 
			
		||||
// SendChunkMessage 回复客户片段端消息
 | 
			
		||||
func SendChunkMessage(client *types.WsClient, message interface{}) {
 | 
			
		||||
	msg, err := json.Marshal(message)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf("Error for decoding json data: %v", err.Error())
 | 
			
		||||
@@ -32,19 +32,19 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReplyMessage 回复客户端一条完整的消息
 | 
			
		||||
func ReplyMessage(ws *types.WsClient, message interface{}) {
 | 
			
		||||
	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
 | 
			
		||||
	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
 | 
			
		||||
// SendMessage 回复客户端一条完整的消息
 | 
			
		||||
func SendMessage(ws *types.WsClient, message interface{}) {
 | 
			
		||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
 | 
			
		||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func ReplyContent(ws *types.WsClient, message interface{}) {
 | 
			
		||||
	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
 | 
			
		||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReplyErrorMessage 向客户端发送错误消息
 | 
			
		||||
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
 | 
			
		||||
	ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
 | 
			
		||||
	SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
 | 
			
		||||
 
 | 
			
		||||
@@ -109,7 +109,7 @@ onMounted(() => {
 | 
			
		||||
})
 | 
			
		||||
 | 
			
		||||
const fetchApps = () => {
 | 
			
		||||
  httpGet("/api/app/list/user").then((res) => {
 | 
			
		||||
  httpGet("/api/app/list").then((res) => {
 | 
			
		||||
    const items = res.data
 | 
			
		||||
    // 处理 hello message
 | 
			
		||||
    for (let i = 0; i < items.length; i++) {
 | 
			
		||||
 
 | 
			
		||||
@@ -153,8 +153,7 @@
 | 
			
		||||
import {onMounted, ref} from "vue";
 | 
			
		||||
import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant";
 | 
			
		||||
import {httpGet, httpPost} from "@/utils/http";
 | 
			
		||||
import Compressor from 'compressorjs';
 | 
			
		||||
import {dateFormat, isWeChatBrowser, showLoginDialog} from "@/utils/libs";
 | 
			
		||||
import {dateFormat, showLoginDialog} from "@/utils/libs";
 | 
			
		||||
import {ElMessage} from "element-plus";
 | 
			
		||||
import {checkSession, getSystemInfo} from "@/store/cache";
 | 
			
		||||
import {useRouter} from "vue-router";
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user