mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-12-26 01:35:58 +08:00
refactor: use adaptor to do relay & test
This commit is contained in:
@@ -9,10 +9,14 @@ import (
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/channel/openai"
|
||||
"github.com/songquanpeng/one-api/relay/constant"
|
||||
"github.com/songquanpeng/one-api/relay/helper"
|
||||
relaymodel "github.com/songquanpeng/one-api/relay/model"
|
||||
"github.com/songquanpeng/one-api/relay/util"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -20,87 +24,13 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) {
|
||||
switch channel.Type {
|
||||
case common.ChannelTypePaLM:
|
||||
fallthrough
|
||||
case common.ChannelTypeGemini:
|
||||
fallthrough
|
||||
case common.ChannelTypeAnthropic:
|
||||
fallthrough
|
||||
case common.ChannelTypeBaidu:
|
||||
fallthrough
|
||||
case common.ChannelTypeZhipu:
|
||||
fallthrough
|
||||
case common.ChannelTypeAli:
|
||||
fallthrough
|
||||
case common.ChannelType360:
|
||||
fallthrough
|
||||
case common.ChannelTypeXunfei:
|
||||
return errors.New("该渠道类型当前版本不支持测试,请手动测试"), nil
|
||||
case common.ChannelTypeAzure:
|
||||
request.Model = "gpt-35-turbo"
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = errors.New("请确保已在 Azure 上创建了 gpt-35-turbo 模型,并且 apiVersion 已正确填写!")
|
||||
}
|
||||
}()
|
||||
default:
|
||||
request.Model = "gpt-3.5-turbo"
|
||||
}
|
||||
requestURL := common.ChannelBaseURLs[channel.Type]
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type)
|
||||
} else {
|
||||
if baseURL := channel.GetBaseURL(); len(baseURL) > 0 {
|
||||
requestURL = baseURL
|
||||
}
|
||||
|
||||
requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type)
|
||||
}
|
||||
jsonData, err := json.Marshal(request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
req, err := http.NewRequest("POST", requestURL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if channel.Type == common.ChannelTypeAzure {
|
||||
req.Header.Set("api-key", channel.Key)
|
||||
} else {
|
||||
req.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err := util.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response openai.SlimTextResponse
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
err = json.Unmarshal(body, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error: %s\nResp body: %s", err, body), nil
|
||||
}
|
||||
if response.Usage.CompletionTokens == 0 {
|
||||
if response.Error.Message == "" {
|
||||
response.Error.Message = "补全 tokens 非预期返回 0"
|
||||
}
|
||||
return fmt.Errorf("type %s, code %v, message %s", response.Error.Type, response.Error.Code, response.Error.Message), &response.Error
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func buildTestRequest() *openai.ChatRequest {
|
||||
testRequest := &openai.ChatRequest{
|
||||
Model: "", // this will be set later
|
||||
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
|
||||
testRequest := &relaymodel.GeneralOpenAIRequest{
|
||||
MaxTokens: 1,
|
||||
Stream: false,
|
||||
Model: "gpt-3.5-turbo",
|
||||
}
|
||||
testMessage := openai.Message{
|
||||
testMessage := relaymodel.Message{
|
||||
Role: "user",
|
||||
Content: "hi",
|
||||
}
|
||||
@@ -108,6 +38,64 @@ func buildTestRequest() *openai.ChatRequest {
|
||||
return testRequest
|
||||
}
|
||||
|
||||
func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error) {
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = &http.Request{
|
||||
Method: "POST",
|
||||
URL: &url.URL{Path: "/v1/chat/completions"},
|
||||
Body: nil,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
c.Request.Header.Set("Authorization", "Bearer "+channel.Key)
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
meta := util.GetRelayMeta(c)
|
||||
apiType := constant.ChannelType2APIType(channel.Type)
|
||||
adaptor := helper.GetAdaptor(apiType)
|
||||
if adaptor == nil {
|
||||
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
|
||||
}
|
||||
modelName := adaptor.GetModelList()[0]
|
||||
request := buildTestRequest()
|
||||
request.Model = modelName
|
||||
meta.OriginModelName, meta.ActualModelName = modelName, modelName
|
||||
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
jsonData, err := json.Marshal(convertedRequest)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
requestBody := bytes.NewBuffer(jsonData)
|
||||
c.Request.Body = io.NopCloser(requestBody)
|
||||
resp, err := adaptor.DoRequest(c, meta, requestBody)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := util.RelayErrorHandler(resp)
|
||||
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
|
||||
}
|
||||
usage, respErr := adaptor.DoResponse(c, resp, meta)
|
||||
if respErr != nil {
|
||||
return fmt.Errorf("%s", respErr.Error.Message), &respErr.Error
|
||||
}
|
||||
if usage == nil {
|
||||
return errors.New("usage is nil"), nil
|
||||
}
|
||||
result := w.Result()
|
||||
// print result.Body
|
||||
respBody, err := io.ReadAll(result.Body)
|
||||
if err != nil {
|
||||
return err, nil
|
||||
}
|
||||
logger.SysLog(fmt.Sprintf("testing channel #%d, response: \n%s", channel.Id, string(respBody)))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestChannel(c *gin.Context) {
|
||||
id, err := strconv.Atoi(c.Param("id"))
|
||||
if err != nil {
|
||||
@@ -125,9 +113,8 @@ func TestChannel(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
tik := time.Now()
|
||||
err, _ = testChannel(channel, *testRequest)
|
||||
err, _ = testChannel(channel)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
go channel.UpdateResponseTime(milliseconds)
|
||||
@@ -192,7 +179,6 @@ func testAllChannels(notify bool) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
testRequest := buildTestRequest()
|
||||
var disableThreshold = int64(config.ChannelDisableThreshold * 1000)
|
||||
if disableThreshold == 0 {
|
||||
disableThreshold = 10000000 // a impossible value
|
||||
@@ -201,7 +187,7 @@ func testAllChannels(notify bool) error {
|
||||
for _, channel := range channels {
|
||||
isChannelEnabled := channel.Status == common.ChannelStatusEnabled
|
||||
tik := time.Now()
|
||||
err, openaiErr := testChannel(channel, *testRequest)
|
||||
err, openaiErr := testChannel(channel)
|
||||
tok := time.Now()
|
||||
milliseconds := tok.Sub(tik).Milliseconds()
|
||||
if isChannelEnabled && milliseconds > disableThreshold {
|
||||
|
||||
Reference in New Issue
Block a user