mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-10 02:23:43 +08:00
fix: improve error handling in rate limiter and remove unnecessary logging
This commit is contained in:
@@ -1,13 +1,12 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
gmw "github.com/Laisky/gin-middlewares/v6"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/pkg/errors"
|
||||||
"github.com/songquanpeng/one-api/common"
|
"github.com/songquanpeng/one-api/common"
|
||||||
"github.com/songquanpeng/one-api/common/config"
|
"github.com/songquanpeng/one-api/common/config"
|
||||||
)
|
)
|
||||||
@@ -17,43 +16,44 @@ var timeFormat = "2006-01-02T15:04:05.000Z"
|
|||||||
var inMemoryRateLimiter common.InMemoryRateLimiter
|
var inMemoryRateLimiter common.InMemoryRateLimiter
|
||||||
|
|
||||||
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
||||||
ctx := context.Background()
|
ctx := gmw.Ctx(c)
|
||||||
|
|
||||||
rdb := common.RDB
|
rdb := common.RDB
|
||||||
key := "rateLimit:" + mark + c.ClientIP()
|
key := "rateLimit:" + mark + c.ClientIP()
|
||||||
listLength, err := rdb.LLen(ctx, key).Result()
|
listLength, err := rdb.LLen(ctx, key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err.Error())
|
AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to get list length"))
|
||||||
c.Status(http.StatusInternalServerError)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if listLength < int64(maxRequestNum) {
|
if listLength < int64(maxRequestNum) {
|
||||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||||
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||||
} else {
|
} else {
|
||||||
oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result()
|
oldTimeStr, err := rdb.LIndex(ctx, key, -1).Result()
|
||||||
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to get old time"))
|
||||||
c.Status(http.StatusInternalServerError)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldTime, err := time.Parse(timeFormat, oldTimeStr)
|
||||||
|
if err != nil {
|
||||||
|
AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to parse old time"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
nowTimeStr := time.Now().Format(timeFormat)
|
nowTimeStr := time.Now().Format(timeFormat)
|
||||||
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
nowTime, err := time.Parse(timeFormat, nowTimeStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
AbortWithError(c, http.StatusInternalServerError, errors.Wrap(err, "failed to parse now time"))
|
||||||
c.Status(http.StatusInternalServerError)
|
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// time.Since will return negative number!
|
// time.Since will return negative number!
|
||||||
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
// See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows
|
||||||
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
if int64(nowTime.Sub(oldTime).Seconds()) < duration {
|
||||||
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration)
|
||||||
c.Status(http.StatusTooManyRequests)
|
AbortWithError(c, http.StatusTooManyRequests, errors.New("rate limit exceeded"))
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
} else {
|
} else {
|
||||||
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
rdb.LPush(ctx, key, time.Now().Format(timeFormat))
|
||||||
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1))
|
||||||
@@ -65,8 +65,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st
|
|||||||
func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
func memoryRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark string) {
|
||||||
key := mark + c.ClientIP()
|
key := mark + c.ClientIP()
|
||||||
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
if !inMemoryRateLimiter.Request(key, maxRequestNum, duration) {
|
||||||
c.Status(http.StatusTooManyRequests)
|
AbortWithError(c, http.StatusTooManyRequests, errors.New("rate limit exceeded"))
|
||||||
c.Abort()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -424,8 +424,6 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
|
|||||||
data = strings.TrimPrefix(data, "data: ")
|
data = strings.TrimPrefix(data, "data: ")
|
||||||
data = strings.TrimSuffix(data, "\"")
|
data = strings.TrimSuffix(data, "\"")
|
||||||
|
|
||||||
fmt.Printf(">> gemini response: %s\n", data)
|
|
||||||
|
|
||||||
var geminiResponse ChatResponse
|
var geminiResponse ChatResponse
|
||||||
err := json.Unmarshal([]byte(data), &geminiResponse)
|
err := json.Unmarshal([]byte(data), &geminiResponse)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
|
|||||||
return len(tokenEncoder.Encode(text, nil, nil))
|
return len(tokenEncoder.Encode(text, nil, nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CountTokenMessages counts the number of tokens in a list of messages.
|
||||||
func CountTokenMessages(ctx context.Context,
|
func CountTokenMessages(ctx context.Context,
|
||||||
messages []model.Message, actualModel string) int {
|
messages []model.Message, actualModel string) int {
|
||||||
tokenEncoder := getTokenEncoder(actualModel)
|
tokenEncoder := getTokenEncoder(actualModel)
|
||||||
|
|||||||
Reference in New Issue
Block a user