one-api/middleware/retry.go

99 lines
2.7 KiB
Go

package middleware
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"strconv"
"time"
)
// OpenAIErrorWithStatusCode couples an OpenAIError with an integer status
// code. Because OpenAIError is embedded without a JSON tag, encoding/json
// places its fields at the same level as "status_code".
type OpenAIErrorWithStatusCode struct {
	OpenAIError

	// StatusCode is read from / written to the "status_code" JSON key.
	StatusCode int `json:"status_code"`
}
// OpenAIError is the JSON error payload: a message, a type, a parameter name
// and a code. Code is declared as any, so it keeps whatever JSON value was
// decoded into it (for example a string or a number).
type OpenAIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
	Param   string `json:"param"`
	Code    any    `json:"code"`
}
// RetryHandler returns a middleware that runs the rest of the handler chain
// and, when a handler records an error on the gin context, replays the
// request against those handlers again.
//
// The number of attempts (including the first one) comes from the "retry"
// query parameter; a missing, malformed, negative or too-large value falls
// back to common.RetryTimes. At least one attempt is always made. Between
// failed attempts the handler sleeps for common.RetryInterval milliseconds.
// When every attempt fails, the last recorded error is decoded as an
// OpenAIErrorWithStatusCode and written as {"error": ...} with its status
// code.
//
// NOTE(review): the position of this middleware is derived from
// len(group.Handlers) at construction time, so this assumes the returned
// handler is registered on group right after this call — confirm at the
// call site.
func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc {
	// Index of the first handler that follows this middleware in the group's
	// chain (the middleware itself is expected at len(group.Handlers)).
	index := len(group.Handlers) + 1

	return func(c *gin.Context) {
		// Back up the request so it can be replayed. ContentLength is -1 when
		// the length is unknown (e.g. chunked uploads), so only an explicit 0
		// is treated as "no body".
		hasBody := c.Request.Body != nil && c.Request.ContentLength != 0
		backupHeader := c.Request.Header.Clone()
		var backupBody []byte
		if hasBody {
			body, readErr := io.ReadAll(c.Request.Body)
			if readErr != nil {
				abortWithMessage(c, http.StatusBadRequest, "Invalid request")
				return
			}
			// The body is fully buffered; a close error cannot change the outcome.
			_ = c.Request.Body.Close()
			backupBody = body
			c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
		}

		// Handlers to replay on a retry: the group middleware registered after
		// this one, followed by the route's main handler (c.Handler()).
		// They are copied into a fresh slice because appending to a sub-slice
		// of group.Handlers could write into its shared backing array while
		// other requests are reading it. The bounds check avoids a slice
		// panic when the middleware is not where index assumes it is.
		var following gin.HandlersChain
		if index < len(group.Handlers) {
			following = group.Handlers[index:]
		}
		nextHandlers := make(gin.HandlersChain, 0, len(following)+1)
		nextHandlers = append(nextHandlers, following...)
		nextHandlers = append(nextHandlers, c.Handler())

		// Atoi already fails on an empty string, so no separate check is needed.
		maxRetry, err := strconv.Atoi(c.Query("retry"))
		if err != nil || maxRetry < 0 || maxRetry > common.RetryTimes {
			maxRetry = common.RetryTimes
		}
		// The loop below counts attempts, not extra retries: with zero
		// attempts no handler would run and there would be no error to report.
		if maxRetry < 1 {
			maxRetry = 1
		}
		retryDelay := time.Duration(common.RetryInterval) * time.Millisecond

		for attempt := 0; attempt < maxRetry; attempt++ {
			if attempt == 0 {
				// First attempt: let gin drive the chain via c.Next() so that
				// handlers which call c.Next() themselves are not run twice.
				c.Next()
			} else {
				// Clear errors so the next attempt is judged on its own.
				c.Errors = c.Errors[:0]
				// Restore the request headers and body, then replay the
				// downstream handlers directly.
				c.Request.Header = backupHeader.Clone()
				if hasBody {
					c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
				}
				for _, handler := range nextHandlers {
					handler(c)
				}
			}

			// No recorded error means the attempt succeeded.
			if len(c.Errors) == 0 {
				return
			}

			// Abort so gin does not run the remaining handlers again once
			// this middleware returns.
			c.Abort()

			// Only wait when another attempt will actually follow.
			if attempt < maxRetry-1 {
				time.Sleep(retryDelay)
			}
		}

		// Every attempt failed; the loop only exits here with a non-empty
		// c.Errors, so Last() is non-nil.
		raw := c.Errors.Last().Error()
		var openaiErr OpenAIErrorWithStatusCode
		if jsonErr := json.Unmarshal([]byte(raw), &openaiErr); jsonErr != nil {
			// The error was not a JSON-encoded OpenAIErrorWithStatusCode;
			// report its text instead of dereferencing an empty result.
			openaiErr = OpenAIErrorWithStatusCode{
				OpenAIError: OpenAIError{
					Message: fmt.Sprintf("request failed after %d attempt(s): %s", maxRetry, raw),
				},
				StatusCode: http.StatusInternalServerError,
			}
		}
		// net/http rejects status codes outside 100-999, and gin ignores 0
		// (which would turn the error into a 200 response).
		if openaiErr.StatusCode < 100 || openaiErr.StatusCode > 999 {
			openaiErr.StatusCode = http.StatusInternalServerError
		}
		c.JSON(openaiErr.StatusCode, gin.H{
			"error": openaiErr.OpenAIError,
		})
	}
}