From cf3d894195880ea9e7b9a79fb92047b1966b3170 Mon Sep 17 00:00:00 2001 From: CalciumIon <1808837298@qq.com> Date: Fri, 5 Jul 2024 20:51:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=B0=E5=BD=95=E6=B8=A0=E9=81=93?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=9A=84=E6=B6=88=E8=B4=B9=E6=97=A5=E5=BF=97?= =?UTF-8?q?=20(close=20#334)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- controller/channel-test.go | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/controller/channel-test.go b/controller/channel-test.go index ea82578..1beb5e1 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "io" + "math" "net/http" "net/http/httptest" "net/url" @@ -24,6 +25,7 @@ import ( ) func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) { + tik := time.Now() if channel.Type == common.ChannelTypeMidjourney { return errors.New("midjourney channel test is not supported"), nil } @@ -120,6 +122,25 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr if err != nil { return err, nil } + modelPrice, usePrice := common.GetModelPrice(testModel, false) + modelRatio := common.GetModelRatio(testModel) + completionRatio := common.GetCompletionRatio(testModel) + ratio := modelRatio + quota := 0 + if !usePrice { + quota = usage.PromptTokens + int(math.Round(float64(usage.CompletionTokens)*completionRatio)) + quota = int(math.Round(float64(quota) * ratio)) + if ratio != 0 && quota <= 0 { + quota = 1 + } + } else { + quota = int(modelPrice * common.QuotaPerUnit) + } + tok := time.Now() + milliseconds := tok.Sub(tik).Milliseconds() + consumedTime := float64(milliseconds) / 1000.0 + other := service.GenerateTextOtherInfo(c, meta, modelRatio, 1, completionRatio, modelPrice) + model.RecordConsumeLog(c, 1, channel.Id, usage.PromptTokens, usage.CompletionTokens, testModel, "模型测试", quota, "模型测试", 0, quota, int(consumedTime), false, other) common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody))) return nil, nil } @@ -140,7 +161,7 @@ func buildTestRequest() *dto.GeneralOpenAIRequest { } func TestChannel(c *gin.Context) { - id, err := strconv.Atoi(c.Param("id")) + channelId, err := strconv.Atoi(c.Param("id")) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -148,7 +169,7 @@ func TestChannel(c *gin.Context) { }) return } - channel, err := model.GetChannelById(id, true) + channel, err := model.GetChannelById(channelId, true) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false,