mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-08 23:23:42 +08:00
feat: 初步重构
This commit is contained in:
@@ -5,9 +5,17 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/model"
|
||||
relaychannel "one-api/relay/channel"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/constant"
|
||||
"one-api/service"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -15,89 +23,77 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) {
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, request.Model))
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
fallthrough
|
||||
case common.ChannelTypeZhipu:
|
||||
fallthrough
|
||||
case common.ChannelTypeAli:
|
||||
fallthrough
|
||||
case common.ChannelType360:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeXunfei:
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
if request.Model == "" {
|
||||
request.Model = "gpt-35-turbo"
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||
}
|
||||
}()
|
||||
default:
|
||||
if request.Model == "" {
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
}
|
||||
func testChannel(channel *model.Channel, testModel string) (err error, openaiErr *dto.OpenAIError) {
|
||||
common.SysLog(fmt.Sprintf("testing channel %d with model %s", channel.Id, testModel))
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
baseUrl := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.GetBaseURL() != "" {
|
||||
baseUrl = channel.GetBaseURL()
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
meta := relaycommon.GenRelayInfo(c)
|
||||
apiType := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := relaychannel.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
requestURL := getFullRequestURL(baseUrl, "/v1/chat/completions", channel.Type)
|
||||
if testModel == "" {
|
||||
testModel = adaptor.GetModelList()[0]
|
||||
}
|
||||
request := buildTestRequest()
|
||||
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||
}
|
||||
adaptor.Init(meta, *request)
|
||||
|
||||
jsonData, err := json.Marshal(request)
|
||||
request.Model = testModel
|
||||
meta.UpstreamModelName = testModel
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", channel.Key)
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := httpClient.Do(req)
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response TextResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := relaycommon.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.OpenAIError.Message), &err.OpenAIError
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.OpenAIError.Message), &respErr.OpenAIError
|
||||
}
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
if response.Error.Message == "" {
|
||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
||||
}
|
||||
return errors.New(fmt.Sprintf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message)), &response.Error
|
||||
}
|
||||
common.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *ChatRequest {
|
||||
testRequest := &ChatRequest{
|
||||
func buildTestRequest() *dto.GeneralOpenAIRequest {
|
||||
testRequest := &dto.GeneralOpenAIRequest{
|
||||
Model: "", // this will be set later
|
||||
MaxTokens: 1,
|
||||
}
|
||||
content, _ := json.Marshal("hi")
|
||||
testMessage := Message{
|
||||
testMessage := dto.Message{
|
||||
Role: "user",
|
||||
Content: content,
|
||||
}
|
||||
@@ -114,7 +110,6 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testModel := c.Query("model")
|
||||
channel, err := model.GetChannelById(id, true)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -123,12 +118,9 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
if testModel != "" {
|
||||
testRequest.Model = testModel
|
||||
}
|
||||
testModel := c.Query("model")
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, *testRequest)
|
||||
err, _ = testChannel(channel, testModel)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
@@ -152,31 +144,6 @@ func TestChannel(c *gin.Context) {
|
||||
var testAllChannelsLock sync.Mutex
|
||||
var testAllChannelsRunning bool = false
|
||||
|
||||
// disable & notify
|
||||
func disableChannel(channelId int, channelName string, reason string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func enableChannel(channelId int, channelName string) {
|
||||
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
|
||||
subject := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
content := fmt.Sprintf("通道「%s」(#%d)已被启用", channelName, channelId)
|
||||
notifyRootUser(subject, content)
|
||||
}
|
||||
|
||||
func notifyRootUser(subject string, content string) {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
}
|
||||
err := common.SendEmail(subject, common.RootUserEmail, content)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
func testAllChannels(notify bool) error {
|
||||
if common.RootUserEmail == "" {
|
||||
common.RootUserEmail = model.GetRootUserEmail()
|
||||
@@ -192,7 +159,6 @@ func testAllChannels(notify bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
var disableThreshold = int64(common.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
@@ -201,7 +167,7 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
err, openaiErr := testChannel(channel, "")
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
|
||||
@@ -218,11 +184,11 @@ func testAllChannels(notify bool) error {
|
||||
if channel.AutoBan != nil && *channel.AutoBan == 0 {
|
||||
ban = false
|
||||
}
|
||||
if isChannelEnabled && shouldDisableChannel(openaiErr, -1) && ban {
|
||||
disableChannel(channel.Id, channel.Name, err.Error())
|
||||
if isChannelEnabled && service.ShouldDisableChannel(openaiErr, -1) && ban {
|
||||
service.DisableChannel(channel.Id, channel.Name, err.Error())
|
||||
}
|
||||
if !isChannelEnabled && shouldEnableChannel(err, openaiErr) {
|
||||
enableChannel(channel.Id, channel.Name)
|
||||
if !isChannelEnabled && service.ShouldEnableChannel(err, openaiErr) {
|
||||
service.EnableChannel(channel.Id, channel.Name)
|
||||
}
|
||||
channel.UpdateResponseTime(milliseconds)
|
||||
time.Sleep(common.RequestInterval)
|
||||
|
||||
Reference in New Issue
Block a user