diff --git a/controller/relay.go b/controller/relay.go
index 4a49d51..f891000 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"io"
 	"log"
 	"net/http"
@@ -134,6 +135,73 @@ func Relay(c *gin.Context) {
 	}
 }
 
+// upgrader promotes incoming HTTP requests to WebSocket connections.
+// NOTE(review): CheckOrigin accepts every origin, so a page on any site can
+// open this endpoint from a browser; confirm token auth alone is sufficient.
+var upgrader = websocket.Upgrader{
+	CheckOrigin: func(r *http.Request) bool {
+		return true // allow cross-origin requests
+	},
+}
+
+// WssRelay upgrades the request to a WebSocket connection and relays it to an
+// upstream channel, moving on to another channel while shouldRetry allows it
+// (at most common.RetryTimes retries). If no attempt succeeds, the last error
+// is sent to the client as a realtime "error" event.
+func WssRelay(c *gin.Context) {
+	// Upgrade the HTTP connection to a WebSocket connection.
+	ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
+	if err != nil {
+		// On failure ws is nil and Upgrade has already replied to the client
+		// with an HTTP error, so there is no socket to send an error event on.
+		common.LogError(c, "websocket upgrade failed: "+err.Error())
+		return
+	}
+	// This handler owns the hijacked connection; close it on every exit path.
+	defer ws.Close()
+
+	relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+	requestId := c.GetString(common.RequestIdKey)
+	group := c.GetString("group")
+	// Example upstream URL: wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+	originalModel := c.GetString("original_model")
+	var openaiErr *dto.OpenAIErrorWithStatusCode
+
+	for i := 0; i <= common.RetryTimes; i++ {
+		channel, err := getChannel(c, group, originalModel, i)
+		if err != nil {
+			common.LogError(c, err.Error())
+			openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
+			break
+		}
+
+		openaiErr = relayRequest(c, relayMode, channel)
+
+		if openaiErr == nil {
+			return // request handled successfully; nothing more to send
+		}
+
+		go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
+
+		if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
+			break
+		}
+	}
+	useChannel := c.GetStringSlice("use_channel")
+	if len(useChannel) > 1 {
+		retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
+		common.LogInfo(c, retryLogStr)
+	}
+
+	if openaiErr != nil {
+		if openaiErr.StatusCode == http.StatusTooManyRequests {
+			openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
+		}
+		openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
+		service.WssError(c, ws, openaiErr.Error)
+	}
+}
+
 func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
 	addUsedChannel(c, channel.Id)
 	requestBody, _ := common.GetRequestBody(c)
diff --git a/dto/realtime.go b/dto/realtime.go
new file mode 100644
index 0000000..dbc3f40
--- /dev/null
+++ b/dto/realtime.go
@@ -0,0 +1,68 @@
+package dto
+
+// Known values of the RealtimeEvent Type field.
+const (
+	RealtimeEventTypeError              = "error"
+	RealtimeEventTypeSessionUpdate      = "session.update"
+	RealtimeEventTypeConversationCreate = "conversation.item.create"
+	RealtimeEventTypeResponseCreate     = "response.create"
+)
+
+// RealtimeEvent is the JSON envelope of a realtime event. Session, Item and
+// Error are optional payloads and are omitted from the JSON when nil.
+type RealtimeEvent struct {
+	EventId string `json:"event_id"`
+	Type    string `json:"type"`
+	//PreviousItemId string `json:"previous_item_id"`
+	Session *RealtimeSession `json:"session,omitempty"`
+	Item    *RealtimeItem    `json:"item,omitempty"`
+	Error   *OpenAIError     `json:"error,omitempty"`
+}
+
+// RealtimeSession is the payload of the "session" field of a RealtimeEvent.
+type RealtimeSession struct {
+	Modalities              []string                `json:"modalities"`
+	Instructions            string                  `json:"instructions"`
+	Voice                   string                  `json:"voice"`
+	InputAudioFormat        string                  `json:"input_audio_format"`
+	OutputAudioFormat       string                  `json:"output_audio_format"`
+	InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
+	TurnDetection           interface{}             `json:"turn_detection"`
+	Tools                   []RealTimeTool          `json:"tools"`
+	ToolChoice              string                  `json:"tool_choice"`
+	Temperature             float64                 `json:"temperature"`
+	MaxResponseOutputTokens int                     `json:"max_response_output_tokens"`
+}
+
+// InputAudioTranscription is the "input_audio_transcription" setting of a RealtimeSession.
+type InputAudioTranscription struct {
+	Model string `json:"model"`
+}
+
+// RealTimeTool describes one entry of RealtimeSession.Tools.
+type RealTimeTool struct {
+	Type        string `json:"type"`
+	Name        string `json:"name"`
+	Description string `json:"description"`
+	Parameters  any    `json:"parameters"`
+}
+
+// RealtimeItem is the payload of the "item" field of a RealtimeEvent.
+type RealtimeItem struct {
+	Id        string          `json:"id"`
+	Type      string          `json:"type"`
+	Status    string          `json:"status"`
+	Role      string          `json:"role"`
+	Content   RealtimeContent `json:"content"`
+	Name      *string         `json:"name,omitempty"`
+	ToolCalls any             `json:"tool_calls,omitempty"`
+	CallId    string          `json:"call_id,omitempty"`
+}
+
+// RealtimeContent is the content of a RealtimeItem: text, audio and/or a transcript.
+type RealtimeContent struct {
+	Type       string `json:"type"`
+	Text       string `json:"text,omitempty"`
+	Audio      string `json:"audio,omitempty"` // Base64-encoded audio bytes.
+	Transcript string `json:"transcript,omitempty"`
+}
diff --git a/middleware/distributor.go b/middleware/distributor.go
index f1f64ca..9bab3e9 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
 		abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
 		return nil, false, errors.New("无效的请求, " + err.Error())
 	}
+	if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
+		// The model comes from the query string, e.g. wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
+		modelRequest.Model = c.Query("model")
+	}
 	if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
 		if modelRequest.Model == "" {
 			modelRequest.Model = "text-moderation-stable"
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
index 7fecda1..845166c 100644
--- a/relay/constant/relay_mode.go
+++ b/relay/constant/relay_mode.go
@@ -38,6 +38,8 @@ const (
 	RelayModeSunoSubmit
 
 	RelayModeRerank
+
+	RelayModeRealtime
 )
 
 func Path2RelayMode(path string) int {
@@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
 		relayMode = RelayModeAudioTranslation
 	} else if strings.HasPrefix(path, "/v1/rerank") {
 		relayMode = RelayModeRerank
+	} else if strings.HasPrefix(path, "/v1/realtime") {
+		relayMode = RelayModeRealtime
 	}
 	return relayMode
 }
diff --git a/router/relay-router.go b/router/relay-router.go
index 0d6cbca..a90c60f 100644
--- a/router/relay-router.go
+++ b/router/relay-router.go
@@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
 		playgroundRouter.POST("/chat/completions", controller.Playground)
 	}
 	relayV1Router := router.Group("/v1")
-	relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute())
+	relayV1Router.Use(middleware.TokenAuth())
 	{
-		relayV1Router.POST("/completions", controller.Relay)
-		relayV1Router.POST("/chat/completions", controller.Relay)
-		relayV1Router.POST("/edits", controller.Relay)
-		relayV1Router.POST("/images/generations", controller.Relay)
-		relayV1Router.POST("/images/edits", controller.RelayNotImplemented)
-		relayV1Router.POST("/images/variations", controller.RelayNotImplemented)
-		relayV1Router.POST("/embeddings", controller.Relay)
-		relayV1Router.POST("/engines/:model/embeddings", controller.Relay)
-		relayV1Router.POST("/audio/transcriptions", controller.Relay)
-		relayV1Router.POST("/audio/translations", controller.Relay)
-		relayV1Router.POST("/audio/speech", controller.Relay)
-		relayV1Router.GET("/files", controller.RelayNotImplemented)
-		relayV1Router.POST("/files", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
-		relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented)
-		relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
-		relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
-		relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
-		relayV1Router.POST("/moderations", controller.Relay)
-		relayV1Router.POST("/rerank", controller.Relay)
+		// WebSocket routes
+		wsRouter := relayV1Router.Group("")
+		wsRouter.Use(middleware.Distribute())
+		wsRouter.GET("/realtime", controller.WssRelay)
+	}
+	{
+		// HTTP routes
+		httpRouter := relayV1Router.Group("")
+		httpRouter.Use(middleware.Distribute())
+		httpRouter.POST("/completions", controller.Relay)
+		httpRouter.POST("/chat/completions", controller.Relay)
+		httpRouter.POST("/edits", controller.Relay)
+		httpRouter.POST("/images/generations", controller.Relay)
+		httpRouter.POST("/images/edits", controller.RelayNotImplemented)
+		httpRouter.POST("/images/variations", controller.RelayNotImplemented)
+		httpRouter.POST("/embeddings", controller.Relay)
+		httpRouter.POST("/engines/:model/embeddings", controller.Relay)
+		httpRouter.POST("/audio/transcriptions", controller.Relay)
+		httpRouter.POST("/audio/translations", controller.Relay)
+		httpRouter.POST("/audio/speech", controller.Relay)
+		httpRouter.GET("/files", controller.RelayNotImplemented)
+		httpRouter.POST("/files", controller.RelayNotImplemented)
+		httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id", controller.RelayNotImplemented)
+		httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
+		httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
+		httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
+		httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
+		httpRouter.POST("/moderations", controller.Relay)
+		httpRouter.POST("/rerank", controller.Relay)
 	}
 
 	relayMjRouter := router.Group("/mj")
diff --git a/service/relay.go b/service/relay.go
index 924e0bb..0aa2f13 100644
--- a/service/relay.go
+++ b/service/relay.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/gorilla/websocket"
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
@@ -42,11 +43,43 @@ func Done(c *gin.Context) {
 	_ = StringData(c, "[DONE]")
 }
 
+// WssObject marshals object to JSON and writes it to ws as one text message.
+// It returns an error if ws is nil, marshalling fails, or the write fails.
+func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
+	if ws == nil {
+		return errors.New("websocket connection is nil")
+	}
+	jsonData, err := json.Marshal(object)
+	if err != nil {
+		return fmt.Errorf("error marshalling object: %w", err)
+	}
+	return ws.WriteMessage(websocket.TextMessage, jsonData)
+}
+
+// WssError sends openaiError to the client as a realtime "error" event.
+// Delivery is best effort: the connection may already be unusable, so a
+// failed write is deliberately ignored.
+func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
+	errorObj := &dto.RealtimeEvent{
+		Type:    dto.RealtimeEventTypeError,
+		EventId: GetLocalRealtimeID(c),
+		Error:   &openaiError,
+	}
+	_ = WssObject(c, ws, errorObj)
+}
+
 func GetResponseID(c *gin.Context) string {
-	logID := c.GetString("X-Oneapi-Request-Id")
+	logID := c.GetString(common.RequestIdKey)
 	return fmt.Sprintf("chatcmpl-%s", logID)
 }
 
+// GetLocalRealtimeID returns an event ID for locally generated realtime
+// events, built as "evt_" plus the request ID stored in the gin context.
+func GetLocalRealtimeID(c *gin.Context) string {
+	logID := c.GetString(common.RequestIdKey)
+	return fmt.Sprintf("evt_%s", logID)
+}
+
 func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
 	return &dto.ChatCompletionsStreamResponse{
 		Id: id,