feat: 初步兼容生成内容检查

This commit is contained in:
CaIon
2024-03-20 19:00:51 +08:00
parent 7a663d26ec
commit 64b9d3b58c
21 changed files with 141 additions and 63 deletions

View File

@@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.RelayErrorHandler(resp)
}
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return openaiErr
if sensitiveResp == nil { // 如果没有敏感词检查结果
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return openaiErr
} else {
// 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
if constant.StopOnSensitiveEnabled { // 是否直接返回错误
return openaiErr
}
return nil
}
}
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
return nil
}
@@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
}
}
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
@@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
} else {
if sensitiveResp != nil {
logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil {