mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-12 11:23:42 +08:00
Refactor codebase, introduce relaymode package, update constants and improve consistency
- Refactor constant definitions and organization - Clean up package level variables and functions - Introduce new `relaymode` and `apitype` packages for constant definitions - Refactor and simplify code in several packages including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, `relay/channeltype` - Add helper functions in `relay/channeltype` package to convert channel type constants to corresponding API type constants - Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go` - Modify code in `relay/util/validation.go` and related files to use new `validator.ValidateTextRequest` function - Rename `util` package to `relaymode` and update related imports in several packages
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"fmt"
|
||||
"github.com/gin-contrib/sessions"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/blacklist"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/common/network"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func authHelper(c *gin.Context, minRole int) {
|
||||
@@ -47,7 +47,7 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
|
||||
if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": "用户已被封禁",
|
||||
@@ -74,24 +74,25 @@ func authHelper(c *gin.Context, minRole int) {
|
||||
|
||||
func UserAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, common.RoleCommonUser)
|
||||
authHelper(c, model.RoleCommonUser)
|
||||
}
|
||||
}
|
||||
|
||||
func AdminAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, common.RoleAdminUser)
|
||||
authHelper(c, model.RoleAdminUser)
|
||||
}
|
||||
}
|
||||
|
||||
func RootAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
authHelper(c, common.RoleRootUser)
|
||||
authHelper(c, model.RoleRootUser)
|
||||
}
|
||||
}
|
||||
|
||||
func TokenAuth() func(c *gin.Context) {
|
||||
return func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
key := c.Request.Header.Get("Authorization")
|
||||
key = strings.TrimPrefix(key, "Bearer ")
|
||||
key = strings.TrimPrefix(strings.TrimPrefix(key, "sk-"), "laisky-")
|
||||
@@ -102,6 +103,12 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusUnauthorized, err.Error())
|
||||
return
|
||||
}
|
||||
if token.Subnet != nil && *token.Subnet != "" {
|
||||
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
|
||||
return
|
||||
}
|
||||
}
|
||||
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusInternalServerError, err.Error())
|
||||
@@ -111,6 +118,19 @@ func TokenAuth() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
|
||||
return
|
||||
}
|
||||
requestModel, err := getRequestModel(c)
|
||||
if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") {
|
||||
abortWithMessage(c, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
c.Set("request_model", requestModel)
|
||||
if token.Models != nil && *token.Models != "" {
|
||||
c.Set("available_models", *token.Models)
|
||||
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
|
||||
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
|
||||
return
|
||||
}
|
||||
}
|
||||
c.Set("id", token.UserId)
|
||||
c.Set("token_id", token.Id)
|
||||
c.Set("token_name", token.Name)
|
||||
|
||||
@@ -2,14 +2,16 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common/config"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"github.com/songquanpeng/one-api/model"
|
||||
"github.com/songquanpeng/one-api/relay/billing/ratio"
|
||||
"github.com/songquanpeng/one-api/relay/channeltype"
|
||||
)
|
||||
|
||||
type ModelRequest struct {
|
||||
@@ -35,42 +37,16 @@ func Distribute() func(c *gin.Context) {
|
||||
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
|
||||
return
|
||||
}
|
||||
if channel.Status != common.ChannelStatusEnabled {
|
||||
if channel.Status != model.ChannelStatusEnabled {
|
||||
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// Select a channel for the user
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
requestModel := c.GetString("request_model")
|
||||
var err error
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, fmt.Sprintf("无效的请求: %+v", err))
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
requestModel = modelRequest.Model
|
||||
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
|
||||
if err != nil {
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
|
||||
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
|
||||
if channel != nil {
|
||||
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
|
||||
message = "数据库一致性已被破坏,请联系管理员"
|
||||
@@ -85,17 +61,18 @@ func Distribute() func(c *gin.Context) {
|
||||
}
|
||||
|
||||
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
|
||||
// one channel could relates to multiple groups,
|
||||
// and each groud has individual ratio,
|
||||
// set minimal group ratio as channel_ratio
|
||||
var minimalRatio float64 = -1
|
||||
for _, grp := range strings.Split(channel.Group, ",") {
|
||||
v := common.GetGroupRatio(grp)
|
||||
v := ratio.GetGroupRatio(grp)
|
||||
if minimalRatio < 0 || v < minimalRatio {
|
||||
minimalRatio = v
|
||||
}
|
||||
}
|
||||
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
|
||||
c.Set("channel_ratio", minimalRatio)
|
||||
|
||||
c.Set("channel", channel.Type)
|
||||
c.Set("channel_id", channel.Id)
|
||||
c.Set("channel_name", channel.Name)
|
||||
@@ -105,19 +82,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
||||
c.Set("base_url", channel.GetBaseURL())
|
||||
// this is for backward compatibility
|
||||
switch channel.Type {
|
||||
case common.ChannelTypeAzure:
|
||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||
case common.ChannelTypeXunfei:
|
||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||
case common.ChannelTypeGemini:
|
||||
c.Set(common.ConfigKeyAPIVersion, channel.Other)
|
||||
case common.ChannelTypeAIProxyLibrary:
|
||||
c.Set(common.ConfigKeyLibraryID, channel.Other)
|
||||
case common.ChannelTypeAli:
|
||||
c.Set(common.ConfigKeyPlugin, channel.Other)
|
||||
case channeltype.Azure:
|
||||
c.Set(config.KeyAPIVersion, channel.Other)
|
||||
case channeltype.Xunfei:
|
||||
c.Set(config.KeyAPIVersion, channel.Other)
|
||||
case channeltype.Gemini:
|
||||
c.Set(config.KeyAPIVersion, channel.Other)
|
||||
case channeltype.AIProxyLibrary:
|
||||
c.Set(config.KeyLibraryID, channel.Other)
|
||||
case channeltype.Ali:
|
||||
c.Set(config.KeyPlugin, channel.Other)
|
||||
}
|
||||
cfg, _ := channel.LoadConfig()
|
||||
for k, v := range cfg {
|
||||
c.Set(common.ConfigKeyPrefix+k, v)
|
||||
c.Set(config.KeyPrefix+k, v)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/songquanpeng/one-api/common"
|
||||
"github.com/songquanpeng/one-api/common/helper"
|
||||
"github.com/songquanpeng/one-api/common/logger"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
|
||||
c.Abort()
|
||||
logger.Error(c.Request.Context(), message)
|
||||
}
|
||||
|
||||
func getRequestModel(c *gin.Context) (string, error) {
|
||||
var modelRequest ModelRequest
|
||||
err := common.UnmarshalBodyReusable(c, &modelRequest)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "text-moderation-stable"
|
||||
}
|
||||
}
|
||||
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = c.Param("model")
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "dall-e-2"
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
|
||||
if modelRequest.Model == "" {
|
||||
modelRequest.Model = "whisper-1"
|
||||
}
|
||||
}
|
||||
return modelRequest.Model, nil
|
||||
}
|
||||
|
||||
func isModelInList(modelName string, models string) bool {
|
||||
modelList := strings.Split(models, ",")
|
||||
for _, model := range modelList {
|
||||
if modelName == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user