Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai
2024-03-05 01:02:35 +00:00
28 changed files with 536 additions and 149 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
@@ -18,6 +19,7 @@ import (
"net/http/httptest"
"net/url"
"strconv"
"strings"
"sync"
"time"
@@ -51,6 +53,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Request.Header.Set("Content-Type", "application/json")
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c)
apiType := constant.ChannelType2APIType(channel.Type)
adaptor := helper.GetAdaptor(apiType)
@@ -59,6 +62,12 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
}
adaptor.Init(meta)
modelName := adaptor.GetModelList()[0]
if !strings.Contains(channel.Models, modelName) {
modelNames := strings.Split(channel.Models, ",")
if len(modelNames) > 0 {
modelName = modelNames[0]
}
}
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName

View File

@@ -4,6 +4,9 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
@@ -102,6 +105,39 @@ func init() {
Parent: nil,
})
}
for _, modelName := range baichuan.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "baichuan",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range minimax.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "minimax",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
for _, modelName := range mistral.ModelList {
openAIModels = append(openAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "mistralai",
Permission: permission,
Root: modelName,
Parent: nil,
})
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model

View File

@@ -41,6 +41,10 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) {
ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
bizErr := relay(c, relayMode)
if bizErr == nil {
return
@@ -58,7 +62,7 @@ func Relay(c *gin.Context) {
retryTimes = 0
}
for i := retryTimes; i > 0; i-- {
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel)
channel, err := dbmodel.CacheGetRandomSatisfiedChannel(group, originalModel, i != retryTimes)
if err != nil {
logger.Errorf(ctx, "CacheGetRandomSatisfiedChannel failed: %+v", err)
break