From a25bcaa58fb627318da1e3fb1298b78f17efb0c5 Mon Sep 17 00:00:00 2001 From: wozulong <> Date: Wed, 22 May 2024 11:30:38 +0800 Subject: [PATCH] fix the type of the logprobs Signed-off-by: wozulong <> --- dto/text_request.go | 2 +- relay/relay-text.go | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/dto/text_request.go b/dto/text_request.go index a73297c..59c7a18 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -33,7 +33,7 @@ type GeneralOpenAIRequest struct { ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` LogitBias any `json:"logit_bias"` - LogProbs bool `json:"logprobs,omitempty"` + LogProbs any `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` } diff --git a/relay/relay-text.go b/relay/relay-text.go index 8e2ab5f..b34705c 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -191,7 +191,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re case relayconstant.RelayModeChatCompletions: promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) + prompts := textRequest.Prompt + switch v := prompts.(type) { + case string: + prompts = v + textRequest.Suffix + case []string: + prompts = append(v, textRequest.Suffix) + } + + promptTokens, err, sensitiveTrigger = service.CountTokenInput(prompts, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) case relayconstant.RelayModeEmbeddings: