mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			255 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			255 lines
		
	
	
		
			7.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package controller
 | 
						||
 | 
						||
import (
 | 
						||
	"bytes"
 | 
						||
	"encoding/json"
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"io"
 | 
						||
	"net/http"
 | 
						||
	"net/http/httptest"
 | 
						||
	"net/url"
 | 
						||
	"strconv"
 | 
						||
	"strings"
 | 
						||
	"sync"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/gin-gonic/gin"
 | 
						||
	"github.com/songquanpeng/one-api/common/config"
 | 
						||
	"github.com/songquanpeng/one-api/common/ctxkey"
 | 
						||
	"github.com/songquanpeng/one-api/common/logger"
 | 
						||
	"github.com/songquanpeng/one-api/common/message"
 | 
						||
	"github.com/songquanpeng/one-api/middleware"
 | 
						||
	"github.com/songquanpeng/one-api/model"
 | 
						||
	"github.com/songquanpeng/one-api/monitor"
 | 
						||
	relay "github.com/songquanpeng/one-api/relay"
 | 
						||
	"github.com/songquanpeng/one-api/relay/channeltype"
 | 
						||
	"github.com/songquanpeng/one-api/relay/controller"
 | 
						||
	"github.com/songquanpeng/one-api/relay/meta"
 | 
						||
	relaymodel "github.com/songquanpeng/one-api/relay/model"
 | 
						||
	"github.com/songquanpeng/one-api/relay/relaymode"
 | 
						||
)
 | 
						||
 | 
						||
func buildTestRequest(model string) *relaymodel.GeneralOpenAIRequest {
 | 
						||
	if model == "" {
 | 
						||
		model = "gpt-3.5-turbo"
 | 
						||
	}
 | 
						||
	testRequest := &relaymodel.GeneralOpenAIRequest{
 | 
						||
		MaxTokens: 2,
 | 
						||
		Model:     model,
 | 
						||
	}
 | 
						||
	testMessage := relaymodel.Message{
 | 
						||
		Role:    "user",
 | 
						||
		Content: "hi",
 | 
						||
	}
 | 
						||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
						||
	return testRequest
 | 
						||
}
 | 
						||
 | 
						||
func testChannel(channel *model.Channel, request *relaymodel.GeneralOpenAIRequest) (err error, openaiErr *relaymodel.Error) {
 | 
						||
	w := httptest.NewRecorder()
 | 
						||
	c, _ := gin.CreateTestContext(w)
 | 
						||
	c.Request = &http.Request{
 | 
						||
		Method: "POST",
 | 
						||
		URL:    &url.URL{Path: "/v1/chat/completions"},
 | 
						||
		Body:   nil,
 | 
						||
		Header: make(http.Header),
 | 
						||
	}
 | 
						||
	c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
 | 
						||
	c.Request.Header.Set("Content-Type", "application/json")
 | 
						||
	c.Set(ctxkey.Channel, channel.Type)
 | 
						||
	c.Set(ctxkey.BaseURL, channel.GetBaseURL())
 | 
						||
	cfg, _ := channel.LoadConfig()
 | 
						||
	c.Set(ctxkey.Config, cfg)
 | 
						||
	middleware.SetupContextForSelectedChannel(c, channel, "")
 | 
						||
	meta := meta.GetByContext(c)
 | 
						||
	apiType := channeltype.ToAPIType(channel.Type)
 | 
						||
	adaptor := relay.GetAdaptor(apiType)
 | 
						||
	if adaptor == nil {
 | 
						||
		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 | 
						||
	}
 | 
						||
	adaptor.Init(meta)
 | 
						||
	modelName := request.Model
 | 
						||
	modelMap := channel.GetModelMapping()
 | 
						||
	if modelName == "" || !strings.Contains(channel.Models, modelName) {
 | 
						||
		modelNames := strings.Split(channel.Models, ",")
 | 
						||
		if len(modelNames) > 0 {
 | 
						||
			modelName = modelNames[0]
 | 
						||
		}
 | 
						||
		if modelMap != nil && modelMap[modelName] != "" {
 | 
						||
			modelName = modelMap[modelName]
 | 
						||
		}
 | 
						||
	}
 | 
						||
	meta.OriginModelName, meta.ActualModelName = request.Model, modelName
 | 
						||
	request.Model = modelName
 | 
						||
	convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
 | 
						||
	if err != nil {
 | 
						||
		return err, nil
 | 
						||
	}
 | 
						||
	jsonData, err := json.Marshal(convertedRequest)
 | 
						||
	if err != nil {
 | 
						||
		return err, nil
 | 
						||
	}
 | 
						||
	logger.SysLog(string(jsonData))
 | 
						||
	requestBody := bytes.NewBuffer(jsonData)
 | 
						||
	c.Request.Body = io.NopCloser(requestBody)
 | 
						||
	resp, err := adaptor.DoRequest(c, meta, requestBody)
 | 
						||
	if err != nil {
 | 
						||
		return err, nil
 | 
						||
	}
 | 
						||
	if resp != nil && resp.StatusCode != http.StatusOK {
 | 
						||
		err := controller.RelayErrorHandler(resp)
 | 
						||
		return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
 | 
						||
	}
 | 
						||
	usage, respErr := adaptor.DoResponse(c, resp, meta)
 | 
						||
	if respErr != nil {
 | 
						||
		return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
 | 
						||
	}
 | 
						||
	if usage == nil {
 | 
						||
		return errors.New("usage is nil"), nil
 | 
						||
	}
 | 
						||
	result := w.Result()
 | 
						||
	// print result.Body
 | 
						||
	respBody, err := io.ReadAll(result.Body)
 | 
						||
	if err != nil {
 | 
						||
		return err, nil
 | 
						||
	}
 | 
						||
	logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
 | 
						||
	return nil, nil
 | 
						||
}
 | 
						||
 | 
						||
func TestChannel(c *gin.Context) {
 | 
						||
	id, err := strconv.Atoi(c.Param("id"))
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusOK, gin.H{
 | 
						||
			"success": false,
 | 
						||
			"message": err.Error(),
 | 
						||
		})
 | 
						||
		return
 | 
						||
	}
 | 
						||
	channel, err := model.GetChannelById(id, true)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusOK, gin.H{
 | 
						||
			"success": false,
 | 
						||
			"message": err.Error(),
 | 
						||
		})
 | 
						||
		return
 | 
						||
	}
 | 
						||
	model := c.Query("model")
 | 
						||
	testRequest := buildTestRequest(model)
 | 
						||
	tik := time.Now()
 | 
						||
	err, _ = testChannel(channel, testRequest)
 | 
						||
	tok := time.Now()
 | 
						||
	milliseconds := tok.Sub(tik).Milliseconds()
 | 
						||
	if err != nil {
 | 
						||
		milliseconds = 0
 | 
						||
	}
 | 
						||
	go channel.UpdateResponseTime(milliseconds)
 | 
						||
	consumedTime := float64(milliseconds) / 1000.0
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusOK, gin.H{
 | 
						||
			"success": false,
 | 
						||
			"message": err.Error(),
 | 
						||
			"time":    consumedTime,
 | 
						||
			"model":   model,
 | 
						||
		})
 | 
						||
		return
 | 
						||
	}
 | 
						||
	c.JSON(http.StatusOK, gin.H{
 | 
						||
		"success": true,
 | 
						||
		"message": "",
 | 
						||
		"time":    consumedTime,
 | 
						||
		"model":   model,
 | 
						||
	})
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
var (
	// testAllChannelsLock guards testAllChannelsRunning.
	testAllChannelsLock sync.Mutex
	// testAllChannelsRunning is true while a test run over all channels is in
	// progress; it starts out false.
	testAllChannelsRunning bool
)
 | 
						||
 | 
						||
func testChannels(notify bool, scope string) error {
 | 
						||
	if config.RootUserEmail == "" {
 | 
						||
		config.RootUserEmail = model.GetRootUserEmail()
 | 
						||
	}
 | 
						||
	testAllChannelsLock.Lock()
 | 
						||
	if testAllChannelsRunning {
 | 
						||
		testAllChannelsLock.Unlock()
 | 
						||
		return errors.New("测试已在运行中")
 | 
						||
	}
 | 
						||
	testAllChannelsRunning = true
 | 
						||
	testAllChannelsLock.Unlock()
 | 
						||
	channels, err := model.GetAllChannels(0, 0, scope)
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
 | 
						||
	if disableThreshold == 0 {
 | 
						||
		disableThreshold = 10000000 // a impossible value
 | 
						||
	}
 | 
						||
	go func() {
 | 
						||
		for _, channel := range channels {
 | 
						||
			isChannelEnabled := channel.Status == model.ChannelStatusEnabled
 | 
						||
			tik := time.Now()
 | 
						||
			testRequest := buildTestRequest("")
 | 
						||
			err, openaiErr := testChannel(channel, testRequest)
 | 
						||
			tok := time.Now()
 | 
						||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
						||
			if isChannelEnabled && milliseconds > disableThreshold {
 | 
						||
				err = fmt.Errorf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)
 | 
						||
				if config.AutomaticDisableChannelEnabled {
 | 
						||
					monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 | 
						||
				} else {
 | 
						||
					_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error())
 | 
						||
				}
 | 
						||
			}
 | 
						||
			if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
 | 
						||
				monitor.DisableChannel(channel.Id, channel.Name, err.Error())
 | 
						||
			}
 | 
						||
			if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
 | 
						||
				monitor.EnableChannel(channel.Id, channel.Name)
 | 
						||
			}
 | 
						||
			channel.UpdateResponseTime(milliseconds)
 | 
						||
			time.Sleep(config.RequestInterval)
 | 
						||
		}
 | 
						||
		testAllChannelsLock.Lock()
 | 
						||
		testAllChannelsRunning = false
 | 
						||
		testAllChannelsLock.Unlock()
 | 
						||
		if notify {
 | 
						||
			err := message.Notify(message.ByAll, "渠道测试完成", "", "渠道测试完成,如果没有收到禁用通知,说明所有渠道都正常")
 | 
						||
			if err != nil {
 | 
						||
				logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 | 
						||
			}
 | 
						||
		}
 | 
						||
	}()
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func TestChannels(c *gin.Context) {
 | 
						||
	scope := c.Query("scope")
 | 
						||
	if scope == "" {
 | 
						||
		scope = "all"
 | 
						||
	}
 | 
						||
	err := testChannels(true, scope)
 | 
						||
	if err != nil {
 | 
						||
		c.JSON(http.StatusOK, gin.H{
 | 
						||
			"success": false,
 | 
						||
			"message": err.Error(),
 | 
						||
		})
 | 
						||
		return
 | 
						||
	}
 | 
						||
	c.JSON(http.StatusOK, gin.H{
 | 
						||
		"success": true,
 | 
						||
		"message": "",
 | 
						||
	})
 | 
						||
	return
 | 
						||
}
 | 
						||
 | 
						||
func AutomaticallyTestChannels(frequency int) {
 | 
						||
	for {
 | 
						||
		time.Sleep(time.Duration(frequency) * time.Minute)
 | 
						||
		logger.SysLog("testing all channels")
 | 
						||
		_ = testChannels(false, "all")
 | 
						||
		logger.SysLog("channel test finished")
 | 
						||
	}
 | 
						||
}
 |