feat: 初步兼容生成内容检查

This commit is contained in:
CaIon 2024-03-20 19:00:51 +08:00
parent 7a663d26ec
commit 64b9d3b58c
21 changed files with 141 additions and 63 deletions

View File

@ -36,3 +36,15 @@ func SundaySearch(text string, pattern string) bool {
} }
return false // 如果没有找到匹配,返回-1 return false // 如果没有找到匹配,返回-1
} }
func RemoveDuplicate(s []string) []string {
result := make([]string, 0, len(s))
temp := map[string]struct{}{}
for _, item := range s {
if _, ok := temp[item]; !ok {
temp[item] = struct{}{}
result = append(result, item)
}
}
return result
}

View File

@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
err := relaycommon.RelayErrorHandler(resp) err := relaycommon.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
if respErr != nil { if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
} }

6
dto/sensitive.go Normal file
View File

@ -0,0 +1,6 @@
package dto
type SensitiveResponse struct {
SensitiveWords []string `json:"sensitive_words"`
Content string `json:"content"`
}

View File

@ -1,9 +1,9 @@
package dto package dto
type TextResponse struct { type TextResponse struct {
Choices []OpenAITextResponseChoice `json:"choices"` Choices []*OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"` Usage `json:"usage"`
Error OpenAIError `json:"error"` Error *OpenAIError `json:"error,omitempty"`
} }
type OpenAITextResponseChoice struct { type OpenAITextResponseChoice struct {

View File

@ -15,7 +15,7 @@ type Adaptor interface {
SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error) ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
GetModelList() []string GetModelList() []string
GetChannelName() string GetChannelName() string
} }

View File

@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = aliStreamHandler(c, resp) err, usage = aliStreamHandler(c, resp)
} else { } else {

View File

@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = baiduStreamHandler(c, resp) err, usage = baiduStreamHandler(c, resp)
} else { } else {

View File

@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp) err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
} else { } else {

View File

@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = geminiChatStreamHandler(c, resp) err, responseText = geminiChatStreamHandler(c, resp)

View File

@ -39,13 +39,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -71,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -4,8 +4,11 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"io" "io"
"log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@ -18,6 +21,7 @@ import (
) )
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) { func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
checkSensitive := constant.ShouldCheckCompletionSensitive()
var responseTextBuilder strings.Builder var responseTextBuilder strings.Builder
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
@ -37,11 +41,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
defer close(stopChan) defer close(stopChan)
defer close(dataChan) defer close(dataChan)
var wg sync.WaitGroup var wg sync.WaitGroup
go func() { go func() {
wg.Add(1) wg.Add(1)
defer wg.Done() defer wg.Done()
var streamItems []string var streamItems []string // store stream items
for scanner.Scan() { for scanner.Scan() {
data := scanner.Text() data := scanner.Text()
if len(data) < 6 { // ignore blank line or wrong format if len(data) < 6 { // ignore blank line or wrong format
@ -50,11 +53,20 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
if data[:6] != "data: " && data[:6] != "[DONE]" { if data[:6] != "data: " && data[:6] != "[DONE]" {
continue continue
} }
sensitive := false
if checkSensitive {
// check sensitive
sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled)
}
dataChan <- data dataChan <- data
data = data[6:] data = data[6:]
if !strings.HasPrefix(data, "[DONE]") { if !strings.HasPrefix(data, "[DONE]") {
streamItems = append(streamItems, data) streamItems = append(streamItems, data)
} }
if sensitive && constant.StopOnSensitiveEnabled {
dataChan <- "data: [DONE]"
break
}
} }
streamResp := "[" + strings.Join(streamItems, ",") + "]" streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode { switch relayMode {
@ -112,50 +124,48 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String() return nil, responseTextBuilder.String()
} }
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) { func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponse dto.TextResponse var textResponse dto.TextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
} }
err = resp.Body.Close() err = resp.Body.Close()
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
} }
err = json.Unmarshal(responseBody, &textResponse) err = json.Unmarshal(responseBody, &textResponse)
if err != nil { if err != nil {
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
} }
if textResponse.Error.Type != "" { log.Printf("textResponse: %+v", textResponse)
if textResponse.Error != nil {
return &dto.OpenAIErrorWithStatusCode{ return &dto.OpenAIErrorWithStatusCode{
Error: textResponse.Error, Error: *textResponse.Error,
StatusCode: resp.StatusCode, StatusCode: resp.StatusCode,
}, nil }, nil, nil
}
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
} }
if textResponse.Usage.TotalTokens == 0 { checkSensitive := constant.ShouldCheckCompletionSensitive()
sensitiveWords := make([]string, 0)
triggerSensitive := false
if textResponse.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0 completionTokens := 0
for _, choice := range textResponse.Choices { for _, choice := range textResponse.Choices {
ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive()) stringContent := string(choice.Message.Content)
ctkm, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive {
triggerSensitive = true
msg := choice.Message
msg.Content = common.StringToByteSlice(stringContent)
choice.Message = msg
sensitiveWords = append(sensitiveWords, words...)
}
}
} }
textResponse.Usage = dto.Usage{ textResponse.Usage = dto.Usage{
PromptTokens: promptTokens, PromptTokens: promptTokens,
@ -163,5 +173,36 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
TotalTokens: promptTokens + completionTokens, TotalTokens: promptTokens + completionTokens,
} }
} }
return nil, &textResponse.Usage
if constant.StopOnSensitiveEnabled {
} else {
responseBody, err = json.Marshal(textResponse)
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
}
if checkSensitive && triggerSensitive {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
}
return nil, &textResponse.Usage, nil
} }

View File

@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = palmStreamHandler(c, resp) err, responseText = palmStreamHandler(c, resp)

View File

@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = tencentStreamHandler(c, resp) err, responseText = tencentStreamHandler(c, resp)

View File

@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return dummyResp, nil return dummyResp, nil
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
splits := strings.Split(info.ApiKey, "|") splits := strings.Split(info.ApiKey, "|")
if len(splits) != 3 { if len(splits) != 3 {
return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
} }
if a.request == nil { if a.request == nil {
return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest) return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
} }
if info.IsStream { if info.IsStream {
err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2]) err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])

View File

@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
err, usage = zhipuStreamHandler(c, resp) err, usage = zhipuStreamHandler(c, resp)
} else { } else {

View File

@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
return channel.DoApiRequest(a, c, info, requestBody) return channel.DoApiRequest(a, c, info, requestBody)
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
if info.IsStream { if info.IsStream {
var responseText string var responseText string
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode) err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens) usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else { } else {
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName) err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
} }
return return
} }

View File

@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
if err != nil { if err != nil {
return return
} }
OpenAIErrorWithStatusCode.Error = textResponse.Error OpenAIErrorWithStatusCode.Error = *textResponse.Error
return return
} }

View File

@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
return service.RelayErrorHandler(resp) return service.RelayErrorHandler(resp)
} }
usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo) usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
if openaiErr != nil { if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota) if sensitiveResp == nil { // 如果没有敏感词检查结果
return openaiErr returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
return openaiErr
} else {
// 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
if constant.StopOnSensitiveEnabled { // 是否直接返回错误
return openaiErr
}
return nil
}
} }
postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice) postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
return nil return nil
} }
@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
} }
} }
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) { func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix() useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens completionTokens := usage.CompletionTokens
@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
logContent += fmt.Sprintf("(可能是上游超时)") logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota)) common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
} else { } else {
if sensitiveResp != nil {
logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
}
quotaDelta := quota - preConsumedQuota quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true) err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
if err != nil { if err != nil {

View File

@ -24,18 +24,21 @@ func SensitiveWordContains(text string) (bool, []string) {
} }
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string) (bool, string) { func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
text = strings.ToLower(text)
m := initAc() m := initAc()
hits := m.MultiPatternSearch([]rune(text), false) hits := m.MultiPatternSearch([]rune(text), returnImmediately)
if len(hits) > 0 { if len(hits) > 0 {
words := make([]string, 0)
for _, hit := range hits { for _, hit := range hits {
pos := hit.Pos pos := hit.Pos
word := string(hit.Word) word := string(hit.Word)
text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):] text = text[:pos] + " *###* " + text[pos+len(word):]
words = append(words, word)
} }
return true, text return true, words, text
} }
return false, text return false, nil, text
} }
func initAc() *goahocorasick.Machine { func initAc() *goahocorasick.Machine {
@ -52,6 +55,7 @@ func readRunes() [][]rune {
var dict [][]rune var dict [][]rune
for _, word := range constant.SensitiveWords { for _, word := range constant.SensitiveWords {
word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word)) l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l)) dict = append(dict, bytes.Runes(l))
} }