From e3c85572d4275b03392b2c47d073401a6c57938e Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Thu, 3 Oct 2024 20:46:00 +0800 Subject: [PATCH 01/13] Update dto (cherry picked from commit 030187ff75c64c40017cda2fa98ef2b3c01f0bd5) --- controller/relay.go | 57 ++++++++++++++++++++++++++++++++++ dto/realtime.go | 59 ++++++++++++++++++++++++++++++++++++ middleware/distributor.go | 4 +++ relay/constant/relay_mode.go | 4 +++ router/relay-router.go | 59 +++++++++++++++++++++--------------- service/relay.go | 25 ++++++++++++++- 6 files changed, 182 insertions(+), 26 deletions(-) create mode 100644 dto/realtime.go diff --git a/controller/relay.go b/controller/relay.go index 4a49d51..f891000 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "io" "log" "net/http" @@ -134,6 +135,62 @@ func Relay(c *gin.Context) { } } +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // 允许跨域 + }, +} + +func WssRelay(c *gin.Context) { + // 将 HTTP 连接升级为 WebSocket 连接 + ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) + service.WssError(c, ws, openaiErr.Error) + return + } + relayMode := constant.Path2RelayMode(c.Request.URL.Path) + requestId := c.GetString(common.RequestIdKey) + group := c.GetString("group") + //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 + originalModel := c.GetString("original_model") + var openaiErr *dto.OpenAIErrorWithStatusCode + + for i := 0; i <= common.RetryTimes; i++ { + channel, err := getChannel(c, group, originalModel, i) + if err != nil { + common.LogError(c, err.Error()) + openaiErr = service.OpenAIErrorWrapperLocal(err, "get_channel_failed", http.StatusInternalServerError) + break + } + + openaiErr = relayRequest(c, relayMode, channel) + + if openaiErr == nil { + return // 成功处理请求,直接返回 + } + + go processChannelError(c, channel.Id, channel.Type, channel.Name, channel.GetAutoBan(), openaiErr) + + if !shouldRetry(c, openaiErr, common.RetryTimes-i) { + break + } + } + useChannel := c.GetStringSlice("use_channel") + if len(useChannel) > 1 { + retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]")) + common.LogInfo(c, retryLogStr) + } + + if openaiErr != nil { + if openaiErr.StatusCode == http.StatusTooManyRequests { + openaiErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" + } + openaiErr.Error.Message = common.MessageWithRequestId(openaiErr.Error.Message, requestId) + service.WssError(c, ws, openaiErr.Error) + } +} + func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { addUsedChannel(c, channel.Id) requestBody, _ := common.GetRequestBody(c) diff --git a/dto/realtime.go b/dto/realtime.go new file mode 100644 index 0000000..dbc3f40 --- /dev/null +++ b/dto/realtime.go @@ -0,0 +1,59 @@ +package dto + +const ( + RealtimeEventTypeError = "error" + RealtimeEventTypeSessionUpdate = "session.update" + RealtimeEventTypeConversationCreate = "conversation.item.create" + RealtimeEventTypeResponseCreate = "response.create" +) + +type RealtimeEvent struct { + EventId string `json:"event_id"` + Type string `json:"type"` + //PreviousItemId string `json:"previous_item_id"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Error *OpenAIError `json:"error,omitempty"` +} + +type RealtimeSession struct { + Modalities []string `json:"modalities"` + Instructions string `json:"instructions"` + Voice string `json:"voice"` + InputAudioFormat string `json:"input_audio_format"` + OutputAudioFormat string `json:"output_audio_format"` + InputAudioTranscription InputAudioTranscription `json:"input_audio_transcription"` + TurnDetection interface{} `json:"turn_detection"` + Tools []RealTimeTool `json:"tools"` + ToolChoice string `json:"tool_choice"` + Temperature float64 `json:"temperature"` + MaxResponseOutputTokens int `json:"max_response_output_tokens"` +} + +type InputAudioTranscription struct { + Model string `json:"model"` +} + +type RealTimeTool struct { + Type string `json:"type"` + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` +} + +type RealtimeItem struct { + Id string `json:"id"` + Type string `json:"type"` + Status string `json:"status"` + Role string `json:"role"` + Content RealtimeContent `json:"content"` + Name *string `json:"name,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + CallId string `json:"call_id,omitempty"` +} +type RealtimeContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + Audio string `json:"audio,omitempty"` // Base64-encoded audio bytes. + Transcript string `json:"transcript,omitempty"` +} diff --git a/middleware/distributor.go b/middleware/distributor.go index f1f64ca..9bab3e9 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -170,6 +170,10 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) { abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) return nil, false, errors.New("无效的请求, " + err.Error()) } + if strings.HasPrefix(c.Request.URL.Path, "/v1/realtime") { + //wss://api.openai.com/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01 + modelRequest.Model = c.Query("model") + } if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if modelRequest.Model == "" { modelRequest.Model = "text-moderation-stable" diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go index 7fecda1..845166c 100644 --- a/relay/constant/relay_mode.go +++ b/relay/constant/relay_mode.go @@ -38,6 +38,8 @@ const ( RelayModeSunoSubmit RelayModeRerank + + RelayModeRealtime ) func Path2RelayMode(path string) int { @@ -64,6 +66,8 @@ func Path2RelayMode(path string) int { relayMode = RelayModeAudioTranslation } else if strings.HasPrefix(path, "/v1/rerank") { relayMode = RelayModeRerank + } else if strings.HasPrefix(path, "/v1/realtime") { + relayMode = RelayModeRealtime } return relayMode } diff --git a/router/relay-router.go b/router/relay-router.go index 0d6cbca..a90c60f 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -22,32 +22,41 @@ func SetRelayRouter(router *gin.Engine) { playgroundRouter.POST("/chat/completions", controller.Playground) } relayV1Router := router.Group("/v1") - relayV1Router.Use(middleware.TokenAuth(), middleware.Distribute()) + relayV1Router.Use(middleware.TokenAuth()) { - relayV1Router.POST("/completions", controller.Relay) - relayV1Router.POST("/chat/completions", controller.Relay) - relayV1Router.POST("/edits", controller.Relay) - relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) - relayV1Router.POST("/images/variations", controller.RelayNotImplemented) - relayV1Router.POST("/embeddings", controller.Relay) - relayV1Router.POST("/engines/:model/embeddings", controller.Relay) - relayV1Router.POST("/audio/transcriptions", controller.Relay) - relayV1Router.POST("/audio/translations", controller.Relay) - relayV1Router.POST("/audio/speech", controller.Relay) - relayV1Router.GET("/files", controller.RelayNotImplemented) - relayV1Router.POST("/files", controller.RelayNotImplemented) - relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) - relayV1Router.GET("/files/:id", controller.RelayNotImplemented) - relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id", controller.RelayNotImplemented) - relayV1Router.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) - relayV1Router.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) - relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) - relayV1Router.POST("/moderations", controller.Relay) - relayV1Router.POST("/rerank", controller.Relay) + // WebSocket 路由 + wsRouter := relayV1Router.Group("") + wsRouter.Use(middleware.Distribute()) + wsRouter.GET("/realtime", controller.WssRelay) + } + { + //http router + httpRouter := relayV1Router.Group("") + httpRouter.Use(middleware.Distribute()) + httpRouter.POST("/completions", controller.Relay) + httpRouter.POST("/chat/completions", controller.Relay) + httpRouter.POST("/edits", controller.Relay) + httpRouter.POST("/images/generations", controller.Relay) + httpRouter.POST("/images/edits", controller.RelayNotImplemented) + httpRouter.POST("/images/variations", controller.RelayNotImplemented) + httpRouter.POST("/embeddings", controller.Relay) + httpRouter.POST("/engines/:model/embeddings", controller.Relay) + httpRouter.POST("/audio/transcriptions", controller.Relay) + httpRouter.POST("/audio/translations", controller.Relay) + httpRouter.POST("/audio/speech", controller.Relay) + httpRouter.GET("/files", controller.RelayNotImplemented) + httpRouter.POST("/files", controller.RelayNotImplemented) + httpRouter.DELETE("/files/:id", controller.RelayNotImplemented) + httpRouter.GET("/files/:id", controller.RelayNotImplemented) + httpRouter.GET("/files/:id/content", controller.RelayNotImplemented) + httpRouter.POST("/fine-tunes", controller.RelayNotImplemented) + httpRouter.GET("/fine-tunes", controller.RelayNotImplemented) + httpRouter.GET("/fine-tunes/:id", controller.RelayNotImplemented) + httpRouter.POST("/fine-tunes/:id/cancel", controller.RelayNotImplemented) + httpRouter.GET("/fine-tunes/:id/events", controller.RelayNotImplemented) + httpRouter.DELETE("/models/:model", controller.RelayNotImplemented) + httpRouter.POST("/moderations", controller.Relay) + httpRouter.POST("/rerank", controller.Relay) } relayMjRouter := router.Group("/mj") diff --git a/service/relay.go b/service/relay.go index 924e0bb..0aa2f13 100644 --- a/service/relay.go +++ b/service/relay.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "net/http" "one-api/common" "one-api/dto" @@ -42,11 +43,33 @@ func Done(c *gin.Context) { _ = StringData(c, "[DONE]") } +func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { + jsonData, err := json.Marshal(object) + if err != nil { + return fmt.Errorf("error marshalling object: %w", err) + } + return ws.WriteMessage(1, jsonData) +} + +func WssError(c *gin.Context, ws *websocket.Conn, openaiError dto.OpenAIError) { + errorObj := &dto.RealtimeEvent{ + Type: "error", + EventId: GetLocalRealtimeID(c), + Error: &openaiError, + } + _ = WssObject(c, ws, errorObj) +} + func GetResponseID(c *gin.Context) string { - logID := c.GetString("X-Oneapi-Request-Id") + logID := c.GetString(common.RequestIdKey) return fmt.Sprintf("chatcmpl-%s", logID) } +func GetLocalRealtimeID(c *gin.Context) string { + logID := c.GetString(common.RequestIdKey) + return fmt.Sprintf("evt_%s", logID) +} + func GenerateStopResponse(id string, createAt int64, model string, finishReason string) *dto.ChatCompletionsStreamResponse { return &dto.ChatCompletionsStreamResponse{ Id: id, From 33af069faeda9bfec67516b2e64c4dcdba16e633 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Fri, 4 Oct 2024 16:08:18 +0800 Subject: [PATCH 02/13] feat: realtime (cherry picked from commit a5529df3e1a4c08a120e8c05203a7d885b0fe8d8) --- common/model-ratio.go | 14 ++ controller/channel-test.go | 15 +- controller/relay.go | 22 ++- dto/realtime.go | 52 ++++-- middleware/auth.go | 19 +++ relay/channel/adapter.go | 6 +- relay/channel/ali/adaptor.go | 12 +- relay/channel/api_request.go | 36 +++- relay/channel/aws/adaptor.go | 6 +- relay/channel/baidu/adaptor.go | 8 +- relay/channel/claude/adaptor.go | 10 +- relay/channel/cloudflare/adaptor.go | 8 +- relay/channel/cohere/adaptor.go | 8 +- relay/channel/dify/adaptor.go | 8 +- relay/channel/gemini/adaptor.go | 8 +- relay/channel/jina/adaptor.go | 8 +- relay/channel/ollama/adaptor.go | 6 +- relay/channel/openai/adaptor.go | 26 ++- relay/channel/openai/relay-openai.go | 104 ++++++++++++ relay/channel/palm/adaptor.go | 8 +- relay/channel/perplexity/adaptor.go | 8 +- relay/channel/siliconflow/adaptor.go | 8 +- relay/channel/tencent/adaptor.go | 14 +- relay/channel/vertex/adaptor.go | 8 +- relay/channel/xunfei/adaptor.go | 6 +- relay/channel/zhipu/adaptor.go | 8 +- relay/channel/zhipu_4v/adaptor.go | 8 +- relay/common/relay_info.go | 9 + relay/relay-audio.go | 12 +- relay/relay-image.go | 12 +- relay/relay-text.go | 12 +- relay/relay_rerank.go | 11 +- relay/websocket.go | 242 +++++++++++++++++++++++++++ service/log.go | 11 ++ service/relay.go | 14 ++ service/token_counter.go | 39 +++++ web/src/components/LogsTable.js | 109 ++++++++---- 37 files changed, 759 insertions(+), 156 deletions(-) create mode 100644 relay/websocket.go diff --git a/common/model-ratio.go b/common/model-ratio.go index 54afbec..c037b8b 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -421,6 +421,20 @@ func GetCompletionRatio(name string) float64 { return 1 } +func GetAudioRatio(name string) float64 { + if strings.HasPrefix(name, "gpt-4o-realtime") { + return 20 + } + return 20 +} + +func GetAudioCompletionRatio(name string) float64 { + if strings.HasPrefix(name, "gpt-4o-realtime") { + return 10 + } + return 10 +} + func GetCompletionRatioMap() map[string]float64 { if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio diff --git a/controller/channel-test.go b/controller/channel-test.go index ff66386..38e5dc7 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -102,17 +102,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if err != nil { return err, nil } - if resp != nil && resp.StatusCode != http.StatusOK { - err := service.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), err + var httpResp *http.Response + if resp != nil { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { + err := service.RelayErrorHandler(httpResp) + return fmt.Errorf("status code %d: %s", httpResp.StatusCode, err.Error.Message), err + } } - usage, respErr := adaptor.DoResponse(c, resp, meta) + usageA, respErr := adaptor.DoResponse(c, httpResp, meta) if respErr != nil { return fmt.Errorf("%s", respErr.Error.Message), respErr } - if usage == nil { + if usageA == nil { return errors.New("usage is nil"), nil } + usage := usageA.(dto.Usage) result := w.Result() respBody, err := io.ReadAll(result.Body) if err != nil { diff --git a/controller/relay.go b/controller/relay.go index f891000..fe65d96 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -39,6 +39,15 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode return err } +func wsHandler(c *gin.Context, ws *websocket.Conn, relayMode int) *dto.OpenAIErrorWithStatusCode { + var err *dto.OpenAIErrorWithStatusCode + switch relayMode { + default: + err = relay.TextHelper(c) + } + return err +} + func Playground(c *gin.Context) { var openaiErr *dto.OpenAIErrorWithStatusCode @@ -143,12 +152,16 @@ var upgrader = websocket.Upgrader{ func WssRelay(c *gin.Context) { // 将 HTTP 连接升级为 WebSocket 连接 + ws, err := upgrader.Upgrade(c.Writer, c.Request, nil) + defer ws.Close() + if err != nil { openaiErr := service.OpenAIErrorWrapper(err, "get_channel_failed", http.StatusInternalServerError) service.WssError(c, ws, openaiErr.Error) return } + relayMode := constant.Path2RelayMode(c.Request.URL.Path) requestId := c.GetString(common.RequestIdKey) group := c.GetString("group") @@ -164,7 +177,7 @@ func WssRelay(c *gin.Context) { break } - openaiErr = relayRequest(c, relayMode, channel) + openaiErr = wssRequest(c, ws, relayMode, channel) if openaiErr == nil { return // 成功处理请求,直接返回 @@ -198,6 +211,13 @@ func relayRequest(c *gin.Context, relayMode int, channel *model.Channel) *dto.Op return relayHandler(c, relayMode) } +func wssRequest(c *gin.Context, ws *websocket.Conn, relayMode int, channel *model.Channel) *dto.OpenAIErrorWithStatusCode { + addUsedChannel(c, channel.Id) + requestBody, _ := common.GetRequestBody(c) + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return relay.WssHelper(c, ws) +} + func addUsedChannel(c *gin.Context, channelId int) { useChannel := c.GetStringSlice("use_channel") useChannel = append(useChannel, fmt.Sprintf("%d", channelId)) diff --git a/dto/realtime.go b/dto/realtime.go index dbc3f40..c470730 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -7,13 +7,41 @@ const ( RealtimeEventTypeResponseCreate = "response.create" ) +const ( + RealtimeEventTypeResponseDone = "response.done" +) + type RealtimeEvent struct { EventId string `json:"event_id"` Type string `json:"type"` //PreviousItemId string `json:"previous_item_id"` - Session *RealtimeSession `json:"session,omitempty"` - Item *RealtimeItem `json:"item,omitempty"` - Error *OpenAIError `json:"error,omitempty"` + Session *RealtimeSession `json:"session,omitempty"` + Item *RealtimeItem `json:"item,omitempty"` + Error *OpenAIError `json:"error,omitempty"` + Response *RealtimeResponse `json:"response,omitempty"` +} + +type RealtimeResponse struct { + Usage *RealtimeUsage `json:"usage"` +} + +type RealtimeUsage struct { + TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails InputTokenDetails `json:"input_token_details"` + OutputTokenDetails OutputTokenDetails `json:"output_token_details"` +} + +type InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` +} + +type OutputTokenDetails struct { + TextTokens int `json:"text_tokens"` + AudioTokens int `json:"audio_tokens"` } type RealtimeSession struct { @@ -27,7 +55,7 @@ type RealtimeSession struct { Tools []RealTimeTool `json:"tools"` ToolChoice string `json:"tool_choice"` Temperature float64 `json:"temperature"` - MaxResponseOutputTokens int `json:"max_response_output_tokens"` + //MaxResponseOutputTokens int `json:"max_response_output_tokens"` } type InputAudioTranscription struct { @@ -42,14 +70,14 @@ type RealTimeTool struct { } type RealtimeItem struct { - Id string `json:"id"` - Type string `json:"type"` - Status string `json:"status"` - Role string `json:"role"` - Content RealtimeContent `json:"content"` - Name *string `json:"name,omitempty"` - ToolCalls any `json:"tool_calls,omitempty"` - CallId string `json:"call_id,omitempty"` + Id string `json:"id"` + Type string `json:"type"` + Status string `json:"status"` + Role string `json:"role"` + Content []RealtimeContent `json:"content"` + Name *string `json:"name,omitempty"` + ToolCalls any `json:"tool_calls,omitempty"` + CallId string `json:"call_id,omitempty"` } type RealtimeContent struct { Type string `json:"type"` diff --git a/middleware/auth.go b/middleware/auth.go index 76f2b6b..53c7079 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -155,8 +155,27 @@ func RootAuth() func(c *gin.Context) { } } +func WssAuth(c *gin.Context) { + +} + func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + // 先检测是否为ws + if c.Request.Header.Get("Sec-WebSocket-Protocol") != "" { + // Sec-WebSocket-Protocol: realtime, openai-insecure-api-key.sk-xxx, openai-beta.realtime-v1 + // read sk from Sec-WebSocket-Protocol + key := c.Request.Header.Get("Sec-WebSocket-Protocol") + parts := strings.Split(key, ",") + for _, part := range parts { + part = strings.TrimSpace(part) + if strings.HasPrefix(part, "openai-insecure-api-key") { + key = strings.TrimPrefix(part, "openai-insecure-api-key.") + break + } + } + c.Request.Header.Set("Authorization", "Bearer "+key) + } key := c.Request.Header.Get("Authorization") parts := make([]string, 0) key = strings.TrimPrefix(key, "Bearer ") diff --git a/relay/channel/adapter.go b/relay/channel/adapter.go index 870b2b0..d72db6e 100644 --- a/relay/channel/adapter.go +++ b/relay/channel/adapter.go @@ -12,13 +12,13 @@ type Adaptor interface { // Init IsStream bool Init(info *relaycommon.RelayInfo) GetRequestURL(info *relaycommon.RelayInfo) (string, error) - SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error + SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) - DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) + DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) + DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) GetModelList() []string GetChannelName() string } diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index ff9d533..aa01ca6 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -32,14 +32,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", "Bearer "+info.ApiKey) + req.Set("Authorization", "Bearer "+info.ApiKey) if info.IsStream { - req.Header.Set("X-DashScope-SSE", "enable") + req.Set("X-DashScope-SSE", "enable") } if c.GetString("plugin") != "" { - req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) + req.Set("X-DashScope-Plugin", c.GetString("plugin")) } return nil } @@ -72,11 +72,11 @@ func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("not implemented") } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeImagesGenerations: err, usage = aliImageHandler(c, resp, info) diff --git a/relay/channel/api_request.go b/relay/channel/api_request.go index 423a91d..8b51c53 100644 --- a/relay/channel/api_request.go +++ b/relay/channel/api_request.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "io" "net/http" "one-api/relay/common" @@ -11,14 +12,16 @@ import ( "one-api/service" ) -func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Request) { +func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { // multipart/form-data + } else if info.RelayMode == constant.RelayModeRealtime { + // websocket } else { - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) + req.Set("Content-Type", c.Request.Header.Get("Content-Type")) + req.Set("Accept", c.Request.Header.Get("Accept")) if info.IsStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") + req.Set("Accept", "text/event-stream") } } } @@ -32,7 +35,7 @@ func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } - err = a.SetupRequestHeader(c, req, info) + err = a.SetupRequestHeader(c, &req.Header, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } @@ -55,7 +58,7 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod // set form data req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - err = a.SetupRequestHeader(c, req, info) + err = a.SetupRequestHeader(c, &req.Header, info) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) } @@ -66,6 +69,27 @@ func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBod return resp, nil } +func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) { + fullRequestURL, err := a.GetRequestURL(info) + if err != nil { + return nil, fmt.Errorf("get request url failed: %w", err) + } + targetHeader := http.Header{} + err = a.SetupRequestHeader(c, &targetHeader, info) + if err != nil { + return nil, fmt.Errorf("setup request header failed: %w", err) + } + targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type")) + targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader) + if err != nil { + return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err) + } + // send request body + //all, err := io.ReadAll(requestBody) + //err = service.WssString(c, targetConn, string(all)) + return targetConn, nil +} + func doRequest(c *gin.Context, req *http.Request) (*http.Response, error) { resp, err := service.GetHttpClient().Do(req) if err != nil { diff --git a/relay/channel/aws/adaptor.go b/relay/channel/aws/adaptor.go index 875d3dd..be72c04 100644 --- a/relay/channel/aws/adaptor.go +++ b/relay/channel/aws/adaptor.go @@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return "", nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { return nil } @@ -59,11 +59,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return nil, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = awsStreamHandler(c, resp, info, a.RequestMode) } else { diff --git a/relay/channel/baidu/adaptor.go b/relay/channel/baidu/adaptor.go index cc0be56..3991a5e 100644 --- a/relay/channel/baidu/adaptor.go +++ b/relay/channel/baidu/adaptor.go @@ -98,9 +98,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", "Bearer "+info.ApiKey) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -122,11 +122,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = baiduStreamHandler(c, resp) } else { diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index b9173af..488d87d 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -47,14 +47,14 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("x-api-key", info.ApiKey) + req.Set("x-api-key", info.ApiKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { anthropicVersion = "2023-06-01" } - req.Header.Set("anthropic-version", anthropicVersion) + req.Set("anthropic-version", anthropicVersion) return nil } @@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { diff --git a/relay/channel/cloudflare/adaptor.go b/relay/channel/cloudflare/adaptor.go index a518da8..fc0ec27 100644 --- a/relay/channel/cloudflare/adaptor.go +++ b/relay/channel/cloudflare/adaptor.go @@ -30,9 +30,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) return nil } @@ -48,7 +48,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re } } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } @@ -78,7 +78,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return nil, errors.New("not implemented") } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeEmbeddings: fallthrough diff --git a/relay/channel/cohere/adaptor.go b/relay/channel/cohere/adaptor.go index 3945774..f8b190e 100644 --- a/relay/channel/cohere/adaptor.go +++ b/relay/channel/cohere/adaptor.go @@ -36,9 +36,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) return nil } @@ -46,7 +46,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return requestOpenAI2Cohere(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } @@ -54,7 +54,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return requestConvertRerank2Cohere(request), nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = cohereRerankHandler(c, resp, info) } else { diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index b582da2..53ba26e 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", "Bearer "+info.ApiKey) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -48,11 +48,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = difyStreamHandler(c, resp, info) } else { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 07cdcfa..437efcc 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -47,9 +47,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("x-goog-api-key", info.ApiKey) + req.Set("x-goog-api-key", info.ApiKey) return nil } @@ -64,11 +64,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = GeminiChatStreamHandler(c, resp, info) } else { diff --git a/relay/channel/jina/adaptor.go b/relay/channel/jina/adaptor.go index f296ed0..ad488f2 100644 --- a/relay/channel/jina/adaptor.go +++ b/relay/channel/jina/adaptor.go @@ -37,9 +37,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return "", errors.New("invalid relay mode") } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) return nil } @@ -47,7 +47,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } @@ -55,7 +55,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.RelayMode == constant.RelayModeRerank { err, usage = jinaRerankHandler(c, resp) } else if info.RelayMode == constant.RelayModeEmbeddings { diff --git a/relay/channel/ollama/adaptor.go b/relay/channel/ollama/adaptor.go index 1478415..3079840 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/channel/ollama/adaptor.go @@ -37,7 +37,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) return nil } @@ -58,11 +58,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 8e4cf78..5ac0306 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -31,6 +31,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { + if info.RelayMode == constant.RelayModeRealtime { + // trim https + baseUrl := strings.TrimPrefix(info.BaseUrl, "https://") + baseUrl = strings.TrimPrefix(baseUrl, "http://") + baseUrl = "wss://" + baseUrl + info.BaseUrl = baseUrl + } switch info.ChannelType { case common.ChannelTypeAzure: // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api @@ -54,16 +61,19 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) if info.ChannelType == common.ChannelTypeAzure { - req.Header.Set("api-key", info.ApiKey) + req.Set("api-key", info.ApiKey) return nil } if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { - req.Header.Set("OpenAI-Organization", info.Organization) + req.Set("OpenAI-Organization", info.Organization) + } + req.Set("Authorization", "Bearer "+info.ApiKey) + if info.RelayMode == constant.RelayModeRealtime { + req.Set("openai-beta", "realtime=v1") } - req.Header.Set("Authorization", "Bearer "+info.ApiKey) //if info.ChannelType == common.ChannelTypeOpenRouter { // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") // req.Header.Set("X-Title", "One API") @@ -131,16 +141,20 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation { return channel.DoFormRequest(a, c, info, requestBody) + } else if info.RelayMode == constant.RelayModeRealtime { + return channel.DoWssRequest(a, c, info, requestBody) } else { return channel.DoApiRequest(a, c, info, requestBody) } } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { + case constant.RelayModeRealtime: + err, usage = OpenaiRealtimeHandler(c, info) case constant.RelayModeAudioSpeech: err, usage = OpenaiTTSHandler(c, resp, info) case constant.RelayModeAudioTranslation: diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 87ad7d3..d2eccb3 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -7,6 +7,7 @@ import ( "fmt" "github.com/bytedance/gopkg/util/gopool" "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "io" "net/http" "one-api/common" @@ -373,3 +374,106 @@ func getTextFromJSON(body []byte) (string, error) { } return whisperResponse.Text, nil } + +func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.RealtimeUsage) { + info.IsStream = true + clientConn := info.ClientWs + targetConn := info.TargetWs + + clientClosed := make(chan struct{}) + targetClosed := make(chan struct{}) + sendChan := make(chan []byte, 100) + receiveChan := make(chan []byte, 100) + errChan := make(chan error, 2) + + usage := &dto.RealtimeUsage{} + + go func() { + for { + select { + case <-c.Done(): + return + default: + _, message, err := clientConn.ReadMessage() + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + errChan <- fmt.Errorf("error reading from client: %v", err) + } + close(clientClosed) + return + } + + err = service.WssString(c, targetConn, string(message)) + if err != nil { + errChan <- fmt.Errorf("error writing to target: %v", err) + return + } + + select { + case sendChan <- message: + default: + } + } + } + }() + + go func() { + for { + select { + case <-c.Done(): + return + default: + _, message, err := targetConn.ReadMessage() + if err != nil { + if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + errChan <- fmt.Errorf("error reading from target: %v", err) + } + close(targetClosed) + return + } + info.SetFirstResponseTime() + realtimeEvent := &dto.RealtimeEvent{} + err = json.Unmarshal(message, realtimeEvent) + if err != nil { + errChan <- fmt.Errorf("error unmarshalling message: %v", err) + return + } + + if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { + realtimeUsage := realtimeEvent.Response.Usage + if realtimeUsage != nil { + usage.TotalTokens += realtimeUsage.TotalTokens + usage.InputTokens += realtimeUsage.InputTokens + usage.OutputTokens += realtimeUsage.OutputTokens + usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens + usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens + usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens + usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens + usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + } + } + + err = service.WssString(c, clientConn, string(message)) + if err != nil { + errChan <- fmt.Errorf("error writing to client: %v", err) + return + } + + select { + case receiveChan <- message: + default: + } + } + } + }() + + select { + case <-clientClosed: + case <-targetClosed: + case <-errChan: + //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil + case <-c.Done(): + } + + return nil, usage +} diff --git a/relay/channel/palm/adaptor.go b/relay/channel/palm/adaptor.go index d8c4ffb..9127233 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/channel/palm/adaptor.go @@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("x-goog-api-key", info.ApiKey) + req.Set("x-goog-api-key", info.ApiKey) return nil } @@ -49,11 +49,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = palmStreamHandler(c, resp) diff --git a/relay/channel/perplexity/adaptor.go b/relay/channel/perplexity/adaptor.go index e9d07fb..18b66a9 100644 --- a/relay/channel/perplexity/adaptor.go +++ b/relay/channel/perplexity/adaptor.go @@ -32,9 +32,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", "Bearer "+info.ApiKey) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -52,11 +52,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { diff --git a/relay/channel/siliconflow/adaptor.go b/relay/channel/siliconflow/adaptor.go index 6906fca..ac722b2 100644 --- a/relay/channel/siliconflow/adaptor.go +++ b/relay/channel/siliconflow/adaptor.go @@ -40,9 +40,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return "", errors.New("invalid relay mode") } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) return nil } @@ -50,7 +50,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, re return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } @@ -58,7 +58,7 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { switch info.RelayMode { case constant.RelayModeRerank: err, usage = siliconflowRerankHandler(c, resp) diff --git a/relay/channel/tencent/adaptor.go b/relay/channel/tencent/adaptor.go index 5811c87..d831cc8 100644 --- a/relay/channel/tencent/adaptor.go +++ b/relay/channel/tencent/adaptor.go @@ -43,12 +43,12 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/", info.BaseUrl), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", a.Sign) - req.Header.Set("X-TC-Action", a.Action) - req.Header.Set("X-TC-Version", a.Version) - req.Header.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) + req.Set("Authorization", a.Sign) + req.Set("X-TC-Action", a.Action) + req.Set("X-TC-Version", a.Version) + req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10)) return nil } @@ -73,11 +73,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { var responseText string err, responseText = tencentStreamHandler(c, resp) diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 4174d78..c9c9a30 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -107,13 +107,13 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return "", errors.New("unsupported request mode") } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) accessToken, err := getAccessToken(a, info) if err != nil { return err } - req.Header.Set("Authorization", "Bearer "+accessToken) + req.Set("Authorization", "Bearer "+accessToken) return nil } @@ -148,11 +148,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { switch a.RequestMode { case RequestModeClaude: diff --git a/relay/channel/xunfei/adaptor.go b/relay/channel/xunfei/adaptor.go index f499bec..31d426a 100644 --- a/relay/channel/xunfei/adaptor.go +++ b/relay/channel/xunfei/adaptor.go @@ -33,7 +33,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return "", nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) return nil } @@ -50,14 +50,14 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { // xunfei's request is not http request, so we don't need to do anything here dummyResp := &http.Response{} dummyResp.StatusCode = http.StatusOK return dummyResp, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { splits := strings.Split(info.ApiKey, "|") if len(splits) != 3 { return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go index f98581f..f0538ed 100644 --- a/relay/channel/zhipu/adaptor.go +++ b/relay/channel/zhipu/adaptor.go @@ -35,10 +35,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) token := getZhipuToken(info.ApiKey) - req.Header.Set("Authorization", token) + req.Set("Authorization", token) return nil } @@ -56,11 +56,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = zhipuStreamHandler(c, resp) } else { diff --git a/relay/channel/zhipu_4v/adaptor.go b/relay/channel/zhipu_4v/adaptor.go index 5e0906e..3d46b79 100644 --- a/relay/channel/zhipu_4v/adaptor.go +++ b/relay/channel/zhipu_4v/adaptor.go @@ -32,10 +32,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/paas/v4/chat/completions", info.BaseUrl), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) token := getZhipuToken(info.ApiKey) - req.Header.Set("Authorization", token) + req.Set("Authorization", token) return nil } @@ -53,11 +53,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 82b5373..bd29b9c 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -2,6 +2,7 @@ package common import ( "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" "one-api/common" "one-api/relay/constant" "strings" @@ -32,6 +33,14 @@ type RelayInfo struct { BaseUrl string SupportStreamOptions bool ShouldIncludeUsage bool + ClientWs *websocket.Conn + TargetWs *websocket.Conn +} + +func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { + info := GenRelayInfo(c) + info.ClientWs = ws + return info } func GenRelayInfo(c *gin.Context) *RelayInfo { diff --git a/relay/relay-audio.go b/relay/relay-audio.go index 88455c6..e1d0a70 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -122,19 +122,21 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - statusCodeMappingStr := c.GetString("status_code_mapping") + + var httpResp *http.Response if resp != nil { - if resp.StatusCode != http.StatusOK { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(resp) + openaiErr := service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } } - usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 @@ -142,7 +144,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return openaiErr } - postConsumeQuota(c, relayInfo, audioRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") + postConsumeQuota(c, relayInfo, audioRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, 0, false, "") return nil } diff --git a/relay/relay-image.go b/relay/relay-image.go index 411a2c3..c114de9 100644 --- a/relay/relay-image.go +++ b/relay/relay-image.go @@ -149,22 +149,24 @@ func ImageHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { requestBody = bytes.NewBuffer(jsonData) statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - + var httpResp *http.Response if resp != nil { - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { - openaiErr := service.RelayErrorHandler(resp) + httpResp = resp.(*http.Response) + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + if httpResp.StatusCode != http.StatusOK { + openaiErr := service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } } - _, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + _, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) diff --git a/relay/relay-text.go b/relay/relay-text.go index d9db1ef..7bb7d99 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -180,30 +180,32 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { } statusCodeMappingStr := c.GetString("status_code_mapping") + var httpResp *http.Response resp, err := adaptor.DoRequest(c, relayInfo, requestBody) if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } if resp != nil { - relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - if resp.StatusCode != http.StatusOK { + httpResp = resp.(*http.Response) + relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") + if httpResp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(resp) + openaiErr := service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } } - usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, textRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + postConsumeQuota(c, relayInfo, textRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index e8bf9d6..4cb1c98 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -99,23 +99,26 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if err != nil { return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } + + var httpResp *http.Response if resp != nil { - if resp.StatusCode != http.StatusOK { + httpResp = resp.(*http.Response) + if httpResp.StatusCode != http.StatusOK { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(resp) + openaiErr := service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } } - usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) + usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postConsumeQuota(c, relayInfo, rerankRequest.Model, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") + postConsumeQuota(c, relayInfo, rerankRequest.Model, usage.(*dto.Usage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, success, "") return nil } diff --git a/relay/websocket.go b/relay/websocket.go new file mode 100644 index 0000000..5bd1e81 --- /dev/null +++ b/relay/websocket.go @@ -0,0 +1,242 @@ +package relay + +import ( + "encoding/json" + "fmt" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "math" + "net/http" + "one-api/common" + "one-api/dto" + "one-api/model" + relaycommon "one-api/relay/common" + "one-api/service" + "strings" + "time" +) + +//func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { +// _, p, err := ws.ReadMessage() +// if err != nil { +// return nil, err +// } +// realtimeEvent := &dto.RealtimeEvent{} +// err = json.Unmarshal(p, realtimeEvent) +// if err != nil { +// return nil, err +// } +// // save the original request +// if realtimeEvent.Session == nil { +// return nil, errors.New("session object is nil") +// } +// c.Set("first_wss_request", p) +// return realtimeEvent, nil +//} + +func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode { + relayInfo := relaycommon.GenRelayInfoWs(c, ws) + + // get & validate textRequest 获取并验证文本请求 + //realtimeEvent, err := getAndValidateWssRequest(c, ws) + //if err != nil { + // common.LogError(c, fmt.Sprintf("getAndValidateWssRequest failed: %s", err.Error())) + // return service.OpenAIErrorWrapperLocal(err, "invalid_text_request", http.StatusBadRequest) + //} + + // map model name + modelMapping := c.GetString("model_mapping") + //isModelMapped := false + if modelMapping != "" && modelMapping != "{}" { + modelMap := make(map[string]string) + err := json.Unmarshal([]byte(modelMapping), &modelMap) + if err != nil { + return service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) + } + if modelMap[relayInfo.OriginModelName] != "" { + relayInfo.UpstreamModelName = modelMap[relayInfo.OriginModelName] + // set upstream model name + //isModelMapped = true + } + } + //relayInfo.UpstreamModelName = textRequest.Model + modelPrice, getModelPriceSuccess := common.GetModelPrice(relayInfo.UpstreamModelName, false) + groupRatio := common.GetGroupRatio(relayInfo.Group) + + var preConsumedQuota int + var ratio float64 + var modelRatio float64 + //err := service.SensitiveWordsCheck(textRequest) + + //if constant.ShouldCheckPromptSensitive() { + // err = checkRequestSensitive(textRequest, relayInfo) + // if err != nil { + // return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) + // } + //} + + //promptTokens, err := getWssPromptTokens(realtimeEvent, relayInfo) + //// count messages token error 计算promptTokens错误 + //if err != nil { + // return service.OpenAIErrorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError) + //} + // + if !getModelPriceSuccess { + preConsumedTokens := common.PreConsumedQuota + //if realtimeEvent.Session.MaxResponseOutputTokens != 0 { + // preConsumedTokens = promptTokens + int(realtimeEvent.Session.MaxResponseOutputTokens) + //} + modelRatio = common.GetModelRatio(relayInfo.UpstreamModelName) + ratio = modelRatio * groupRatio + preConsumedQuota = int(float64(preConsumedTokens) * ratio) + } else { + preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + + // pre-consume quota 预消耗配额 + preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, preConsumedQuota, relayInfo) + if openaiErr != nil { + return openaiErr + } + + adaptor := GetAdaptor(relayInfo.ApiType) + if adaptor == nil { + return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(relayInfo) + //var requestBody io.Reader + //firstWssRequest, _ := c.Get("first_wss_request") + //requestBody = bytes.NewBuffer(firstWssRequest.([]byte)) + + statusCodeMappingStr := c.GetString("status_code_mapping") + resp, err := adaptor.DoRequest(c, relayInfo, nil) + if err != nil { + return service.OpenAIErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + if resp != nil { + relayInfo.TargetWs = resp.(*websocket.Conn) + defer relayInfo.TargetWs.Close() + } + + usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + // reset status code 重置状态码 + service.ResetStatusCode(openaiErr, statusCodeMappingStr) + return openaiErr + } + postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + return nil +} + +func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, + usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, + groupRatio float64, + modelPrice float64, usePrice bool, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + textInputTokens := usage.InputTokenDetails.TextTokens + textOutTokens := usage.OutputTokenDetails.TextTokens + + audioInputTokens := usage.InputTokenDetails.AudioTokens + audioOutTokens := usage.OutputTokenDetails.AudioTokens + + tokenName := ctx.GetString("token_name") + completionRatio := common.GetCompletionRatio(modelName) + audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + + quota := 0 + if !usePrice { + quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + totalTokens := usage.TotalTokens + var logContent string + if !usePrice { + logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) + } else { + logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) + } + + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += fmt.Sprintf("(可能是上游超时)") + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + } else { + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} + quotaDelta := quota - preConsumedQuota + if quotaDelta != 0 { + err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + if err != nil { + common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + } + } + err := model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + common.LogError(ctx, "error update user quota cache: "+err.Error()) + } + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + logModel := modelName + if strings.HasPrefix(logModel, "gpt-4-gizmo") { + logModel = "gpt-4-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if strings.HasPrefix(logModel, "gpt-4o-gizmo") { + logModel = "gpt-4o-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if extraContent != "" { + logContent += ", " + extraContent + } + other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) + + //if quota != 0 { + // + //} +} + +func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { + var promptTokens int + var err error + switch info.RelayMode { + default: + promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) + } + info.PromptTokens = promptTokens + return promptTokens, err +} + +//func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { +// var err error +// switch info.RelayMode { +// case relayconstant.RelayModeChatCompletions: +// err = service.CheckSensitiveMessages(textRequest.Messages) +// case relayconstant.RelayModeCompletions: +// err = service.CheckSensitiveInput(textRequest.Prompt) +// case relayconstant.RelayModeModerations: +// err = service.CheckSensitiveInput(textRequest.Input) +// case relayconstant.RelayModeEmbeddings: +// err = service.CheckSensitiveInput(textRequest.Input) +// } +// return err +//} diff --git a/service/log.go b/service/log.go index 506effb..e5354cd 100644 --- a/service/log.go +++ b/service/log.go @@ -2,6 +2,7 @@ package service import ( "github.com/gin-gonic/gin" + "one-api/dto" relaycommon "one-api/relay/common" ) @@ -17,3 +18,13 @@ func GenerateTextOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, m other["admin_info"] = adminInfo return other } + +func GenerateWssOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage, modelRatio, groupRatio, completionRatio, modelPrice float64) map[string]interface{} { + info := GenerateTextOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio, modelPrice) + info["ws"] = true + info["audio_input"] = usage.InputTokenDetails.AudioTokens + info["audio_output"] = usage.OutputTokenDetails.AudioTokens + info["text_input"] = usage.InputTokenDetails.TextTokens + info["text_output"] = usage.OutputTokenDetails.TextTokens + return info +} diff --git a/service/relay.go b/service/relay.go index 0aa2f13..4b5ed36 100644 --- a/service/relay.go +++ b/service/relay.go @@ -43,11 +43,25 @@ func Done(c *gin.Context) { _ = StringData(c, "[DONE]") } +func WssString(c *gin.Context, ws *websocket.Conn, str string) error { + if ws == nil { + common.LogError(c, "websocket connection is nil") + return errors.New("websocket connection is nil") + } + common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) + return ws.WriteMessage(1, []byte(str)) +} + func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { jsonData, err := json.Marshal(object) if err != nil { return fmt.Errorf("error marshalling object: %w", err) } + if ws == nil { + common.LogError(c, "websocket connection is nil") + return errors.New("websocket connection is nil") + } + common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) return ws.WriteMessage(1, jsonData) } diff --git a/service/token_counter.go b/service/token_counter.go index 97ade1f..e169a25 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -191,6 +191,45 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, return tkm, nil } +func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) { + tkm := 0 + ratio := 1 + if request.Session != nil { + msgTokens, err := CountTokenText(request.Session.Instructions, model) + if err != nil { + return 0, err + } + ratio = len(request.Session.Modalities) + tkm += msgTokens + if request.Session.Tools != nil { + toolsData, _ := json.Marshal(request.Session.Tools) + var openaiTools []dto.OpenAITools + err := json.Unmarshal(toolsData, &openaiTools) + if err != nil { + return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) + } + countStr := "" + for _, tool := range openaiTools { + countStr = tool.Function.Name + if tool.Function.Description != "" { + countStr += tool.Function.Description + } + if tool.Function.Parameters != nil { + countStr += fmt.Sprintf("%v", tool.Function.Parameters) + } + } + toolTokens, err := CountTokenInput(countStr, model) + if err != nil { + return 0, err + } + tkm += 8 + tkm += toolTokens + } + } + tkm *= ratio + return tkm, nil +} + func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { //recover when panic tokenEncoder := getTokenEncoder(model) diff --git a/web/src/components/LogsTable.js b/web/src/components/LogsTable.js index 77d18b9..80bb3fe 100644 --- a/web/src/components/LogsTable.js +++ b/web/src/components/LogsTable.js @@ -11,7 +11,7 @@ import { import { Avatar, - Button, + Button, Descriptions, Form, Layout, Modal, @@ -20,7 +20,7 @@ import { Spin, Table, Tag, - Tooltip, + Tooltip } from '@douyinfe/semi-ui'; import { ITEMS_PER_PAGE } from '../constants'; import { @@ -336,33 +336,33 @@ const LogsTable = () => { ); }, }, - { - title: '重试', - dataIndex: 'retry', - className: isAdmin() ? 'tableShow' : 'tableHiddle', - render: (text, record, index) => { - let content = '渠道:' + record.channel; - if (record.other !== '') { - let other = JSON.parse(record.other); - if (other === null) { - return <>; - } - if (other.admin_info !== undefined) { - if ( - other.admin_info.use_channel !== null && - other.admin_info.use_channel !== undefined && - other.admin_info.use_channel !== '' - ) { - // channel id array - let useChannel = other.admin_info.use_channel; - let useChannelStr = useChannel.join('->'); - content = `渠道:${useChannelStr}`; - } - } - } - return isAdminUser ?
{content}
: <>; - }, - }, + // { + // title: '重试', + // dataIndex: 'retry', + // className: isAdmin() ? 'tableShow' : 'tableHiddle', + // render: (text, record, index) => { + // let content = '渠道:' + record.channel; + // if (record.other !== '') { + // let other = JSON.parse(record.other); + // if (other === null) { + // return <>; + // } + // if (other.admin_info !== undefined) { + // if ( + // other.admin_info.use_channel !== null && + // other.admin_info.use_channel !== undefined && + // other.admin_info.use_channel !== '' + // ) { + // // channel id array + // let useChannel = other.admin_info.use_channel; + // let useChannelStr = useChannel.join('->'); + // content = `渠道:${useChannelStr}`; + // } + // } + // } + // return isAdminUser ?
{content}
: <>; + // }, + // }, { title: '详情', dataIndex: 'content', @@ -409,6 +409,7 @@ const LogsTable = () => { ]; const [logs, setLogs] = useState([]); + const [expandData, setExpandData] = useState({}); const [showStat, setShowStat] = useState(false); const [loading, setLoading] = useState(false); const [loadingStat, setLoadingStat] = useState(false); @@ -512,10 +513,54 @@ const LogsTable = () => { }; const setLogsFormat = (logs) => { + let expandDatesLocal = {}; for (let i = 0; i < logs.length; i++) { logs[i].timestamp2string = timestamp2string(logs[i].created_at); logs[i].key = '' + logs[i].id; + let other = getLogOther(logs[i].other); + let expandDataLocal = []; + if (isAdmin()) { + let content = '渠道:' + logs[i].channel; + if (other.admin_info !== undefined) { + if ( + other.admin_info.use_channel !== null && + other.admin_info.use_channel !== undefined && + other.admin_info.use_channel !== '' + ) { + // channel id array + let useChannel = other.admin_info.use_channel; + let useChannelStr = useChannel.join('->'); + content = `渠道:${useChannelStr}`; + } + } + expandDataLocal.push({ + key: '重试', + value: content, + }) + } + if (other.ws) { + expandDataLocal.push({ + key: '语音输入', + value: other.audio_input, + }); + expandDataLocal.push({ + key: '语音输出', + value: other.audio_output, + }); + expandDataLocal.push({ + key: '文字输入', + value: other.text_input, + }); + expandDataLocal.push({ + key: '文字输出', + value: other.text_output, + }); + } + expandDatesLocal[logs[i].key] = expandDataLocal; } + console.log(expandDatesLocal); + setExpandData(expandDatesLocal); + setLogs(logs); }; @@ -588,6 +633,10 @@ const LogsTable = () => { handleEyeClick(); }, []); + const expandRowRender = (record, index) => { + return ; + }; + return ( <> @@ -686,7 +735,9 @@ const LogsTable = () => { Date: Sun, 6 Oct 2024 14:13:41 +0800 Subject: [PATCH 03/13] feat: realtime (cherry picked from commit d4966246e68dbdcdab45ec5c5141362834d74425) --- common/model-ratio.go | 16 ++- dto/realtime.go | 12 +- relay/channel/claude/relay-claude.go | 2 +- relay/channel/cloudflare/relay_cloudflare.go | 2 +- relay/channel/dify/relay-dify.go | 2 +- relay/channel/openai/adaptor.go | 4 +- relay/channel/openai/relay-openai.go | 63 +++++++++- relay/channel/palm/relay-palm.go | 2 +- relay/common/relay_info.go | 8 ++ relay/relay-audio.go | 2 +- relay/websocket.go | 22 ++-- service/audio.go | 31 +++++ service/relay.go | 4 +- service/token_counter.go | 117 +++++++++++++------ service/usage_helpr.go | 2 +- 15 files changed, 227 insertions(+), 62 deletions(-) create mode 100644 service/audio.go diff --git a/common/model-ratio.go b/common/model-ratio.go index c037b8b..6eab850 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -432,9 +432,23 @@ func GetAudioCompletionRatio(name string) float64 { if strings.HasPrefix(name, "gpt-4o-realtime") { return 10 } - return 10 + return 2 } +//func GetAudioPricePerMinute(name string) float64 { +// if strings.HasPrefix(name, "gpt-4o-realtime") { +// return 0.06 +// } +// return 0.06 +//} +// +//func GetAudioCompletionPricePerMinute(name string) float64 { +// if strings.HasPrefix(name, "gpt-4o-realtime") { +// return 0.24 +// } +// return 0.24 +//} + func GetCompletionRatioMap() map[string]float64 { if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio diff --git a/dto/realtime.go b/dto/realtime.go index c470730..cca99f3 100644 --- a/dto/realtime.go +++ b/dto/realtime.go @@ -5,10 +5,18 @@ const ( RealtimeEventTypeSessionUpdate = "session.update" RealtimeEventTypeConversationCreate = "conversation.item.create" RealtimeEventTypeResponseCreate = "response.create" + RealtimeEventInputAudioBufferAppend = "input_audio_buffer.append" ) const ( - RealtimeEventTypeResponseDone = "response.done" + RealtimeEventTypeResponseDone = "response.done" + RealtimeEventTypeSessionUpdated = "session.updated" + RealtimeEventTypeSessionCreated = "session.created" + RealtimeEventResponseAudioDelta = "response.audio.delta" + RealtimeEventResponseAudioTranscriptionDelta = "response.audio_transcript.delta" + RealtimeEventResponseFunctionCallArgumentsDelta = "response.function_call_arguments.delta" + RealtimeEventResponseFunctionCallArgumentsDone = "response.function_call_arguments.done" + RealtimeEventConversationItemCreated = "conversation.item.created" ) type RealtimeEvent struct { @@ -19,6 +27,8 @@ type RealtimeEvent struct { Item *RealtimeItem `json:"item,omitempty"` Error *OpenAIError `json:"error,omitempty"` Response *RealtimeResponse `json:"response,omitempty"` + Delta string `json:"delta,omitempty"` + Audio string `json:"audio,omitempty"` } type RealtimeResponse struct { diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 781b9a7..4c7f188 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -509,7 +509,7 @@ func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *r }, nil } fullTextResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse) - completionTokens, err := service.CountTokenText(claudeResponse.Completion, info.OriginModelName) + completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName) if err != nil { return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError), nil } diff --git a/relay/channel/cloudflare/relay_cloudflare.go b/relay/channel/cloudflare/relay_cloudflare.go index 69d6b85..d21e524 100644 --- a/relay/channel/cloudflare/relay_cloudflare.go +++ b/relay/channel/cloudflare/relay_cloudflare.go @@ -149,7 +149,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText(cfResp.Result.Text, info.UpstreamModelName) + usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 66ba839..5df34d3 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -108,7 +108,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re } if usage.TotalTokens == 0 { usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText("gpt-3.5-turbo", responseText) + usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } return nil, usage diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index 5ac0306..a663d15 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -47,8 +47,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { model_ := info.UpstreamModelName model_ = strings.Replace(model_, ".", "", -1) // https://github.com/songquanpeng/one-api/issues/67 - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) + if info.RelayMode == constant.RelayModeRealtime { + requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, info.ApiVersion) + } return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil case common.ChannelTypeMiniMax: return minimax.GetRequestURL(info) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d2eccb3..1aef14e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" + "log" "net/http" "one-api/common" "one-api/constant" @@ -232,7 +233,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) { completionTokens := 0 for _, choice := range simpleResponse.Choices { - ctkm, _ := service.CountTokenText(string(choice.Message.Content), model) + ctkm, _ := service.CountTextToken(string(choice.Message.Content), model) completionTokens += ctkm } simpleResponse.Usage = dto.Usage{ @@ -325,7 +326,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel usage := &dto.Usage{} usage.PromptTokens = info.PromptTokens - usage.CompletionTokens, _ = service.CountTokenText(text, info.UpstreamModelName) + usage.CompletionTokens, _ = service.CountTextToken(text, info.UpstreamModelName) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return nil, usage } @@ -387,6 +388,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan := make(chan error, 2) usage := &dto.RealtimeUsage{} + localUsage := &dto.RealtimeUsage{} go func() { for { @@ -403,6 +405,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op return } + realtimeEvent := &dto.RealtimeEvent{} + err = json.Unmarshal(message, realtimeEvent) + if err != nil { + errChan <- fmt.Errorf("error unmarshalling message: %v", err) + return + } + + if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate { + if realtimeEvent.Session != nil { + if realtimeEvent.Session.Tools != nil { + info.RealtimeTools = realtimeEvent.Session.Tools + } + } + } + + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + localUsage.InputTokens += textToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + err = service.WssString(c, targetConn, string(message)) if err != nil { errChan <- fmt.Errorf("error writing to target: %v", err) @@ -451,6 +479,32 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens } + } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { + realtimeSession := realtimeEvent.Session + if realtimeSession != nil { + // update audio format + info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat) + info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat) + } + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + + if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { + info.IsFirstRequest = false + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken + } else { + localUsage.OutputTokens += textToken + audioToken + localUsage.OutputTokenDetails.TextTokens += textToken + localUsage.OutputTokenDetails.AudioTokens += audioToken + } } err = service.WssString(c, clientConn, string(message)) @@ -475,5 +529,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op case <-c.Done(): } + // check usage total tokens, if 0, use local usage + + if usage.TotalTokens == 0 { + usage = localUsage + } return nil, usage } diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index 47588a2..dfde59f 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -156,7 +156,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st }, nil } fullTextResponse := responsePaLM2OpenAI(&palmResponse) - completionTokens, _ := service.CountTokenText(palmResponse.Candidates[0].Content, model) + completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model) usage := dto.Usage{ PromptTokens: promptTokens, CompletionTokens: completionTokens, diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index bd29b9c..b43f917 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "one-api/common" + "one-api/dto" "one-api/relay/constant" "strings" "time" @@ -35,11 +36,18 @@ type RelayInfo struct { ShouldIncludeUsage bool ClientWs *websocket.Conn TargetWs *websocket.Conn + InputAudioFormat string + OutputAudioFormat string + RealtimeTools []dto.RealTimeTool + IsFirstRequest bool } func GenRelayInfoWs(c *gin.Context, ws *websocket.Conn) *RelayInfo { info := GenRelayInfo(c) info.ClientWs = ws + info.InputAudioFormat = "pcm16" + info.OutputAudioFormat = "pcm16" + info.IsFirstRequest = true return info } diff --git a/relay/relay-audio.go b/relay/relay-audio.go index e1d0a70..b65f612 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -58,7 +58,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { promptTokens := 0 preConsumedTokens := common.PreConsumedQuota if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech { - promptTokens, err = service.CountAudioToken(audioRequest.Input, audioRequest.Model) + promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model) if err != nil { return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError) } diff --git a/relay/websocket.go b/relay/websocket.go index 5bd1e81..089805d 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -150,7 +150,7 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod quota := 0 if !usePrice { quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*completionRatio*audioCompletionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) quota = int(math.Round(float64(quota) * ratio)) if ratio != 0 && quota <= 0 { @@ -215,16 +215,16 @@ func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod //} } -func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { - var promptTokens int - var err error - switch info.RelayMode { - default: - promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) - } - info.PromptTokens = promptTokens - return promptTokens, err -} +//func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { +// var promptTokens int +// var err error +// switch info.RelayMode { +// default: +// promptTokens, err = service.CountTokenRealtime(*textRequest, info.UpstreamModelName) +// } +// info.PromptTokens = promptTokens +// return promptTokens, err +//} //func checkWssRequestSensitive(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) error { // var err error diff --git a/service/audio.go b/service/audio.go new file mode 100644 index 0000000..d558e96 --- /dev/null +++ b/service/audio.go @@ -0,0 +1,31 @@ +package service + +import ( + "encoding/base64" + "fmt" +) + +func parseAudio(audioBase64 string, format string) (duration float64, err error) { + audioData, err := base64.StdEncoding.DecodeString(audioBase64) + if err != nil { + return 0, fmt.Errorf("base64 decode error: %v", err) + } + + var samplesCount int + var sampleRate int + + switch format { + case "pcm16": + samplesCount = len(audioData) / 2 // 16位 = 2字节每样本 + sampleRate = 24000 // 24kHz + case "g711_ulaw", "g711_alaw": + samplesCount = len(audioData) // 8位 = 1字节每样本 + sampleRate = 8000 // 8kHz + default: + samplesCount = len(audioData) // 8位 = 1字节每样本 + sampleRate = 8000 // 8kHz + } + + duration = float64(samplesCount) / float64(sampleRate) + return duration, nil +} diff --git a/service/relay.go b/service/relay.go index 4b5ed36..6ffed1e 100644 --- a/service/relay.go +++ b/service/relay.go @@ -48,7 +48,7 @@ func WssString(c *gin.Context, ws *websocket.Conn, str string) error { common.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } - common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) + //common.LogInfo(c, fmt.Sprintf("sending message: %s", str)) return ws.WriteMessage(1, []byte(str)) } @@ -61,7 +61,7 @@ func WssObject(c *gin.Context, ws *websocket.Conn, object interface{}) error { common.LogError(c, "websocket connection is nil") return errors.New("websocket connection is nil") } - common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) + //common.LogInfo(c, fmt.Sprintf("sending message: %s", jsonData)) return ws.WriteMessage(1, jsonData) } diff --git a/service/token_counter.go b/service/token_counter.go index e169a25..63eb712 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -11,6 +11,7 @@ import ( "one-api/common" "one-api/constant" "one-api/dto" + relaycommon "one-api/relay/common" "strings" "unicode/utf8" ) @@ -191,43 +192,55 @@ func CountTokenChatRequest(request dto.GeneralOpenAIRequest, model string) (int, return tkm, nil } -func CountTokenRealtime(request dto.RealtimeEvent, model string) (int, error) { - tkm := 0 - ratio := 1 - if request.Session != nil { - msgTokens, err := CountTokenText(request.Session.Instructions, model) - if err != nil { - return 0, err +func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, model string) (int, int, error) { + audioToken := 0 + textToken := 0 + switch request.Type { + case dto.RealtimeEventTypeSessionUpdate: + if request.Session != nil { + msgTokens, err := CountTextToken(request.Session.Instructions, model) + if err != nil { + return 0, 0, err + } + textToken += msgTokens } - ratio = len(request.Session.Modalities) - tkm += msgTokens - if request.Session.Tools != nil { - toolsData, _ := json.Marshal(request.Session.Tools) - var openaiTools []dto.OpenAITools - err := json.Unmarshal(toolsData, &openaiTools) - if err != nil { - return 0, errors.New(fmt.Sprintf("count_tools_token_fail: %s", err.Error())) - } - countStr := "" - for _, tool := range openaiTools { - countStr = tool.Function.Name - if tool.Function.Description != "" { - countStr += tool.Function.Description - } - if tool.Function.Parameters != nil { - countStr += fmt.Sprintf("%v", tool.Function.Parameters) + case dto.RealtimeEventResponseAudioDelta: + // count audio token + atk, err := CountAudioTokenOutput(request.Delta, info.OutputAudioFormat) + if err != nil { + return 0, 0, fmt.Errorf("error counting audio token: %v", err) + } + audioToken += atk + case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta: + // count text token + tkm, err := CountTextToken(request.Delta, model) + if err != nil { + return 0, 0, fmt.Errorf("error counting text token: %v", err) + } + textToken += tkm + case dto.RealtimeEventInputAudioBufferAppend: + // count audio token + atk, err := CountAudioTokenInput(request.Audio, info.InputAudioFormat) + if err != nil { + return 0, 0, fmt.Errorf("error counting audio token: %v", err) + } + audioToken += atk + case dto.RealtimeEventTypeResponseDone: + // count tools token + if !info.IsFirstRequest { + if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 { + for _, tool := range info.RealtimeTools { + toolTokens, err := CountTokenInput(tool, model) + if err != nil { + return 0, 0, err + } + textToken += 8 + textToken += toolTokens } } - toolTokens, err := CountTokenInput(countStr, model) - if err != nil { - return 0, err - } - tkm += 8 - tkm += toolTokens } } - tkm *= ratio - return tkm, nil + return textToken, audioToken, nil } func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, error) { @@ -287,13 +300,13 @@ func CountTokenMessages(messages []dto.Message, model string, stream bool) (int, func CountTokenInput(input any, model string) (int, error) { switch v := input.(type) { case string: - return CountTokenText(v, model) + return CountTextToken(v, model) case []string: text := "" for _, s := range v { text += s } - return CountTokenText(text, model) + return CountTextToken(text, model) } return CountTokenInput(fmt.Sprintf("%v", input), model) } @@ -315,16 +328,44 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, return tokens } -func CountAudioToken(text string, model string) (int, error) { +func CountTTSToken(text string, model string) (int, error) { if strings.HasPrefix(model, "tts") { return utf8.RuneCountInString(text), nil } else { - return CountTokenText(text, model) + return CountTextToken(text, model) } } -// CountTokenText 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 -func CountTokenText(text string, model string) (int, error) { +func CountAudioTokenInput(audioBase64 string, audioFormat string) (int, error) { + if audioBase64 == "" { + return 0, nil + } + duration, err := parseAudio(audioBase64, audioFormat) + if err != nil { + return 0, err + } + return int(duration / 60 * 100 / 0.06), nil +} + +func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error) { + if audioBase64 == "" { + return 0, nil + } + duration, err := parseAudio(audioBase64, audioFormat) + if err != nil { + return 0, err + } + return int(duration / 60 * 200 / 0.24), nil +} + +//func CountAudioToken(sec float64, audioType string) { +// if audioType == "input" { +// +// } +//} + +// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量 +func CountTextToken(text string, model string) (int, error) { var err error tokenEncoder := getTokenEncoder(model) return getTokenNum(tokenEncoder, text), err diff --git a/service/usage_helpr.go b/service/usage_helpr.go index d2fa102..c52e1e1 100644 --- a/service/usage_helpr.go +++ b/service/usage_helpr.go @@ -19,7 +19,7 @@ import ( func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) { usage := &dto.Usage{} usage.PromptTokens = promptTokens - ctkm, err := CountTokenText(responseText, modeName) + ctkm, err := CountTextToken(responseText, modeName) usage.CompletionTokens = ctkm usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens return usage, err From 8de79382f023ac91a810e32e5d9ebaf4b92a8318 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 7 Oct 2024 17:18:11 +0800 Subject: [PATCH 04/13] feat: azure realtime (cherry picked from commit 75ff3d98f06103dc2df1f8817bd3fcbf433e0f20) --- relay/channel/openai/relay-openai.go | 26 +++++++++++++++----------- service/token_counter.go | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 1aef14e..6b11aac 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -478,6 +478,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + } else { + textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) + if err != nil { + errChan <- fmt.Errorf("error counting text token: %v", err) + return + } + log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + localUsage.TotalTokens += textToken + audioToken + info.IsFirstRequest = false + localUsage.InputTokens += textToken + audioToken + localUsage.InputTokenDetails.TextTokens += textToken + localUsage.InputTokenDetails.AudioTokens += audioToken } } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session @@ -494,17 +506,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) localUsage.TotalTokens += textToken + audioToken - - if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone { - info.IsFirstRequest = false - localUsage.InputTokens += textToken + audioToken - localUsage.InputTokenDetails.TextTokens += textToken - localUsage.InputTokenDetails.AudioTokens += audioToken - } else { - localUsage.OutputTokens += textToken + audioToken - localUsage.OutputTokenDetails.TextTokens += textToken - localUsage.OutputTokenDetails.AudioTokens += audioToken - } + localUsage.OutputTokens += textToken + audioToken + localUsage.OutputTokenDetails.TextTokens += textToken + localUsage.OutputTokenDetails.AudioTokens += audioToken } err = service.WssString(c, clientConn, string(message)) diff --git a/service/token_counter.go b/service/token_counter.go index 63eb712..17fbe0a 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -225,6 +225,21 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, return 0, 0, fmt.Errorf("error counting audio token: %v", err) } audioToken += atk + case dto.RealtimeEventConversationItemCreated: + if request.Item != nil { + switch request.Item.Type { + case "message": + for _, content := range request.Item.Content { + if content.Type == "input_text" { + tokens, err := CountTextToken(content.Text, model) + if err != nil { + return 0, 0, err + } + textToken += tokens + } + } + } + } case dto.RealtimeEventTypeResponseDone: // count tools token if !info.IsFirstRequest { From 24b3ed50d75af7b2c243ddb2ea15ae4b93fa8f62 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 7 Oct 2024 19:08:20 +0800 Subject: [PATCH 05/13] feat: realtime pre consume (cherry picked from commit d87917f8f6eb9d2e144a9f840d6d91767ea2eb69) --- relay/channel/openai/relay-openai.go | 51 ++++++++++- relay/common/relay_info.go | 1 + relay/websocket.go | 92 +------------------ service/quota.go | 132 +++++++++++++++++++++++++++ 4 files changed, 181 insertions(+), 95 deletions(-) create mode 100644 service/quota.go diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 6b11aac..60d09a0 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -389,6 +389,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage := &dto.RealtimeUsage{} localUsage := &dto.RealtimeUsage{} + sumUsage := &dto.RealtimeUsage{} go func() { for { @@ -478,6 +479,12 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens + err := preConsumeUsage(c, info, usage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + usage = &dto.RealtimeUsage{} } else { textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) if err != nil { @@ -490,7 +497,18 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.AudioTokens += audioToken + err = preConsumeUsage(c, info, localUsage, sumUsage) + if err != nil { + errChan <- fmt.Errorf("error consume usage: %v", err) + return + } + localUsage = &dto.RealtimeUsage{} + // print now usage } + common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session if realtimeSession != nil { @@ -528,15 +546,38 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op select { case <-clientClosed: case <-targetClosed: - case <-errChan: + case err := <-errChan: //return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil + common.LogError(c, "realtime error: "+err.Error()) case <-c.Done(): } + if usage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, usage, sumUsage) + } + + if localUsage.TotalTokens != 0 { + _ = preConsumeUsage(c, info, localUsage, sumUsage) + } + // check usage total tokens, if 0, use local usage - if usage.TotalTokens == 0 { - usage = localUsage - } - return nil, usage + return nil, sumUsage +} + +func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error { + totalUsage.TotalTokens += usage.TotalTokens + totalUsage.InputTokens += usage.InputTokens + totalUsage.OutputTokens += usage.OutputTokens + totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens + totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens + totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens + totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens + totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens + // clear usage + err := service.PreWssConsumeQuota(ctx, info, usage) + if err == nil { + common.LogInfo(ctx, "realtime streaming consume usage success") + } + return err } diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index b43f917..21e3691 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -23,6 +23,7 @@ type RelayInfo struct { ApiType int IsStream bool IsPlayground bool + UsePrice bool RelayMode int UpstreamModelName string OriginModelName string diff --git a/relay/websocket.go b/relay/websocket.go index 089805d..09d8298 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -5,15 +5,11 @@ import ( "fmt" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" - "math" "net/http" "one-api/common" "one-api/dto" - "one-api/model" relaycommon "one-api/relay/common" "one-api/service" - "strings" - "time" ) //func getAndValidateWssRequest(c *gin.Context, ws *websocket.Conn) (*dto.RealtimeEvent, error) { @@ -91,6 +87,7 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod preConsumedQuota = int(float64(preConsumedTokens) * ratio) } else { preConsumedQuota = int(modelPrice * common.QuotaPerUnit * groupRatio) + relayInfo.UsePrice = true } // pre-consume quota 预消耗配额 @@ -126,95 +123,10 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr } - postWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") + service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, getModelPriceSuccess, "") return nil } -func postWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, - usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, - groupRatio float64, - modelPrice float64, usePrice bool, extraContent string) { - - useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() - textInputTokens := usage.InputTokenDetails.TextTokens - textOutTokens := usage.OutputTokenDetails.TextTokens - - audioInputTokens := usage.InputTokenDetails.AudioTokens - audioOutTokens := usage.OutputTokenDetails.AudioTokens - - tokenName := ctx.GetString("token_name") - completionRatio := common.GetCompletionRatio(modelName) - audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) - audioCompletionRatio := common.GetAudioCompletionRatio(modelName) - - quota := 0 - if !usePrice { - quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - - quota = int(math.Round(float64(quota) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - } else { - quota = int(modelPrice * common.QuotaPerUnit * groupRatio) - } - totalTokens := usage.TotalTokens - var logContent string - if !usePrice { - logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) - } else { - logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) - } - - // record all the consume log even if quota is 0 - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - logContent += fmt.Sprintf("(可能是上游超时)") - common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ - "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) - } else { - //if sensitiveResp != nil { - // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) - //} - quotaDelta := quota - preConsumedQuota - if quotaDelta != 0 { - err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - } - err := model.CacheUpdateUserQuota(relayInfo.UserId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) - model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) - } - - logModel := modelName - if strings.HasPrefix(logModel, "gpt-4-gizmo") { - logModel = "gpt-4-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) - } - if strings.HasPrefix(logModel, "gpt-4o-gizmo") { - logModel = "gpt-4o-gizmo-*" - logContent += fmt.Sprintf(",模型 %s", modelName) - } - if extraContent != "" { - logContent += ", " + extraContent - } - other := service.GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice) - model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, - tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) - - //if quota != 0 { - // - //} -} - //func getWssPromptTokens(textRequest *dto.RealtimeEvent, info *relaycommon.RelayInfo) (int, error) { // var promptTokens int // var err error diff --git a/service/quota.go b/service/quota.go new file mode 100644 index 0000000..09c2fd5 --- /dev/null +++ b/service/quota.go @@ -0,0 +1,132 @@ +package service + +import ( + "fmt" + "github.com/gin-gonic/gin" + "math" + "one-api/common" + "one-api/dto" + "one-api/model" + relaycommon "one-api/relay/common" + "strings" + "time" +) + +func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.RealtimeUsage) error { + if relayInfo.UsePrice { + return nil + } + modelName := relayInfo.UpstreamModelName + textInputTokens := usage.InputTokenDetails.TextTokens + textOutTokens := usage.OutputTokenDetails.TextTokens + audioInputTokens := usage.InputTokenDetails.AudioTokens + audioOutTokens := usage.OutputTokenDetails.AudioTokens + + completionRatio := common.GetCompletionRatio(modelName) + audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + groupRatio := common.GetGroupRatio(relayInfo.Group) + modelRatio := common.GetModelRatio(modelName) + + ratio := groupRatio * modelRatio + + quota := textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + + err := model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) + if err != nil { + return err + } + err = model.CacheUpdateUserQuota(relayInfo.UserId) + if err != nil { + return err + } + return nil +} + +func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string, + usage *dto.RealtimeUsage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, + groupRatio float64, + modelPrice float64, usePrice bool, extraContent string) { + + useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() + textInputTokens := usage.InputTokenDetails.TextTokens + textOutTokens := usage.OutputTokenDetails.TextTokens + + audioInputTokens := usage.InputTokenDetails.AudioTokens + audioOutTokens := usage.OutputTokenDetails.AudioTokens + + tokenName := ctx.GetString("token_name") + completionRatio := common.GetCompletionRatio(modelName) + audioRatio := common.GetAudioRatio(relayInfo.UpstreamModelName) + audioCompletionRatio := common.GetAudioCompletionRatio(modelName) + + quota := 0 + if !usePrice { + quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) + + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit * groupRatio) + } + totalTokens := usage.TotalTokens + var logContent string + if !usePrice { + logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,音频倍率 %.2f,音频补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, audioRatio, audioCompletionRatio, groupRatio) + } else { + logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio) + } + + // record all the consume log even if quota is 0 + if totalTokens == 0 { + // in this case, must be some error happened + // we cannot just return, because we may have to return the pre-consumed quota + quota = 0 + logContent += fmt.Sprintf("(可能是上游超时)") + common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+ + "tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota)) + } else { + //if sensitiveResp != nil { + // logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", ")) + //} + //quotaDelta := quota - preConsumedQuota + //if quotaDelta != 0 { + // err := model.PostConsumeTokenQuota(relayInfo, userQuota, quotaDelta, preConsumedQuota, true) + // if err != nil { + // common.LogError(ctx, "error consuming token remain quota: "+err.Error()) + // } + //} + + //err := model.CacheUpdateUserQuota(relayInfo.UserId) + //if err != nil { + // common.LogError(ctx, "error update user quota cache: "+err.Error()) + //} + model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota) + model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota) + } + + logModel := modelName + if strings.HasPrefix(logModel, "gpt-4-gizmo") { + logModel = "gpt-4-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if strings.HasPrefix(logModel, "gpt-4o-gizmo") { + logModel = "gpt-4o-gizmo-*" + logContent += fmt.Sprintf(",模型 %s", modelName) + } + if extraContent != "" { + logContent += ", " + extraContent + } + other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio, completionRatio, modelPrice) + model.RecordConsumeLog(ctx, relayInfo.UserId, relayInfo.ChannelId, usage.InputTokens, usage.OutputTokens, logModel, + tokenName, quota, logContent, relayInfo.TokenId, userQuota, int(useTimeSeconds), relayInfo.IsStream, other) +} From e5c05d77b7bf0e342444cbf5c08accaa0cca36b4 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 7 Oct 2024 19:15:14 +0800 Subject: [PATCH 06/13] feat: realtime pre consume (cherry picked from commit 273d154e1640bae26b7caedddf1685e9ff21ab74) --- relay/channel/openai/relay-openai.go | 3 --- service/quota.go | 12 +++++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 60d09a0..bb19684 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -576,8 +576,5 @@ func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.R totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens // clear usage err := service.PreWssConsumeQuota(ctx, info, usage) - if err == nil { - common.LogInfo(ctx, "realtime streaming consume usage success") - } return err } diff --git a/service/quota.go b/service/quota.go index 09c2fd5..9a3c542 100644 --- a/service/quota.go +++ b/service/quota.go @@ -1,6 +1,7 @@ package service import ( + "errors" "fmt" "github.com/gin-gonic/gin" "math" @@ -16,6 +17,10 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag if relayInfo.UsePrice { return nil } + userQuota, err := model.GetUserQuota(relayInfo.UserId) + if err != nil { + return err + } modelName := relayInfo.UpstreamModelName textInputTokens := usage.InputTokenDetails.TextTokens textOutTokens := usage.OutputTokenDetails.TextTokens @@ -38,10 +43,15 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag quota = 1 } - err := model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) + if userQuota < quota { + return errors.New(fmt.Sprintf("用户额度不足,剩余额度为 %d", userQuota)) + } + + err = model.PostConsumeTokenQuota(relayInfo, 0, quota, 0, false) if err != nil { return err } + common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota)) err = model.CacheUpdateUserQuota(relayInfo.UserId) if err != nil { return err From f0907bf60ae4350ce404e16413f77dbe1e946356 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Mon, 7 Oct 2024 20:35:33 +0800 Subject: [PATCH 07/13] =?UTF-8?q?fix:=20=E9=83=A8=E5=88=86=E6=83=85?= =?UTF-8?q?=E5=86=B5=E7=BC=BA=E5=B0=91=E8=BF=94=E5=9B=9E=E9=A2=84=E6=89=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 96373455521a38095706bd81c57f9a18557d9c2e) --- relay/channel/openai/relay-openai.go | 8 ++++---- relay/relay-audio.go | 10 +++++++--- relay/relay-text.go | 12 +++++++----- relay/relay_rerank.go | 12 ++++++++---- relay/websocket.go | 9 +++++++-- 5 files changed, 33 insertions(+), 18 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index bb19684..2b237ea 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -391,7 +391,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage := &dto.RealtimeUsage{} sumUsage := &dto.RealtimeUsage{} - go func() { + gopool.Go(func() { for { select { case <-c.Done(): @@ -444,9 +444,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } } } - }() + }) - go func() { + gopool.Go(func() { for { select { case <-c.Done(): @@ -541,7 +541,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } } } - }() + }) select { case <-clientClosed: diff --git a/relay/relay-audio.go b/relay/relay-audio.go index b65f612..4bf916b 100644 --- a/relay/relay-audio.go +++ b/relay/relay-audio.go @@ -46,7 +46,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto. return audioRequest, nil } -func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { +func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) audioRequest, err := getAndValidAudioRequest(c, relayInfo) @@ -92,6 +92,11 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { return service.OpenAIErrorWrapperLocal(err, "pre_consume_token_quota_failed", http.StatusForbidden) } } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() // map model name modelMapping := c.GetString("model_mapping") @@ -128,8 +133,7 @@ func AudioHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/relay-text.go b/relay/relay-text.go index 7bb7d99..463947f 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -64,7 +64,7 @@ func getAndValidateTextRequest(c *gin.Context, relayInfo *relaycommon.RelayInfo) return textRequest, nil } -func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { +func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) @@ -131,7 +131,11 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { if openaiErr != nil { return openaiErr } - + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() includeUsage := false // 判断用户是否需要返回使用情况 if textRequest.StreamOptions != nil && textRequest.StreamOptions.IncludeUsage { @@ -190,8 +194,7 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { httpResp = resp.(*http.Response) relayInfo.IsStream = relayInfo.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr @@ -200,7 +203,6 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode { usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/relay_rerank.go b/relay/relay_rerank.go index 4cb1c98..a627b78 100644 --- a/relay/relay_rerank.go +++ b/relay/relay_rerank.go @@ -23,7 +23,7 @@ func getRerankPromptToken(rerankRequest dto.RerankRequest) int { return token } -func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode { +func RerankHelper(c *gin.Context, relayMode int) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfo(c) var rerankRequest *dto.RerankRequest @@ -79,6 +79,12 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if openaiErr != nil { return openaiErr } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) @@ -104,8 +110,7 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode if resp != nil { httpResp = resp.(*http.Response) if httpResp.StatusCode != http.StatusOK { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) - openaiErr := service.RelayErrorHandler(httpResp) + openaiErr = service.RelayErrorHandler(httpResp) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr @@ -114,7 +119,6 @@ func RerankHelper(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode usage, openaiErr := adaptor.DoResponse(c, httpResp, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr diff --git a/relay/websocket.go b/relay/websocket.go index 09d8298..247169e 100644 --- a/relay/websocket.go +++ b/relay/websocket.go @@ -30,7 +30,7 @@ import ( // return realtimeEvent, nil //} -func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCode { +func WssHelper(c *gin.Context, ws *websocket.Conn) (openaiErr *dto.OpenAIErrorWithStatusCode) { relayInfo := relaycommon.GenRelayInfoWs(c, ws) // get & validate textRequest 获取并验证文本请求 @@ -96,6 +96,12 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod return openaiErr } + defer func() { + if openaiErr != nil { + returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) + } + }() + adaptor := GetAdaptor(relayInfo.ApiType) if adaptor == nil { return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest) @@ -118,7 +124,6 @@ func WssHelper(c *gin.Context, ws *websocket.Conn) *dto.OpenAIErrorWithStatusCod usage, openaiErr := adaptor.DoResponse(c, nil, relayInfo) if openaiErr != nil { - returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota) // reset status code 重置状态码 service.ResetStatusCode(openaiErr, statusCodeMappingStr) return openaiErr From d596699250d6d445c64399e4b8a12eddbdd48e16 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Mon, 7 Oct 2024 20:46:13 +0800 Subject: [PATCH 08/13] refactor: realtime log (cherry picked from commit fd24dc467bfc360008b313220e607f0176ee7aa3) --- relay/channel/openai/relay-openai.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 2b237ea..400e111 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -9,7 +9,6 @@ import ( "github.com/gin-gonic/gin" "github.com/gorilla/websocket" "io" - "log" "net/http" "one-api/common" "one-api/constant" @@ -426,7 +425,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error counting text token: %v", err) return } - log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.InputTokens += textToken localUsage.InputTokenDetails.TextTokens += textToken @@ -491,7 +490,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error counting text token: %v", err) return } - log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken info.IsFirstRequest = false localUsage.InputTokens += textToken + audioToken @@ -522,7 +521,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error counting text token: %v", err) return } - log.Printf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken) + common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken localUsage.OutputTokens += textToken + audioToken localUsage.OutputTokenDetails.TextTokens += textToken From be64408a250f60bca8d5e49171be71ed5a943300 Mon Sep 17 00:00:00 2001 From: Xyfacai Date: Thu, 10 Oct 2024 00:15:27 +0800 Subject: [PATCH 09/13] =?UTF-8?q?fix(realtime):=20=E4=BF=AE=E5=A4=8Dws=20?= =?UTF-8?q?=E6=8F=A1=E6=89=8B=E5=A4=B1=E8=B4=A5=E3=80=81=E8=AE=A1=E8=B4=B9?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 618dffc43fd5a5f4065944db87761f9ee18e44d3) --- controller/relay.go | 1 + relay/channel/openai/adaptor.go | 27 +++++++++++++++++++++------ relay/channel/openai/relay-openai.go | 4 ++++ service/quota.go | 6 ++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/controller/relay.go b/controller/relay.go index fe65d96..c2e7523 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -145,6 +145,7 @@ func Relay(c *gin.Context) { } var upgrader = websocket.Upgrader{ + Subprotocols: []string{"realtime"}, // WS 握手支持的协议,如果有使用 Sec-WebSocket-Protocol,则必须在此声明对应的 Protocol TODO add other protocol CheckOrigin: func(r *http.Request) bool { return true // 允许跨域 }, diff --git a/relay/channel/openai/adaptor.go b/relay/channel/openai/adaptor.go index a663d15..def0850 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/channel/openai/adaptor.go @@ -63,18 +63,33 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { - channel.SetupApiRequestHeader(info, c, req) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error { + channel.SetupApiRequestHeader(info, c, header) if info.ChannelType == common.ChannelTypeAzure { - req.Set("api-key", info.ApiKey) + header.Set("api-key", info.ApiKey) return nil } if info.ChannelType == common.ChannelTypeOpenAI && "" != info.Organization { - req.Set("OpenAI-Organization", info.Organization) + header.Set("OpenAI-Organization", info.Organization) } - req.Set("Authorization", "Bearer "+info.ApiKey) if info.RelayMode == constant.RelayModeRealtime { - req.Set("openai-beta", "realtime=v1") + swp := c.Request.Header.Get("Sec-WebSocket-Protocol") + if swp != "" { + items := []string{ + "realtime", + "openai-insecure-api-key." + info.ApiKey, + "openai-beta.realtime-v1", + } + header.Set("Sec-WebSocket-Protocol", strings.Join(items, ",")) + //req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key")) + //req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions")) + //req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version")) + } else { + header.Set("openai-beta", "realtime=v1") + header.Set("Authorization", "Bearer "+info.ApiKey) + } + } else { + header.Set("Authorization", "Bearer "+info.ApiKey) } //if info.ChannelType == common.ChannelTypeOpenRouter { // req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 400e111..d172525 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -483,7 +483,10 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error consume usage: %v", err) return } + // 本次计费完成,清除 usage = &dto.RealtimeUsage{} + + localUsage = &dto.RealtimeUsage{} } else { textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName) if err != nil { @@ -501,6 +504,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op errChan <- fmt.Errorf("error consume usage: %v", err) return } + // 本次计费完成,清除 localUsage = &dto.RealtimeUsage{} // print now usage } diff --git a/service/quota.go b/service/quota.go index 9a3c542..7dd49b9 100644 --- a/service/quota.go +++ b/service/quota.go @@ -78,10 +78,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod quota := 0 if !usePrice { - quota = textInputTokens + int(math.Round(float64(textOutTokens)*completionRatio)) - quota += int(math.Round(float64(audioInputTokens)*audioRatio)) + int(math.Round(float64(audioOutTokens)*audioRatio*audioCompletionRatio)) - - quota = int(math.Round(float64(quota) * ratio)) + quota = int(math.Round(float64(textInputTokens)*ratio + float64(textOutTokens)*ratio*completionRatio)) + quota += int(math.Round(float64(audioInputTokens)*ratio*audioRatio + float64(audioOutTokens)*ratio*audioRatio*audioCompletionRatio)) if ratio != 0 && quota <= 0 { quota = 1 } From f08f7ae9402da7a42377bbd54e08448a6da273e9 Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Sat, 12 Oct 2024 14:13:11 +0800 Subject: [PATCH 10/13] fix: channel test (cherry picked from commit 052bdab1c45b3a4ba5f079afc763f54e751b1cd7) --- controller/channel-test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index 38e5dc7..5e69f70 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -117,7 +117,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr if usageA == nil { return errors.New("usage is nil"), nil } - usage := usageA.(dto.Usage) + usage := usageA.(*dto.Usage) result := w.Result() respBody, err := io.ReadAll(result.Body) if err != nil { From 4e0c522cd00d2b975175e462a509d32de4faacaa Mon Sep 17 00:00:00 2001 From: "1808837298@qq.com" <1808837298@qq.com> Date: Mon, 14 Oct 2024 15:40:34 +0800 Subject: [PATCH 11/13] =?UTF-8?q?fix:=20realtime=E8=AE=A1=E8=B4=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit fdfea8726c6d86d3844af1ac18d7b3df908f26a7) --- relay/channel/openai/relay-openai.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index d172525..9e0036e 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -427,7 +427,7 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op } common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken)) localUsage.TotalTokens += textToken + audioToken - localUsage.InputTokens += textToken + localUsage.InputTokens += textToken + audioToken localUsage.InputTokenDetails.TextTokens += textToken localUsage.InputTokenDetails.AudioTokens += audioToken From 7b1ff41e4cf3a38d7295519ea46bc3fc892d0fa1 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 5 Nov 2024 19:32:51 +0800 Subject: [PATCH 12/13] fix: mistral adaptor --- relay/channel/mistral/adaptor.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/relay/channel/mistral/adaptor.go b/relay/channel/mistral/adaptor.go index 5ca095d..4ab1a35 100644 --- a/relay/channel/mistral/adaptor.go +++ b/relay/channel/mistral/adaptor.go @@ -31,9 +31,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { channel.SetupApiRequestHeader(info, c, req) - req.Header.Set("Authorization", "Bearer "+info.ApiKey) + req.Set("Authorization", "Bearer "+info.ApiKey) return nil } @@ -50,11 +50,11 @@ func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dt return nil, nil } -func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { return channel.DoApiRequest(a, c, info, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) { if info.IsStream { err, usage = openai.OaiStreamHandler(c, resp, info) } else { From 0a80231e18acd991d159707603cb8cafd70d9676 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Tue, 5 Nov 2024 19:41:38 +0800 Subject: [PATCH 13/13] =?UTF-8?q?chore:=20=E5=88=A0=E9=99=A4=E6=97=A0?= =?UTF-8?q?=E7=94=A8=E6=97=A5=E5=BF=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- relay/channel/openai/relay-openai.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/relay/channel/openai/relay-openai.go b/relay/channel/openai/relay-openai.go index 9e0036e..2a087a6 100644 --- a/relay/channel/openai/relay-openai.go +++ b/relay/channel/openai/relay-openai.go @@ -508,9 +508,9 @@ func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*dto.Op localUsage = &dto.RealtimeUsage{} // print now usage } - common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) - common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + //common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage)) + //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) + //common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage)) } else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated { realtimeSession := realtimeEvent.Session