mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 13:23:42 +08:00 
			
		
		
		
	feat: 初步兼容生成内容检查
This commit is contained in:
		@@ -36,3 +36,15 @@ func SundaySearch(text string, pattern string) bool {
 | 
			
		||||
	}
 | 
			
		||||
	return false // 如果没有找到匹配,返回-1
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func RemoveDuplicate(s []string) []string {
 | 
			
		||||
	result := make([]string, 0, len(s))
 | 
			
		||||
	temp := map[string]struct{}{}
 | 
			
		||||
	for _, item := range s {
 | 
			
		||||
		if _, ok := temp[item]; !ok {
 | 
			
		||||
			temp[item] = struct{}{}
 | 
			
		||||
			result = append(result, item)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return result
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -87,7 +87,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
 | 
			
		||||
		err := relaycommon.RelayErrorHandler(resp)
 | 
			
		||||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 | 
			
		||||
	}
 | 
			
		||||
	usage, respErr := adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
	usage, respErr, _ := adaptor.DoResponse(c, resp, meta)
 | 
			
		||||
	if respErr != nil {
 | 
			
		||||
		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										6
									
								
								dto/sensitive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								dto/sensitive.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type SensitiveResponse struct {
 | 
			
		||||
	SensitiveWords []string `json:"sensitive_words"`
 | 
			
		||||
	Content        string   `json:"content"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,9 +1,9 @@
 | 
			
		||||
package dto
 | 
			
		||||
 | 
			
		||||
type TextResponse struct {
 | 
			
		||||
	Choices []OpenAITextResponseChoice `json:"choices"`
 | 
			
		||||
	Choices []*OpenAITextResponseChoice `json:"choices"`
 | 
			
		||||
	Usage   `json:"usage"`
 | 
			
		||||
	Error   OpenAIError `json:"error"`
 | 
			
		||||
	Error   *OpenAIError `json:"error,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type OpenAITextResponseChoice struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -15,7 +15,7 @@ type Adaptor interface {
 | 
			
		||||
	SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error
 | 
			
		||||
	ConvertRequest(c *gin.Context, relayMode int, request *dto.GeneralOpenAIRequest) (any, error)
 | 
			
		||||
	DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode)
 | 
			
		||||
	DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse)
 | 
			
		||||
	GetModelList() []string
 | 
			
		||||
	GetChannelName() string
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -57,7 +57,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = aliStreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -69,7 +69,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = baiduStreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -63,7 +63,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = claudeStreamHandler(a.RequestMode, info.UpstreamModelName, info.PromptTokens, c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -47,7 +47,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = geminiChatStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -39,13 +39,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -71,13 +71,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
		err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,8 +4,11 @@ import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"one-api/common"
 | 
			
		||||
	"one-api/constant"
 | 
			
		||||
@@ -18,6 +21,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*dto.OpenAIErrorWithStatusCode, string) {
 | 
			
		||||
	checkSensitive := constant.ShouldCheckCompletionSensitive()
 | 
			
		||||
	var responseTextBuilder strings.Builder
 | 
			
		||||
	scanner := bufio.NewScanner(resp.Body)
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
@@ -37,11 +41,10 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 | 
			
		||||
	defer close(stopChan)
 | 
			
		||||
	defer close(dataChan)
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		wg.Add(1)
 | 
			
		||||
		defer wg.Done()
 | 
			
		||||
		var streamItems []string
 | 
			
		||||
		var streamItems []string // store stream items
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			data := scanner.Text()
 | 
			
		||||
			if len(data) < 6 { // ignore blank line or wrong format
 | 
			
		||||
@@ -50,11 +53,20 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 | 
			
		||||
			if data[:6] != "data: " && data[:6] != "[DONE]" {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			sensitive := false
 | 
			
		||||
			if checkSensitive {
 | 
			
		||||
				// check sensitive
 | 
			
		||||
				sensitive, _, data = service.SensitiveWordReplace(data, constant.StopOnSensitiveEnabled)
 | 
			
		||||
			}
 | 
			
		||||
			dataChan <- data
 | 
			
		||||
			data = data[6:]
 | 
			
		||||
			if !strings.HasPrefix(data, "[DONE]") {
 | 
			
		||||
				streamItems = append(streamItems, data)
 | 
			
		||||
			}
 | 
			
		||||
			if sensitive && constant.StopOnSensitiveEnabled {
 | 
			
		||||
				dataChan <- "data: [DONE]"
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		streamResp := "[" + strings.Join(streamItems, ",") + "]"
 | 
			
		||||
		switch relayMode {
 | 
			
		||||
@@ -112,50 +124,48 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
 | 
			
		||||
	return nil, responseTextBuilder.String()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
 | 
			
		||||
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
 | 
			
		||||
	var textResponse dto.TextResponse
 | 
			
		||||
	responseBody, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	err = json.Unmarshal(responseBody, &textResponse)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	if textResponse.Error.Type != "" {
 | 
			
		||||
	log.Printf("textResponse: %+v", textResponse)
 | 
			
		||||
	if textResponse.Error != nil {
 | 
			
		||||
		return &dto.OpenAIErrorWithStatusCode{
 | 
			
		||||
			Error:      textResponse.Error,
 | 
			
		||||
			Error:      *textResponse.Error,
 | 
			
		||||
			StatusCode: resp.StatusCode,
 | 
			
		||||
		}, nil
 | 
			
		||||
	}
 | 
			
		||||
	// Reset response body
 | 
			
		||||
	resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
	// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
	// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
	// So the httpClient will be confused by the response.
 | 
			
		||||
	// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
	for k, v := range resp.Header {
 | 
			
		||||
		c.Writer.Header().Set(k, v[0])
 | 
			
		||||
	}
 | 
			
		||||
	c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
	_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
	}
 | 
			
		||||
	err = resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
 | 
			
		||||
		}, nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 {
 | 
			
		||||
	checkSensitive := constant.ShouldCheckCompletionSensitive()
 | 
			
		||||
	sensitiveWords := make([]string, 0)
 | 
			
		||||
	triggerSensitive := false
 | 
			
		||||
 | 
			
		||||
	if textResponse.Usage.TotalTokens == 0 || checkSensitive {
 | 
			
		||||
		completionTokens := 0
 | 
			
		||||
		for _, choice := range textResponse.Choices {
 | 
			
		||||
			ctkm, _ := service.CountTokenText(string(choice.Message.Content), model, constant.ShouldCheckCompletionSensitive())
 | 
			
		||||
			stringContent := string(choice.Message.Content)
 | 
			
		||||
			ctkm, _ := service.CountTokenText(stringContent, model, false)
 | 
			
		||||
			completionTokens += ctkm
 | 
			
		||||
			if checkSensitive {
 | 
			
		||||
				sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
 | 
			
		||||
				if sensitive {
 | 
			
		||||
					triggerSensitive = true
 | 
			
		||||
					msg := choice.Message
 | 
			
		||||
					msg.Content = common.StringToByteSlice(stringContent)
 | 
			
		||||
					choice.Message = msg
 | 
			
		||||
					sensitiveWords = append(sensitiveWords, words...)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		textResponse.Usage = dto.Usage{
 | 
			
		||||
			PromptTokens:     promptTokens,
 | 
			
		||||
@@ -163,5 +173,36 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
 | 
			
		||||
			TotalTokens:      promptTokens + completionTokens,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &textResponse.Usage
 | 
			
		||||
 | 
			
		||||
	if constant.StopOnSensitiveEnabled {
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		responseBody, err = json.Marshal(textResponse)
 | 
			
		||||
		// Reset response body
 | 
			
		||||
		resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
 | 
			
		||||
		// We shouldn't set the header before we parse the response body, because the parse part may fail.
 | 
			
		||||
		// And then we will have to send an error response, but in this case, the header has already been set.
 | 
			
		||||
		// So the httpClient will be confused by the response.
 | 
			
		||||
		// For example, Postman will report error, and we cannot check the response at all.
 | 
			
		||||
		for k, v := range resp.Header {
 | 
			
		||||
			c.Writer.Header().Set(k, v[0])
 | 
			
		||||
		}
 | 
			
		||||
		c.Writer.WriteHeader(resp.StatusCode)
 | 
			
		||||
		_, err = io.Copy(c.Writer, resp.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil, nil
 | 
			
		||||
		}
 | 
			
		||||
		err = resp.Body.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if checkSensitive && triggerSensitive {
 | 
			
		||||
		sensitiveWords = common.RemoveDuplicate(sensitiveWords)
 | 
			
		||||
		return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected: %s", strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest), &textResponse.Usage, &dto.SensitiveResponse{
 | 
			
		||||
			SensitiveWords: sensitiveWords,
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil, &textResponse.Usage, nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -39,7 +39,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = palmStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -53,7 +53,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = tencentStreamHandler(c, resp)
 | 
			
		||||
 
 | 
			
		||||
@@ -43,13 +43,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return dummyResp, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	splits := strings.Split(info.ApiKey, "|")
 | 
			
		||||
	if len(splits) != 3 {
 | 
			
		||||
		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)
 | 
			
		||||
		return nil, service.OpenAIErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest), nil
 | 
			
		||||
	}
 | 
			
		||||
	if a.request == nil {
 | 
			
		||||
		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest)
 | 
			
		||||
		return nil, service.OpenAIErrorWrapper(errors.New("request is nil"), "request_is_nil", http.StatusBadRequest), nil
 | 
			
		||||
	}
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = xunfeiStreamHandler(c, *a.request, splits[0], splits[1], splits[2])
 | 
			
		||||
 
 | 
			
		||||
@@ -46,7 +46,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		err, usage = zhipuStreamHandler(c, resp)
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -44,13 +44,13 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 | 
			
		||||
	return channel.DoApiRequest(a, c, info, requestBody)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
 | 
			
		||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
	if info.IsStream {
 | 
			
		||||
		var responseText string
 | 
			
		||||
		err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
 | 
			
		||||
		usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
 | 
			
		||||
	} else {
 | 
			
		||||
		err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
		err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -40,7 +40,7 @@ func RelayErrorHandler(resp *http.Response) (OpenAIErrorWithStatusCode *dto.Open
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	OpenAIErrorWithStatusCode.Error = textResponse.Error
 | 
			
		||||
	OpenAIErrorWithStatusCode.Error = *textResponse.Error
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -162,12 +162,21 @@ func TextHelper(c *gin.Context) *dto.OpenAIErrorWithStatusCode {
 | 
			
		||||
		return service.RelayErrorHandler(resp)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	usage, openaiErr := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	usage, openaiErr, sensitiveResp := adaptor.DoResponse(c, resp, relayInfo)
 | 
			
		||||
	if openaiErr != nil {
 | 
			
		||||
		returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
		return openaiErr
 | 
			
		||||
		if sensitiveResp == nil { // 如果没有敏感词检查结果
 | 
			
		||||
			returnPreConsumedQuota(c, relayInfo.TokenId, userQuota, preConsumedQuota)
 | 
			
		||||
			return openaiErr
 | 
			
		||||
		} else {
 | 
			
		||||
			// 如果有敏感词检查结果,不返回预消耗配额,继续消耗配额
 | 
			
		||||
			postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, sensitiveResp)
 | 
			
		||||
			if constant.StopOnSensitiveEnabled { // 是否直接返回错误
 | 
			
		||||
				return openaiErr
 | 
			
		||||
			}
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice)
 | 
			
		||||
	postConsumeQuota(c, relayInfo, *textRequest, usage, ratio, preConsumedQuota, userQuota, modelRatio, groupRatio, modelPrice, nil)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -243,7 +252,10 @@ func returnPreConsumedQuota(c *gin.Context, tokenId int, userQuota int, preConsu
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest, usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64, modelPrice float64) {
 | 
			
		||||
func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRequest dto.GeneralOpenAIRequest,
 | 
			
		||||
	usage *dto.Usage, ratio float64, preConsumedQuota int, userQuota int, modelRatio float64, groupRatio float64,
 | 
			
		||||
	modelPrice float64, sensitiveResp *dto.SensitiveResponse) {
 | 
			
		||||
 | 
			
		||||
	useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
 | 
			
		||||
	promptTokens := usage.PromptTokens
 | 
			
		||||
	completionTokens := usage.CompletionTokens
 | 
			
		||||
@@ -277,6 +289,9 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, textRe
 | 
			
		||||
		logContent += fmt.Sprintf("(可能是上游超时)")
 | 
			
		||||
		common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, textRequest.Model, preConsumedQuota))
 | 
			
		||||
	} else {
 | 
			
		||||
		if sensitiveResp != nil {
 | 
			
		||||
			logContent += fmt.Sprintf(",敏感词:%s", strings.Join(sensitiveResp.SensitiveWords, ", "))
 | 
			
		||||
		}
 | 
			
		||||
		quotaDelta := quota - preConsumedQuota
 | 
			
		||||
		err := model.PostConsumeTokenQuota(relayInfo.TokenId, userQuota, quotaDelta, preConsumedQuota, true)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -24,18 +24,21 @@ func SensitiveWordContains(text string) (bool, []string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
 | 
			
		||||
func SensitiveWordReplace(text string) (bool, string) {
 | 
			
		||||
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
 | 
			
		||||
	text = strings.ToLower(text)
 | 
			
		||||
	m := initAc()
 | 
			
		||||
	hits := m.MultiPatternSearch([]rune(text), false)
 | 
			
		||||
	hits := m.MultiPatternSearch([]rune(text), returnImmediately)
 | 
			
		||||
	if len(hits) > 0 {
 | 
			
		||||
		words := make([]string, 0)
 | 
			
		||||
		for _, hit := range hits {
 | 
			
		||||
			pos := hit.Pos
 | 
			
		||||
			word := string(hit.Word)
 | 
			
		||||
			text = text[:pos] + strings.Repeat("*", len(word)) + text[pos+len(word):]
 | 
			
		||||
			text = text[:pos] + " *###* " + text[pos+len(word):]
 | 
			
		||||
			words = append(words, word)
 | 
			
		||||
		}
 | 
			
		||||
		return true, text
 | 
			
		||||
		return true, words, text
 | 
			
		||||
	}
 | 
			
		||||
	return false, text
 | 
			
		||||
	return false, nil, text
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func initAc() *goahocorasick.Machine {
 | 
			
		||||
@@ -52,6 +55,7 @@ func readRunes() [][]rune {
 | 
			
		||||
	var dict [][]rune
 | 
			
		||||
 | 
			
		||||
	for _, word := range constant.SensitiveWords {
 | 
			
		||||
		word = strings.ToLower(word)
 | 
			
		||||
		l := bytes.TrimSpace([]byte(word))
 | 
			
		||||
		dict = append(dict, bytes.Runes(l))
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user