// Mirror of https://github.com/songquanpeng/one-api.git
// (synced 2025-09-30 23:26:39 +08:00; 81 lines, 2.2 KiB, Go)
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"github.com/gin-gonic/gin"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"strconv"
|
|
"time"
|
|
)
|
|
|
|
func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc {
|
|
var retryHandler gin.HandlerFunc
|
|
// 获取RetryHandler在当前HandlersChain的位置
|
|
index := len(group.Handlers) + 1
|
|
retryHandler = func(c *gin.Context) {
|
|
// Backup request
|
|
hasBody := c.Request.ContentLength > 0
|
|
backupHeader := c.Request.Header.Clone()
|
|
var backupBody []byte
|
|
var err error
|
|
if hasBody {
|
|
backupBody, err = io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
abortWithMessage(c, http.StatusBadRequest, "Invalid request")
|
|
return
|
|
}
|
|
_ = c.Request.Body.Close()
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
|
|
}
|
|
|
|
// 获取 retryHandler 后续的中间件
|
|
// Get next handlers
|
|
nextHandlers := group.Handlers[index:]
|
|
|
|
// 加入Relay处理函数 c.Handler() => c.handlers.Last() => controller.Relay
|
|
// Add Relay handler
|
|
nextHandlers = append(nextHandlers, c.Handler())
|
|
|
|
// Retry
|
|
maxRetryStr := c.Query("retry")
|
|
maxRetry, err := strconv.Atoi(maxRetryStr)
|
|
if err != nil || maxRetryStr == "" || maxRetry < 0 || maxRetry > common.RetryTimes {
|
|
maxRetry = common.RetryTimes
|
|
}
|
|
retryDelay := time.Duration(common.RetryInterval) * time.Millisecond
|
|
for i := maxRetry; i >= 0; i-- {
|
|
c.Set("retry", i)
|
|
|
|
if i == maxRetry {
|
|
// 第一次请求, 直接执行使用c.Next()调用后续中间件, 防止直接使用handler 内部调用c.Next() 导致重复执行
|
|
// First request, execute next middleware
|
|
c.Next()
|
|
} else {
|
|
// 重试, 恢复请求头和请求体, 并执行后续中间件
|
|
// Retry, restore request and execute next middleware
|
|
c.Request.Header = backupHeader.Clone()
|
|
if hasBody {
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
|
|
}
|
|
for _, handler := range nextHandlers {
|
|
handler(c)
|
|
}
|
|
}
|
|
|
|
// If no errors, return
|
|
if len(c.Errors) == 0 {
|
|
return
|
|
}
|
|
// c.index 指向 AbortIndex 可以防止出错时重复执行后续中间件
|
|
c.Abort()
|
|
// If errors, retry after delay
|
|
time.Sleep(retryDelay)
|
|
// Clear errors to avoid confusion in next middleware
|
|
c.Errors = c.Errors[:0]
|
|
}
|
|
}
|
|
return retryHandler
|
|
}
|