Refactor codebase, introduce relaymode package, update constants and improve consistency

- Refactor constant definitions and organization
- Clean up package level variables and functions
- Introduce new `relaymode` and `apitype` packages for constant definitions
- Refactor and simplify code in several packages including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, `relay/channeltype`
- Add helper functions in `relay/channeltype` package to convert channel type constants to corresponding API type constants
- Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go`
- Modify code in `relay/util/validation.go` and related files to use new `validator.ValidateTextRequest` function
- Rename `util` package to `relaymode` and update related imports in several packages
This commit is contained in:
Laisky.Cai
2024-04-06 05:18:04 +00:00
parent 7c59a97a6b
commit 50bab08496
157 changed files with 3217 additions and 19160 deletions

View File

@@ -1,15 +1,15 @@
package middleware
import (
"net/http"
"strings"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strings"
)
func authHelper(c *gin.Context, minRole int) {
@@ -47,7 +47,7 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
@@ -74,24 +74,25 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser)
authHelper(c, model.RoleCommonUser)
}
}
func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleAdminUser)
authHelper(c, model.RoleAdminUser)
}
}
func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) {
authHelper(c, common.RoleRootUser)
authHelper(c, model.RoleRootUser)
}
}
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
ctx := c.Request.Context()
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(strings.TrimPrefix(key, "sk-"), "laisky-")
@@ -102,6 +103,12 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP()))
return
}
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
@@ -111,6 +118,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
requestModel, err := getRequestModel(c)
if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") {
abortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
c.Set("request_model", requestModel)
if token.Models != nil && *token.Models != "" {
c.Set("available_models", *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
return
}
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)

View File

@@ -2,14 +2,16 @@ package middleware
import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/channeltype"
)
type ModelRequest struct {
@@ -35,42 +37,16 @@ func Distribute() func(c *gin.Context) {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
if channel.Status != common.ChannelStatusEnabled {
if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
requestModel := c.GetString("request_model")
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, fmt.Sprintf("无效的请求: %+v", err))
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
@@ -85,17 +61,18 @@ func Distribute() func(c *gin.Context) {
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
// one channel could relates to multiple groups,
// and each groud has individual ratio,
// set minimal group ratio as channel_ratio
var minimalRatio float64 = -1
for _, grp := range strings.Split(channel.Group, ",") {
v := common.GetGroupRatio(grp)
v := ratio.GetGroupRatio(grp)
if minimalRatio < 0 || v < minimalRatio {
minimalRatio = v
}
}
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
c.Set("channel_ratio", minimalRatio)
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
@@ -105,19 +82,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
case common.ChannelTypeAzure:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeGemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other)
case common.ChannelTypeAli:
c.Set(common.ConfigKeyPlugin, channel.Other)
case channeltype.Azure:
c.Set(config.KeyAPIVersion, channel.Other)
case channeltype.Xunfei:
c.Set(config.KeyAPIVersion, channel.Other)
case channeltype.Gemini:
c.Set(config.KeyAPIVersion, channel.Other)
case channeltype.AIProxyLibrary:
c.Set(config.KeyLibraryID, channel.Other)
case channeltype.Ali:
c.Set(config.KeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v)
c.Set(config.KeyPrefix+k, v)
}
}

View File

@@ -1,9 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
func getRequestModel(c *gin.Context) (string, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
return modelRequest.Model, nil
}
func isModelInList(modelName string, models string) bool {
modelList := strings.Split(models, ",")
for _, model := range modelList {
if modelName == model {
return true
}
}
return false
}