diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 63d7d549..38a2f97c 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -1,13 +1,12 @@ package middleware import ( - "context" - "fmt" "net/http" "time" + gmw "github.com/Laisky/gin-middlewares/v6" "github.com/gin-gonic/gin" - + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" ) @@ -17,43 +16,44 @@ var timeFormat = "2006-01-02T15:04:05.000Z" var inMemoryRateLimiter common.InMemoryRateLimiter func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { - ctx := context.Background() + ctx := gmw.Ctx(c) + rdb := common.RDB key := "rateLimit:" + mark + c.ClientIP() listLength, err := rdb.LLen(ctx, key).Result() if err != nil { - fmt.Println(err.Error()) - c.Status(http.StatusInternalServerError) - c.Abort() + AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to get list length")) return } + if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } else { - oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() - oldTime, err := time.Parse(timeFormat, oldTimeStr) + oldTimeStr, err := rdb.LIndex(ctx, key, -1).Result() if err != nil { - fmt.Println(err) - c.Status(http.StatusInternalServerError) - c.Abort() + AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to get old time")) return } + + oldTime, err := time.Parse(timeFormat, oldTimeStr) + if err != nil { + AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to parse old time")) + return + } + nowTimeStr := time.Now().Format(timeFormat) nowTime, err := time.Parse(timeFormat, nowTimeStr) if err != nil { - fmt.Println(err) - c.Status(http.StatusInternalServerError) - c.Abort() + AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to parse now time")) return } + // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) - c.Status(http.StatusTooManyRequests) - c.Abort() - return + AbortWithError(c, http.StatusTooManyRequests, errors.New("rate limit exceeded")) } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) @@ -65,8 +65,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) { key := mark + c.ClientIP() if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) { - c.Status(http.StatusTooManyRequests) - c.Abort() + AbortWithError(c, http.StatusTooManyRequests, errors.New("rate limit exceeded")) return } } diff --git a/relay/adaptor/gemini/main.go b/relay/adaptor/gemini/main.go index 498741fc..ce9c42d0 100644 --- a/relay/adaptor/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -424,8 +424,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC data = strings.TrimPrefix(data, "data: ") data = strings.TrimSuffix(data, "\"") - fmt.Printf(">> gemini response: %s\n", data) - var geminiResponse ChatResponse err := json.Unmarshal([]byte(data), &geminiResponse) if err != nil { diff --git a/relay/adaptor/openai/token.go b/relay/adaptor/openai/token.go index fba3f198..af4ba0e5 100644 --- a/relay/adaptor/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -77,6 +77,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { return len(tokenEncoder.Encode(text, nil, nil)) } +// CountTokenMessages counts the number of tokens in a list of messages. func CountTokenMessages(ctx context.Context, messages []model.Message, actualModel string) int { tokenEncoder := getTokenEncoder(actualModel)