merge upstream

Signed-off-by: wozulong <>
This commit is contained in:
wozulong
2024-05-16 19:00:58 +08:00
21 changed files with 218 additions and 71 deletions

View File

@@ -64,7 +64,21 @@ func testChannel(channel *model.Channel, testModel string) (err error, openaiErr
} else {
testModel = adaptor.GetModelList()[0]
}
} else {
modelMapping := *channel.ModelMapping
if modelMapping != "" && modelMapping != "{}" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
openaiErr := service.OpenAIErrorWrapperLocal(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError).Error
return err, &openaiErr
}
if modelMap[testModel] != "" {
testModel = modelMap[testModel]
}
}
}
request := buildTestRequest()
request.Model = testModel
meta.UpstreamModelName = testModel

View File

@@ -108,8 +108,8 @@ func init() {
})
}
openAIModelsMap = make(map[string]dto.OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
for _, aiModel := range openAIModels {
openAIModelsMap[aiModel.Id] = aiModel
}
channelId2Models = make(map[int][]string)
for i := 1; i <= common.ChannelTypeDummy; i++ {
@@ -174,8 +174,8 @@ func DashboardListModels(c *gin.Context) {
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, model)
if aiModel, ok := openAIModelsMap[modelId]; ok {
c.JSON(200, aiModel)
} else {
openAIError := dto.OpenAIError{
Message: fmt.Sprintf("The model '%s' does not exist", modelId),
@@ -191,12 +191,12 @@ func RetrieveModel(c *gin.Context) {
func GetPricing(c *gin.Context) {
userId := c.GetInt("id")
user, _ := model.GetUserById(userId, true)
group, err := model.CacheGetUserGroup(userId)
groupRatio := common.GetGroupRatio("default")
if user != nil {
groupRatio = common.GetGroupRatio(user.Group)
if err != nil {
groupRatio = common.GetGroupRatio(group)
}
pricing := model.GetPricing(user, openAIModels)
pricing := model.GetPricing(group)
c.JSON(200, gin.H{
"success": true,
"data": pricing,

View File

@@ -43,7 +43,7 @@ func Relay(c *gin.Context) {
group := c.GetString("group")
originalModel := c.GetString("original_model")
openaiErr := relayHandler(c, relayMode)
useChannel := []int{channelId}
c.Set("use_channel", []string{fmt.Sprintf("%d", channelId)})
if openaiErr != nil {
go processChannelError(c, channelId, openaiErr)
} else {
@@ -56,7 +56,9 @@ func Relay(c *gin.Context) {
break
}
channelId = channel.Id
useChannel = append(useChannel, channelId)
useChannel := c.GetStringSlice("use_channel")
useChannel = append(useChannel, fmt.Sprintf("%d", channelId))
c.Set("use_channel", useChannel)
common.LogInfo(c.Request.Context(), fmt.Sprintf("using channel #%d to retry (remain times %d)", channel.Id, i))
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
@@ -67,6 +69,7 @@ func Relay(c *gin.Context) {
go processChannelError(c, channelId, openaiErr)
}
}
useChannel := c.GetStringSlice("use_channel")
if len(useChannel) > 1 {
retryLogStr := fmt.Sprintf("重试:%s", strings.Trim(strings.Join(strings.Fields(fmt.Sprint(useChannel)), "->"), "[]"))
common.LogInfo(c.Request.Context(), retryLogStr)