diff --git a/controller/channel-test.go b/controller/channel-test.go index b20bbeb..1409e41 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -34,14 +34,18 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai case common.ChannelTypeXunfei: return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil case common.ChannelTypeAzure: - request.Model = "gpt-35-turbo" + if request.Model == "" { + request.Model = "gpt-35-turbo" + } defer func() { if err != nil { err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!") } }() default: - request.Model = "gpt-3.5-turbo" + if request.Model == "" { + request.Model = "gpt-3.5-turbo" + } } requestURL := getFullRequestURL(channel.GetBaseURL(), "/v1/chat/completions", channel.Type) @@ -102,6 +106,7 @@ func TestChannel(c *gin.Context) { }) return } + testModel := c.Param("model") channel, err := model.GetChannelById(id, true) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -111,6 +116,9 @@ func TestChannel(c *gin.Context) { return } testRequest := buildTestRequest() + if testModel != "" { + testRequest.Model = testModel + } tik := time.Now() err, _ = testChannel(channel, *testRequest) tok := time.Now() diff --git a/web/src/components/ChannelsTable.js b/web/src/components/ChannelsTable.js index 5611743..47b5131 100644 --- a/web/src/components/ChannelsTable.js +++ b/web/src/components/ChannelsTable.js @@ -23,9 +23,10 @@ import { Space, Tooltip, Switch, - Typography, InputNumber + Typography, InputNumber, Dropdown, SplitButtonGroup } from "@douyinfe/semi-ui"; import EditChannel from "../pages/Channel/EditChannel"; +import {IconTreeTriangleDown} from "@douyinfe/semi-icons"; function renderTimestamp(timestamp) { return ( @@ -195,7 +196,14 @@ const ChannelsTable = () => { dataIndex: 'operate', render: (text, record, index) => (
- + + + + + + + {/**/} { const setChannelFormat = (channels) => { for (let i = 0; i < channels.length; i++) { channels[i].key = '' + channels[i].id; + let test_models = [] + channels[i].models.split(',').forEach((item, index) => { + test_models.push({ + node: 'item', + name: item, + onClick: () => { + testChannel(channels[i], item) + } + }) + }) + channels[i].test_models = test_models } // data.key = '' + data.id setChannels(channels); @@ -440,14 +459,15 @@ const ChannelsTable = () => { setSearching(false); }; - const testChannel = async (record) => { - const res = await API.get(`/api/channel/test/${record.id}/`); + const testChannel = async (record, model) => { + const res = await API.get(`/api/channel/test/${record.id}?model=${model}`); const {success, message, time} = res.data; if (success) { let newChannels = [...channels]; record.response_time = time * 1000; record.test_time = Date.now() / 1000; - setChannels(newChannels); + + setChannelFormat(newChannels) showInfo(`通道 ${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`); } else { showError(message);