mirror of
				https://github.com/linux-do/new-api.git
				synced 2025-11-04 05:13:41 +08:00 
			
		
		
		
	feat: support o1 channel test
This commit is contained in:
		@@ -20,6 +20,7 @@ import (
 | 
			
		||||
	"one-api/relay/constant"
 | 
			
		||||
	"one-api/service"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
@@ -81,8 +82,7 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 | 
			
		||||
		return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	request := buildTestRequest()
 | 
			
		||||
	request.Model = testModel
 | 
			
		||||
	request := buildTestRequest(testModel)
 | 
			
		||||
	meta.UpstreamModelName = testModel
 | 
			
		||||
	common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
 | 
			
		||||
 | 
			
		||||
@@ -141,17 +141,22 @@ func testChannel(channel *model.Channel, testModel string) (err error, openAIErr
 | 
			
		||||
	return nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func buildTestRequest() *dto.GeneralOpenAIRequest {
 | 
			
		||||
func buildTestRequest(model string) *dto.GeneralOpenAIRequest {
 | 
			
		||||
	testRequest := &dto.GeneralOpenAIRequest{
 | 
			
		||||
		Model:     "", // this will be set later
 | 
			
		||||
		MaxTokens: 1,
 | 
			
		||||
		Stream:    false,
 | 
			
		||||
		Model:  "", // this will be set later
 | 
			
		||||
		Stream: false,
 | 
			
		||||
	}
 | 
			
		||||
	if strings.HasPrefix(model, "o1-") {
 | 
			
		||||
		testRequest.MaxCompletionTokens = 1
 | 
			
		||||
	} else {
 | 
			
		||||
		testRequest.MaxTokens = 1
 | 
			
		||||
	}
 | 
			
		||||
	content, _ := json.Marshal("hi")
 | 
			
		||||
	testMessage := dto.Message{
 | 
			
		||||
		Role:    "user",
 | 
			
		||||
		Content: content,
 | 
			
		||||
	}
 | 
			
		||||
	testRequest.Model = model
 | 
			
		||||
	testRequest.Messages = append(testRequest.Messages, testMessage)
 | 
			
		||||
	return testRequest
 | 
			
		||||
}
 | 
			
		||||
@@ -226,26 +231,22 @@ func testAllChannels(notify bool) error {
 | 
			
		||||
			tok := time.Now()
 | 
			
		||||
			milliseconds := tok.Sub(tik).Milliseconds()
 | 
			
		||||
 | 
			
		||||
			ban := false
 | 
			
		||||
			if milliseconds > disableThreshold {
 | 
			
		||||
				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				ban = true
 | 
			
		||||
			}
 | 
			
		||||
			shouldBanChannel := false
 | 
			
		||||
 | 
			
		||||
			// request error disables the channel
 | 
			
		||||
			if openaiWithStatusErr != nil {
 | 
			
		||||
				oaiErr := openaiWithStatusErr.Error
 | 
			
		||||
				err = errors.New(fmt.Sprintf("type %s, httpCode %d, code %v, message %s", oaiErr.Type, openaiWithStatusErr.StatusCode, oaiErr.Code, oaiErr.Message))
 | 
			
		||||
				ban = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 | 
			
		||||
				shouldBanChannel = service.ShouldDisableChannel(channel.Type, openaiWithStatusErr)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// parse *int to bool
 | 
			
		||||
			if !channel.GetAutoBan() {
 | 
			
		||||
				ban = false
 | 
			
		||||
			if milliseconds > disableThreshold {
 | 
			
		||||
				err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | 
			
		||||
				shouldBanChannel = true
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// disable channel
 | 
			
		||||
			if ban && isChannelEnabled {
 | 
			
		||||
			if isChannelEnabled && shouldBanChannel && channel.GetAutoBan() {
 | 
			
		||||
				service.DisableChannel(channel.Id, channel.Name, err.Error())
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user