fix: fix embedding

This commit is contained in:
CaIon 2024-03-21 17:39:05 +08:00
parent 031957714a
commit c4b3d3a975
6 changed files with 60 additions and 37 deletions

View File

@ -1,12 +1,16 @@
package dto
type TextResponseWithError struct {
Choices []OpenAITextResponseChoice `json:"choices"`
Choices []OpenAITextResponseChoice `json:"choices"`
Object string `json:"object"`
Data []OpenAIEmbeddingResponseItem `json:"data"`
Model string `json:"model"`
Usage `json:"usage"`
Error OpenAIError `json:"error"`
}
type TextResponse struct {
Model string `json:"model"`
Choices []OpenAITextResponseChoice `json:"choices"`
Usage `json:"usage"`
}

View File

@ -45,7 +45,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
}
return
}

View File

@ -77,7 +77,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage, sensitiveResp = OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
}
return
}

View File

@ -124,8 +124,8 @@ func OpenaiStreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*d
return nil, responseTextBuilder.String()
}
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var textResponseWithError dto.TextResponseWithError
func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model string, relayMode int) (*dto.OpenAIErrorWithStatusCode, *dto.Usage, *dto.SensitiveResponse) {
var responseWithError dto.TextResponseWithError
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil, nil
@ -134,62 +134,81 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
err = json.Unmarshal(responseBody, &textResponseWithError)
err = json.Unmarshal(responseBody, &responseWithError)
if err != nil {
log.Printf("unmarshal_response_body_failed: body: %s, err: %v", string(responseBody), err)
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil, nil
}
if textResponseWithError.Error.Type != "" {
if responseWithError.Error.Type != "" {
return &dto.OpenAIErrorWithStatusCode{
Error: textResponseWithError.Error,
Error: responseWithError.Error,
StatusCode: resp.StatusCode,
}, nil, nil
}
textResponse := &dto.TextResponse{
Choices: textResponseWithError.Choices,
Usage: textResponseWithError.Usage,
}
checkSensitive := constant.ShouldCheckCompletionSensitive()
sensitiveWords := make([]string, 0)
triggerSensitive := false
if textResponse.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0
for i, choice := range textResponse.Choices {
stringContent := string(choice.Message.Content)
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive {
triggerSensitive = true
msg := choice.Message
msg.Content = common.StringToByteSlice(stringContent)
textResponse.Choices[i].Message = msg
sensitiveWords = append(sensitiveWords, words...)
usage := &responseWithError.Usage
//textResponse := &dto.TextResponse{
// Choices: responseWithError.Choices,
// Usage: responseWithError.Usage,
//}
var doResponseBody []byte
switch relayMode {
case relayconstant.RelayModeEmbeddings:
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: responseWithError.Object,
Data: responseWithError.Data,
Model: responseWithError.Model,
Usage: *usage,
}
doResponseBody, err = json.Marshal(embeddingResponse)
default:
if responseWithError.Usage.TotalTokens == 0 || checkSensitive {
completionTokens := 0
for i, choice := range responseWithError.Choices {
stringContent := string(choice.Message.Content)
ctkm, _, _ := service.CountTokenText(stringContent, model, false)
completionTokens += ctkm
if checkSensitive {
sensitive, words, stringContent := service.SensitiveWordReplace(stringContent, false)
if sensitive {
triggerSensitive = true
msg := choice.Message
msg.Content = common.StringToByteSlice(stringContent)
responseWithError.Choices[i].Message = msg
sensitiveWords = append(sensitiveWords, words...)
}
}
}
responseWithError.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
}
}
textResponse.Usage = dto.Usage{
PromptTokens: promptTokens,
CompletionTokens: completionTokens,
TotalTokens: promptTokens + completionTokens,
textResponse := &dto.TextResponse{
Choices: responseWithError.Choices,
Model: responseWithError.Model,
Usage: *usage,
}
doResponseBody, err = json.Marshal(textResponse)
}
if checkSensitive && triggerSensitive && constant.StopOnSensitiveEnabled {
sensitiveWords = common.RemoveDuplicate(sensitiveWords)
return service.OpenAIErrorWrapper(errors.New(fmt.Sprintf("sensitive words detected on response: %s",
strings.Join(sensitiveWords, ", "))), "sensitive_words_detected", http.StatusBadRequest),
&textResponse.Usage, &dto.SensitiveResponse{
usage, &dto.SensitiveResponse{
SensitiveWords: sensitiveWords,
}
} else {
responseBody, err = json.Marshal(textResponse)
// Reset response body
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
resp.Body = io.NopCloser(bytes.NewBuffer(doResponseBody))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
@ -207,5 +226,5 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, promptTokens int, model
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil, nil
}
}
return nil, &textResponse.Usage, nil
return nil, usage, nil
}

View File

@ -49,7 +49,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
}
return
}

View File

@ -50,7 +50,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
err, responseText = openai.OpenaiStreamHandler(c, resp, info.RelayMode)
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
err, usage, sensitiveResp = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName, info.RelayMode)
}
return
}