new-api/middleware/distributor.go
2023-08-14 22:16:32 +08:00

157 lines
4.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// ModelRequest holds just the "model" field of a request payload. Distribute
// decodes the incoming body into it to learn which model is being requested.
type ModelRequest struct {
	// Model is the requested model name (JSON key "model").
	Model string `json:"model"`
}
// Distribute returns a gin middleware that selects the upstream channel for
// the current request and records its routing details in the gin context.
//
// If an earlier handler stored a "channelId" string in the context, that
// channel is loaded with model.GetChannelById and must be enabled. Otherwise
// the model name is read from the request body (or defaulted from the URL
// path for endpoints that may omit it) and a channel for that model in the
// user's group is obtained from model.CacheGetRandomSatisfiedChannel.
//
// When no channel is available for a "gpt-4*" model, users with StableMode
// enabled and a MaxPrice of at least common.StablePrice are retried against
// the "svip" group; on success the context key "stable" is set to true.
//
// On success the context receives "channel" (channel type), "channel_id",
// "channel_name", "model_mapping", "base_url" and, for Azure channels,
// "api_version", and the Authorization header is replaced with the channel
// key. On failure the request is aborted with a JSON "one_api_error" body.
func Distribute() func(c *gin.Context) {
	return func(c *gin.Context) {
		// abort writes a one_api_error JSON body and stops the handler chain.
		abort := func(status int, message string) {
			c.JSON(status, gin.H{
				"error": gin.H{
					"message": message,
					"type":    "one_api_error",
				},
			})
			c.Abort()
		}

		userID := c.GetInt("id")
		// NOTE(review): the lookup error is ignored, so userGroup is whatever
		// CacheGetUserGroup returns on failure — confirm this is intended.
		userGroup, _ := model.CacheGetUserGroup(userID)
		c.Set("group", userGroup)

		var channel *model.Channel
		if rawChannelID, ok := c.Get("channelId"); ok {
			// Comma-ok assertion: a non-string value is rejected instead of panicking.
			idStr, isString := rawChannelID.(string)
			if !isString {
				abort(http.StatusBadRequest, "无效的渠道 ID")
				return
			}
			id, err := strconv.Atoi(idStr)
			if err != nil {
				abort(http.StatusBadRequest, "无效的渠道 ID")
				return
			}
			channel, err = model.GetChannelById(id, true)
			if err != nil || channel == nil {
				abort(http.StatusBadRequest, "无效的渠道 ID")
				return
			}
			if channel.Status != common.ChannelStatusEnabled {
				abort(http.StatusForbidden, "该渠道已被禁用")
				return
			}
		} else {
			// Select a channel for the user based on the requested model.
			var modelRequest ModelRequest
			if strings.HasPrefix(c.Request.URL.Path, "/mj") {
				// Midjourney requests are not parsed for a model field.
				modelRequest.Model = "midjourney"
			} else if err := common.UnmarshalBodyReusable(c, &modelRequest); err != nil {
				log.Println(err)
				abort(http.StatusBadRequest, "无效的请求")
				return
			}

			// Endpoints whose requests may omit "model" fall back to a default,
			// checked in order and only while the model is still empty.
			path := c.Request.URL.Path
			if modelRequest.Model == "" && strings.HasPrefix(path, "/v1/moderations") {
				modelRequest.Model = "text-moderation-stable"
			}
			if modelRequest.Model == "" && strings.HasSuffix(path, "embeddings") {
				modelRequest.Model = c.Param("model")
			}
			if modelRequest.Model == "" && strings.HasPrefix(path, "/v1/images/generations") {
				modelRequest.Model = "dall-e"
			}

			c.Set("stable", false)
			var err error
			channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model)
			if err != nil {
				message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
				isStable := false
				// Only gpt-4* models may fall back to the stable ("svip") group.
				if strings.HasPrefix(modelRequest.Model, "gpt-4") {
					common.SysLog("GPT-4低价渠道宕机正在尝试转换")
					if nowUser, userErr := model.GetUserById(userID, false); userErr == nil {
						if !nowUser.StableMode {
							message = "当前低价通道不可用,请稍后再试,或者在后台开启稳定渠道模式"
						} else {
							// An unparsable MaxPrice yields 0 and is therefore treated
							// as below the stable price.
							userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
							if userMaxPrice < common.StablePrice {
								message = "当前低价通道不可用,稳定渠道价格为" + strconv.FormatFloat(common.StablePrice, 'f', -1, 64) + "R/刀"
							} else {
								stableChannel, stableErr := model.CacheGetRandomSatisfiedChannel("svip", modelRequest.Model)
								if stableErr != nil || stableChannel == nil {
									// Previously isStable was set even here, which skipped
									// the abort below and dereferenced a bad channel.
									message = "稳定渠道已经宕机,请联系管理员"
								} else {
									channel = stableChannel
									isStable = true
									// Log the id only: %v on the channel would print its Key.
									common.SysLog(fmt.Sprintf("用户 %s 使用稳定渠道 %d", nowUser.Username, channel.Id))
									c.Set("stable", true)
								}
							}
						}
					}
				}
				if !isStable {
					abort(http.StatusInternalServerError, message)
					return
				}
			}
		}

		// Defensive guard: never dereference a nil channel reported without an error.
		if channel == nil {
			abort(http.StatusInternalServerError, "数据库一致性已被破坏,请联系管理员")
			return
		}

		c.Set("channel", channel.Type)
		c.Set("channel_id", channel.Id)
		c.Set("channel_name", channel.Name)
		c.Set("model_mapping", channel.ModelMapping)
		c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
		c.Set("base_url", channel.BaseURL)
		if channel.Type == common.ChannelTypeAzure {
			c.Set("api_version", channel.Other)
		}
		c.Next()
	}
}