the relay server for openai websocket is ready

This commit is contained in:
RockYang
2024-10-17 16:46:41 +08:00
parent e356771049
commit 43c507c597
13 changed files with 184 additions and 263 deletions

View File

@@ -146,19 +146,15 @@ func (h *RedeemHandler) Set(c *gin.Context) {
}
func (h *RedeemHandler) Remove(c *gin.Context) {
var data struct {
Id uint
}
if err := c.ShouldBindJSON(&data); err != nil {
id := h.GetInt(c, "id", 0)
if id <= 0 {
resp.ERROR(c, types.InvalidArgs)
return
}
if data.Id > 0 {
err := h.DB.Where("id", data.Id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
err := h.DB.Where("id", id).Delete(&model.Redeem{}).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
resp.SUCCESS(c)
}

View File

@@ -1,10 +1,15 @@
package handler
import (
"fmt"
"geekai/core"
"geekai/store/model"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"log"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@@ -14,28 +19,34 @@ import (
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// 实时 API 中继器
// OpenAI Realtime API Relay Server
type RealtimeHandler struct {
BaseHandler
}
func NewRealtimeHandler() *RealtimeHandler {
return &RealtimeHandler{}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
// 获取客户端请求中指定的子协议
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
logger.Info(clientProtocols)
md := c.Query("model")
// 升级HTTP连接为WebSocket并传入客户端请求的子协议
upgrader := websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: []string{clientProtocols},
userId := h.GetLoginUserId(c)
var user model.User
if err := h.DB.Where("id", userId).First(&user).Error; err != nil {
c.Abort()
return
}
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
// 将 HTTP 协议升级为 Websocket 协议
subProtocols := strings.Split(clientProtocols, ",")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: subProtocols,
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()
@@ -43,29 +54,46 @@ func (h *RealtimeHandler) Connection(c *gin.Context) {
}
defer ws.Close()
// 目前只针对 VIP 用户可以访问
if !user.Vip {
sendError(ws, "当前功能只针对 VIP 用户开放")
c.Abort()
return
}
var apiKey model.ApiKey
h.DB.Where("type", "realtime").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
if apiKey.Id == 0 {
sendError(ws, "管理员未配置 Realtime API KEY")
c.Abort()
return
}
apiURL := fmt.Sprintf("%s/v1/realtime?model=%s", apiKey.ApiURL, md)
// 连接到真实的后端服务器,传入相同的子协议
headers := http.Header{}
// 修正子协议内容
subProtocols[1] = "openai-insecure-api-key." + apiKey.Value
if clientProtocols != "" {
headers.Set("Sec-WebSocket-Protocol", clientProtocols)
headers.Set("Sec-WebSocket-Protocol", strings.Join(subProtocols, ","))
}
for key, values := range headers {
for _, value := range values {
logger.Infof("%s: %s", key, value)
}
}
backendConn, _, err := websocket.DefaultDialer.Dial("wss://api.geekai.pro/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01", headers)
backendConn, _, err := websocket.DefaultDialer.Dial(apiURL, headers)
if err != nil {
log.Printf("Failed to connect to backend: %v", err)
sendError(ws, "桥接后端 API 失败:"+err.Error())
c.Abort()
return
}
defer backendConn.Close()
//logger.Info(ws.Subprotocol(), ",", backendConn.Subprotocol())
//// 确保协议一致性,如果失败返回
//if ws.Subprotocol() != backendConn.Subprotocol() {
// log.Println("Subprotocol mismatch")
// return
//}
// 确保协议一致性,如果失败返回
if ws.Subprotocol() != backendConn.Subprotocol() {
sendError(ws, "Websocket 子协议不匹配")
c.Abort()
return
}
// 更新API KEY 最后使用时间
h.DB.Model(&model.ApiKey{}).Where("id", apiKey.Id).UpdateColumn("last_used_at", time.Now().Unix())
// 开始双向转发
errorChan := make(chan error, 2)
@@ -73,8 +101,8 @@ func (h *RealtimeHandler) Connection(c *gin.Context) {
go relay(backendConn, ws, errorChan)
// 等待其中一个连接关闭
<-errorChan
log.Println("Relay ended")
err = <-errorChan
logger.Infof("Relay ended: %v", err)
}
func relay(src, dst *websocket.Conn, errorChan chan error) {
@@ -92,82 +120,9 @@ func relay(src, dst *websocket.Conn, errorChan chan error) {
}
}
//func (h *RealtimeHandler) handleMessage(client *RealtimeClient, message []byte) {
// var event Event
// err := json.Unmarshal(message, &event)
// if err != nil {
// logger.Infof("Error parsing event from client: %s", message)
// return
// }
// logger.Infof("Relaying %q to OpenAI", event.Type)
// client.Send(event)
//}
//
//func relay(src, dst *websocket.Conn, errorChan chan error) {
// for {
// messageType, message, err := src.ReadMessage()
// if err != nil {
// errorChan <- err
// return
// }
// err = dst.WriteMessage(messageType, message)
// if err != nil {
// errorChan <- err
// return
// }
// }
//}
//
//func NewRealtimeClient(apiKey string) *RealtimeClient {
// return &RealtimeClient{
// APIKey: apiKey,
// send: make(chan Event, 100),
// }
//}
//
//func (rc *RealtimeClient) Connect() error {
// u := url.URL{Scheme: "wss", Host: "api.geekai.pro", Path: "v1/realtime", RawQuery: "model=gpt-4o-realtime-preview-2024-10-01"}
// c, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
// if err != nil {
// return err
// }
// rc.conn = c
//
// go rc.readPump()
// go rc.writePump()
//
// return nil
//}
//
//func (rc *RealtimeClient) readPump() {
// defer rc.conn.Close()
// for {
// _, message, err := rc.conn.ReadMessage()
// if err != nil {
// log.Println("read error:", err)
// return
// }
// var event Event
// err = json.Unmarshal(message, &event)
// if err != nil {
// log.Println("parse error:", err)
// continue
// }
// rc.send <- event
// }
//}
//
//func (rc *RealtimeClient) writePump() {
// defer rc.conn.Close()
// for event := range rc.send {
// err := rc.conn.WriteJSON(event)
// if err != nil {
// log.Println("write error:", err)
// return
// }
// }
//}
//
//func (rc *RealtimeClient) Send(event Event) {
// rc.send <- event
//}
func sendError(ws *websocket.Conn, message string) {
err := ws.WriteJSON(map[string]string{"event_id": "event_01", "type": "error", "error": message})
if err != nil {
logger.Error(err)
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/gorilla/websocket"
"gorm.io/gorm"
"net/http"
"strings"
)
// Websocket 连接处理 handler
@@ -37,7 +38,11 @@ func NewWebsocketHandler(app *core.AppServer, s *service.WebsocketService, db *g
}
func (h *WebsocketHandler) Client(c *gin.Context) {
ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
clientProtocols := c.GetHeader("Sec-WebSocket-Protocol")
ws, err := (&websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool { return true },
Subprotocols: strings.Split(clientProtocols, ","),
}).Upgrade(c.Writer, c.Request, nil)
if err != nil {
logger.Error(err)
c.Abort()