diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 0ddd31a8..20753f9f 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -12,9 +12,8 @@ assignees: "" - [ ] I have confirmed there are no similar issues - [ ] I have confirmed I am using the latest version -- [ ] I have thoroughly read the project README, especially the FAQ section +- [ ] I have thoroughly read the project README - [ ] I understand and am willing to follow up on this issue, assist with testing and provide feedback -- [ ] I understand and agree to the above, and understand that maintainers have limited time - **issues not following guidelines may be ignored or closed** ## Issue Description diff --git a/controller/model.go b/controller/model.go index e4822f89..e0502060 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,10 +3,11 @@ package controller import ( "fmt" "net/http" - "strings" + "sort" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" relay "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -34,16 +35,20 @@ type OpenAIModelPermission struct { } type OpenAIModels struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` + // Id model's name + // + // BUG: Different channels may have the same model name + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + // OwnedBy is the channel's adaptor name OwnedBy string `json:"owned_by"` Permission []OpenAIModelPermission `json:"permission"` Root string `json:"root"` Parent *string `json:"parent"` } -var models []OpenAIModels +var allModels []OpenAIModels var modelsMap map[string]OpenAIModels var channelId2Models map[int][]string @@ -76,7 +81,7 @@ func init() { channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - models = append(models, OpenAIModels{ + allModels = append(allModels, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -93,7 +98,7 @@ func init() { } channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) for _, modelName := range channelModelList { - models = append(models, OpenAIModels{ + allModels = append(allModels, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -105,7 +110,7 @@ func init() { } } modelsMap = make(map[string]OpenAIModels) - for _, model := range models { + for _, model := range allModels { modelsMap[model.Id] = model } channelId2Models = make(map[int][]string) @@ -134,49 +139,56 @@ func DashboardListModels(c *gin.Context) { func ListAllModels(c *gin.Context) { c.JSON(200, gin.H{ "object": "list", - "data": models, + "data": allModels, }) } func ListModels(c *gin.Context) { - ctx := c.Request.Context() - var availableModels []string - if c.GetString(ctxkey.AvailableModels) != "" { - availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") - } else { - userId := c.GetInt(ctxkey.Id) - userGroup, _ := model.CacheGetUserGroup(userId) - availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + userId := c.GetInt(ctxkey.Id) + userGroup, err := model.CacheGetUserGroup(userId) + if err != nil { + middleware.AbortWithError(c, http.StatusBadRequest, err) + return } - modelSet := make(map[string]bool) - for _, availableModel := range availableModels { - modelSet[availableModel] = true + + // Get available models with their channel names + availableAbilities, err := model.GetGroupModelsV2(c.Request.Context(), userGroup) + if err != nil { + middleware.AbortWithError(c, http.StatusBadRequest, err) + return } + + // Create a map for quick lookup of enabled model+channel combinations + // Only store the exact model:channel combinations from abilities + abilityMap := make(map[string]bool) + for _, ability := range availableAbilities { + // Store as "modelName:channelName" for exact matching + adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType)) + key := ability.Model + ":" + adaptor.GetChannelName() + abilityMap[key] = true + } + + // Filter models that match user's abilities with EXACT model+channel matches availableOpenAIModels := make([]OpenAIModels, 0) - for _, model := range models { - if _, ok := modelSet[model.Id]; ok { - modelSet[model.Id] = false + + // Only include models that have a matching model+channel combination + for _, model := range allModels { + key := model.Id + ":" + model.OwnedBy + if abilityMap[key] { availableOpenAIModels = append(availableOpenAIModels, model) } } - for modelName, ok := range modelSet { - if ok { - availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Root: modelName, - Parent: nil, - }) - } - } + + // Sort models alphabetically for consistent presentation + sort.Slice(availableOpenAIModels, func(i, j int) bool { + return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id + }) + c.JSON(200, gin.H{ "object": "list", "data": availableOpenAIModels, }) } - func RetrieveModel(c *gin.Context) { modelId := c.Param("model") if model, ok := modelsMap[modelId]; ok { diff --git a/middleware/auth.go b/middleware/auth.go index f3af72aa..f5b43bc5 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -102,34 +102,34 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - abortWithError(c, http.StatusUnauthorized, err) + AbortWithError(c, http.StatusUnauthorized, err) return } if token.Subnet != nil && *token.Subnet != "" { if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { - abortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP())) + AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key can only be used in the specified subnet: %s, current IP: %s", *token.Subnet, c.ClientIP())) return } } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithError(c, http.StatusInternalServerError, err) + AbortWithError(c, http.StatusInternalServerError, err) return } if !userEnabled || blacklist.IsUserBanned(token.UserId) { - abortWithError(c, http.StatusForbidden, errors.New("User has been banned")) + AbortWithError(c, http.StatusForbidden, errors.New("User has been banned")) return } requestModel, err := getRequestModel(c) if err != nil && shouldCheckModel(c) { - abortWithError(c, http.StatusBadRequest, err) + AbortWithError(c, http.StatusBadRequest, err) return } c.Set(ctxkey.RequestModel, requestModel) if token.Models != nil && *token.Models != "" { c.Set(ctxkey.AvailableModels, *token.Models) if requestModel != "" && !isModelInList(requestModel, *token.Models) { - abortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel)) + AbortWithError(c, http.StatusForbidden, errors.Errorf("This API key does not have permission to use the model: %s", requestModel)) return } } @@ -144,7 +144,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set(ctxkey.SpecificChannelId, parts[1]) } else { - abortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels")) + AbortWithError(c, http.StatusForbidden, errors.New("Ordinary users do not support specifying channels")) return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 3d663557..cc540726 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -32,16 +32,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) + AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) + AbortWithError(c, http.StatusBadRequest, errors.New("Invalid Channel Id")) return } if channel.Status != model.ChannelStatusEnabled { - abortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled")) + AbortWithError(c, http.StatusForbidden, errors.New("The channel has been disabled")) return } } else { @@ -54,7 +54,7 @@ func Distribute() func(c *gin.Context) { logger.SysError(fmt.Sprintf("Channel does not exist: %d", channel.Id)) message = "Database consistency has been broken, please contact the administrator" } - abortWithError(c, http.StatusServiceUnavailable, errors.New(message)) + AbortWithError(c, http.StatusServiceUnavailable, errors.New(message)) return } } diff --git a/middleware/utils.go b/middleware/utils.go index 7e445b34..7353af98 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -23,7 +23,8 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { logger.Error(c.Request.Context(), message) } -func abortWithError(c *gin.Context, statusCode int, err error) { +// AbortWithError aborts the request with an error message +func AbortWithError(c *gin.Context, statusCode int, err error) { logger := gmw.GetLogger(c) logger.Error("server abort", zap.Error(err)) c.JSON(statusCode, gin.H{ diff --git a/model/ability.go b/model/ability.go index 5cfb9949..bc11a689 100644 --- a/model/ability.go +++ b/model/ability.go @@ -4,11 +4,13 @@ import ( "context" "sort" "strings" + "time" - "gorm.io/gorm" - + gutils "github.com/Laisky/go-utils/v5" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/utils" + "gorm.io/gorm" ) type Ability struct { @@ -110,3 +112,44 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) { sort.Strings(models) return models, err } + +var getGroupModelsV2Cache = gutils.NewExpCache[[]EnabledAbility](context.Background(), time.Second*10) + +type EnabledAbility struct { + Model string `json:"model" gorm:"model"` + ChannelType int `json:"channel_type" gorm:"channel_type"` +} + +// GetGroupModelsV2 returns all enabled models for this group with their channel names. +func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) { + // get from cache first + if models, ok := getGroupModelsV2Cache.Load(group); ok { + return models, nil + } + + // prepare query based on database type + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + + // query with JOIN to get model and channel name in a single query + var models []EnabledAbility + query := DB.Model(&Ability{}). + Select("abilities.model AS model, channels.type AS channel_type"). + Joins("JOIN channels ON abilities.channel_id = channels.id"). + Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group). + Order("abilities.priority DESC") + + err := query.Find(&models).Error + if err != nil { + return nil, errors.Wrap(err, "get group models") + } + + // store in cache + getGroupModelsV2Cache.Store(group, models) + + return models, nil +} diff --git a/model/cache.go b/model/cache.go index 997d8361..ddf9fcfa 100644 --- a/model/cache.go +++ b/model/cache.go @@ -4,17 +4,18 @@ import ( "context" "encoding/json" "fmt" - "github.com/pkg/errors" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/common/random" "math/rand" "sort" "strconv" "strings" "sync" "time" + + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" ) var ( diff --git a/model/main.go b/model/main.go index fcf2b3ae..264f4053 100644 --- a/model/main.go +++ b/model/main.go @@ -1,6 +1,7 @@ package model import ( + "context" "database/sql" "fmt" "os" @@ -118,6 +119,11 @@ func InitDB() { return } + if config.DebugSQLEnabled { + logger.Debug(context.TODO(), "debug sql enabled") + DB = DB.Debug() + } + sqlDB := setDBConns(DB) if !config.IsMasterNode { @@ -203,10 +209,6 @@ func migrateLOGDB() error { } func setDBConns(db *gorm.DB) *sql.DB { - if config.DebugSQLEnabled { - db = db.Debug() - } - sqlDB, err := db.DB() if err != nil { logger.FatalLog("failed to connect database: " + err.Error())