mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-10-31 22:03:41 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			272 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			272 lines
		
	
	
		
			6.8 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package controller
 | ||
| 
 | ||
| import (
 | ||
| 	"bufio"
 | ||
| 	"bytes"
 | ||
| 	"encoding/json"
 | ||
| 	"errors"
 | ||
| 	"fmt"
 | ||
| 	"io/ioutil"
 | ||
| 	"net/http"
 | ||
| 	"one-api/common"
 | ||
| 	"one-api/model"
 | ||
| 	"strconv"
 | ||
| 	"strings"
 | ||
| 	"sync"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| )
 | ||
| 
 | ||
| func testChannel(channel *model.Channel, request ChatRequest) error {
 | ||
| 	switch channel.Type {
 | ||
| 	case common.ChannelTypeAzure:
 | ||
| 		request.Model = "gpt-35-turbo"
 | ||
| 	default:
 | ||
| 		request.Model = "gpt-3.5-turbo"
 | ||
| 	}
 | ||
| 	requestURL := common.ChannelBaseURLs[channel.Type]
 | ||
| 	if channel.Type == common.ChannelTypeAzure {
 | ||
| 		requestURL = fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", channel.BaseURL, request.Model)
 | ||
| 	} else {
 | ||
| 		if channel.BaseURL != "" {
 | ||
| 			requestURL = channel.BaseURL
 | ||
| 		}
 | ||
| 		requestURL += "/v1/chat/completions"
 | ||
| 	}
 | ||
| 
 | ||
| 	jsonData, err := json.Marshal(request)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	if channel.Type == common.ChannelTypeAzure {
 | ||
| 		req.Header.Set("api-key", channel.Key)
 | ||
| 	} else {
 | ||
| 		req.Header.Set("Authorization", "Bearer "+channel.Key)
 | ||
| 	}
 | ||
| 	req.Header.Set("Content-Type", "application/json")
 | ||
| 	client := &http.Client{}
 | ||
| 	resp, err := client.Do(req)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	if resp.StatusCode != http.StatusOK {
 | ||
| 		return errors.New("invalid status code: " + strconv.Itoa(resp.StatusCode))
 | ||
| 	}
 | ||
| 
 | ||
| 	var response TextResponse
 | ||
| 	isStream := strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
 | ||
| 	var streamResponseText string
 | ||
| 
 | ||
| 	if isStream {
 | ||
| 		scanner := bufio.NewScanner(resp.Body)
 | ||
| 		scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | ||
| 			if atEOF && len(data) == 0 {
 | ||
| 				return 0, nil, nil
 | ||
| 			}
 | ||
| 
 | ||
| 			if i := strings.Index(string(data), "\n\n"); i >= 0 {
 | ||
| 				return i + 2, data[0:i], nil
 | ||
| 			}
 | ||
| 
 | ||
| 			if atEOF {
 | ||
| 				return len(data), data, nil
 | ||
| 			}
 | ||
| 
 | ||
| 			return 0, nil, nil
 | ||
| 		})
 | ||
| 		for scanner.Scan() {
 | ||
| 			data := scanner.Text()
 | ||
| 			if len(data) < 6 { // must be something wrong!
 | ||
| 				common.SysError("invalid stream response: " + data)
 | ||
| 				continue
 | ||
| 			}
 | ||
| 			data = data[6:]
 | ||
| 			if !strings.HasPrefix(data, "[DONE]") {
 | ||
| 				var streamResponse ChatCompletionsStreamResponse
 | ||
| 				err = json.Unmarshal([]byte(data), &streamResponse)
 | ||
| 				if err != nil {
 | ||
| 					common.SysError("error unmarshalling stream response: " + err.Error())
 | ||
| 					return err
 | ||
| 				}
 | ||
| 				for _, choice := range streamResponse.Choices {
 | ||
| 					streamResponseText += choice.Delta.Content
 | ||
| 				}
 | ||
| 			}
 | ||
| 		}
 | ||
| 
 | ||
| 		if streamResponseText == "" {
 | ||
| 			return errors.New("empty stream response")
 | ||
| 		}
 | ||
| 	} else {
 | ||
| 		body, err := ioutil.ReadAll(resp.Body)
 | ||
| 		if err != nil {
 | ||
| 			return err
 | ||
| 		}
 | ||
| 		err = json.Unmarshal(body, &response)
 | ||
| 		if err != nil {
 | ||
| 			return err
 | ||
| 		}
 | ||
| 
 | ||
| 		// channel.BaseURL starts with https://api.openai.com
 | ||
| 		if response.Usage.CompletionTokens == 0 && strings.HasPrefix(channel.BaseURL, "https://api.openai.com") {
 | ||
| 			return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message))
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	defer resp.Body.Close()
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| func buildTestRequest() *ChatRequest {
 | ||
| 	testRequest := &ChatRequest{
 | ||
| 		Model:     "", // this will be set later
 | ||
| 		MaxTokens: 1,
 | ||
| 	}
 | ||
| 	testMessage := Message{
 | ||
| 		Role:    "user",
 | ||
| 		Content: "hi",
 | ||
| 	}
 | ||
| 	testRequest.Messages = append(testRequest.Messages, testMessage)
 | ||
| 	return testRequest
 | ||
| }
 | ||
| 
 | ||
| func TestChannel(c *gin.Context) {
 | ||
| 	id, err := strconv.Atoi(c.Param("id"))
 | ||
| 	if err != nil {
 | ||
| 		c.JSON(http.StatusOK, gin.H{
 | ||
| 			"success": false,
 | ||
| 			"message": err.Error(),
 | ||
| 		})
 | ||
| 		return
 | ||
| 	}
 | ||
| 	channel, err := model.GetChannelById(id, true)
 | ||
| 	if err != nil {
 | ||
| 		c.JSON(http.StatusOK, gin.H{
 | ||
| 			"success": false,
 | ||
| 			"message": err.Error(),
 | ||
| 		})
 | ||
| 		return
 | ||
| 	}
 | ||
| 	testRequest := buildTestRequest()
 | ||
| 	tik := time.Now()
 | ||
| 	err = testChannel(channel, *testRequest)
 | ||
| 	tok := time.Now()
 | ||
| 	milliseconds := tok.Sub(tik).Milliseconds()
 | ||
| 	go channel.UpdateResponseTime(milliseconds)
 | ||
| 	consumedTime := float64(milliseconds) / 1000.0
 | ||
| 	if err != nil {
 | ||
| 		c.JSON(http.StatusOK, gin.H{
 | ||
| 			"success": false,
 | ||
| 			"message": err.Error(),
 | ||
| 			"time":    consumedTime,
 | ||
| 		})
 | ||
| 		return
 | ||
| 	}
 | ||
| 	c.JSON(http.StatusOK, gin.H{
 | ||
| 		"success": true,
 | ||
| 		"message": "",
 | ||
| 		"time":    consumedTime,
 | ||
| 	})
 | ||
| 	return
 | ||
| }
 | ||
| 
 | ||
var (
	// testAllChannelsLock guards testAllChannelsRunning.
	testAllChannelsLock sync.Mutex
	// testAllChannelsRunning is true while a batch test of all channels is in
	// progress; it is read and written only while holding testAllChannelsLock.
	testAllChannelsRunning bool
)
 | ||
| 
 | ||
| // disable & notify
 | ||
| func disableChannel(channelId int, channelName string, reason string) {
 | ||
| 	if common.RootUserEmail == "" {
 | ||
| 		common.RootUserEmail = model.GetRootUserEmail()
 | ||
| 	}
 | ||
| 	model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled)
 | ||
| 	subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
 | ||
| 	content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
 | ||
| 	err := common.SendEmail(subject, common.RootUserEmail, content)
 | ||
| 	if err != nil {
 | ||
| 		common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| func testAllChannels(notify bool) error {
 | ||
| 	if common.RootUserEmail == "" {
 | ||
| 		common.RootUserEmail = model.GetRootUserEmail()
 | ||
| 	}
 | ||
| 	testAllChannelsLock.Lock()
 | ||
| 	if testAllChannelsRunning {
 | ||
| 		testAllChannelsLock.Unlock()
 | ||
| 		return errors.New("测试已在运行中")
 | ||
| 	}
 | ||
| 	testAllChannelsRunning = true
 | ||
| 	testAllChannelsLock.Unlock()
 | ||
| 	channels, err := model.GetAllChannels(0, 0, true)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	testRequest := buildTestRequest()
 | ||
| 	var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
 | ||
| 	if disableThreshold == 0 {
 | ||
| 		disableThreshold = 10000000 // a impossible value
 | ||
| 	}
 | ||
| 	go func() {
 | ||
| 		for _, channel := range channels {
 | ||
| 			if channel.Status != common.ChannelStatusEnabled {
 | ||
| 				continue
 | ||
| 			}
 | ||
| 			tik := time.Now()
 | ||
| 			err := testChannel(channel, *testRequest)
 | ||
| 			tok := time.Now()
 | ||
| 			milliseconds := tok.Sub(tik).Milliseconds()
 | ||
| 			if err != nil || milliseconds > disableThreshold {
 | ||
| 				if milliseconds > disableThreshold {
 | ||
| 					err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0))
 | ||
| 				}
 | ||
| 				disableChannel(channel.Id, channel.Name, err.Error())
 | ||
| 			}
 | ||
| 			channel.UpdateResponseTime(milliseconds)
 | ||
| 			time.Sleep(common.RequestInterval)
 | ||
| 		}
 | ||
| 		testAllChannelsLock.Lock()
 | ||
| 		testAllChannelsRunning = false
 | ||
| 		testAllChannelsLock.Unlock()
 | ||
| 		if notify {
 | ||
| 			err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常")
 | ||
| 			if err != nil {
 | ||
| 				common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}()
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| func TestAllChannels(c *gin.Context) {
 | ||
| 	err := testAllChannels(true)
 | ||
| 	if err != nil {
 | ||
| 		c.JSON(http.StatusOK, gin.H{
 | ||
| 			"success": false,
 | ||
| 			"message": err.Error(),
 | ||
| 		})
 | ||
| 		return
 | ||
| 	}
 | ||
| 	c.JSON(http.StatusOK, gin.H{
 | ||
| 		"success": true,
 | ||
| 		"message": "",
 | ||
| 	})
 | ||
| 	return
 | ||
| }
 | ||
| 
 | ||
| func AutomaticallyTestChannels(frequency int) {
 | ||
| 	for {
 | ||
| 		time.Sleep(time.Duration(frequency) * time.Minute)
 | ||
| 		common.SysLog("testing all channels")
 | ||
| 		_ = testAllChannels(false)
 | ||
| 		common.SysLog("channel test finished")
 | ||
| 	}
 | ||
| }
 |