// File: one-api/middleware/retry.go

package middleware
import (
"bytes"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"time"
)
// RetryHandler returns a middleware that re-runs the handlers registered after
// it on group, followed by the route's final handler, whenever an attempt
// leaves entries in c.Errors.
//
// Registration: the middleware must be appended to group directly after it is
// built, e.g. group.Use(RetryHandler(group)). Its own position in the chain is
// derived from len(group.Handlers) at construction time, while the handlers
// that follow it are read from group.Handlers on every request.
//
// Attempts: at most common.RetryTimes retries follow the first attempt. A
// request may ask for fewer via the "retry" query parameter; values that do
// not parse, are negative, or exceed common.RetryTimes fall back to
// common.RetryTimes, and a negative common.RetryTimes is treated as 0 so the
// request is always handled at least once. Before each attempt the number of
// retries still remaining is stored in the context under the key "retry"
// (0 on the last attempt).
//
// Between attempts the original request header and body are restored and the
// handler waits common.RetryInterval milliseconds, giving up early if the
// request context is cancelled. Errors from the final attempt are left in
// c.Errors so outer middleware can observe them.
//
// NOTE(review): only the handlers on group and the route's last handler
// (c.Handler()) are re-run; any extra handlers registered on the route itself
// are skipped on retries. A handler that aborts during a retry also does not
// stop the remaining handlers from running, because the context is already
// aborted at that point.
func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc {
	// Handlers already on the group occupy [0, n); this middleware lands at
	// index n, so the handlers that follow it start at n+1.
	index := len(group.Handlers) + 1

	return func(c *gin.Context) {
		// Back up the request so every retry starts from the same input.
		// ContentLength is -1 when the length is unknown (e.g. chunked
		// transfer encoding); such bodies must be backed up as well.
		backupHeader := c.Request.Header.Clone()
		hasBody := c.Request.Body != nil && c.Request.ContentLength != 0
		var backupBody []byte
		if hasBody {
			var err error
			backupBody, err = io.ReadAll(c.Request.Body)
			if err != nil {
				abortWithMessage(c, http.StatusBadRequest, "Invalid request")
				return
			}
			_ = c.Request.Body.Close() // the bytes are already held in backupBody
			c.Request.Body = io.NopCloser(bytes.NewReader(backupBody))
		}

		// Handlers after this middleware, plus the route's final handler
		// (e.g. controller.Relay). The three-index slice caps capacity at the
		// current length so append always allocates: group.Handlers is shared
		// by all concurrent requests and may have spare capacity left over
		// from group.Use, which append would otherwise write into.
		handlers := group.Handlers
		nextHandlers := append(handlers[index:len(handlers):len(handlers)], c.Handler())

		maxRetry := common.RetryTimes
		if maxRetry < 0 {
			maxRetry = 0
		}
		if n, err := strconv.Atoi(c.Query("retry")); err == nil && n >= 0 && n <= maxRetry {
			maxRetry = n
		}
		retryDelay := time.Duration(common.RetryInterval) * time.Millisecond

		for remaining := maxRetry; ; remaining-- {
			c.Set("retry", remaining)
			if remaining == maxRetry {
				// First attempt: let gin advance its own index via c.Next().
				// Calling the handlers directly here would make any handler
				// that itself calls c.Next() run the rest of the chain twice.
				c.Next()
			} else {
				// Retry: restore the original request and run the chain by hand.
				c.Request.Header = backupHeader.Clone()
				if hasBody {
					c.Request.Body = io.NopCloser(bytes.NewReader(backupBody))
				}
				for _, handler := range nextHandlers {
					handler(c)
				}
			}

			if len(c.Errors) == 0 {
				return
			}
			// Move c.index to gin's abort index so that c.Next() calls made by
			// handlers during a manual retry are no-ops instead of re-running
			// the chain.
			c.Abort()
			if remaining == 0 {
				// Out of retries: keep the final errors for outer middleware
				// and skip the pointless delay before responding.
				return
			}

			timer := time.NewTimer(retryDelay)
			select {
			case <-c.Request.Context().Done():
				// The request context was cancelled (e.g. the client
				// disconnected); further attempts would be wasted.
				timer.Stop()
				return
			case <-timer.C:
			}

			// Clear errors so the next attempt's outcome is judged on its own.
			c.Errors = c.Errors[:0]
		}
	}
}