package controller

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common"
	"github.com/songquanpeng/one-api/common/client"
	"github.com/songquanpeng/one-api/common/config"
	"github.com/songquanpeng/one-api/common/ctxkey"
	"github.com/songquanpeng/one-api/common/logger"
	"github.com/songquanpeng/one-api/model"
	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/billing"
	billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
	"github.com/songquanpeng/one-api/relay/channeltype"
	"github.com/songquanpeng/one-api/relay/meta"
	relaymodel "github.com/songquanpeng/one-api/relay/model"
	"github.com/songquanpeng/one-api/relay/relaymode"
)

func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
	ctx := c.Request.Context()
	meta := meta.GetByContext(c)
	audioModel := "gpt-4o-transcribe"

	tokenId := c.GetInt(ctxkey.TokenId)
	channelType := c.GetInt(ctxkey.Channel)
	channelId := c.GetInt(ctxkey.ChannelId)
	userId := c.GetInt(ctxkey.Id)
	group := c.GetString(ctxkey.Group)
	tokenName := c.GetString(ctxkey.TokenName)

	var ttsRequest openai.TextToSpeechRequest
	if relayMode == relaymode.AudioSpeech {
		// Read the JSON body and check that it is valid
		err := common.UnmarshalBodyReusable(c, &ttsRequest)
		if err != nil {
			return openai.ErrorWrapper(err, "invalid_json", http.StatusBadRequest)
		}
		audioModel = ttsRequest.Model
		// Check if the input text is too long (4096 characters max)
		if len(ttsRequest.Input) > 4096 {
			return openai.ErrorWrapper(errors.New("input is too long (over 4096 characters)"), "text_too_long", http.StatusBadRequest)
		}
	}

	modelRatio := billingratio.GetModelRatio(audioModel, channelType)
	groupRatio := billingratio.GetGroupRatio(group)
	ratio := modelRatio * groupRatio
	var quota int64
	var preConsumedQuota int64
	switch relayMode {
	case relaymode.AudioSpeech:
		preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
		quota = preConsumedQuota
	default:
		preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
	}

	userQuota, err := model.CacheGetUserQuota(ctx, userId)
	if err != nil {
		return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
	}
	// Check if the user's quota is enough
	if userQuota-preConsumedQuota < 0 {
		return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
	}
	err = model.CacheDecreaseUserQuota(userId, preConsumedQuota)
	if err != nil {
		return openai.ErrorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError)
	}
	if userQuota > 100*preConsumedQuota {
		// in this case, we do not pre-consume quota
		// because the user has enough quota
		preConsumedQuota = 0
	}
	if preConsumedQuota > 0 {
		err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota)
		if err != nil {
			return openai.ErrorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden)
		}
	}

	succeed := false
	defer func() {
		if succeed {
			return
		}
		if preConsumedQuota > 0 {
			// we need to roll back the pre-consumed quota
			defer func(ctx context.Context) {
				go func() {
					// negative means adding quota back for the token & user
					err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
					if err != nil {
						logger.Error(ctx, fmt.Sprintf("error rollback pre-consumed quota: %s", err.Error()))
					}
				}()
			}(c.Request.Context())
		}
	}()

	// map model name
	modelMapping := c.GetStringMapString(ctxkey.ModelMapping)
	if modelMapping != nil && modelMapping[audioModel] != "" {
		audioModel = modelMapping[audioModel]
	}
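	// Resolve the upstream endpoint: start from the channel type's default base URL,
	// allow a per-channel override, and build deployment-scoped paths for Azure.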
	baseURL := channeltype.ChannelBaseURLs[channelType]
	requestURL := c.Request.URL.String()
	if c.GetString(ctxkey.BaseURL) != "" {
		baseURL = c.GetString(ctxkey.BaseURL)
	}

	fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
	if channelType == channeltype.Azure {
		apiVersion := meta.Config.APIVersion
		deploymentName := c.GetString(ctxkey.ChannelName)
		if relayMode == relaymode.AudioTranscription {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, deploymentName, apiVersion)
		} else if relayMode == relaymode.AudioSpeech {
			// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
			fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, deploymentName, apiVersion)
		}
	}

	requestBody := &bytes.Buffer{}
	_, err = io.Copy(requestBody, c.Request.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError)
	}

	// Handle form data
	contentType := c.Request.Header.Get("Content-Type")
	responseFormat := "json"
	var contentTypeWithBoundary string
	if strings.Contains(contentType, "multipart/form-data") {
		originalBody := requestBody.Bytes()
		c.Request.Body = io.NopCloser(bytes.NewBuffer(originalBody))
		err = c.Request.ParseMultipartForm(32 << 20) // 32 MB max in-memory
		if err != nil {
			return openai.ErrorWrapper(err, "parse_multipart_form_failed", http.StatusInternalServerError)
		}

		// Get the requested response format
		if format := c.Request.FormValue("response_format"); format != "" {
			responseFormat = format
		}

		requestBody = &bytes.Buffer{}
		writer := multipart.NewWriter(requestBody)

		// Copy form fields
		for key, values := range c.Request.MultipartForm.Value {
			for _, value := range values {
				err = writer.WriteField(key, value)
				if err != nil {
					return openai.ErrorWrapper(err, "write_field_failed", http.StatusInternalServerError)
				}
			}
		}

		// Copy uploaded files
		for key, fileHeaders := range c.Request.MultipartForm.File {
			for _, fileHeader := range fileHeaders {
				file, err := fileHeader.Open()
				if err != nil {
					return openai.ErrorWrapper(err, "open_file_failed", http.StatusInternalServerError)
				}
				part, err := writer.CreateFormFile(key, fileHeader.Filename)
				if err != nil {
					file.Close()
					return openai.ErrorWrapper(err, "create_form_file_failed", http.StatusInternalServerError)
				}
				_, err = io.Copy(part, file)
				file.Close()
				if err != nil {
					return openai.ErrorWrapper(err, "copy_file_failed", http.StatusInternalServerError)
				}
			}
		}

		// Finish writing the multipart body
		err = writer.Close()
		if err != nil {
			return openai.ErrorWrapper(err, "close_writer_failed", http.StatusInternalServerError)
		}

		// Update Content-Type with the boundary generated by the writer
		contentTypeWithBoundary = writer.FormDataContentType()
		c.Request.Header.Set("Content-Type", contentTypeWithBoundary)
	} else {
		// For non-form requests, simply reset the request body
		c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes()))
	}

	req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
	if err != nil {
		return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}
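	// Forward credentials and headers to the upstream request: Azure expects the key
	// in an api-key header instead of a Bearer token, and Content-Length must match
	// the (possibly rebuilt) request body.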
	if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
		// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
		apiKey := c.Request.Header.Get("Authorization")
		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
		req.Header.Set("api-key", apiKey)
		// Keep Content-Length consistent with the request body size
		req.ContentLength = int64(requestBody.Len())
	} else {
		req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
		// Keep Content-Length consistent with the request body size
		req.ContentLength = int64(requestBody.Len())
	}

	// Make sure Content-Type is forwarded correctly
	if strings.Contains(contentType, "multipart/form-data") && c.Request.MultipartForm != nil {
		// For multipart requests, use the Content-Type generated while rebuilding the body;
		// note that the boundary produced by the writer must be used here
		if contentTypeWithBoundary != "" {
			req.Header.Set("Content-Type", contentTypeWithBoundary)
		} else {
			req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
		}
	} else {
		req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
	}
	req.Header.Set("Accept", c.Request.Header.Get("Accept"))

	resp, err := client.HTTPClient.Do(req)
	if err != nil {
		return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
	}

	err = req.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}
	err = c.Request.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
	}

	if relayMode != relaymode.AudioSpeech {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
		}
		err = resp.Body.Close()
		if err != nil {
			return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
		}

		var openAIErr openai.SlimTextResponse
		if err = json.Unmarshal(responseBody, &openAIErr); err == nil {
			if openAIErr.Error.Message != "" {
				return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError)
			}
		}

		var text string
		switch responseFormat {
		case "json":
			text, err = getTextFromJSON(responseBody)
		case "text":
			text, err = getTextFromText(responseBody)
		case "srt":
			text, err = getTextFromSRT(responseBody)
		case "verbose_json":
			text, err = getTextFromVerboseJSON(responseBody)
		case "vtt":
			text, err = getTextFromVTT(responseBody)
		default:
			return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError)
		}
		if err != nil {
			return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
		}
		quota = int64(openai.CountTokenText(text, audioModel))
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
	}
	if resp.StatusCode != http.StatusOK {
		return RelayErrorHandler(resp)
	}

	succeed = true
	quotaDelta := quota - preConsumedQuota
	defer func(ctx context.Context) {
		go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
	}(c.Request.Context())

	for k, v := range resp.Header {
		c.Writer.Header().Set(k, v[0])
	}
	c.Writer.WriteHeader(resp.StatusCode)

	_, err = io.Copy(c.Writer, resp.Body)
	if err != nil {
		return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
	}
	err = resp.Body.Close()
	if err != nil {
		return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
	}
	return nil
}

func getTextFromVTT(body []byte) (string, error) {
	return getTextFromSRT(body)
}

func getTextFromVerboseJSON(body []byte) (string, error) {
	var whisperResponse openai.WhisperVerboseJSONResponse
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
	}
	return whisperResponse.Text, nil
}
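// getTextFromSRT extracts the spoken text from an SRT (or VTT) payload by keeping
// the line that follows each "-->" timing line; cue indices and blank lines are
// skipped, since the result is only used for token counting.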
func getTextFromSRT(body []byte) (string, error) {
	scanner := bufio.NewScanner(strings.NewReader(string(body)))
	var builder strings.Builder
	var textLine bool
	for scanner.Scan() {
		line := scanner.Text()
		if textLine {
			builder.WriteString(line)
			textLine = false
			continue
		} else if strings.Contains(line, "-->") {
			textLine = true
			continue
		}
	}
	if err := scanner.Err(); err != nil {
		return "", err
	}
	return builder.String(), nil
}

func getTextFromText(body []byte) (string, error) {
	return strings.TrimSuffix(string(body), "\n"), nil
}

func getTextFromJSON(body []byte) (string, error) {
	var whisperResponse openai.WhisperJSONResponse
	if err := json.Unmarshal(body, &whisperResponse); err != nil {
		return "", fmt.Errorf("unmarshal_response_body_failed err :%w", err)
	}
	return whisperResponse.Text, nil
}