feat:修复transcribe bug

This commit is contained in:
suziheng 2025-04-21 15:59:44 +08:00
parent 9746803a2f
commit c2bd301e0a

View File

@ -8,6 +8,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"strings" "strings"
@ -30,8 +31,7 @@ import (
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context() ctx := c.Request.Context()
meta := meta.GetByContext(c) meta := meta.GetByContext(c)
audioModel := "whisper-1" audioModel := "gpt-4o-transcribe"
tokenId := c.GetInt(ctxkey.TokenId) tokenId := c.GetInt(ctxkey.TokenId)
channelType := c.GetInt(ctxkey.Channel) channelType := c.GetInt(ctxkey.Channel)
channelId := c.GetInt(ctxkey.ChannelId) channelId := c.GetInt(ctxkey.ChannelId)
@ -124,12 +124,13 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
if channelType == channeltype.Azure { if channelType == channeltype.Azure {
apiVersion := meta.Config.APIVersion apiVersion := meta.Config.APIVersion
deploymentName := c.GetString(ctxkey.ChannelName)
if relayMode == relaymode.AudioTranscription { if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, deploymentName, apiVersion)
} else if relayMode == relaymode.AudioSpeech { } else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, deploymentName, apiVersion)
} }
} }
@ -138,8 +139,73 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil { if err != nil {
return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
} }
// 处理表单数据
contentType := c.Request.Header.Get("Content-Type")
responseFormat := "json"
var contentTypeWithBoundary string
if strings.Contains(contentType, "multipart/form-data") {
originalBody := requestBody.Bytes()
c.Request.Body = io.NopCloser(bytes.NewBuffer(originalBody))
err = c.Request.ParseMultipartForm(32 << 20) // 32MB 最大内存
if err != nil {
return openai.ErrorWrapper(err, "parse_multipart_form_failed", http.StatusInternalServerError)
}
// 获取响应格式
if format := c.Request.FormValue("response_format"); format != "" {
responseFormat = format
}
requestBody = &bytes.Buffer{}
writer := multipart.NewWriter(requestBody)
// 复制表单字段
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
err = writer.WriteField(key, value)
if err != nil {
return openai.ErrorWrapper(err, "write_field_failed", http.StatusInternalServerError)
}
}
}
// 复制文件
for key, fileHeaders := range c.Request.MultipartForm.File {
for _, fileHeader := range fileHeaders {
file, err := fileHeader.Open()
if err != nil {
return openai.ErrorWrapper(err, "open_file_failed", http.StatusInternalServerError)
}
part, err := writer.CreateFormFile(key, fileHeader.Filename)
if err != nil {
file.Close()
return openai.ErrorWrapper(err, "create_form_file_failed", http.StatusInternalServerError)
}
_, err = io.Copy(part, file)
file.Close()
if err != nil {
return openai.ErrorWrapper(err, "copy_file_failed", http.StatusInternalServerError)
}
}
}
// 完成multipart写入
err = writer.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_writer_failed", http.StatusInternalServerError)
}
// 更新Content-Type
contentTypeWithBoundary = writer.FormDataContentType()
c.Request.Header.Set("Content-Type", contentTypeWithBoundary)
} else {
// 对于非表单请求,直接重置请求体
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
responseFormat := c.DefaultPostForm("response_format", "json") }
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil { if err != nil {
@ -151,11 +217,26 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
apiKey := c.Request.Header.Get("Authorization") apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ") apiKey = strings.TrimPrefix(apiKey, "Bearer ")
req.Header.Set("api-key", apiKey) req.Header.Set("api-key", apiKey)
req.ContentLength = c.Request.ContentLength // 确保请求体大小与Content-Length一致
req.ContentLength = int64(requestBody.Len())
} else { } else {
req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
// 确保请求体大小与Content-Length一致
req.ContentLength = int64(requestBody.Len())
} }
// 确保Content-Type正确传递
if strings.Contains(contentType, "multipart/form-data") && c.Request.MultipartForm != nil {
// 对于multipart请求使用我们重建时生成的Content-Type
// 注意此处必须使用writer生成的boundary
if contentTypeWithBoundary != "" {
req.Header.Set("Content-Type", contentTypeWithBoundary)
} else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
} else {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
}
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := client.HTTPClient.Do(req) resp, err := client.HTTPClient.Do(req)