one-api/controller/channel-test.go
2025-02-16 13:18:50 +00:00

310 lines
9.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package controller
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
if model == "" {
model = "gpt-3.5-turbo"
}
testRequest := &relaymodel.GeneralOpenAIRequest{
Model: model,
}
testMessage := relaymodel.Message{
Role: "user",
Content: config.TestPrompt,
}
testRequest.Messages = append(testRequest.Messages, testMessage)
return testRequest
}
func parseTestResponse(resp string) (*openai.TextResponse, string, error) {
var response openai.TextResponse
err := json.Unmarshal([]byte(resp), &response)
if err != nil {
return nil, "", err
}
if len(response.Choices) == 0 {
return nil, "", errors.New("response has no choices")
}
stringContent, ok := response.Choices[0].Content.(string)
if !ok {
return nil, "", errors.New("response content is not string")
}
return &response, stringContent, nil
}
func testChannel(ctx context.Context, channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (responseMessage string, err error, openaiErr *relaymodel.Error) {
startTime := time.Now()
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/chat/completions"},
Body: nil,
Header: make(http.Header),
}
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
c.Request.Header.Set("Content-Type", "application/json")
c.Set(ctxkey.Channel, channel.Type)
c.Set(ctxkey.BaseURL, channel.GetBaseURL())
cfg, _ := channel.LoadConfig()
c.Set(ctxkey.Config, cfg)
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := meta.GetByContext(c)
apiType := channeltype.ToAPIType(channel.Type)
adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return "", fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
adaptor.Init(meta)
modelName := request.Model
modelMap := channel.GetModelMapping()
if modelName == "" || !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
if modelMap != nil && modelMap[modelName] != "" {
modelName = modelMap[modelName]
}
meta.OriginModelName, meta.ActualModelName = request.Model, modelName
request.Model = modelName
convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return "", err, nil
}
jsonData, err := json.Marshal(convertedRequest)
if err != nil {
return "", err, nil
}
defer func() {
logContent := fmt.Sprintf("渠道 %s 测试成功,响应:%s", channel.Name, responseMessage)
if err != nil || openaiErr != nil {
errorMessage := ""
if err != nil {
errorMessage = err.Error()
} else {
errorMessage = openaiErr.Message
}
logContent = fmt.Sprintf("渠道 %s 测试失败,错误:%s", channel.Name, errorMessage)
}
go model.RecordTestLog(ctx, &model.Log{
ChannelId: channel.Id,
ModelName: modelName,
Content: logContent,
ElapsedTime: helper.CalcElapsedTime(startTime),
})
}()
logger.SysLog(string(jsonData))
requestBody := bytes.NewBuffer(jsonData)
c.Request.Body = io.NopCloser(requestBody)
var resp *http.Response
resp, err = adaptor.DoRequest(c, meta, requestBody)
if err != nil {
return "", err, nil
}
if resp != nil && resp.StatusCode != http.StatusOK {
wrappedErr := controller.RelayErrorHandler(resp)
errorMessage := wrappedErr.Error.Message
if errorMessage != "" {
errorMessage = ", error message: " + errorMessage
}
err = fmt.Errorf("http status code: %d%s", resp.StatusCode, errorMessage)
return "", err, &wrappedErr.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
err = fmt.Errorf("%s", respErr.Error.Message)
return "", err, &respErr.Error
}
if usage == nil {
err = errors.New("usage is nil")
return "", err, nil
}
rawResponse := w.Body.String()
_, responseMessage, err = parseTestResponse(rawResponse)
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse error: %s, \nresponse: %s", err.Error(), rawResponse))
return "", err, nil
}
result := w.Result()
// print result.Body
var respBody []byte
respBody, err = io.ReadAll(result.Body)
if err != nil {
return "", err, nil
}
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
return responseMessage, nil, nil
}
func TestChannel(c *gin.Context) {
ctx := c.Request.Context()
id, err := strconv.Atoi(c.Param("id"))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
channel, err := model.GetChannelById(id, true)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
modelName := c.Query("model")
testRequest := buildTestRequest(modelName)
tik := time.Now()
responseMessage, err, _ := testChannel(ctx, channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if err != nil {
milliseconds = 0
}
go channel.UpdateResponseTime(milliseconds)
consumedTime := float64(milliseconds) / 1000.0
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
"time": consumedTime,
"modelName": modelName,
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": responseMessage,
"time": consumedTime,
"modelName": modelName,
})
return
}
var testAllChannelsLock sync.Mutex
var testAllChannelsRunning bool = false
func testChannels(ctx context.Context, notify bool, scope string) error {
if config.RootUserEmail == "" {
config.RootUserEmail = model.GetRootUserEmail()
}
testAllChannelsLock.Lock()
if testAllChannelsRunning {
testAllChannelsLock.Unlock()
return errors.New("Test is already running")
}
testAllChannelsRunning = true
testAllChannelsLock.Unlock()
channels, err := model.GetAllChannels(0, 0, scope)
if err != nil {
return err
}
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
if disableThreshold == 0 {
disableThreshold = 10000000 // a impossible value
}
go func() {
for _, channel := range channels {
isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
testRequest := buildTestRequest("")
_, err, openaiErr := testChannel(ctx, channel, testRequest)
tok := time.Now()
milliseconds := tok.Sub(tik).Milliseconds()
if isChannelEnabled && milliseconds > disableThreshold {
err = fmt.Errorf("Response time %.2fs exceeds threshold %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
if config.AutomaticDisableChannelEnabled {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} else {
_ = message.Notify(message.ByAll, fmt.Sprintf("Channel %s %dTest超时", channel.Name, channel.Id), "", err.Error())
}
}
if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
time.Sleep(config.RequestInterval)
}
testAllChannelsLock.Lock()
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := message.Notify(message.ByAll, "Channel test completed", "", "Channel test completed, if you have not received the disable notification, it means that all channels are normal")
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}
}
}()
return nil
}
func TestChannels(c *gin.Context) {
ctx := c.Request.Context()
scope := c.Query("scope")
if scope == "" {
scope = "all"
}
err := testChannels(ctx, true, scope)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}
func AutomaticallyTestChannels(frequency int) {
ctx := context.Background()
for {
time.Sleep(time.Duration(frequency) * time.Minute)
logger.SysLog("testing all channels")
_ = testChannels(ctx, false, "all")
logger.SysLog("channel test finished")
}
}