diff --git a/common/redis.go b/common/redis.go index 55d4931c..317e224e 100644 --- a/common/redis.go +++ b/common/redis.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-redis/redis/v8" + "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" ) @@ -20,7 +21,7 @@ func InitRedisClient() (err error) { logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } - if os.Getenv("SYNC_FREQUENCY") == "" { + if config.SyncFrequency == 0 { RedisEnabled = false logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil diff --git a/controller/model.go b/controller/model.go index e0502060..4803e673 100644 --- a/controller/model.go +++ b/controller/model.go @@ -152,7 +152,7 @@ func ListModels(c *gin.Context) { } // Get available models with their channel names - availableAbilities, err := model.GetGroupModelsV2(c.Request.Context(), userGroup) + availableAbilities, err := model.CacheGetGroupModelsV2(c.Request.Context(), userGroup) if err != nil { middleware.AbortWithError(c, http.StatusBadRequest, err) return @@ -217,7 +217,8 @@ func GetUserAvailableModels(c *gin.Context) { }) return } - models, err := model.CacheGetGroupModels(ctx, userGroup) + + models, err := model.CacheGetGroupModelsV2(ctx, userGroup) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -225,10 +226,20 @@ func GetUserAvailableModels(c *gin.Context) { }) return } + + var modelNames []string + modelsMap := map[string]bool{} + for _, model := range models { + modelsMap[model.Model] = true + } + for modelName := range modelsMap { + modelNames = append(modelNames, modelName) + } + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": models, + "data": modelNames, }) return } diff --git a/model/cache.go b/model/cache.go index ddf9fcfa..cdbbbd53 100644 --- a/model/cache.go +++ b/model/cache.go @@ -150,7 +150,10 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } -func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { +// CacheGetGroupModels returns models of a group +// +// Deprecated: use CacheGetGroupModelsV2 instead +func CacheGetGroupModels(ctx context.Context, group string) (models []string, err error) { if !common.RedisEnabled { return GetGroupModels(ctx, group) } @@ -158,7 +161,7 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { if err == nil { return strings.Split(modelsStr, ","), nil } - models, err := GetGroupModels(ctx, group) + models, err = GetGroupModels(ctx, group) if err != nil { return nil, err } @@ -169,6 +172,42 @@ func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { return models, nil } +// CacheGetGroupModelsV2 is a version of CacheGetGroupModels that returns EnabledAbility instead of string +func CacheGetGroupModelsV2(ctx context.Context, group string) (models []EnabledAbility, err error) { + if !common.RedisEnabled { + return GetGroupModelsV2(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models_v2:%s", group)) + if err != nil { + logger.Warnf(ctx, "Redis get group models error: %+v", err) + } else { + if err = json.Unmarshal([]byte(modelsStr), &models); err != nil { + logger.Warnf(ctx, "Redis get group models error: %+v", err) + } else { + return models, nil + } + } + + models, err = GetGroupModelsV2(ctx, group) + if err != nil { + return nil, errors.Wrap(err, "get group models") + } + + cachePayload, err := json.Marshal(models) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + return models, nil + } + + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), string(cachePayload), + time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex