feat: able to query test log

This commit is contained in:
JustSong
2025-01-31 21:23:12 +08:00
parent 4f68f3e1b3
commit fa2a772731
6 changed files with 377 additions and 184 deletions

View File

@@ -2,6 +2,7 @@ package controller
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -15,14 +16,17 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
@@ -35,18 +39,34 @@ func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 2,
Model: model,
Model: model,
}
testMessage := relaymodel.Message{
Role: "user",
Content: "hi",
Content: config.TestPrompt,
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
var response openai.TextResponse
err := json.Unmarshal([]byte(resp), &response)
if err != nil {
return nil, "", err
}
if len(response.Choices) == 0 {
return nil, "", errors.New("response has no choices")
}
stringContent, ok := response.Choices[0].Content.(string)
if !ok {
return nil, "", errors.New("response content is not string")
}
return &response, stringContent, nil
}
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
startTime := time.Now()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
@@ -66,7 +86,7 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
apiType := channeltype.ToAPIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := request.Model
@@ -84,41 +104,69 @@ func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIReques
request.Model = modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
return "", err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return err, nil
return "", err, nil
}
defer func() {
logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
if err != nil || openaiErr != nil {
errorMessage := ""
if err != nil {
errorMessage = err.Error()
} else {
errorMessage = openaiErr.Message
}
logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
}
go model.RecordTestLog(ctx, &model.Log{
ChannelId: channel.Id,
ModelName: modelName,
Content: logContent,
ElapsedTime: helper.CalcElapsedTime(startTime),
})
}()
logger.SysLog(string(jsonData))
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return err, nil
return "", err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
err := controller.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
errorMessage := err.Error.Message
if errorMessage != "" {
errorMessage = ", error message: " + errorMessage
}
return "", fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
return "", fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
}
if usage == nil {
return errors.New("usage is nil"), nil
return "", errors.New("usage is nil"), nil
}
rawResponse := w.Body.String()
_, responseMessage, err = parseTestResponse(rawResponse)
if err != nil {
return "", err, nil
}
result := w.Result()
// print result.Body
respBody, err := io.ReadAll(result.Body)
if err != nil {
return err, nil
return "", err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return nil, nil
return responseMessage, nil, nil
}
func TestChannel(c *gin.Context) {
ctx := c.Request.Context()
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -135,10 +183,10 @@ func TestChannel(c *gin.Context) {
})
return
}
model := c.Query("model")
testRequest := buildTestRequest(model)
modelName := c.Query("model")
testRequest := buildTestRequest(modelName)
tik := time.Now()
err, _ = testChannel(channel, testRequest)
responseMessage, err, _ := testChannel(ctx, channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
@@ -148,18 +196,18 @@ func TestChannel(c *gin.Context) {
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
"model": model,
"success": false,
"message": err.Error(),
"time": consumedTime,
"modelName": modelName,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"time": consumedTime,
"model": model,
"success": true,
"message": responseMessage,
"time": consumedTime,
"modelName": modelName,
})
return
}
@@ -167,7 +215,7 @@ func TestChannel(c *gin.Context) {
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testChannels(notify bool, scope string) error {
func testChannels(ctx context.Context, notify bool, scope string) error {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
@@ -191,7 +239,7 @@ func testChannels(notify bool, scope string) error {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
testRequest := buildTestRequest("")
err, openaiErr := testChannel(channel, testRequest)
_, err, openaiErr := testChannel(ctx, channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
@@ -225,11 +273,12 @@ func testChannels(notify bool, scope string) error {
}
func TestChannels(c *gin.Context) {
ctx := c.Request.Context()
scope := c.Query("scope")
if scope == "" {
scope = "all"
}
err := testChannels(true, scope)
err := testChannels(ctx, true, scope)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -245,10 +294,11 @@ func TestChannels(c *gin.Context) {
}
func AutomaticallyTestChannels(frequency int) {
ctx := context.Background()
for {
time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("testing all channels")
_ = testChannels(false, "all")
_ = testChannels(ctx, false, "all")
logger.SysLog("channel test finished")
}
}