diff --git a/dto/text_request.go b/dto/text_request.go index a73297c..59c7a18 100644 --- a/dto/text_request.go +++ b/dto/text_request.go @@ -33,7 +33,7 @@ type GeneralOpenAIRequest struct { ToolChoice any `json:"tool_choice,omitempty"` User string `json:"user,omitempty"` LogitBias any `json:"logit_bias"` - LogProbs bool `json:"logprobs,omitempty"` + LogProbs any `json:"logprobs,omitempty"` TopLogProbs int `json:"top_logprobs,omitempty"` } diff --git a/relay/relay-text.go b/relay/relay-text.go index 8e2ab5f..b34705c 100644 --- a/relay/relay-text.go +++ b/relay/relay-text.go @@ -191,7 +191,15 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re case relayconstant.RelayModeChatCompletions: promptTokens, err, sensitiveTrigger = service.CountTokenChatRequest(*textRequest, textRequest.Model, checkSensitive) case relayconstant.RelayModeCompletions: - promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Prompt, textRequest.Model, checkSensitive) + prompts := textRequest.Prompt + switch v := prompts.(type) { + case string: + prompts = v + textRequest.Suffix + case []string: + prompts = append(v, textRequest.Suffix) + } + + promptTokens, err, sensitiveTrigger = service.CountTokenInput(prompts, textRequest.Model, checkSensitive) case relayconstant.RelayModeModerations: promptTokens, err, sensitiveTrigger = service.CountTokenInput(textRequest.Input, textRequest.Model, checkSensitive) case relayconstant.RelayModeEmbeddings: