From fa2a7727316c8d8843e046fb82a1049d33578325 Mon Sep 17 00:00:00 2001 From: JustSong Date: Fri, 31 Jan 2025 21:23:12 +0800 Subject: [PATCH] feat: able to query test log --- README.md | 1 + common/config/config.go | 4 +- controller/channel-test.go | 108 +++++--- model/log.go | 7 + web/default/src/components/ChannelsTable.js | 269 +++++++++++++------- web/default/src/components/LogsTable.js | 172 ++++++++----- 6 files changed, 377 insertions(+), 184 deletions(-) diff --git a/README.md b/README.md index 853ec067..37781eb7 100644 --- a/README.md +++ b/README.md @@ -410,6 +410,7 @@ graph LR 27. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。 28. `INITIAL_ROOT_ACCESS_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量的 root 用户创建系统管理令牌。 29. `ENFORCE_INCLUDE_USAGE`:是否强制在 stream 模型下返回 usage,默认不开启,可选值为 `true` 和 `false`。 +30. `TEST_PROMPT`:测试模型时的用户 prompt,默认为 `Print your model name exactly and do not output without any other text.`。 ### 命令行参数 1. `--port `: 指定服务器监听的端口号,默认为 `3000`。 diff --git a/common/config/config.go b/common/config/config.go index 2eb894ef..2cd14363 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -1,13 +1,14 @@ package config import ( - "github.com/songquanpeng/one-api/common/env" "os" "strconv" "strings" "sync" "time" + "github.com/songquanpeng/one-api/common/env" + "github.com/google/uuid" ) @@ -162,3 +163,4 @@ var UserContentRequestProxy = env.String("USER_CONTENT_REQUEST_PROXY", "") var UserContentRequestTimeout = env.Int("USER_CONTENT_REQUEST_TIMEOUT", 30) var EnforceIncludeUsage = env.Bool("ENFORCE_INCLUDE_USAGE", false) +var TestPrompt = env.String("TEST_PROMPT", "Print your model name exactly and do not output without any other text.") diff --git a/controller/channel-test.go b/controller/channel-test.go index 971f5382..c24ad971 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -2,6 +2,7 @@ package controller import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -15,14 +16,17 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/meta" @@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest { model = "gpt-3.5-turbo" } testRequest := &relaymodel.GeneralOpenAIRequest{ - MaxTokens: 2, - Model: model, + Model: model, } testMessage := relaymodel.Message{ Role: "user", - Content: "hi", + Content: config.TestPrompt, } testRequest.Messages = append(testRequest.Messages, testMessage) return testRequest } -func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) { +func parseTestResponse(resp string) (*openai.TextResponse, string, error) { + var response openai.TextResponse + err := json.Unmarshal([]byte(resp), &response) + if err != nil { + return nil, "", err + } + if len(response.Choices) == 0 { + return nil, "", errors.New("response has no choices") + } + stringContent, ok := response.Choices[0].Content.(string) + if !ok { + return nil, "", errors.New("response content is not string") + } + return &response, stringContent, nil +} + +func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) { + startTime := time.Now() w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.Request = &http.Request{ @@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques apiType := channeltype.ToAPIType(channel.Type) adaptor := relay.GetAdaptor(apiType) if adaptor == nil { - return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil + return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } adaptor.Init(meta) modelName := request.Model @@ -84,41 +104,69 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques request.Model = modelName convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { - return err, nil + return "", err, nil } jsonData, err := json.Marshal(convertedRequest) if err != nil { - return err, nil + return "", err, nil } + defer func() { + logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage) + if err != nil || openaiErr != nil { + errorMessage := "" + if err != nil { + errorMessage = err.Error() + } else { + errorMessage = openaiErr.Message + } + logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage) + } + go model.RecordTestLog(ctx, &model.Log{ + ChannelId: channel.Id, + ModelName: modelName, + Content: logContent, + ElapsedTime: helper.CalcElapsedTime(startTime), + }) + }() logger.SysLog(string(jsonData)) requestBody := bytes.NewBuffer(jsonData) c.Request.Body = io.NopCloser(requestBody) resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { - return err, nil + return "", err, nil } if resp != nil && resp.StatusCode != http.StatusOK { err := controller.RelayErrorHandler(resp) - return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + errorMessage := err.Error.Message + if errorMessage != "" { + errorMessage = ", error message: " + errorMessage + } + return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { - return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error + return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error } if usage == nil { - return errors.New("usage is nil"), nil + return "", errors.New("usage is nil"), nil + } + rawResponse := w.Body.String() + _, responseMessage, err = parseTestResponse(rawResponse) + if err != nil { + return "", err, nil } result := w.Result() // print result.Body respBody, err := io.ReadAll(result.Body) if err != nil { - return err, nil + return "", err, nil } logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) - return nil, nil + return responseMessage, nil, nil } func TestChannel(c *gin.Context) { + ctx := c.Request.Context() id, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) { }) return } - model := c.Query("model") - testRequest := buildTestRequest(model) + modelName := c.Query("model") + testRequest := buildTestRequest(modelName) tik := time.Now() - err, _ = testChannel(channel, testRequest) + responseMessage, err, _ := testChannel(ctx, channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if err != nil { @@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) { consumedTime := float64(milliseconds) / 1000.0 if err != nil { c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - "time": consumedTime, - "model": model, + "success": false, + "message": err.Error(), + "time": consumedTime, + "modelName": modelName, }) return } c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "time": consumedTime, - "model": model, + "success": true, + "message": responseMessage, + "time": consumedTime, + "modelName": modelName, }) return } @@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false -func testChannels(notify bool, scope string) error { +func testChannels(ctx context.Context, notify bool, scope string) error { if config.RootUserEmail == "" { config.RootUserEmail = model.GetRootUserEmail() } @@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error { isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() testRequest := buildTestRequest("") - err, openaiErr := testChannel(channel, testRequest) + _, err, openaiErr := testChannel(ctx, channel, testRequest) tok := time.Now() milliseconds := tok.Sub(tik).Milliseconds() if isChannelEnabled && milliseconds > disableThreshold { @@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error { } func TestChannels(c *gin.Context) { + ctx := c.Request.Context() scope := c.Query("scope") if scope == "" { scope = "all" } - err := testChannels(true, scope) + err := testChannels(ctx, true, scope) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) { } func AutomaticallyTestChannels(frequency int) { + ctx := context.Background() for { time.Sleep(time.Duration(frequency) * time.Minute) logger.SysLog("testing all channels") - _ = testChannels(false, "all") + _ = testChannels(ctx, false, "all") logger.SysLog("channel test finished") } } diff --git a/model/log.go b/model/log.go index 17525500..2c920652 100644 --- a/model/log.go +++ b/model/log.go @@ -37,6 +37,7 @@ const ( LogTypeConsume LogTypeManage LogTypeSystem + LogTypeTest ) func recordLogHelper(ctx context.Context, log *Log) { @@ -86,6 +87,12 @@ func RecordConsumeLog(ctx context.Context, log *Log) { recordLogHelper(ctx, log) } +func RecordTestLog(ctx context.Context, log *Log) { + log.CreatedAt = helper.GetTimestamp() + log.Type = LogTypeTest + recordLogHelper(ctx, log) +} + func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) { var tx *gorm.DB if logType == LogTypeUnknown { diff --git a/web/default/src/components/ChannelsTable.js b/web/default/src/components/ChannelsTable.js index e745814b..1d7f41b2 100644 --- a/web/default/src/components/ChannelsTable.js +++ b/web/default/src/components/ChannelsTable.js @@ -1,5 +1,15 @@ import React, { useEffect, useState } from 'react'; -import { Button, Dropdown, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react'; +import { + Button, + Dropdown, + Form, + Input, + Label, + Message, + Pagination, + Popup, + Table, +} from 'semantic-ui-react'; import { Link } from 'react-router-dom'; import { API, @@ -9,31 +19,31 @@ import { showError, showInfo, showSuccess, - timestamp2string + timestamp2string, } from '../helpers'; import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants'; import { renderGroup, renderNumber } from '../helpers/render'; function renderTimestamp(timestamp) { - return ( - <> - {timestamp2string(timestamp)} - - ); + return <>{timestamp2string(timestamp)}; } let type2label = undefined; function renderType(type) { if (!type2label) { - type2label = new Map; + type2label = new Map(); for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { type2label[CHANNEL_OPTIONS[i].value] = CHANNEL_OPTIONS[i]; } type2label[0] = { value: 0, text: '未知类型', color: 'grey' }; } - return ; + return ( + + ); } function renderBalance(type, balance) { @@ -62,10 +72,10 @@ function renderBalance(type, balance) { } function isShowDetail() { - return localStorage.getItem("show_detail") === "true"; + return localStorage.getItem('show_detail') === 'true'; } -const promptID = "detail" +const promptID = 'detail'; const ChannelsTable = () => { const [channels, setChannels] = useState([]); @@ -81,33 +91,37 @@ const ChannelsTable = () => { const res = await API.get(`/api/channel/?p=${startIdx}`); const { success, message, data } = res.data; if (success) { - let localChannels = data.map((channel) => { - if (channel.models === '') { - channel.models = []; - channel.test_model = ""; - } else { - channel.models = channel.models.split(','); - if (channel.models.length > 0) { - channel.test_model = channel.models[0]; - } - channel.model_options = channel.models.map((model) => { - return { - key: model, - text: model, - value: model, - } - }) - console.log('channel', channel) - } - return channel; - }); - if (startIdx === 0) { - setChannels(localChannels); + let localChannels = data.map((channel) => { + if (channel.models === '') { + channel.models = []; + channel.test_model = ''; } else { - let newChannels = [...channels]; - newChannels.splice(startIdx * ITEMS_PER_PAGE, data.length, ...localChannels); - setChannels(newChannels); + channel.models = channel.models.split(','); + if (channel.models.length > 0) { + channel.test_model = channel.models[0]; + } + channel.model_options = channel.models.map((model) => { + return { + key: model, + text: model, + value: model, + }; + }); + console.log('channel', channel); } + return channel; + }); + if (startIdx === 0) { + setChannels(localChannels); + } else { + let newChannels = [...channels]; + newChannels.splice( + startIdx * ITEMS_PER_PAGE, + data.length, + ...localChannels + ); + setChannels(newChannels); + } } else { showError(message); } @@ -131,8 +145,8 @@ const ChannelsTable = () => { const toggleShowDetail = () => { setShowDetail(!showDetail); - localStorage.setItem("show_detail", (!showDetail).toString()); - } + localStorage.setItem('show_detail', (!showDetail).toString()); + }; useEffect(() => { loadChannels(0) @@ -196,13 +210,19 @@ const ChannelsTable = () => { const renderStatus = (status) => { switch (status) { case 1: - return ; + return ( + + ); case 2: return ( - 已禁用 - } + trigger={ + + } content='本渠道被手动禁用' basic /> @@ -210,9 +230,11 @@ const ChannelsTable = () => { case 3: return ( - 已禁用 - } + trigger={ + + } content='本渠道被程序自动禁用' basic /> @@ -230,15 +252,35 @@ const ChannelsTable = () => { let time = responseTime / 1000; time = time.toFixed(2) + ' 秒'; if (responseTime === 0) { - return ; + return ( + + ); } else if (responseTime <= 1000) { - return ; + return ( + + ); } else if (responseTime <= 3000) { - return ; + return ( + + ); } else if (responseTime <= 5000) { - return ; + return ( + + ); } else { - return ; + return ( + + ); } }; @@ -277,7 +319,11 @@ const ChannelsTable = () => { newChannels[realIdx].response_time = time * 1000; newChannels[realIdx].test_time = Date.now() / 1000; setChannels(newChannels); - showInfo(`渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed(2)} 秒。`); + showInfo( + `渠道 ${name} 测试成功,模型 ${model},耗时 ${time.toFixed( + 2 + )} 秒,模型输出:${message}` + ); } else { showError(message); } @@ -360,7 +406,6 @@ const ChannelsTable = () => { setLoading(false); }; - return ( <>
@@ -374,20 +419,22 @@ const ChannelsTable = () => { onChange={handleKeywordChange} />
- { - showPrompt && ( - { + {showPrompt && ( + { setShowPrompt(false); setPromptShown(promptID); - }}> - OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 0。对于支持的渠道类型,请点击余额进行刷新。 -
- 渠道测试仅支持 chat 模型,优先使用 gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。 -
- 点击下方详情按钮可以显示余额以及设置额外的测试模型。 -
- ) - } + }} + > + OpenAI 渠道已经不再支持通过 key 获取余额,因此余额显示为 + 0。对于支持的渠道类型,请点击余额进行刷新。 +
+ 渠道测试仅支持 chat 模型,优先使用 + gpt-3.5-turbo,如果该模型不可用则使用你所配置的模型列表中的第一个模型。 +
+ 点击下方详情按钮可以显示余额以及设置额外的测试模型。 +
+ )} @@ -478,7 +525,11 @@ const ChannelsTable = () => { {renderStatus(channel.status)} { { - manageChannel( - channel.id, - 'priority', - idx, - event.target.value - ); - }}> - - } + trigger={ + { + manageChannel( + channel.id, + 'priority', + idx, + event.target.value + ); + }} + > + + + } content='渠道选择优先级,越高越优先' basic /> @@ -528,7 +590,12 @@ const ChannelsTable = () => { size={'small'} positive onClick={() => { - testChannel(channel.id, channel.name, idx, channel.test_model); + testChannel( + channel.id, + channel.name, + idx, + channel.test_model + ); }} > 测试 @@ -590,14 +657,31 @@ const ChannelsTable = () => { - - - - {/* @@ -627,8 +716,12 @@ const ChannelsTable = () => { (channels.length % ITEMS_PER_PAGE === 0 ? 1 : 0) } /> - - + + diff --git a/web/default/src/components/LogsTable.js b/web/default/src/components/LogsTable.js index 12b8dc60..1ae9fd6e 100644 --- a/web/default/src/components/LogsTable.js +++ b/web/default/src/components/LogsTable.js @@ -21,6 +21,7 @@ import { import { ITEMS_PER_PAGE } from '../constants'; import { renderColorLabel, renderQuota } from '../helpers/render'; +import { Link } from 'react-router-dom'; function renderTimestamp(timestamp, request_id) { return ( @@ -50,6 +51,7 @@ const LOG_OPTIONS = [ { key: '2', text: '消费', value: 2 }, { key: '3', text: '管理', value: 3 }, { key: '4', text: '系统', value: 4 }, + { key: '5', text: '测试', value: 5 }, ]; function renderType(type) { @@ -78,6 +80,12 @@ function renderType(type) { 系统 ); + case 5: + return ( + + ); default: return (