add ws handler

This commit is contained in:
RockYang 2024-09-25 18:43:12 +08:00
parent dfd2be1265
commit 478bc32ddd
9 changed files with 136 additions and 33 deletions

View File

@ -19,23 +19,35 @@ type BizVo struct {
// ReplyMessage 对话回复消息结构 // ReplyMessage 对话回复消息结构
type ReplyMessage struct { type ReplyMessage struct {
Type WsMsgType `json:"type"` // 消息类别start, end, img Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
Type WsMsgType `json:"type"` // 消息类别
Content interface{} `json:"content"` Content interface{} `json:"content"`
} }
type WsMsgType string type WsMsgType string
type WsChannel string
const ( const (
WsContent = WsMsgType("content") // 输出内容 WsMsgTypeContent = WsMsgType("content") // 输出内容
WsEnd = WsMsgType("end") WsMsgTypeEnd = WsMsgType("end")
WsErr = WsMsgType("error") WsMsgTypeErr = WsMsgType("error")
WsMsgTypePing = WsMsgType("ping") // 心跳消息
WsChat = WsChannel("chat")
WsMj = WsChannel("mj")
WsSd = WsChannel("sd")
WsDall = WsChannel("dall")
WsSuno = WsChannel("suno")
WsLuma = WsChannel("luma")
) )
// InputMessage 对话输入消息结构 // InputMessage 对话输入消息结构
type InputMessage struct { type InputMessage struct {
Content string `json:"content"` Channel WsChannel `json:"channel"` // 消息频道
Tools []int `json:"tools"` // 允许调用工具列表 Type WsMsgType `json:"type"` // 消息类别
Stream bool `json:"stream"` // 是否采用流式输出 Content string `json:"content"`
Tools []int `json:"tools"` // 允许调用工具列表
Stream bool `json:"stream"` // 是否采用流式输出
} }
type BizCode int type BizCode int

View File

@ -151,9 +151,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client) err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
if err != nil { if err != nil {
logger.Error(err) logger.Error(err)
utils.ReplyMessage(client, err.Error()) utils.SendMessage(client, err.Error())
} else { } else {
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
logger.Infof("回答完毕: %v", message.Content) logger.Infof("回答完毕: %v", message.Content)
} }

View File

@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
} }
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 { if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
utils.ReplyMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。") utils.SendMessage(ws, "抱歉😔😔😔AI助手由于未知原因已经停止输出内容。")
break break
} }
@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if res.Error == nil { if res.Error == nil {
toolCall = true toolCall = true
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label) callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg}) utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg})
contents = append(contents, callMsg) contents = append(contents, callMsg)
} }
continue continue
@ -153,8 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
} else { } else {
content := responseBody.Choices[0].Delta.Content content := responseBody.Choices[0].Delta.Content
contents = append(contents, utils.InterfaceToString(content)) contents = append(contents, utils.InterfaceToString(content))
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent, Type: types.WsMsgTypeContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
} }
@ -186,14 +186,14 @@ func (h *ChatHandler) sendOpenAiMessage(
} }
if errMsg != "" || apiRes.Code != types.Success { if errMsg != "" || apiRes.Code != types.Success {
msg := "调用函数工具出错:" + apiRes.Message + errMsg msg := "调用函数工具出错:" + apiRes.Message + errMsg
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent, Type: types.WsMsgTypeContent,
Content: msg, Content: msg,
}) })
contents = append(contents, msg) contents = append(contents, msg)
} else { } else {
utils.ReplyChunkMessage(ws, types.ReplyMessage{ utils.SendChunkMessage(ws, types.ReplyMessage{
Type: types.WsContent, Type: types.WsMsgTypeContent,
Content: apiRes.Data, Content: apiRes.Data,
}) })
contents = append(contents, utils.InterfaceToString(apiRes.Data)) contents = append(contents, utils.InterfaceToString(apiRes.Data))
@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
if strings.HasPrefix(req.Model, "o1-") { if strings.HasPrefix(req.Model, "o1-") {
content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content) content = fmt.Sprintf("AI思考结束耗时%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
} }
utils.ReplyMessage(ws, content) utils.SendMessage(ws, content)
respVo.Usage.Prompt = prompt respVo.Usage.Prompt = prompt
respVo.Usage.Content = content respVo.Usage.Content = content
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now()) h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())

View File

@ -87,7 +87,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
logger.Error(err) logger.Error(err)
utils.ReplyErrorMessage(client, err.Error()) utils.ReplyErrorMessage(client, err.Error())
} else { } else {
utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd}) utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
} }
} }
@ -170,13 +170,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
break break
} }
utils.ReplyChunkMessage(client, types.ReplyMessage{ utils.SendChunkMessage(client, types.ReplyMessage{
Type: types.WsContent, Type: types.WsMsgTypeContent,
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content), Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
}) })
} // end for } // end for
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd}) utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
} else { } else {
body, _ := io.ReadAll(response.Body) body, _ := io.ReadAll(response.Body)

85
api/handler/ws_handler.go Normal file
View File

@ -0,0 +1,85 @@
package handler
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"geekai/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
)
// Websocket 连接处理 handler
type WebsocketHandler struct {
BaseHandler
wsService *service.WebsocketService
}
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler {
return &WebsocketHandler{
BaseHandler: BaseHandler{App: app},
wsService: s,
}
}
func (h *WebsocketHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
return
}
userId := h.GetInt(c, "user_id", 0)
clientId := c.Query("client")
client := types.NewWsClient(ws)
if userId == 0 {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
_ = client.Send([]byte("Invalid user_id"))
c.Abort()
return
}
h.wsService.Clients.Put(clientId, client)
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
go func() {
for {
_, msg, err := client.Receive()
if err != nil {
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
client.Close()
}
var message types.InputMessage
err = utils.JsonDecode(string(msg), &message)
if err != nil {
continue
}
logger.Infof("Receive a message:%+v", message)
if message.Type == types.WsMsgTypePing {
_ = client.Send([]byte(`{"type":"pong"}`))
}
switch message.Channel {
}
}
}()
}

View File

@ -0,0 +1,7 @@
package service
import "geekai/core/types"
type WebsocketService struct {
Clients *types.LMap[string, *types.WsClient] // clientId => Client
}

View File

@ -19,8 +19,8 @@ import (
var logger = logger2.GetLogger() var logger = logger2.GetLogger()
// ReplyChunkMessage 回复客户片段端消息 // SendChunkMessage 回复客户片段端消息
func ReplyChunkMessage(client *types.WsClient, message interface{}) { func SendChunkMessage(client *types.WsClient, message interface{}) {
msg, err := json.Marshal(message) msg, err := json.Marshal(message)
if err != nil { if err != nil {
logger.Errorf("Error for decoding json data: %v", err.Error()) logger.Errorf("Error for decoding json data: %v", err.Error())
@ -32,19 +32,19 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
} }
} }
// ReplyMessage 回复客户端一条完整的消息 // SendMessage 回复客户端一条完整的消息
func ReplyMessage(ws *types.WsClient, message interface{}) { func SendMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message}) SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd}) SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd})
} }
func ReplyContent(ws *types.WsClient, message interface{}) { func ReplyContent(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message}) SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
} }
// ReplyErrorMessage 向客户端发送错误消息 // ReplyErrorMessage 向客户端发送错误消息
func ReplyErrorMessage(ws *types.WsClient, message interface{}) { func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message}) SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message})
} }
func DownloadImage(imageURL string, proxy string) ([]byte, error) { func DownloadImage(imageURL string, proxy string) ([]byte, error) {

View File

@ -109,7 +109,7 @@ onMounted(() => {
}) })
const fetchApps = () => { const fetchApps = () => {
httpGet("/api/app/list/user").then((res) => { httpGet("/api/app/list").then((res) => {
const items = res.data const items = res.data
// hello message // hello message
for (let i = 0; i < items.length; i++) { for (let i = 0; i < items.length; i++) {

View File

@ -153,8 +153,7 @@
import {onMounted, ref} from "vue"; import {onMounted, ref} from "vue";
import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant"; import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant";
import {httpGet, httpPost} from "@/utils/http"; import {httpGet, httpPost} from "@/utils/http";
import Compressor from 'compressorjs'; import {dateFormat, showLoginDialog} from "@/utils/libs";
import {dateFormat, isWeChatBrowser, showLoginDialog} from "@/utils/libs";
import {ElMessage} from "element-plus"; import {ElMessage} from "element-plus";
import {checkSession, getSystemInfo} from "@/store/cache"; import {checkSession, getSystemInfo} from "@/store/cache";
import {useRouter} from "vue-router"; import {useRouter} from "vue-router";