diff --git a/CHANGELOG.md b/CHANGELOG.md index 923e36b8..5e3d4630 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,10 @@ # 更新日志 ## v4.1.6 +* 功能新增:**支持OpenAI实时语音对话功能** :rocket: :rocket: :rocket:, Beta 版,目前没有做算力计费控制,目前只有 VIP 用户可以使用。 * 功能优化:优化MysQL容器配置文档,解决MysQL容器资源占用过高问题 * 功能新增:管理后台增加AI绘图任务管理,可在管理后台浏览和删除用户的绘图任务 * 功能新增:管理后台增加Suno和Luma任务管理功能 +* Bug修复:修复管理后台删除兑换码报 404 错误 ## v4.1.5 * 功能优化:重构 websocket 组件,减少 websocket 连接数,全站共享一个 websocket 连接 diff --git a/api/core/app_server.go b/api/core/app_server.go index 766290b8..668488db 100644 --- a/api/core/app_server.go +++ b/api/core/app_server.go @@ -127,12 +127,19 @@ func corsMiddleware() gin.HandlerFunc { // 用户授权验证 func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc { return func(c *gin.Context) { + clientProtocols := c.GetHeader("Sec-WebSocket-Protocol") var tokenString string isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/") if isAdminApi { // 后台管理 API tokenString = c.GetHeader(types.AdminAuthHeader) - } else if c.Request.URL.Path == "/api/ws" { // Websocket 连接 - tokenString = c.Query("token") + } else if clientProtocols != "" { // Websocket 连接 + // 解析子协议内容 + protocols := strings.Split(clientProtocols, ",") + if protocols[0] == "realtime" { + tokenString = strings.TrimSpace(protocols[1][25:]) + } else if protocols[0] == "token" { + tokenString = strings.TrimSpace(protocols[1]) + } } else { tokenString = c.GetHeader(types.UserAuthHeader) } @@ -221,7 +228,6 @@ func needLogin(c *gin.Context) bool { c.Request.URL.Path == "/api/suno/detail" || c.Request.URL.Path == "/api/suno/play" || c.Request.URL.Path == "/api/download" || - c.Request.URL.Path == "/api/realtime" || strings.HasPrefix(c.Request.URL.Path, "/api/test") || strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") || strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") || diff --git a/api/handler/admin/redeem_handler.go b/api/handler/admin/redeem_handler.go index 97fc3620..3bce461c 100644 --- a/api/handler/admin/redeem_handler.go +++ b/api/handler/admin/redeem_handler.go @@ -146,19 +146,15 @@ func (h *RedeemHandler) Set(c *gin.Context) { } func (h *RedeemHandler) Remove(c *gin.Context) { - var data struct { - Id uint - } - if err := c.ShouldBindJSON(&data); err != nil { + id := h.GetInt(c, "id", 0) + if id <= 0 { resp.ERROR(c, types.InvalidArgs) return } - if data.Id > 0 { - err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error - if err != nil { - resp.ERROR(c, err.Error()) - return - } + err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error + if err != nil { + resp.ERROR(c, err.Error()) + return } resp.SUCCESS(c) } diff --git a/api/handler/realtime_handler.go b/api/handler/realtime_handler.go index 0d9a7e51..9cb49859 100644 --- a/api/handler/realtime_handler.go +++ b/api/handler/realtime_handler.go @@ -1,10 +1,15 @@ package handler import ( + "fmt" + "geekai/core" + "geekai/store/model" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "log" + "gorm.io/gorm" "net/http" + "strings" + "time" ) // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ @@ -14,28 +19,34 @@ import ( // * @Author yangjian102621@163.com // * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -// 实时 API 中继器 +// OpenAI Realtime API Relay Server type RealtimeHandler struct { BaseHandler } -func NewRealtimeHandler() *RealtimeHandler { - return &RealtimeHandler{} +func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler { + return &RealtimeHandler{BaseHandler{App: server, DB: db}} } func (h *RealtimeHandler) Connection(c *gin.Context) { // 获取客户端请求中指定的子协议 clientProtocols := c.GetHeader("Sec-WebSocket-Protocol") - logger.Info(clientProtocols) + md := c.Query("model") - // 升级HTTP连接为WebSocket,并传入客户端请求的子协议 - upgrader := websocket.Upgrader{ - CheckOrigin: func(r *http.Request) bool { return true }, - Subprotocols: []string{clientProtocols}, + userId := h.GetLoginUserId(c) + var user model.User + if err := h.DB.Where("id", userId).First(&user).Error; err != nil { + c.Abort() + return } - ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) + // 将 HTTP 协议升级为 Websocket 协议 + subProtocols := strings.Split(clientProtocols, ",") + ws, err := (&websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: subProtocols, + }).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Error(err) c.Abort() @@ -43,29 +54,46 @@ func (h *RealtimeHandler) Connection(c *gin.Context) { } defer ws.Close() + // 目前只针对 VIP 用户可以访问 + if !user.Vip { + sendError(ws, "当前功能只针对 VIP 用户开放") + c.Abort() + return + } + + var apiKey model.ApiKey + h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey) + if apiKey.Id == 0 { + sendError(ws, "管理员未配置 Realtime API KEY") + c.Abort() + return + } + + apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md) // 连接到真实的后端服务器,传入相同的子协议 headers := http.Header{} + // 修正子协议内容 + subProtocols[1] = "openai-insecure-api-key." + apiKey.Value if clientProtocols != "" { - headers.Set("Sec-WebSocket-Protocol", clientProtocols) + headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ",")) } - for key, values := range headers { - for _, value := range values { - logger.Infof("%s: %s", key, value) - } - } - backendConn, _, err := websocket.DefaultDialer.Dial("wss://api.geekai.pro/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01", headers) + backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers) if err != nil { - log.Printf("Failed to connect to backend: %v", err) + sendError(ws, "桥接后端 API 失败:"+err.Error()) + c.Abort() return } defer backendConn.Close() - //logger.Info(ws.Subprotocol(), ",", backendConn.Subprotocol()) - //// 确保协议一致性,如果失败返回 - //if ws.Subprotocol() != backendConn.Subprotocol() { - // log.Println("Subprotocol mismatch") - // return - //} + // 确保协议一致性,如果失败返回 + if ws.Subprotocol() != backendConn.Subprotocol() { + sendError(ws, "Websocket 子协议不匹配") + c.Abort() + return + } + + // 更新API KEY 最后使用时间 + h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix()) // 开始双向转发 errorChan := make(chan error, 2) @@ -73,8 +101,8 @@ func (h *RealtimeHandler) Connection(c *gin.Context) { go relay(backendConn, ws, errorChan) // 等待其中一个连接关闭 - <-errorChan - log.Println("Relay ended") + err = <-errorChan + logger.Infof("Relay ended: %v", err) } func relay(src, dst *websocket.Conn, errorChan chan error) { @@ -92,82 +120,9 @@ func relay(src, dst *websocket.Conn, errorChan chan error) { } } -//func (h *RealtimeHandler) handleMessage(client *RealtimeClient, message []byte) { -// var event Event -// err := json.Unmarshal(message, &event) -// if err != nil { -// logger.Infof("Error parsing event from client: %s", message) -// return -// } -// logger.Infof("Relaying %q to OpenAI", event.Type) -// client.Send(event) -//} -// -//func relay(src, dst *websocket.Conn, errorChan chan error) { -// for { -// messageType, message, err := src.ReadMessage() -// if err != nil { -// errorChan <- err -// return -// } -// err = dst.WriteMessage(messageType, message) -// if err != nil { -// errorChan <- err -// return -// } -// } -//} -// -//func NewRealtimeClient(apiKey string) *RealtimeClient { -// return &RealtimeClient{ -// APIKey: apiKey, -// send: make(chan Event, 100), -// } -//} -// -//func (rc *RealtimeClient) Connect() error { -// u := url.URL{Scheme: "wss", Host: "api.geekai.pro", Path: "v1/realtime", RawQuery: "model=gpt-4o-realtime-preview-2024-10-01"} -// c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) -// if err != nil { -// return err -// } -// rc.conn = c -// -// go rc.readPump() -// go rc.writePump() -// -// return nil -//} -// -//func (rc *RealtimeClient) readPump() { -// defer rc.conn.Close() -// for { -// _, message, err := rc.conn.ReadMessage() -// if err != nil { -// log.Println("read error:", err) -// return -// } -// var event Event -// err = json.Unmarshal(message, &event) -// if err != nil { -// log.Println("parse error:", err) -// continue -// } -// rc.send <- event -// } -//} -// -//func (rc *RealtimeClient) writePump() { -// defer rc.conn.Close() -// for event := range rc.send { -// err := rc.conn.WriteJSON(event) -// if err != nil { -// log.Println("write error:", err) -// return -// } -// } -//} -// -//func (rc *RealtimeClient) Send(event Event) { -// rc.send <- event -//} +func sendError(ws *websocket.Conn, message string) { + err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message}) + if err != nil { + logger.Error(err) + } +} diff --git a/api/handler/ws_handler.go b/api/handler/ws_handler.go index 1835f8ab..bac48ad3 100644 --- a/api/handler/ws_handler.go +++ b/api/handler/ws_handler.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/websocket" "gorm.io/gorm" "net/http" + "strings" ) // Websocket 连接处理 handler @@ -37,7 +38,11 @@ func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *g } func (h *WebsocketHandler) Client(c *gin.Context) { - ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil) + clientProtocols := c.GetHeader("Sec-WebSocket-Protocol") + ws, err := (&websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { return true }, + Subprotocols: strings.Split(clientProtocols, ","), + }).Upgrade(c.Writer, c.Request, nil) if err != nil { logger.Error(err) c.Abort() diff --git a/api/main.go b/api/main.go index 34d01a0e..8fd36c82 100644 --- a/api/main.go +++ b/api/main.go @@ -349,7 +349,7 @@ func main() { group.GET("list", h.List) group.POST("create", h.Create) group.POST("set", h.Set) - group.POST("remove", h.Remove) + group.GET("remove", h.Remove) }), fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) { group := s.Engine.Group("/api/admin/dashboard/") diff --git a/web/src/App.vue b/web/src/App.vue index cbde045c..f5c6b429 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -12,8 +12,6 @@ import {isChrome, isMobile} from "@/utils/libs"; import {showMessageInfo} from "@/utils/dialog"; import {useSharedStore} from "@/store/sharedata"; import {getUserToken} from "@/store/session"; -import {router} from "@/router"; -import {onBeforeRouteLeave, onBeforeRouteUpdate} from "vue-router"; const debounce = (fn, delay) => { let timer @@ -71,7 +69,7 @@ const connect = () => { } } const clientId = getClientId() - const _socket = new WebSocket(host + `/api/ws?client_id=${clientId}&token=${getUserToken()}`); + const _socket = new WebSocket(host + `/api/ws?client_id=${clientId}`,["token",getUserToken()]); _socket.addEventListener('open', () => { console.log('WebSocket 已连接') handler.value = setInterval(() => { diff --git a/web/src/assets/css/realtime.styl b/web/src/assets/css/realtime.styl index 78903246..18f5337a 100644 --- a/web/src/assets/css/realtime.styl +++ b/web/src/assets/css/realtime.styl @@ -9,7 +9,7 @@ margin: 0; overflow: hidden; font-family: Arial, sans-serif; - width 100vw + width 100% .phone-container { position: relative; @@ -90,7 +90,7 @@ justify-content: space-between; align-items: center; padding: 0; - width 100vw + width 100% .wave-container { padding 3rem @@ -148,13 +148,6 @@ flex-flow row justify-content: space-between; width 100% - - .left { - margin-left 3rem - } - .right { - margin-right 3rem - } } .call-controls { diff --git a/web/src/components/RealtimeConversation .vue b/web/src/components/RealtimeConversation.vue similarity index 65% rename from web/src/components/RealtimeConversation .vue rename to web/src/components/RealtimeConversation.vue index 8ffd4147..a13da6cf 100644 --- a/web/src/components/RealtimeConversation .vue +++ b/web/src/components/RealtimeConversation.vue @@ -54,6 +54,7 @@ import { WavRecorder, WavStreamPlayer } from '@/lib/wavtools/index.js'; import { instructions } from '@/utils/conversation_config.js'; import { WavRenderer } from '@/utils/wav_renderer'; import {showMessageError} from "@/utils/dialog"; +import {getUserToken} from "@/store/session"; // eslint-disable-next-line no-unused-vars,no-undef const props = defineProps({ @@ -73,7 +74,7 @@ const typeText = () => { if (index < fullText.length) { connectingText.value += fullText[index]; index++; - setTimeout(typeText, 300); // 每300毫秒显示一个字 + setTimeout(typeText, 200); // 每300毫秒显示一个字 } else { setTimeout(() => { connectingText.value = ''; @@ -97,10 +98,18 @@ const animateVoice = () => { const wavRecorder = ref(new WavRecorder({ sampleRate: 24000 })); const wavStreamPlayer = ref(new WavStreamPlayer({ sampleRate: 24000 })); +let host = process.env.VUE_APP_WS_HOST +if (host === '') { + if (location.protocol === 'https:') { + host = 'wss://' + location.host; + } else { + host = 'ws://' + location.host; + } +} const client = ref( new RealtimeClient({ - url: "ws://localhost:5678/api/realtime", - apiKey: "sk-Gc5cEzDzGQLIqxWA9d62089350F3454bB359C4A3Fa21B3E4", + url: `${host}/api/realtime`, + apiKey: getUserToken(), dangerouslyAllowAPIKeyInBrowser: true, }) ); @@ -115,41 +124,10 @@ client.value.updateSession({ // set voice wave canvas const clientCanvasRef = ref(null); const serverCanvasRef = ref(null); -// const eventsScrollRef = ref(null); -// const startTime = ref(new Date().toISOString()); - -// const items = ref([]); -// const realtimeEvents = ref([]); -// const expandedEvents = reactive({}); const isConnected = ref(false); -// const canPushToTalk = ref(true); const isRecording = ref(false); -// const memoryKv = ref({}); -// const coords = ref({ lat: 37.775593, lng: -122.418137 }); -// const marker = ref(null); - -// Methods -// const formatTime = (timestamp) => { -// const t0 = new Date(startTime.value).valueOf(); -// const t1 = new Date(timestamp).valueOf(); -// const delta = t1 - t0; -// const hs = Math.floor(delta / 10) % 100; -// const s = Math.floor(delta / 1000) % 60; -// const m = Math.floor(delta / 60_000) % 60; -// const pad = (n) => { -// let s = n + ''; -// while (s.length < 2) { -// s = '0' + s; -// } -// return s; -// }; -// return `${pad(m)}:${pad(s)}.${pad(hs)}`; -// }; const connect = async () => { - // startTime.value = new Date().toISOString(); - // realtimeEvents.value = []; - // items.value = client.value.conversation.getItems(); if (isConnected.value) { return } @@ -158,54 +136,54 @@ const connect = async () => { await client.value.connect(); await wavRecorder.value.begin(); await wavStreamPlayer.value.connect(); - isConnected.value = true; console.log("对话连接成功!") + if (!client.value.isConnected()) { + return + } + + isConnected.value = true; client.value.sendUserMessageContent([ { type: 'input_text', text: '你好,我是老阳!', }, ]); - if (client.value.getTurnDetectionType() === 'server_vad') { await wavRecorder.value.record((data) => client.value.appendInputAudio(data.mono)); } } catch (e) { - showMessageError(e.message) + console.error(e) } }; -// const disconnectConversation = async () => { -// isConnected.value = false; -// // realtimeEvents.value = []; -// // items.value = []; -// // memoryKv.value = {}; -// // coords.value = { lat: 37.775593, lng: -122.418137 }; -// // marker.value = null; -// -// client.value.disconnect(); -// await wavRecorder.value.end(); -// await wavStreamPlayer.value.interrupt(); -// }; - -// const deleteConversationItem = async (id) => { -// client.value.deleteItem(id); -// }; - +// 开始语音输入 const startRecording = async () => { - isRecording.value = true; - const trackSampleOffset = await wavStreamPlayer.value.interrupt(); - if (trackSampleOffset?.trackId) { - const { trackId, offset } = trackSampleOffset; - client.value.cancelResponse(trackId, offset); + if (isRecording.value) { + return + } + + isRecording.value = true; + try { + const trackSampleOffset = await wavStreamPlayer.value.interrupt(); + if (trackSampleOffset?.trackId) { + const { trackId, offset } = trackSampleOffset; + client.value.cancelResponse(trackId, offset); + } + await wavRecorder.value.record((data) => client.value.appendInputAudio(data.mono)); + } catch (e) { + console.error(e) } - await wavRecorder.value.record((data) => client.value.appendInputAudio(data.mono)); }; +// 结束语音输入 const stopRecording = async () => { - isRecording.value = false; - await wavRecorder.value.pause(); - client.value.createResponse(); + try { + isRecording.value = false; + await wavRecorder.value.pause(); + client.value.createResponse(); + } catch (e) { + console.error(e) + } }; // const changeTurnEndType = async (value) => { @@ -220,16 +198,8 @@ const stopRecording = async () => { // } // canPushToTalk.value = value === 'none'; // }; -// -// const toggleEventDetails = (eventId) => { -// if (expandedEvents[eventId]) { -// delete expandedEvents[eventId]; -// } else { -// expandedEvents[eventId] = true; -// } -// }; -// Lifecycle hooks and watchers +// 初始化 WaveRecorder 组件和 RealtimeClient 事件处理 const initialize = async () => { // Set up render loops for the visualization canvas let isLoaded = true; @@ -270,21 +240,15 @@ const initialize = async () => { }; render(); - - // Set up client event listeners - client.value.on('realtime.event', (realtimeEvent) => { - // realtimeEvents.value = realtimeEvents.value.slice(); - // const lastEvent = realtimeEvents.value[realtimeEvents.value.length - 1]; - // if (lastEvent?.event.type === realtimeEvent.event.type) { - // lastEvent.count = (lastEvent.count || 0) + 1; - // realtimeEvents.value.splice(-1, 1, lastEvent); - // } else { - // realtimeEvents.value.push(realtimeEvent); - // } - // console.log(realtimeEvent) + client.value.on('error', (event) => { + showMessageError(event.error) }); - client.value.on('error', (event) => console.error(event)); + client.value.on('realtime.event', (re) => { + if (re.event.type === 'error') { + showMessageError(re.event.error) + } + }); client.value.on('conversation.interrupted', async () => { const trackSampleOffset = await wavStreamPlayer.value.interrupt(); @@ -295,41 +259,19 @@ const initialize = async () => { }); client.value.on('conversation.updated', async ({ item, delta }) => { - console.log('item updated', item, delta) + // console.log('item updated', item, delta) if (delta?.audio) { wavStreamPlayer.value.add16BitPCM(delta.audio, item.id); } - if (item.status === 'completed' && item.formatted.audio?.length) { - const wavFile = await WavRecorder.decode( - item.formatted.audio, - 24000, - 24000 - ); - item.formatted.file = wavFile; - } }); } -// Watchers -// watch(realtimeEvents, () => { -// if (eventsScrollRef.value) { -// const eventsEl = eventsScrollRef.value; -// eventsEl.scrollTop = eventsEl.scrollHeight; -// } -// }); - -// watch(items, () => { -// const conversationEls = document.querySelectorAll('[data-conversation-content]'); -// conversationEls.forEach((el) => { -// el.scrollTop = el.scrollHeight; -// }); -// }); - const voiceInterval = ref(null); onMounted(() => { initialize() - voiceInterval.value = setInterval(animateVoice, 500); + // 启动聊天进行中的动画 + voiceInterval.value = setInterval(animateVoice, 200); typeText() }); @@ -338,16 +280,21 @@ onUnmounted(() => { client.value.reset(); }); +// 挂断通话 const hangUp = async () => { - emits('close') - isConnected.value = false; - client.value.disconnect(); - await wavRecorder.value.end(); - await wavStreamPlayer.value.interrupt(); + try { + isConnected.value = false; + client.value.disconnect(); + await wavRecorder.value.end(); + await wavStreamPlayer.value.interrupt(); + emits('close') + } catch (e) { + console.error(e) + } }; // eslint-disable-next-line no-undef -defineExpose({ connect }); +defineExpose({ connect,hangUp });