fix: select channel

This commit is contained in:
CaIon 2024-04-07 22:08:11 +08:00
parent 2d1d1b4631
commit 34bf8f8945

View File

@ -24,6 +24,9 @@ func Distribute() func(c *gin.Context) {
userId := c.GetInt("id") userId := c.GetInt("id")
var channel *model.Channel var channel *model.Channel
channelId, ok := c.Get("specific_channel_id") channelId, ok := c.Get("specific_channel_id")
modelRequest, shouldSelectChannel, err := getModelRequest(c)
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if ok { if ok {
id, err := strconv.Atoi(channelId.(string)) id, err := strconv.Atoi(channelId.(string))
if err != nil { if err != nil {
@ -40,9 +43,56 @@ func Distribute() func(c *gin.Context) {
return return
} }
} else { } else {
shouldSelectChannel := true
// Select a channel for the user // Select a channel for the user
// check token model mapping
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
}
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
c.Next()
}
}
func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
var modelRequest ModelRequest var modelRequest ModelRequest
shouldSelectChannel := true
var err error var err error
if strings.Contains(c.Request.URL.Path, "/mj/") { if strings.Contains(c.Request.URL.Path, "/mj/") {
relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path) relayMode := relayconstant.Path2RelayModeMidjourney(c.Request.URL.Path)
@ -56,17 +106,17 @@ func Distribute() func(c *gin.Context) {
err = common.UnmarshalBodyReusable(c, &midjourneyRequest) err = common.UnmarshalBodyReusable(c, &midjourneyRequest)
if err != nil { if err != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error()) abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, "+err.Error())
return return nil, false, err
} }
midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest) midjourneyModel, mjErr, success := service.GetMjRequestModel(relayMode, &midjourneyRequest)
if mjErr != nil { if mjErr != nil {
abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description) abortWithMidjourneyMessage(c, http.StatusBadRequest, mjErr.Code, mjErr.Description)
return return nil, false, fmt.Errorf(mjErr.Description)
} }
if midjourneyModel == "" { if midjourneyModel == "" {
if !success { if !success {
abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型") abortWithMidjourneyMessage(c, http.StatusBadRequest, constant.MjErrorUnknown, "无效的请求, 无法解析模型")
return return nil, false, fmt.Errorf("无效的请求, 无法解析模型")
} else { } else {
// task fetch, task fetch by condition, notify // task fetch, task fetch by condition, notify
shouldSelectChannel = false shouldSelectChannel = false
@ -80,7 +130,7 @@ func Distribute() func(c *gin.Context) {
} }
if err != nil { if err != nil {
abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error()) abortWithOpenAiMessage(c, http.StatusBadRequest, "无效的请求, "+err.Error())
return return nil, false, err
} }
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" { if modelRequest.Model == "" {
@ -106,52 +156,7 @@ func Distribute() func(c *gin.Context) {
} }
} }
} }
// check token model mapping return &modelRequest, shouldSelectChannel, nil
modelLimitEnable := c.GetBool("token_model_limit_enabled")
if modelLimitEnable {
s, ok := c.Get("token_model_limit")
var tokenModelLimit map[string]bool
if ok {
tokenModelLimit = s.(map[string]bool)
} else {
tokenModelLimit = map[string]bool{}
}
if tokenModelLimit != nil {
if _, ok := tokenModelLimit[modelRequest.Model]; !ok {
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问模型 "+modelRequest.Model)
return
}
} else {
// token model limit is empty, all models are not allowed
abortWithOpenAiMessage(c, http.StatusForbidden, "该令牌无权访问任何模型")
return
}
}
userGroup, _ := model.CacheGetUserGroup(userId)
c.Set("group", userGroup)
if shouldSelectChannel {
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, 0)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
}
// 如果错误,而且渠道为空,说明是没有可用渠道
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, message)
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
return
}
SetupContextForSelectedChannel(c, channel, modelRequest.Model)
}
}
c.Next()
}
} }
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {