From 54635ca4fe6b27e7860a3fb0292c222d6855d274 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 26 Feb 2025 11:22:03 +0000 Subject: [PATCH] fix: models api return models in deactivate channels - Enhance logging functionality by adding context support and improving debugging options. - Standardize function naming conventions across middleware to ensure consistency. - Optimize data retrieval and handling in the model controller, including caching and error management. - Simplify the bug report template to streamline the issue reporting process. --- .github/ISSUE_TEMPLATE/bug_report.md | 21 +++--- controller/model.go | 102 +++++++++++++++++---------- middleware/auth.go | 21 +++--- middleware/distributor.go | 10 +-- middleware/utils.go | 15 +++- model/ability.go | 35 ++++++++- model/cache.go | 54 ++++++++++++-- model/main.go | 17 +++-- 8 files changed, 194 insertions(+), 81 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index dd688493..b1633905 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -1,20 +1,20 @@ --- name: 报告问题 about: 使用简练详细的语言描述你遇到的问题 -title: '' +title: "" labels: bug -assignees: '' - +assignees: "" --- **例行检查** -[//]: # (方框内删除已有的空格,填 x 号) -+ [ ] 我已确认目前没有类似 issue -+ [ ] 我已确认我已升级到最新版本 -+ [ ] 我已完整查看过项目 README,尤其是常见问题部分 -+ [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 -+ [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** +[//]: # "方框内删除已有的空格,填 x 号" + +- [ ] 我已确认目前没有类似 issue +- [ ] 我已确认我已升级到最新版本 +- [ ] 我已完整查看过项目 README,尤其是常见问题部分 +- [ ] 我理解并愿意跟进此 issue,协助测试和提供反馈 +- [ ] 我理解并认可上述内容,并理解项目维护者精力有限,**不遵循规则的 issue 可能会被无视或直接关闭** **问题描述** @@ -23,4 +23,5 @@ assignees: '' **预期结果** **相关截图** -如果没有的话,请删除此节。 \ No newline at end of file + +如果没有的话,请删除此节。 diff --git a/controller/model.go b/controller/model.go index dcbe709e..8b42cd2a 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,8 +2,12 @@ package controller import ( "fmt" + "net/http" + "sort" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" relay "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -11,8 +15,6 @@ import ( "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "net/http" - "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -33,16 +35,20 @@ type OpenAIModelPermission struct { } type OpenAIModels struct { - Id string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` + // Id model's name + // + // BUG: Different channels may have the same model name + Id string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + // OwnedBy is the channel's adaptor name OwnedBy string `json:"owned_by"` Permission []OpenAIModelPermission `json:"permission"` Root string `json:"root"` Parent *string `json:"parent"` } -var models []OpenAIModels +var allModels []OpenAIModels var modelsMap map[string]OpenAIModels var channelId2Models map[int][]string @@ -71,7 +77,7 @@ func init() { channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - models = append(models, OpenAIModels{ + allModels = append(allModels, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -88,7 +94,7 @@ func init() { } channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) for _, modelName := range channelModelList { - models = append(models, OpenAIModels{ + allModels = append(allModels, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -100,7 +106,7 @@ func init() { } } modelsMap = make(map[string]OpenAIModels) - for _, model := range models { + for _, model := range allModels { modelsMap[model.Id] = model } channelId2Models = make(map[int][]string) @@ -125,49 +131,56 @@ func DashboardListModels(c *gin.Context) { func ListAllModels(c *gin.Context) { c.JSON(200, gin.H{ "object": "list", - "data": models, + "data": allModels, }) } func ListModels(c *gin.Context) { - ctx := c.Request.Context() - var availableModels []string - if c.GetString(ctxkey.AvailableModels) != "" { - availableModels = strings.Split(c.GetString(ctxkey.AvailableModels), ",") - } else { - userId := c.GetInt(ctxkey.Id) - userGroup, _ := model.CacheGetUserGroup(userId) - availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + userId := c.GetInt(ctxkey.Id) + userGroup, err := model.CacheGetUserGroup(userId) + if err != nil { + middleware.AbortWithError(c, http.StatusBadRequest, err) + return } - modelSet := make(map[string]bool) - for _, availableModel := range availableModels { - modelSet[availableModel] = true + + // Get available models with their channel names + availableAbilities, err := model.CacheGetGroupModelsV2(c.Request.Context(), userGroup) + if err != nil { + middleware.AbortWithError(c, http.StatusBadRequest, err) + return } + + // Create a map for quick lookup of enabled model+channel combinations + // Only store the exact model:channel combinations from abilities + abilityMap := make(map[string]bool) + for _, ability := range availableAbilities { + // Store as "modelName:channelName" for exact matching + adaptor := relay.GetAdaptor(channeltype.ToAPIType(ability.ChannelType)) + key := ability.Model + ":" + adaptor.GetChannelName() + abilityMap[key] = true + } + + // Filter models that match user's abilities with EXACT model+channel matches availableOpenAIModels := make([]OpenAIModels, 0) - for _, model := range models { - if _, ok := modelSet[model.Id]; ok { - modelSet[model.Id] = false + + // Only include models that have a matching model+channel combination + for _, model := range allModels { + key := model.Id + ":" + model.OwnedBy + if abilityMap[key] { availableOpenAIModels = append(availableOpenAIModels, model) } } - for modelName, ok := range modelSet { - if ok { - availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ - Id: modelName, - Object: "model", - Created: 1626777600, - OwnedBy: "custom", - Root: modelName, - Parent: nil, - }) - } - } + + // Sort models alphabetically for consistent presentation + sort.Slice(availableOpenAIModels, func(i, j int) bool { + return availableOpenAIModels[i].Id < availableOpenAIModels[j].Id + }) + c.JSON(200, gin.H{ "object": "list", "data": availableOpenAIModels, }) } - func RetrieveModel(c *gin.Context) { modelId := c.Param("model") if model, ok := modelsMap[modelId]; ok { @@ -196,7 +209,8 @@ func GetUserAvailableModels(c *gin.Context) { }) return } - models, err := model.CacheGetGroupModels(ctx, userGroup) + + models, err := model.CacheGetGroupModelsV2(ctx, userGroup) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -204,10 +218,20 @@ func GetUserAvailableModels(c *gin.Context) { }) return } + + var modelNames []string + modelsMap := map[string]bool{} + for _, model := range models { + modelsMap[model.Model] = true + } + for modelName := range modelsMap { + modelNames = append(modelNames, modelName) + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": models, + "data": modelNames, }) return } diff --git a/middleware/auth.go b/middleware/auth.go index e0019838..6853232b 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,15 +1,16 @@ package middleware import ( - "fmt" + "net/http" + "strings" + "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" - "net/http" - "strings" ) func authHelper(c *gin.Context, minRole int) { @@ -98,34 +99,34 @@ func TokenAuth() func(c *gin.Context) { key = parts[0] token, err := model.ValidateUserToken(key) if err != nil { - abortWithMessage(c, http.StatusUnauthorized, err.Error()) + AbortWithError(c, http.StatusUnauthorized, err) return } if token.Subnet != nil && *token.Subnet != "" { if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { - abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) return } } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { - abortWithMessage(c, http.StatusInternalServerError, err.Error()) + AbortWithError(c, http.StatusInternalServerError, err) return } if !userEnabled || blacklist.IsUserBanned(token.UserId) { - abortWithMessage(c, http.StatusForbidden, "用户已被封禁") + AbortWithError(c, http.StatusForbidden, errors.New("用户已被封禁")) return } requestModel, err := getRequestModel(c) if err != nil && shouldCheckModel(c) { - abortWithMessage(c, http.StatusBadRequest, err.Error()) + AbortWithError(c, http.StatusBadRequest, err) return } c.Set(ctxkey.RequestModel, requestModel) if token.Models != nil && *token.Models != "" { c.Set(ctxkey.AvailableModels, *token.Models) if requestModel != "" && !isModelInList(requestModel, *token.Models) { - abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + AbortWithError(c, http.StatusForbidden, errors.Errorf("该令牌无权使用模型: %s", requestModel)) return } } @@ -136,7 +137,7 @@ func TokenAuth() func(c *gin.Context) { if model.IsAdmin(token.UserId) { c.Set(ctxkey.SpecificChannelId, parts[1]) } else { - abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") + AbortWithError(c, http.StatusForbidden, errors.New("普通用户不支持指定渠道")) return } } diff --git a/middleware/distributor.go b/middleware/distributor.go index 58ae0556..4015504d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -6,7 +6,7 @@ import ( "strconv" "github.com/gin-gonic/gin" - + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" @@ -29,16 +29,16 @@ func Distribute() func(c *gin.Context) { if ok { id, err := strconv.Atoi(channelId.(string)) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id")) return } channel, err = model.GetChannelById(id, true) if err != nil { - abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") + AbortWithError(c, http.StatusBadRequest, errors.New("无效的渠道 Id")) return } if channel.Status != model.ChannelStatusEnabled { - abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") + AbortWithError(c, http.StatusForbidden, errors.New("该渠道已被禁用")) return } } else { @@ -51,7 +51,7 @@ func Distribute() func(c *gin.Context) { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } - abortWithMessage(c, http.StatusServiceUnavailable, message) + AbortWithError(c, http.StatusServiceUnavailable, errors.New(message)) return } } diff --git a/middleware/utils.go b/middleware/utils.go index 4d2f8092..ad5646a9 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,11 +2,12 @@ package middleware import ( "fmt" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -20,6 +21,18 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { logger.Error(c.Request.Context(), message) } +// AbortWithError aborts the request with an error message +func AbortWithError(c *gin.Context, statusCode int, err error) { + logger.Errorf(c, "server abort: %+v", err) + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "message": helper.MessageWithRequestId(err.Error(), c.GetString(helper.RequestIdKey)), + "type": "one_api_error", + }, + }) + c.Abort() +} + func getRequestModel(c *gin.Context) (string, error) { var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) diff --git a/model/ability.go b/model/ability.go index 5cfb9949..22d7f423 100644 --- a/model/ability.go +++ b/model/ability.go @@ -5,10 +5,10 @@ import ( "sort" "strings" - "gorm.io/gorm" - + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/utils" + "gorm.io/gorm" ) type Ability struct { @@ -110,3 +110,34 @@ func GetGroupModels(ctx context.Context, group string) ([]string, error) { sort.Strings(models) return models, err } + +type EnabledAbility struct { + Model string `json:"model" gorm:"model"` + ChannelType int `json:"channel_type" gorm:"channel_type"` +} + +// GetGroupModelsV2 returns all enabled models for this group with their channel names. +func GetGroupModelsV2(ctx context.Context, group string) ([]EnabledAbility, error) { + // prepare query based on database type + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + + // query with JOIN to get model and channel name in a single query + var models []EnabledAbility + query := DB.Model(&Ability{}). + Select("abilities.model AS model, channels.type AS channel_type"). + Joins("JOIN channels ON abilities.channel_id = channels.id"). + Where("abilities."+groupCol+" = ? AND abilities.enabled = "+trueVal, group). + Order("abilities.priority DESC") + + err := query.Find(&models).Error + if err != nil { + return nil, errors.Wrap(err, "get group models") + } + + return models, nil +} diff --git a/model/cache.go b/model/cache.go index cfb0f8a4..a92e6d66 100644 --- a/model/cache.go +++ b/model/cache.go @@ -3,18 +3,19 @@ package model import ( "context" "encoding/json" - "errors" "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/common/random" "math/rand" "sort" "strconv" "strings" "sync" "time" + + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" ) var ( @@ -148,7 +149,10 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } -func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { +// CacheGetGroupModels returns models of a group +// +// Deprecated: use CacheGetGroupModelsV2 instead +func CacheGetGroupModels(ctx context.Context, group string) (models []string, err error) { if !common.RedisEnabled { return GetGroupModels(ctx, group) } @@ -156,7 +160,7 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { if err == nil { return strings.Split(modelsStr, ","), nil } - models, err := GetGroupModels(ctx, group) + models, err = GetGroupModels(ctx, group) if err != nil { return nil, err } @@ -167,6 +171,42 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { return models, nil } +// CacheGetGroupModelsV2 is a version of CacheGetGroupModels that returns EnabledAbility instead of string +func CacheGetGroupModelsV2(ctx context.Context, group string) (models []EnabledAbility, err error) { + if !common.RedisEnabled { + return GetGroupModelsV2(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models_v2:%s", group)) + if err != nil { + logger.Warnf(ctx, "Redis get group models error: %+v", err) + } else { + if err = json.Unmarshal([]byte(modelsStr), &models); err != nil { + logger.Warnf(ctx, "Redis get group models error: %+v", err) + } else { + return models, nil + } + } + + models, err = GetGroupModelsV2(ctx, group) + if err != nil { + return nil, errors.Wrap(err, "get group models") + } + + cachePayload, err := json.Marshal(models) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + return models, nil + } + + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), string(cachePayload), + time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex diff --git a/model/main.go b/model/main.go index 72e271a0..c8aefdb2 100644 --- a/model/main.go +++ b/model/main.go @@ -1,8 +1,13 @@ package model import ( + "context" "database/sql" "fmt" + "os" + "strings" + "time" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/env" @@ -13,9 +18,6 @@ import ( "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" - "os" - "strings" - "time" ) var DB *gorm.DB @@ -116,6 +118,11 @@ func InitDB() { return } + if config.DebugSQLEnabled { + logger.Debug(context.TODO(), "debug sql enabled") + DB = DB.Debug() + } + sqlDB := setDBConns(DB) if !config.IsMasterNode { @@ -201,10 +208,6 @@ func migrateLOGDB() error { } func setDBConns(db *gorm.DB) *sql.DB { - if config.DebugSQLEnabled { - db = db.Debug() - } - sqlDB, err := db.DB() if err != nil { logger.FatalLog("failed to connect database: " + err.Error())