package middleware import ( "context" "fmt" "net/http" "one-api/common" "time" "github.com/gin-gonic/gin" ) var timeFormat = "2006-01-02T15:04:05.000Z" var inMemoryRateLimiter common.InMemoryRateLimiter // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( GlobalApiRateLimitNum = 180 GlobalApiRateLimitDuration int64 = 3 * 60 GlobalWebRateLimitNum = 100 GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 UploadRateLimitDuration int64 = 60 DownloadRateLimitNum = 10 DownloadRateLimitDuration int64 = 60 CriticalRateLimitNum = 20 CriticalRateLimitDuration int64 = 20 * 60 ) func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { ctx := context.Background() rdb := common.RDB key := "rateLimit:" + mark + c.ClientIP() listLength, err := rdb.LLen(ctx, key).Result() if err != nil { fmt.Println(err.Error()) c.Status(http.StatusInternalServerError) c.Abort() return } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) if err != nil { fmt.Println(err) c.Status(http.StatusInternalServerError) c.Abort() return } nowTimeStr := time.Now().Format(timeFormat) nowTime, err := time.Parse(timeFormat, nowTimeStr) if err != nil { fmt.Println(err) c.Status(http.StatusInternalServerError) c.Abort() return } // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) } } } func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { key := mark + c.ClientIP() if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { c.Status(http.StatusTooManyRequests) c.Abort() return } } func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gin.Context) { if common.RedisEnabled { return func(c *gin.Context) { redisRateLimiter(c, maxRequestNum, duration, mark) } } else { // It's safe to call multi times. inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } } } func GlobalWebRateLimit() func(c *gin.Context) { return rateLimitFactory(common.GetOrDefault("GLOBAL_WEB_RATE_LIMIT", GlobalWebRateLimitNum), GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { return rateLimitFactory(common.GetOrDefault("GLOBAL_API_RATE_LIMIT", GlobalApiRateLimitNum), GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { return rateLimitFactory(CriticalRateLimitNum, CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { return rateLimitFactory(DownloadRateLimitNum, DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { return rateLimitFactory(UploadRateLimitNum, UploadRateLimitDuration, "UP") }