mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-09-23 03:36:39 +08:00
add ws handler
This commit is contained in:
parent
dfd2be1265
commit
478bc32ddd
@ -19,23 +19,35 @@ type BizVo struct {
|
|||||||
|
|
||||||
// ReplyMessage 对话回复消息结构
|
// ReplyMessage 对话回复消息结构
|
||||||
type ReplyMessage struct {
|
type ReplyMessage struct {
|
||||||
Type WsMsgType `json:"type"` // 消息类别,start, end, img
|
Channel WsChannel `json:"channel"` // 消息频道,目前只有 chat
|
||||||
|
Type WsMsgType `json:"type"` // 消息类别
|
||||||
Content interface{} `json:"content"`
|
Content interface{} `json:"content"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WsMsgType string
|
type WsMsgType string
|
||||||
|
type WsChannel string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
WsContent = WsMsgType("content") // 输出内容
|
WsMsgTypeContent = WsMsgType("content") // 输出内容
|
||||||
WsEnd = WsMsgType("end")
|
WsMsgTypeEnd = WsMsgType("end")
|
||||||
WsErr = WsMsgType("error")
|
WsMsgTypeErr = WsMsgType("error")
|
||||||
|
WsMsgTypePing = WsMsgType("ping") // 心跳消息
|
||||||
|
|
||||||
|
WsChat = WsChannel("chat")
|
||||||
|
WsMj = WsChannel("mj")
|
||||||
|
WsSd = WsChannel("sd")
|
||||||
|
WsDall = WsChannel("dall")
|
||||||
|
WsSuno = WsChannel("suno")
|
||||||
|
WsLuma = WsChannel("luma")
|
||||||
)
|
)
|
||||||
|
|
||||||
// InputMessage 对话输入消息结构
|
// InputMessage 对话输入消息结构
|
||||||
type InputMessage struct {
|
type InputMessage struct {
|
||||||
Content string `json:"content"`
|
Channel WsChannel `json:"channel"` // 消息频道
|
||||||
Tools []int `json:"tools"` // 允许调用工具列表
|
Type WsMsgType `json:"type"` // 消息类别
|
||||||
Stream bool `json:"stream"` // 是否采用流式输出
|
Content string `json:"content"`
|
||||||
|
Tools []int `json:"tools"` // 允许调用工具列表
|
||||||
|
Stream bool `json:"stream"` // 是否采用流式输出
|
||||||
}
|
}
|
||||||
|
|
||||||
type BizCode int
|
type BizCode int
|
||||||
|
@ -151,9 +151,9 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
|
|||||||
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
utils.ReplyMessage(client, err.Error())
|
utils.SendMessage(client, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
|
utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||||
logger.Infof("回答完毕: %v", message.Content)
|
logger.Infof("回答完毕: %v", message.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,7 +108,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
|
||||||
utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
utils.SendMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -136,7 +136,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
if res.Error == nil {
|
if res.Error == nil {
|
||||||
toolCall = true
|
toolCall = true
|
||||||
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
|
||||||
utils.ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: callMsg})
|
utils.SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: callMsg})
|
||||||
contents = append(contents, callMsg)
|
contents = append(contents, callMsg)
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@ -153,8 +153,8 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
} else {
|
} else {
|
||||||
content := responseBody.Choices[0].Delta.Content
|
content := responseBody.Choices[0].Delta.Content
|
||||||
contents = append(contents, utils.InterfaceToString(content))
|
contents = append(contents, utils.InterfaceToString(content))
|
||||||
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
utils.SendChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsContent,
|
Type: types.WsMsgTypeContent,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -186,14 +186,14 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
}
|
}
|
||||||
if errMsg != "" || apiRes.Code != types.Success {
|
if errMsg != "" || apiRes.Code != types.Success {
|
||||||
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
msg := "调用函数工具出错:" + apiRes.Message + errMsg
|
||||||
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
utils.SendChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsContent,
|
Type: types.WsMsgTypeContent,
|
||||||
Content: msg,
|
Content: msg,
|
||||||
})
|
})
|
||||||
contents = append(contents, msg)
|
contents = append(contents, msg)
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyChunkMessage(ws, types.ReplyMessage{
|
utils.SendChunkMessage(ws, types.ReplyMessage{
|
||||||
Type: types.WsContent,
|
Type: types.WsMsgTypeContent,
|
||||||
Content: apiRes.Data,
|
Content: apiRes.Data,
|
||||||
})
|
})
|
||||||
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
contents = append(contents, utils.InterfaceToString(apiRes.Data))
|
||||||
@ -226,7 +226,7 @@ func (h *ChatHandler) sendOpenAiMessage(
|
|||||||
if strings.HasPrefix(req.Model, "o1-") {
|
if strings.HasPrefix(req.Model, "o1-") {
|
||||||
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
content = fmt.Sprintf("AI思考结束,耗时:%d 秒。\n%s", time.Now().Unix()-session.Start, respVo.Choices[0].Message.Content)
|
||||||
}
|
}
|
||||||
utils.ReplyMessage(ws, content)
|
utils.SendMessage(ws, content)
|
||||||
respVo.Usage.Prompt = prompt
|
respVo.Usage.Prompt = prompt
|
||||||
respVo.Usage.Content = content
|
respVo.Usage.Content = content
|
||||||
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
|
h.saveChatHistory(req, respVo.Usage, respVo.Choices[0].Message, chatCtx, session, role, userVo, promptCreatedAt, time.Now())
|
||||||
|
@ -87,7 +87,7 @@ func (h *MarkMapHandler) Client(c *gin.Context) {
|
|||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
utils.ReplyErrorMessage(client, err.Error())
|
utils.ReplyErrorMessage(client, err.Error())
|
||||||
} else {
|
} else {
|
||||||
utils.ReplyMessage(client, types.ReplyMessage{Type: types.WsEnd})
|
utils.SendMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -170,13 +170,13 @@ func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, mode
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
utils.ReplyChunkMessage(client, types.ReplyMessage{
|
utils.SendChunkMessage(client, types.ReplyMessage{
|
||||||
Type: types.WsContent,
|
Type: types.WsMsgTypeContent,
|
||||||
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
|
||||||
})
|
})
|
||||||
} // end for
|
} // end for
|
||||||
|
|
||||||
utils.ReplyChunkMessage(client, types.ReplyMessage{Type: types.WsEnd})
|
utils.SendChunkMessage(client, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
body, _ := io.ReadAll(response.Body)
|
body, _ := io.ReadAll(response.Body)
|
||||||
|
85
api/handler/ws_handler.go
Normal file
85
api/handler/ws_handler.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
|
||||||
|
// * Use of this source code is governed by a Apache-2.0 license
|
||||||
|
// * that can be found in the LICENSE file.
|
||||||
|
// * @Author yangjian102621@163.com
|
||||||
|
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||||
|
|
||||||
|
import (
|
||||||
|
"geekai/core"
|
||||||
|
"geekai/core/types"
|
||||||
|
"geekai/service"
|
||||||
|
"geekai/store/model"
|
||||||
|
"geekai/utils"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Websocket 连接处理 handler
|
||||||
|
|
||||||
|
type WebsocketHandler struct {
|
||||||
|
BaseHandler
|
||||||
|
wsService *service.WebsocketService
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService) *WebsocketHandler {
|
||||||
|
return &WebsocketHandler{
|
||||||
|
BaseHandler: BaseHandler{App: app},
|
||||||
|
wsService: s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||||
|
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userId := h.GetInt(c, "user_id", 0)
|
||||||
|
clientId := c.Query("client")
|
||||||
|
client := types.NewWsClient(ws)
|
||||||
|
if userId == 0 {
|
||||||
|
_ = client.Send([]byte("Invalid user_id"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var user model.User
|
||||||
|
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||||
|
_ = client.Send([]byte("Invalid user_id"))
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.wsService.Clients.Put(clientId, client)
|
||||||
|
logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
_, msg, err := client.Receive()
|
||||||
|
if err != nil {
|
||||||
|
logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
|
||||||
|
client.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
var message types.InputMessage
|
||||||
|
err = utils.JsonDecode(string(msg), &message)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Infof("Receive a message:%+v", message)
|
||||||
|
if message.Type == types.WsMsgTypePing {
|
||||||
|
_ = client.Send([]byte(`{"type":"pong"}`))
|
||||||
|
}
|
||||||
|
|
||||||
|
switch message.Channel {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
7
api/service/ws_service.go
Normal file
7
api/service/ws_service.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import "geekai/core/types"
|
||||||
|
|
||||||
|
type WebsocketService struct {
|
||||||
|
Clients *types.LMap[string, *types.WsClient] // clientId => Client
|
||||||
|
}
|
@ -19,8 +19,8 @@ import (
|
|||||||
|
|
||||||
var logger = logger2.GetLogger()
|
var logger = logger2.GetLogger()
|
||||||
|
|
||||||
// ReplyChunkMessage 回复客户片段端消息
|
// SendChunkMessage 回复客户片段端消息
|
||||||
func ReplyChunkMessage(client *types.WsClient, message interface{}) {
|
func SendChunkMessage(client *types.WsClient, message interface{}) {
|
||||||
msg, err := json.Marshal(message)
|
msg, err := json.Marshal(message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Error for decoding json data: %v", err.Error())
|
logger.Errorf("Error for decoding json data: %v", err.Error())
|
||||||
@ -32,19 +32,19 @@ func ReplyChunkMessage(client *types.WsClient, message interface{}) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyMessage 回复客户端一条完整的消息
|
// SendMessage 回复客户端一条完整的消息
|
||||||
func ReplyMessage(ws *types.WsClient, message interface{}) {
|
func SendMessage(ws *types.WsClient, message interface{}) {
|
||||||
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
|
SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
|
||||||
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsEnd})
|
SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeEnd})
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReplyContent(ws *types.WsClient, message interface{}) {
|
func ReplyContent(ws *types.WsClient, message interface{}) {
|
||||||
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsContent, Content: message})
|
SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeContent, Content: message})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplyErrorMessage 向客户端发送错误消息
|
// ReplyErrorMessage 向客户端发送错误消息
|
||||||
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
|
func ReplyErrorMessage(ws *types.WsClient, message interface{}) {
|
||||||
ReplyChunkMessage(ws, types.ReplyMessage{Type: types.WsErr, Content: message})
|
SendChunkMessage(ws, types.ReplyMessage{Type: types.WsMsgTypeErr, Content: message})
|
||||||
}
|
}
|
||||||
|
|
||||||
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
func DownloadImage(imageURL string, proxy string) ([]byte, error) {
|
||||||
|
@ -109,7 +109,7 @@ onMounted(() => {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const fetchApps = () => {
|
const fetchApps = () => {
|
||||||
httpGet("/api/app/list/user").then((res) => {
|
httpGet("/api/app/list").then((res) => {
|
||||||
const items = res.data
|
const items = res.data
|
||||||
// 处理 hello message
|
// 处理 hello message
|
||||||
for (let i = 0; i < items.length; i++) {
|
for (let i = 0; i < items.length; i++) {
|
||||||
|
@ -153,8 +153,7 @@
|
|||||||
import {onMounted, ref} from "vue";
|
import {onMounted, ref} from "vue";
|
||||||
import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant";
|
import {showFailToast, showLoadingToast, showNotify, showSuccessToast} from "vant";
|
||||||
import {httpGet, httpPost} from "@/utils/http";
|
import {httpGet, httpPost} from "@/utils/http";
|
||||||
import Compressor from 'compressorjs';
|
import {dateFormat, showLoginDialog} from "@/utils/libs";
|
||||||
import {dateFormat, isWeChatBrowser, showLoginDialog} from "@/utils/libs";
|
|
||||||
import {ElMessage} from "element-plus";
|
import {ElMessage} from "element-plus";
|
||||||
import {checkSession, getSystemInfo} from "@/store/cache";
|
import {checkSession, getSystemInfo} from "@/store/cache";
|
||||||
import {useRouter} from "vue-router";
|
import {useRouter} from "vue-router";
|
||||||
|
Loading…
Reference in New Issue
Block a user