mirror of
https://github.com/yangjian102621/geekai.git
synced 2025-11-08 18:23:45 +08:00
the relay server for openai websocket is ready
This commit is contained in:
@@ -127,12 +127,19 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
// 用户授权验证
|
||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
var tokenString string
|
||||
isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
|
||||
if isAdminApi { // 后台管理 API
|
||||
tokenString = c.GetHeader(types.AdminAuthHeader)
|
||||
} else if c.Request.URL.Path == "/api/ws" { // Websocket 连接
|
||||
tokenString = c.Query("token")
|
||||
} else if clientProtocols != "" { // Websocket 连接
|
||||
// 解析子协议内容
|
||||
protocols := strings.Split(clientProtocols, ",")
|
||||
if protocols[0] == "realtime" {
|
||||
tokenString = strings.TrimSpace(protocols[1][25:])
|
||||
} else if protocols[0] == "token" {
|
||||
tokenString = strings.TrimSpace(protocols[1])
|
||||
}
|
||||
} else {
|
||||
tokenString = c.GetHeader(types.UserAuthHeader)
|
||||
}
|
||||
@@ -221,7 +228,6 @@ func needLogin(c *gin.Context) bool {
|
||||
c.Request.URL.Path == "/api/suno/detail" ||
|
||||
c.Request.URL.Path == "/api/suno/play" ||
|
||||
c.Request.URL.Path == "/api/download" ||
|
||||
c.Request.URL.Path == "/api/realtime" ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/payment/notify/") ||
|
||||
strings.HasPrefix(c.Request.URL.Path, "/api/user/clogin") ||
|
||||
|
||||
@@ -146,19 +146,15 @@ func (h *RedeemHandler) Set(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *RedeemHandler) Remove(c *gin.Context) {
|
||||
var data struct {
|
||||
Id uint
|
||||
}
|
||||
if err := c.ShouldBindJSON(&data); err != nil {
|
||||
id := h.GetInt(c, "id", 0)
|
||||
if id <= 0 {
|
||||
resp.ERROR(c, types.InvalidArgs)
|
||||
return
|
||||
}
|
||||
if data.Id > 0 {
|
||||
err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
|
||||
if err != nil {
|
||||
resp.ERROR(c, err.Error())
|
||||
return
|
||||
}
|
||||
resp.SUCCESS(c)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"geekai/core"
|
||||
"geekai/store/model"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
"log"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
@@ -14,28 +19,34 @@ import (
|
||||
// * @Author yangjian102621@163.com
|
||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
||||
|
||||
// 实时 API 中继器
|
||||
// OpenAI Realtime API Relay Server
|
||||
|
||||
type RealtimeHandler struct {
|
||||
BaseHandler
|
||||
}
|
||||
|
||||
func NewRealtimeHandler() *RealtimeHandler {
|
||||
return &RealtimeHandler{}
|
||||
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
|
||||
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
|
||||
}
|
||||
|
||||
func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
// 获取客户端请求中指定的子协议
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
logger.Info(clientProtocols)
|
||||
md := c.Query("model")
|
||||
|
||||
// 升级HTTP连接为WebSocket,并传入客户端请求的子协议
|
||||
upgrader := websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: []string{clientProtocols},
|
||||
userId := h.GetLoginUserId(c)
|
||||
var user model.User
|
||||
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
// 将 HTTP 协议升级为 Websocket 协议
|
||||
subProtocols := strings.Split(clientProtocols, ",")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: subProtocols,
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
@@ -43,29 +54,46 @@ func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
}
|
||||
defer ws.Close()
|
||||
|
||||
// 目前只针对 VIP 用户可以访问
|
||||
if !user.Vip {
|
||||
sendError(ws, "当前功能只针对 VIP 用户开放")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
var apiKey model.ApiKey
|
||||
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
|
||||
if apiKey.Id == 0 {
|
||||
sendError(ws, "管理员未配置 Realtime API KEY")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
|
||||
// 连接到真实的后端服务器,传入相同的子协议
|
||||
headers := http.Header{}
|
||||
// 修正子协议内容
|
||||
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
|
||||
if clientProtocols != "" {
|
||||
headers.Set("Sec-WebSocket-Protocol", clientProtocols)
|
||||
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
|
||||
}
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
logger.Infof("%s: %s", key, value)
|
||||
}
|
||||
}
|
||||
backendConn, _, err := websocket.DefaultDialer.Dial("wss://api.geekai.pro/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01", headers)
|
||||
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to backend: %v", err)
|
||||
sendError(ws, "桥接后端 API 失败:"+err.Error())
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
//logger.Info(ws.Subprotocol(), ",", backendConn.Subprotocol())
|
||||
//// 确保协议一致性,如果失败返回
|
||||
//if ws.Subprotocol() != backendConn.Subprotocol() {
|
||||
// log.Println("Subprotocol mismatch")
|
||||
// return
|
||||
//}
|
||||
// 确保协议一致性,如果失败返回
|
||||
if ws.Subprotocol() != backendConn.Subprotocol() {
|
||||
sendError(ws, "Websocket 子协议不匹配")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 更新API KEY 最后使用时间
|
||||
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
|
||||
|
||||
// 开始双向转发
|
||||
errorChan := make(chan error, 2)
|
||||
@@ -73,8 +101,8 @@ func (h *RealtimeHandler) Connection(c *gin.Context) {
|
||||
go relay(backendConn, ws, errorChan)
|
||||
|
||||
// 等待其中一个连接关闭
|
||||
<-errorChan
|
||||
log.Println("Relay ended")
|
||||
err = <-errorChan
|
||||
logger.Infof("Relay ended: %v", err)
|
||||
}
|
||||
|
||||
func relay(src, dst *websocket.Conn, errorChan chan error) {
|
||||
@@ -92,82 +120,9 @@ func relay(src, dst *websocket.Conn, errorChan chan error) {
|
||||
}
|
||||
}
|
||||
|
||||
//func (h *RealtimeHandler) handleMessage(client *RealtimeClient, message []byte) {
|
||||
// var event Event
|
||||
// err := json.Unmarshal(message, &event)
|
||||
// if err != nil {
|
||||
// logger.Infof("Error parsing event from client: %s", message)
|
||||
// return
|
||||
// }
|
||||
// logger.Infof("Relaying %q to OpenAI", event.Type)
|
||||
// client.Send(event)
|
||||
//}
|
||||
//
|
||||
//func relay(src, dst *websocket.Conn, errorChan chan error) {
|
||||
// for {
|
||||
// messageType, message, err := src.ReadMessage()
|
||||
// if err != nil {
|
||||
// errorChan <- err
|
||||
// return
|
||||
// }
|
||||
// err = dst.WriteMessage(messageType, message)
|
||||
// if err != nil {
|
||||
// errorChan <- err
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func NewRealtimeClient(apiKey string) *RealtimeClient {
|
||||
// return &RealtimeClient{
|
||||
// APIKey: apiKey,
|
||||
// send: make(chan Event, 100),
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func (rc *RealtimeClient) Connect() error {
|
||||
// u := url.URL{Scheme: "wss", Host: "api.geekai.pro", Path: "v1/realtime", RawQuery: "model=gpt-4o-realtime-preview-2024-10-01"}
|
||||
// c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// rc.conn = c
|
||||
//
|
||||
// go rc.readPump()
|
||||
// go rc.writePump()
|
||||
//
|
||||
// return nil
|
||||
//}
|
||||
//
|
||||
//func (rc *RealtimeClient) readPump() {
|
||||
// defer rc.conn.Close()
|
||||
// for {
|
||||
// _, message, err := rc.conn.ReadMessage()
|
||||
// if err != nil {
|
||||
// log.Println("read error:", err)
|
||||
// return
|
||||
// }
|
||||
// var event Event
|
||||
// err = json.Unmarshal(message, &event)
|
||||
// if err != nil {
|
||||
// log.Println("parse error:", err)
|
||||
// continue
|
||||
// }
|
||||
// rc.send <- event
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func (rc *RealtimeClient) writePump() {
|
||||
// defer rc.conn.Close()
|
||||
// for event := range rc.send {
|
||||
// err := rc.conn.WriteJSON(event)
|
||||
// if err != nil {
|
||||
// log.Println("write error:", err)
|
||||
// return
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//func (rc *RealtimeClient) Send(event Event) {
|
||||
// rc.send <- event
|
||||
//}
|
||||
func sendError(ws *websocket.Conn, message string) {
|
||||
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Websocket 连接处理 handler
|
||||
@@ -37,7 +38,11 @@ func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *g
|
||||
}
|
||||
|
||||
func (h *WebsocketHandler) Client(c *gin.Context) {
|
||||
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
|
||||
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
|
||||
ws, err := (&websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
Subprotocols: strings.Split(clientProtocols, ","),
|
||||
}).Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
logger.Error(err)
|
||||
c.Abort()
|
||||
|
||||
@@ -349,7 +349,7 @@ func main() {
|
||||
group.GET("list", h.List)
|
||||
group.POST("create", h.Create)
|
||||
group.POST("set", h.Set)
|
||||
group.POST("remove", h.Remove)
|
||||
group.GET("remove", h.Remove)
|
||||
}),
|
||||
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
|
||||
group := s.Engine.Group("/api/admin/dashboard/")
|
||||
|
||||
Reference in New Issue
Block a user