add custom test model

This commit is contained in:
Martial BE
2023-12-29 16:23:25 +08:00
parent 61c47a3b08
commit 9c0a49b97a
4 changed files with 63 additions and 31 deletions

View File

@@ -18,6 +18,10 @@ import (
)
func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (err error, openaiErr *types.OpenAIError) {
if channel.TestModel == "" {
return errors.New("请填写测速模型后再试"), nil
}
// 创建一个 http.Request
req, err := http.NewRequest("POST", "/v1/chat/completions", nil)
if err != nil {
@@ -28,26 +32,7 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = req
// 创建映射
channelTypeToModel := map[int]string{
common.ChannelTypePaLM: "PaLM-2",
common.ChannelTypeAnthropic: "claude-2",
common.ChannelTypeBaidu: "ERNIE-Bot",
common.ChannelTypeZhipu: "chatglm_lite",
common.ChannelTypeAli: "qwen-turbo",
common.ChannelType360: "360GPT_S2_V9",
common.ChannelTypeXunfei: "SparkDesk",
common.ChannelTypeTencent: "hunyuan",
common.ChannelTypeAzure: "gpt-3.5-turbo",
}
// 从映射中获取模型名称
model, ok := channelTypeToModel[channel.Type]
if !ok {
model = "gpt-3.5-turbo" // 默认值
}
request.Model = model
request.Model = channel.TestModel
provider := providers.GetProvider(channel, c)
if provider == nil {
@@ -69,13 +54,15 @@ func testChannel(channel *model.Channel, request types.ChatCompletionRequest) (e
promptTokens := common.CountTokenMessages(request.Messages, request.Model)
Usage, openAIErrorWithStatusCode := chatProvider.ChatAction(&request, true, promptTokens)
if openAIErrorWithStatusCode != nil {
return nil, &openAIErrorWithStatusCode.OpenAIError
return errors.New(openAIErrorWithStatusCode.Message), &openAIErrorWithStatusCode.OpenAIError
}
if Usage.CompletionTokens == 0 {
return fmt.Errorf("channel %s, message 补全 tokens 非预期返回 0", channel.Name), nil
}
common.SysLog(fmt.Sprintf("测试模型 %s 返回内容为:%s", channel.Name, w.Body.String()))
return nil, nil
}