add transcriptions api

This commit is contained in:
Martial BE
2023-12-01 10:54:07 +08:00
parent 7c6dee7390
commit a013b1a166
18 changed files with 304 additions and 24 deletions

View File

@@ -3,6 +3,7 @@ package controller
import (
"context"
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/model"
@@ -58,6 +59,8 @@ func relayHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatusCode
usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group)
case common.RelayModeAudioSpeech:
usage, openAIErrorWithStatusCode = handleSpeech(c, provider, modelMap, quotaInfo, group)
case common.RelayModeAudioTranscription:
usage, openAIErrorWithStatusCode = handleTranscriptions(c, provider, modelMap, quotaInfo, group)
default:
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
}
@@ -257,3 +260,37 @@ func handleSpeech(c *gin.Context, provider providers_base.ProviderInterface, mod
}
return speechProvider.SpeechAction(&speechRequest, isModelMapped, promptTokens)
}
func handleTranscriptions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
var audioRequest types.AudioRequest
isModelMapped := false
speechProvider, ok := provider.(providers_base.TranscriptionsInterface)
if !ok {
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
}
err := common.UnmarshalBodyReusable(c, &audioRequest)
if err != nil {
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
}
if audioRequest.File == nil {
fmt.Println(audioRequest)
return nil, types.ErrorWrapper(errors.New("field file is required"), "required_field_missing", http.StatusBadRequest)
}
if modelMap != nil && modelMap[audioRequest.Model] != "" {
audioRequest.Model = modelMap[audioRequest.Model]
isModelMapped = true
}
promptTokens := 0
quotaInfo.modelName = audioRequest.Model
quotaInfo.promptTokens = promptTokens
quotaInfo.initQuotaInfo(group)
quota_err := quotaInfo.preQuotaConsumption()
if quota_err != nil {
return nil, quota_err
}
return speechProvider.TranscriptionsAction(&audioRequest, isModelMapped, promptTokens)
}

View File

@@ -222,6 +222,7 @@ type CompletionsStreamResponse struct {
}
func Relay(c *gin.Context) {
defer c.Request.Body.Close()
var err *types.OpenAIErrorWithStatusCode
relayMode := common.RelayModeUnknown
@@ -237,13 +238,14 @@ func Relay(c *gin.Context) {
relayMode = common.RelayModeModerations
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") {
relayMode = common.RelayModeAudioSpeech
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
relayMode = common.RelayModeAudioTranscription
}
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
// relayMode = RelayModeImagesGenerations
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
// relayMode = RelayModeEdits
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") {
// relayMode = RelayModeAudioTranscription
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
// relayMode = RelayModeAudioTranslation
// }