♻️ refactor: 重构moderation接口

This commit is contained in:
Martial BE
2023-11-29 16:54:37 +08:00
parent 455269c145
commit 1c7c2d40bb
10 changed files with 183 additions and 28 deletions

View File

@@ -21,6 +21,7 @@ type BaseProvider struct {
ChatCompletions string
Embeddings string
AudioSpeech string
Moderation string
AudioTranscriptions string
AudioTranslations string
Proxy string
@@ -125,3 +126,24 @@ func (p *BaseProvider) HandleErrorResp(resp *http.Response) (openAIErrorWithStat
}
return
}
// SupportAPI reports whether this provider has an endpoint configured for
// the given relay mode. An endpoint left as the empty string means the
// corresponding API is not supported by the provider.
func (p *BaseProvider) SupportAPI(relayMode int) bool {
	var endpoint string
	switch relayMode {
	case common.RelayModeChatCompletions:
		endpoint = p.ChatCompletions
	case common.RelayModeCompletions:
		endpoint = p.Completions
	case common.RelayModeEmbeddings:
		endpoint = p.Embeddings
	case common.RelayModeAudioSpeech:
		endpoint = p.AudioSpeech
	case common.RelayModeAudioTranscription:
		endpoint = p.AudioTranscriptions
	case common.RelayModeAudioTranslation:
		endpoint = p.AudioTranslations
	case common.RelayModeModerations:
		endpoint = p.Moderation
	default:
		// Unknown relay modes are never supported.
		return false
	}
	return endpoint != ""
}

View File

@@ -11,6 +11,7 @@ type ProviderInterface interface {
GetBaseURL() string
GetFullRequestURL(requestURL string, modelName string) string
GetRequestHeaders() (headers map[string]string)
SupportAPI(relayMode int) bool
}
// 完成接口
@@ -31,6 +32,12 @@ type EmbeddingsInterface interface {
EmbeddingsAction(request *types.EmbeddingRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// ModerationInterface is implemented by providers that can serve
// OpenAI-style moderation requests. It embeds ProviderInterface for the
// common URL/header plumbing; ModerationAction performs the call and
// returns the token usage or a status-coded OpenAI error.
type ModerationInterface interface {
ProviderInterface
ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
}
// 余额接口
type BalanceInterface interface {
BalanceAction(channel *model.Channel) (float64, error)

View File

@@ -34,6 +34,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
Completions: "/v1/completions",
ChatCompletions: "/v1/chat/completions",
Embeddings: "/v1/embeddings",
Moderation: "/v1/moderations",
AudioSpeech: "/v1/audio/speech",
AudioTranscriptions: "/v1/audio/transcriptions",
AudioTranslations: "/v1/audio/translations",

View File

@@ -0,0 +1,49 @@
package openai
import (
"net/http"
"one-api/common"
"one-api/types"
)
// responseHandler checks the decoded moderation response for an embedded
// OpenAI error. When the upstream reported an error (non-empty error
// type), it is wrapped together with the HTTP status code; otherwise nil
// signals success.
func (c *OpenAIProviderModerationResponse) responseHandler(resp *http.Response) (errWithCode *types.OpenAIErrorWithStatusCode) {
	if c.Error.Type == "" {
		return nil
	}
	return &types.OpenAIErrorWithStatusCode{
		OpenAIError: c.Error,
		StatusCode:  resp.StatusCode,
	}
}
// ModerationAction sends a moderation request to the provider's
// /v1/moderations endpoint and returns the resulting token usage.
// Moderation produces no completion output, so only the prompt tokens are
// counted. isModelMapped is forwarded to getRequestBody, which decides
// whether the (possibly remapped) request must be re-marshaled.
func (p *OpenAIProvider) ModerationAction(request *types.ModerationRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
	// request is already a pointer; the original passed &request, which
	// marshals a **types.ModerationRequest. encoding/json follows the
	// extra indirection so the payload was identical, but the double
	// pointer is accidental — pass the request itself.
	requestBody, err := p.getRequestBody(request, isModelMapped)
	if err != nil {
		return nil, types.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
	}

	fullRequestURL := p.GetFullRequestURL(p.Moderation, request.Model)
	headers := p.GetRequestHeaders()

	client := common.NewClient()
	req, err := client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(requestBody), common.WithHeader(headers))
	if err != nil {
		return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
	}

	openAIProviderModerationResponse := &OpenAIProviderModerationResponse{}
	errWithCode = p.sendRequest(req, openAIProviderModerationResponse)
	if errWithCode != nil {
		return
	}

	// No generated tokens for moderation: total equals the prompt count.
	usage = &types.Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: 0,
		TotalTokens:      promptTokens,
	}
	return
}

View File

@@ -21,3 +21,8 @@ type OpenAIProviderEmbeddingsResponse struct {
types.EmbeddingResponse
types.OpenAIErrorResponse
}
// OpenAIProviderModerationResponse is the decode target for the
// /v1/moderations endpoint: it embeds the normal moderation payload and
// the OpenAI error envelope so a single JSON decode captures either a
// success or an upstream error body.
type OpenAIProviderModerationResponse struct {
types.ModerationResponse
types.OpenAIErrorResponse
}