diff --git a/relay/controller/audio.go b/relay/controller/audio.go index e3d57b1e..2f792ce3 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "io" + "mime/multipart" "net/http" "strings" @@ -30,8 +31,7 @@ import ( func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) - audioModel := "whisper-1" - + audioModel := "gpt-4o-transcribe" tokenId := c.GetInt(ctxkey.TokenId) channelType := c.GetInt(ctxkey.Channel) channelId := c.GetInt(ctxkey.ChannelId) @@ -124,12 +124,13 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) if channelType == channeltype.Azure { apiVersion := meta.Config.APIVersion + deploymentName := c.GetString(ctxkey.ChannelName) if relayMode == relaymode.AudioTranscription { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, deploymentName, apiVersion) } else if relayMode == relaymode.AudioSpeech { // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) + fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, deploymentName, apiVersion) } } @@ -138,8 +139,73 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus if err != nil { return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") + + // 处理表单数据 + contentType := c.Request.Header.Get("Content-Type") + responseFormat := "json" + var contentTypeWithBoundary string + + if strings.Contains(contentType, "multipart/form-data") { + originalBody := requestBody.Bytes() + c.Request.Body = io.NopCloser(bytes.NewBuffer(originalBody)) + err = c.Request.ParseMultipartForm(32 << 20) // 32MB 最大内存 + if err != nil { + return openai.ErrorWrapper(err, "parse_multipart_form_failed", http.StatusInternalServerError) + } + + // 获取响应格式 + if format := c.Request.FormValue("response_format"); format != "" { + responseFormat = format + } + + requestBody = &bytes.Buffer{} + writer := multipart.NewWriter(requestBody) + + // 复制表单字段 + for key, values := range c.Request.MultipartForm.Value { + for _, value := range values { + err = writer.WriteField(key, value) + if err != nil { + return openai.ErrorWrapper(err, "write_field_failed", http.StatusInternalServerError) + } + } + } + + // 复制文件 + for key, fileHeaders := range c.Request.MultipartForm.File { + for _, fileHeader := range fileHeaders { + file, err := fileHeader.Open() + if err != nil { + return openai.ErrorWrapper(err, "open_file_failed", http.StatusInternalServerError) + } + + part, err := writer.CreateFormFile(key, fileHeader.Filename) + if err != nil { + file.Close() + return openai.ErrorWrapper(err, "create_form_file_failed", http.StatusInternalServerError) + } + + _, err = io.Copy(part, file) + file.Close() + if err != nil { + return openai.ErrorWrapper(err, "copy_file_failed", http.StatusInternalServerError) + } + } + } + + // 完成multipart写入 + err = writer.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_writer_failed", http.StatusInternalServerError) + } + + // 更新Content-Type + contentTypeWithBoundary = writer.FormDataContentType() + c.Request.Header.Set("Content-Type", contentTypeWithBoundary) + } else { + // 对于非表单请求,直接重置请求体 + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) + } req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -151,11 +217,26 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") req.Header.Set("api-key", apiKey) - req.ContentLength = c.Request.ContentLength + // 确保请求体大小与Content-Length一致 + req.ContentLength = int64(requestBody.Len()) } else { req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) + // 确保请求体大小与Content-Length一致 + req.ContentLength = int64(requestBody.Len()) + } + + // 确保Content-Type正确传递 + if strings.Contains(contentType, "multipart/form-data") && c.Request.MultipartForm != nil { + // 对于multipart请求,使用我们重建时生成的Content-Type + // 注意:此处必须使用writer生成的boundary + if contentTypeWithBoundary != "" { + req.Header.Set("Content-Type", contentTypeWithBoundary) + } else { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) + } + } else { + req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) resp, err := client.HTTPClient.Do(req)