Update dto

(cherry picked from commit 030187ff75c64c40017cda2fa98ef2b3c01f0bd5)
This commit is contained in:
1808837298@qq.com 2024-10-03 20:46:00 +08:00 committed by CalciumIon
parent 4b48e490fa
commit e3c85572d4
6 changed files with 182 additions and 26 deletions

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"io" "io"
"log" "log"
"net/http" "net/http"
@ -134,6 +135,62 @@ func Relay(c *gin.Context) {
} }
} }
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true // 允许跨域
},
}
func WssRelay(c *gin.Context) {
// 将 HTTP 连接升级为 WebSocket 连接
ws, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError)
service.WssError(c, ws, openaiErr.Error)
return
}
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
requestId := c.GetString(common.RequestIdKey)
group := c.GetString("group")
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
originalModel := c.GetString("original_model")
var openaiErr *dto.OpenAIErrorWithStatusCode
for i := 0; i <= common.RetryTimes; i++ {
channel, err := getChannel(c, group, originalModel, i)
if err != nil {
common.LogError(c, err.Error())
openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError)
break
}
openaiErr = relayRequest(c, relayMode, channel)
if openaiErr == nil {
return // 成功处理请求,直接返回
}
go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr)
if !shouldRetry(c, openaiErr, common.RetryTimes-i) {
break
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c, retryLogStr)
}
if openaiErr != nil {
if openaiErr.StatusCode == http.StatusTooManyRequests {
openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
}
openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId)
service.WssError(c, ws, openaiErr.Error)
}
}
func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode {
addUsedChannel(c, channel.Id) addUsedChannel(c, channel.Id)
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)

59
dto/realtime.go Normal file
View File

@ -0,0 +1,59 @@
package dto
const (
RealtimeEventTypeError = "error"
RealtimeEventTypeSessionUpdate = "session.update"
RealtimeEventTypeConversationCreate = "conversation.item.create"
RealtimeEventTypeResponseCreate = "response.create"
)
type RealtimeEvent struct {
EventId string `json:"event_id"`
Type string `json:"type"`
//PreviousItemId string `json:"previous_item_id"`
Session *RealtimeSession `json:"session,omitempty"`
Item *RealtimeItem `json:"item,omitempty"`
Error *OpenAIError `json:"error,omitempty"`
}
type RealtimeSession struct {
Modalities []string `json:"modalities"`
Instructions string `json:"instructions"`
Voice string `json:"voice"`
InputAudioFormat string `json:"input_audio_format"`
OutputAudioFormat string `json:"output_audio_format"`
InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"`
TurnDetection interface{} `json:"turn_detection"`
Tools []RealTimeTool `json:"tools"`
ToolChoice string `json:"tool_choice"`
Temperature float64 `json:"temperature"`
MaxResponseOutputTokens int `json:"max_response_output_tokens"`
}
type InputAudioTranscription struct {
Model string `json:"model"`
}
type RealTimeTool struct {
Type string `json:"type"`
Name string `json:"name"`
Description string `json:"description"`
Parameters any `json:"parameters"`
}
type RealtimeItem struct {
Id string `json:"id"`
Type string `json:"type"`
Status string `json:"status"`
Role string `json:"role"`
Content RealtimeContent `json:"content"`
Name *string `json:"name,omitempty"`
ToolCalls any `json:"tool_calls,omitempty"`
CallId string `json:"call_id,omitempty"`
}
type RealtimeContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes.
Transcript string `json:"transcript,omitempty"`
}

View File

@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return nil, false, errors.New("无效的请求, " + err.Error()) return nil, false, errors.New("无效的请求, " + err.Error())
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") {
//wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01
modelRequest.Model = c.Query("model")
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable" modelRequest.Model = "text-moderation-stable"

View File

@ -38,6 +38,8 @@ const (
RelayModeSunoSubmit RelayModeSunoSubmit
RelayModeRerank RelayModeRerank
RelayModeRealtime
) )
func Path2RelayMode(path string) int { func Path2RelayMode(path string) int {
@ -64,6 +66,8 @@ func Path2RelayMode(path string) int {
relayMode = RelayModeAudioTranslation relayMode = RelayModeAudioTranslation
} else if strings.HasPrefix(path, "/v1/rerank") { } else if strings.HasPrefix(path, "/v1/rerank") {
relayMode = RelayModeRerank relayMode = RelayModeRerank
} else if strings.HasPrefix(path, "/v1/realtime") {
relayMode = RelayModeRealtime
} }
return relayMode return relayMode
} }

View File

@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) {
playgroundRouter.POST("/chat/completions", controller.Playground) playgroundRouter.POST("/chat/completions", controller.Playground)
} }
relayV1Router := router.Group("/v1") relayV1Router := router.Group("/v1")
relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) relayV1Router.Use(middleware.TokenAuth())
{ {
relayV1Router.POST("/completions", controller.Relay) // WebSocket 路由
relayV1Router.POST("/chat/completions", controller.Relay) wsRouter := relayV1Router.Group("")
relayV1Router.POST("/edits", controller.Relay) wsRouter.Use(middleware.Distribute())
relayV1Router.POST("/images/generations", controller.Relay) wsRouter.GET("/realtime", controller.WssRelay)
relayV1Router.POST("/images/edits", controller.RelayNotImplemented) }
relayV1Router.POST("/images/variations", controller.RelayNotImplemented) {
relayV1Router.POST("/embeddings", controller.Relay) //http router
relayV1Router.POST("/engines/:model/embeddings", controller.Relay) httpRouter := relayV1Router.Group("")
relayV1Router.POST("/audio/transcriptions", controller.Relay) httpRouter.Use(middleware.Distribute())
relayV1Router.POST("/audio/translations", controller.Relay) httpRouter.POST("/completions", controller.Relay)
relayV1Router.POST("/audio/speech", controller.Relay) httpRouter.POST("/chat/completions", controller.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented) httpRouter.POST("/edits", controller.Relay)
relayV1Router.POST("/files", controller.RelayNotImplemented) httpRouter.POST("/images/generations", controller.Relay)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) httpRouter.POST("/images/edits", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented) httpRouter.POST("/images/variations", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) httpRouter.POST("/embeddings", controller.Relay)
relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) httpRouter.POST("/engines/:model/embeddings", controller.Relay)
relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) httpRouter.POST("/audio/transcriptions", controller.Relay)
relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) httpRouter.POST("/audio/translations", controller.Relay)
relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) httpRouter.POST("/audio/speech", controller.Relay)
relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) httpRouter.GET("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) httpRouter.POST("/files", controller.RelayNotImplemented)
relayV1Router.POST("/moderations", controller.Relay) httpRouter.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.POST("/rerank", controller.Relay) httpRouter.GET("/files/:id", controller.RelayNotImplemented)
httpRouter.GET("/files/:id/content", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented)
httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented)
httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented)
httpRouter.DELETE("/models/:model", controller.RelayNotImplemented)
httpRouter.POST("/moderations", controller.Relay)
httpRouter.POST("/rerank", controller.Relay)
} }
relayMjRouter := router.Group("/mj") relayMjRouter := router.Group("/mj")

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/dto" "one-api/dto"
@ -42,11 +43,33 @@ func Done(c *gin.Context) {
_ = StringData(c, "[DONE]") _ = StringData(c, "[DONE]")
} }
func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error {
jsonData, err := json.Marshal(object)
if err != nil {
return fmt.Errorf("error marshalling object: %w", err)
}
return ws.WriteMessage(1, jsonData)
}
func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) {
errorObj := &dto.RealtimeEvent{
Type: "error",
EventId: GetLocalRealtimeID(c),
Error: &openaiError,
}
_ = WssObject(c, ws, errorObj)
}
func GetResponseID(c *gin.Context) string { func GetResponseID(c *gin.Context) string {
logID := c.GetString("X-Oneapi-Request-Id") logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("chatcmpl-%s", logID) return fmt.Sprintf("chatcmpl-%s", logID)
} }
func GetLocalRealtimeID(c *gin.Context) string {
logID := c.GetString(common.RequestIdKey)
return fmt.Sprintf("evt_%s", logID)
}
func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse { func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse {
return &dto.ChatCompletionsStreamResponse{ return &dto.ChatCompletionsStreamResponse{
Id: id, Id: id,