diff --git a/controller/channel-test.go b/controller/channel-test.go
index 18f1e9b..c0ac9a6 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -86,9 +86,10 @@ func buildTestRequest() *ChatRequest {
 		Model:     "", // this will be set later
 		MaxTokens: 1,
 	}
+	content, _ := json.Marshal("hi")
 	testMessage := Message{
 		Role:    "user",
-		Content: "hi",
+		Content: content,
 	}
 	testRequest.Messages = append(testRequest.Messages, testMessage)
 	return testRequest
@@ -186,6 +187,10 @@ func testAllChannels(notify bool) error {
 			err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 			ban = true
 		}
+		if openaiErr != nil {
+			err = errors.New(fmt.Sprintf("type %s, code %v, message %s", openaiErr.Type, openaiErr.Code, openaiErr.Message))
+			ban = true
+		}
 		// parse *int to bool
 		if channel.AutoBan != nil && *channel.AutoBan == 0 {
 			ban = false
diff --git a/controller/relay-aiproxy.go b/controller/relay-aiproxy.go
index d0159ce..7dbf679 100644
--- a/controller/relay-aiproxy.go
+++ b/controller/relay-aiproxy.go
@@ -48,7 +48,7 @@ type AIProxyLibraryStreamResponse struct {
 func requestOpenAI2AIProxyLibrary(request GeneralOpenAIRequest) *AIProxyLibraryRequest {
 	query := ""
 	if len(request.Messages) != 0 {
-		query = request.Messages[len(request.Messages)-1].Content
+		query = string(request.Messages[len(request.Messages)-1].Content)
 	}
 	return &AIProxyLibraryRequest{
 		Model: request.Model,
@@ -69,7 +69,7 @@ func aiProxyDocuments2Markdown(documents []AIProxyLibraryDocument) string {
 }
 
 func responseAIProxyLibrary2OpenAI(response *AIProxyLibraryResponse) *OpenAITextResponse {
-	content := response.Answer + aiProxyDocuments2Markdown(response.Documents)
+	content, _ := json.Marshal(response.Answer + aiProxyDocuments2Markdown(response.Documents))
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
diff --git a/controller/relay-ali.go b/controller/relay-ali.go
index 50dc743..6a79d2b 100644
--- a/controller/relay-ali.go
+++ b/controller/relay-ali.go
@@ -88,18 +88,18 @@ func requestOpenAI2Ali(request GeneralOpenAIRequest) *AliChatRequest {
 		message := request.Messages[i]
 		if message.Role == "system" {
 			messages = append(messages, AliMessage{
-				User: message.Content,
+				User: string(message.Content),
 				Bot:  "Okay",
 			})
 			continue
 		} else {
 			if i == len(request.Messages)-1 {
-				prompt = message.Content
+				prompt = string(message.Content)
 				break
 			}
 			messages = append(messages, AliMessage{
-				User: message.Content,
-				Bot:  request.Messages[i+1].Content,
+				User: string(message.Content),
+				Bot:  string(request.Messages[i+1].Content),
 			})
 			i++
 		}
@@ -184,11 +184,12 @@ func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse) *OpenAIEmbeddin
 }
 
 func responseAli2OpenAI(response *AliChatResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(response.Output.Text)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Output.Text,
+			Content: content,
 		},
 		FinishReason: response.Output.FinishReason,
 	}
diff --git a/controller/relay-baidu.go b/controller/relay-baidu.go
index ed08ac0..05bbad0 100644
--- a/controller/relay-baidu.go
+++ b/controller/relay-baidu.go
@@ -89,7 +89,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 		if message.Role == "system" {
 			messages = append(messages, BaiduMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, BaiduMessage{
 				Role:    "assistant",
@@ -98,7 +98,7 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 		} else {
 			messages = append(messages, BaiduMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -109,11 +109,12 @@ func requestOpenAI2Baidu(request GeneralOpenAIRequest) *BaiduChatRequest {
 }
 
 func responseBaidu2OpenAI(response *BaiduChatResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(response.Result)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Result,
+			Content: content,
 		},
 		FinishReason: "stop",
 	}
diff --git a/controller/relay-claude.go b/controller/relay-claude.go
index 1f4a3e7..e131263 100644
--- a/controller/relay-claude.go
+++ b/controller/relay-claude.go
@@ -93,11 +93,12 @@ func streamResponseClaude2OpenAI(claudeResponse *ClaudeResponse) *ChatCompletion
 }
 
 func responseClaude2OpenAI(claudeResponse *ClaudeResponse) *OpenAITextResponse {
+	content, _ := json.Marshal(strings.TrimPrefix(claudeResponse.Completion, " "))
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: strings.TrimPrefix(claudeResponse.Completion, " "),
+			Content: content,
 			Name:    nil,
 		},
 		FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
diff --git a/controller/relay-openai.go b/controller/relay-openai.go
index 6bdfbc0..9b08f85 100644
--- a/controller/relay-openai.go
+++ b/controller/relay-openai.go
@@ -132,7 +132,7 @@ func openaiHandler(c *gin.Context, resp *http.Response, consumeQuota bool, promp
 		if textResponse.Usage.TotalTokens == 0 {
 			completionTokens := 0
 			for _, choice := range textResponse.Choices {
-				completionTokens += countTokenText(choice.Message.Content, model)
+				completionTokens += countTokenText(string(choice.Message.Content), model)
 			}
 			textResponse.Usage = Usage{
 				PromptTokens: promptTokens,
diff --git a/controller/relay-palm.go b/controller/relay-palm.go
index a705b31..a7b0c1f 100644
--- a/controller/relay-palm.go
+++ b/controller/relay-palm.go
@@ -59,7 +59,7 @@ func requestOpenAI2PaLM(textRequest GeneralOpenAIRequest) *PaLMChatRequest {
 	}
 	for _, message := range textRequest.Messages {
 		palmMessage := PaLMChatMessage{
-			Content: message.Content,
+			Content: string(message.Content),
 		}
 		if message.Role == "user" {
 			palmMessage.Author = "0"
@@ -76,11 +76,12 @@ func responsePaLM2OpenAI(response *PaLMChatResponse) *OpenAITextResponse {
 		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
 	}
 	for i, candidate := range response.Candidates {
+		content, _ := json.Marshal(candidate.Content)
 		choice := OpenAITextResponseChoice{
 			Index: i,
 			Message: Message{
 				Role:    "assistant",
-				Content: candidate.Content,
+				Content: content,
 			},
 			FinishReason: "stop",
 		}
diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go
index 024468b..c96e6d4 100644
--- a/controller/relay-tencent.go
+++ b/controller/relay-tencent.go
@@ -84,7 +84,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 		if message.Role == "system" {
 			messages = append(messages, TencentMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, TencentMessage{
 				Role:    "assistant",
@@ -93,7 +93,7 @@ func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest {
 			continue
 		}
 		messages = append(messages, TencentMessage{
-			Content: message.Content,
+			Content: string(message.Content),
 			Role:    message.Role,
 		})
 	}
@@ -119,11 +119,12 @@ func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse {
 		Usage: response.Usage,
 	}
 	if len(response.Choices) > 0 {
+		content, _ := json.Marshal(response.Choices[0].Messages.Content)
 		choice := OpenAITextResponseChoice{
 			Index: 0,
 			Message: Message{
 				Role:    "assistant",
-				Content: response.Choices[0].Messages.Content,
+				Content: content,
 			},
 			FinishReason: response.Choices[0].FinishReason,
 		}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index 2729650..a009267 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -199,9 +199,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 	}
 	var promptTokens int
 	var completionTokens int
+	var err error
 	switch relayMode {
 	case RelayModeChatCompletions:
-		promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model)
+		promptTokens, err = countTokenMessages(textRequest.Messages, textRequest.Model)
+		if err != nil {
+			return errorWrapper(err, "count_token_messages_failed", http.StatusInternalServerError)
+		}
 	case RelayModeCompletions:
 		promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model)
 	case RelayModeModerations:
diff --git a/controller/relay-utils.go b/controller/relay-utils.go
index 40aa547..177d853 100644
--- a/controller/relay-utils.go
+++ b/controller/relay-utils.go
@@ -63,7 +63,8 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
 	return len(tokenEncoder.Encode(text, nil, nil))
 }
 
-func countTokenMessages(messages []Message, model string) int {
+func countTokenMessages(messages []Message, model string) (int, error) {
+	//recover when panic
 	tokenEncoder := getTokenEncoder(model)
 	// Reference:
 	// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
@@ -82,15 +83,33 @@ func countTokenMessages(messages []Message, model string) int {
 	tokenNum := 0
 	for _, message := range messages {
 		tokenNum += tokensPerMessage
-		tokenNum += getTokenNum(tokenEncoder, message.Content)
 		tokenNum += getTokenNum(tokenEncoder, message.Role)
-		if message.Name != nil {
-			tokenNum += tokensPerName
-			tokenNum += getTokenNum(tokenEncoder, *message.Name)
+		var arrayContent []MediaMessage
+		if err := json.Unmarshal(message.Content, &arrayContent); err != nil {
+
+			var stringContent string
+			if err := json.Unmarshal(message.Content, &stringContent); err != nil {
+				return 0, err
+			} else {
+				tokenNum += getTokenNum(tokenEncoder, stringContent)
+				if message.Name != nil {
+					tokenNum += tokensPerName
+					tokenNum += getTokenNum(tokenEncoder, *message.Name)
+				}
+			}
+		} else {
+			for _, m := range arrayContent {
+				if m.Type == "image_url" {
+					//TODO: getImageToken
+					tokenNum += 1000
+				} else {
+					tokenNum += getTokenNum(tokenEncoder, m.Text)
+				}
+			}
 		}
 	}
 	tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
-	return tokenNum
+	return tokenNum, nil
 }
 
 func countTokenInput(input any, model string) int {
diff --git a/controller/relay-xunfei.go b/controller/relay-xunfei.go
index 91fb604..33383d8 100644
--- a/controller/relay-xunfei.go
+++ b/controller/relay-xunfei.go
@@ -81,7 +81,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 		if message.Role == "system" {
 			messages = append(messages, XunfeiMessage{
 				Role:    "user",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, XunfeiMessage{
 				Role:    "assistant",
@@ -90,7 +90,7 @@ func requestOpenAI2Xunfei(request GeneralOpenAIRequest, xunfeiAppId string, doma
 		} else {
 			messages = append(messages, XunfeiMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -112,11 +112,12 @@ func responseXunfei2OpenAI(response *XunfeiChatResponse) *OpenAITextResponse {
 			},
 		}
 	}
+	content, _ := json.Marshal(response.Payload.Choices.Text[0].Content)
 	choice := OpenAITextResponseChoice{
 		Index: 0,
 		Message: Message{
 			Role:    "assistant",
-			Content: response.Payload.Choices.Text[0].Content,
+			Content: content,
 		},
 		FinishReason: stopFinishReason,
 	}
diff --git a/controller/relay-zhipu.go b/controller/relay-zhipu.go
index 7a4a582..5ad4151 100644
--- a/controller/relay-zhipu.go
+++ b/controller/relay-zhipu.go
@@ -114,7 +114,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 		if message.Role == "system" {
 			messages = append(messages, ZhipuMessage{
 				Role:    "system",
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 			messages = append(messages, ZhipuMessage{
 				Role:    "user",
@@ -123,7 +123,7 @@ func requestOpenAI2Zhipu(request GeneralOpenAIRequest) *ZhipuRequest {
 		} else {
 			messages = append(messages, ZhipuMessage{
 				Role:    message.Role,
-				Content: message.Content,
+				Content: string(message.Content),
 			})
 		}
 	}
@@ -144,11 +144,12 @@ func responseZhipu2OpenAI(response *ZhipuResponse) *OpenAITextResponse {
 		Usage: response.Data.Usage,
 	}
 	for i, choice := range response.Data.Choices {
+		content, _ := json.Marshal(strings.Trim(choice.Content, "\""))
 		openaiChoice := OpenAITextResponseChoice{
 			Index: i,
 			Message: Message{
 				Role:    choice.Role,
-				Content: strings.Trim(choice.Content, "\""),
+				Content: content,
 			},
 			FinishReason: "",
 		}
diff --git a/controller/relay.go b/controller/relay.go
index 06ef341..714910c 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -1,6 +1,7 @@
 package controller
 
 import (
+	"encoding/json"
 	"fmt"
 	"log"
 	"net/http"
@@ -12,9 +13,20 @@ import (
 )
 
 type Message struct {
-	Role    string  `json:"role"`
-	Content string  `json:"content"`
-	Name    *string `json:"name,omitempty"`
+	Role    string          `json:"role"`
+	Content json.RawMessage `json:"content"`
+	Name    *string         `json:"name,omitempty"`
+}
+
+type MediaMessage struct {
+	Type     string          `json:"type"`
+	Text     string          `json:"text"`
+	ImageUrl MessageImageUrl `json:"image_url,omitempty"`
+}
+
+type MessageImageUrl struct {
+	Url    string `json:"url"`
+	Detail string `json:"detail"`
 }
 
 const (
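
Note (illustration, not part of the patch): the central change above is that Message.Content becomes a json.RawMessage, so a message body may be either a plain string or an array of content parts (text / image_url), and the patched countTokenMessages tries both decodings. The standalone Go sketch below shows that two-step decoding, assuming the Message, MediaMessage, and MessageImageUrl shapes added in controller/relay.go; describeContent is a hypothetical helper, not a function in this repository.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Shapes mirroring the types added in controller/relay.go.
type MessageImageUrl struct {
	Url    string `json:"url"`
	Detail string `json:"detail"`
}

type MediaMessage struct {
	Type     string          `json:"type"`
	Text     string          `json:"text"`
	ImageUrl MessageImageUrl `json:"image_url,omitempty"`
}

type Message struct {
	Role    string          `json:"role"`
	Content json.RawMessage `json:"content"`
	Name    *string         `json:"name,omitempty"`
}

// describeContent (hypothetical helper) tries the array form first and falls
// back to a plain string, in the same order as the patched countTokenMessages.
func describeContent(m Message) (string, error) {
	var parts []MediaMessage
	if err := json.Unmarshal(m.Content, &parts); err == nil {
		out := ""
		for _, p := range parts {
			if p.Type == "image_url" {
				out += "[image: " + p.ImageUrl.Url + "] "
			} else {
				out += p.Text + " "
			}
		}
		return out, nil
	}
	var s string
	if err := json.Unmarshal(m.Content, &s); err != nil {
		return "", err
	}
	return s, nil
}

func main() {
	plain := Message{Role: "user", Content: json.RawMessage(`"hi"`)}
	multimodal := Message{Role: "user", Content: json.RawMessage(
		`[{"type":"text","text":"what is in this picture?"},
		  {"type":"image_url","image_url":{"url":"https://example.com/cat.png"}}]`)}

	for _, m := range []Message{plain, multimodal} {
		text, err := describeContent(m)
		fmt.Println(text, err)
	}
}
```

A plain string content such as "hi" fails to unmarshal into []MediaMessage and falls through to the string branch, while a GPT-4-Vision style array decodes into the MediaMessage slice, which is the same fallback order the patched countTokenMessages relies on.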