mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-14 20:23:46 +08:00
♻️ refactor: 重构moderation接口
This commit is contained in:
@@ -24,7 +24,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
||||
// 获取 Provider
|
||||
provider := providers.GetProvider(channelType, c)
|
||||
if provider == nil {
|
||||
return types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
return types.ErrorWrapper(errors.New("channel not found"), "channel_not_found", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
if !provider.SupportAPI(relayMode) {
|
||||
return types.ErrorWrapper(errors.New("channel does not support this API"), "channel_not_support_api", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
modelMap, err := parseModelMapping(c.GetString("model_mapping"))
|
||||
@@ -45,12 +49,14 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
||||
var openAIErrorWithStatusCode *types.OpenAIErrorWithStatusCode
|
||||
|
||||
switch relayMode {
|
||||
case RelayModeChatCompletions:
|
||||
case common.RelayModeChatCompletions:
|
||||
usage, openAIErrorWithStatusCode = handleChatCompletions(c, provider, modelMap, quotaInfo, group)
|
||||
case RelayModeCompletions:
|
||||
case common.RelayModeCompletions:
|
||||
usage, openAIErrorWithStatusCode = handleCompletions(c, provider, modelMap, quotaInfo, group)
|
||||
case RelayModeEmbeddings:
|
||||
case common.RelayModeEmbeddings:
|
||||
usage, openAIErrorWithStatusCode = handleEmbeddings(c, provider, modelMap, quotaInfo, group)
|
||||
case common.RelayModeModerations:
|
||||
usage, openAIErrorWithStatusCode = handleModerations(c, provider, modelMap, quotaInfo, group)
|
||||
default:
|
||||
return types.ErrorWrapper(errors.New("invalid relay mode"), "invalid_relay_mode", http.StatusBadRequest)
|
||||
}
|
||||
@@ -84,14 +90,21 @@ func relayTextHelper(c *gin.Context, relayMode int) *types.OpenAIErrorWithStatus
|
||||
func handleChatCompletions(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var chatRequest types.ChatCompletionRequest
|
||||
isModelMapped := false
|
||||
|
||||
chatProvider, ok := provider.(providers_base.ChatInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &chatRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if chatRequest.Messages == nil || len(chatRequest.Messages) == 0 {
|
||||
return nil, types.ErrorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if modelMap != nil && modelMap[chatRequest.Model] != "" {
|
||||
chatRequest.Model = modelMap[chatRequest.Model]
|
||||
isModelMapped = true
|
||||
@@ -114,10 +127,16 @@ func handleCompletions(c *gin.Context, provider providers_base.ProviderInterface
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &completionRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if completionRequest.Prompt == "" {
|
||||
return nil, types.ErrorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if modelMap != nil && modelMap[completionRequest.Model] != "" {
|
||||
completionRequest.Model = modelMap[completionRequest.Model]
|
||||
isModelMapped = true
|
||||
@@ -140,10 +159,16 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &embeddingsRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if embeddingsRequest.Input == "" {
|
||||
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if modelMap != nil && modelMap[embeddingsRequest.Model] != "" {
|
||||
embeddingsRequest.Model = modelMap[embeddingsRequest.Model]
|
||||
isModelMapped = true
|
||||
@@ -158,3 +183,39 @@ func handleEmbeddings(c *gin.Context, provider providers_base.ProviderInterface,
|
||||
}
|
||||
return embeddingsProvider.EmbeddingsAction(&embeddingsRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
func handleModerations(c *gin.Context, provider providers_base.ProviderInterface, modelMap map[string]string, quotaInfo *QuotaInfo, group string) (*types.Usage, *types.OpenAIErrorWithStatusCode) {
|
||||
var moderationRequest types.ModerationRequest
|
||||
isModelMapped := false
|
||||
moderationProvider, ok := provider.(providers_base.ModerationInterface)
|
||||
if !ok {
|
||||
return nil, types.ErrorWrapper(errors.New("channel not implemented"), "channel_not_implemented", http.StatusNotImplemented)
|
||||
}
|
||||
|
||||
err := common.UnmarshalBodyReusable(c, &moderationRequest)
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "bind_request_body_failed", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if moderationRequest.Input == "" {
|
||||
return nil, types.ErrorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if moderationRequest.Model == "" {
|
||||
moderationRequest.Model = "text-moderation-latest"
|
||||
}
|
||||
|
||||
if modelMap != nil && modelMap[moderationRequest.Model] != "" {
|
||||
moderationRequest.Model = modelMap[moderationRequest.Model]
|
||||
isModelMapped = true
|
||||
}
|
||||
promptTokens := common.CountTokenInput(moderationRequest.Input, moderationRequest.Model)
|
||||
|
||||
quotaInfo.modelName = moderationRequest.Model
|
||||
quotaInfo.initQuotaInfo(group)
|
||||
quota_err := quotaInfo.preQuotaConsumption()
|
||||
if quota_err != nil {
|
||||
return nil, quota_err
|
||||
}
|
||||
return moderationProvider.ModerationAction(&moderationRequest, isModelMapped, promptTokens)
|
||||
}
|
||||
|
||||
@@ -56,19 +56,6 @@ func (m Message) StringContent() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
RelayModeUnknown = iota
|
||||
RelayModeChatCompletions
|
||||
RelayModeCompletions
|
||||
RelayModeEmbeddings
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeAudioSpeech
|
||||
RelayModeAudioTranscription
|
||||
RelayModeAudioTranslation
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
|
||||
type ResponseFormat struct {
|
||||
@@ -237,21 +224,18 @@ type CompletionsStreamResponse struct {
|
||||
func Relay(c *gin.Context) {
|
||||
var err *types.OpenAIErrorWithStatusCode
|
||||
|
||||
relayMode := RelayModeUnknown
|
||||
relayMode := common.RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
// err = relayChatHelper(c)
|
||||
relayMode = RelayModeChatCompletions
|
||||
relayMode = common.RelayModeChatCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
|
||||
// err = relayCompletionHelper(c)
|
||||
relayMode = RelayModeCompletions
|
||||
relayMode = common.RelayModeCompletions
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") {
|
||||
// err = relayEmbeddingsHelper(c)
|
||||
relayMode = RelayModeEmbeddings
|
||||
relayMode = common.RelayModeEmbeddings
|
||||
} else if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
relayMode = RelayModeEmbeddings
|
||||
relayMode = common.RelayModeEmbeddings
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
relayMode = common.RelayModeModerations
|
||||
}
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
// relayMode = RelayModeModerations
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
// relayMode = RelayModeImagesGenerations
|
||||
// } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") {
|
||||
|
||||
Reference in New Issue
Block a user