From 660b9b3c9907bcb08e0a887f22d9493a11a14723 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Thu, 4 Apr 2024 17:28:56 +0800 Subject: [PATCH] feat: able to set default test model (#138) --- controller/channel-test.go | 9 ++++++--- middleware/distributor.go | 2 +- model/channel.go | 1 + web/src/pages/Channel/EditChannel.js | 12 ++++++++++++ 4 files changed, 20 insertions(+), 4 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index a4dcfe9..e407193 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -27,7 +27,6 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } - common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -60,12 +59,16 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } if testModel == "" { - testModel = adaptor.GetModelList()[0] - meta.UpstreamModelName = testModel + if channel.TestModel != nil && *channel.TestModel != "" { + testModel = *channel.TestModel + } else { + testModel = adaptor.GetModelList()[0] + } } request := buildTestRequest() request.Model = testModel meta.UpstreamModelName = testModel + common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel)) adaptor.Init(meta, *request) diff --git a/middleware/distributor.go b/middleware/distributor.go index 4db683f..35cb6df 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -163,7 +163,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel.AutoBan != nil && *channel.AutoBan == 0 { ban = false } - if nil != channel.OpenAIOrganization { + if nil != channel.OpenAIOrganization && "" != *channel.OpenAIOrganization { c.Set("channel_organization", *channel.OpenAIOrganization) } c.Set("auto_ban", ban) diff --git a/model/channel.go b/model/channel.go index b06f578..3e30ad4 100644 --- a/model/channel.go +++ b/model/channel.go @@ -10,6 +10,7 @@ type Channel struct { Type int `json:"type" gorm:"default:0"` Key string `json:"key" gorm:"not null"` OpenAIOrganization *string `json:"openai_organization"` + TestModel *string `json:"test_model"` Status int `json:"status" gorm:"default:1"` Name string `json:"name" gorm:"index"` Weight *uint `json:"weight" gorm:"default:0"` diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js index 9b98de2..0fe6e2b 100644 --- a/web/src/pages/Channel/EditChannel.js +++ b/web/src/pages/Channel/EditChannel.js @@ -63,6 +63,7 @@ const EditChannel = (props) => { model_mapping: '', models: [], auto_ban: 1, + test_model: '', groups: ['default'], }; const [batch, setBatch] = useState(false); @@ -669,6 +670,17 @@ const EditChannel = (props) => { }} value={inputs.openai_organization} /> +
+ 默认测试模型: +
+ { + handleInputChange('test_model', value); + }} + value={inputs.test_model} + />