mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-21 15:36:49 +08:00
fix: Change the implementation of retry from 307 redirect to Retry middleware
This commit is contained in:
98
middleware/retry.go
Normal file
98
middleware/retry.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAIErrorWithStatusCode struct {
|
||||
OpenAIError
|
||||
StatusCode int `json:"status_code"`
|
||||
}
|
||||
|
||||
type OpenAIError struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Param string `json:"param"`
|
||||
Code any `json:"code"`
|
||||
}
|
||||
|
||||
func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc {
|
||||
var retryHandler gin.HandlerFunc
|
||||
// 获取RetryHandler在当前HandlersChain的位置
|
||||
index := len(group.Handlers) + 1
|
||||
retryHandler = func(c *gin.Context) {
|
||||
// Backup request
|
||||
hasBody := c.Request.ContentLength > 0
|
||||
backupHeader := c.Request.Header.Clone()
|
||||
var backupBody []byte
|
||||
var err error
|
||||
if hasBody {
|
||||
backupBody, err = io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
abortWithMessage(c, http.StatusBadRequest, "Invalid request")
|
||||
return
|
||||
}
|
||||
_ = c.Request.Body.Close()
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
|
||||
}
|
||||
|
||||
// 获取 retryHandler 后续的中间件
|
||||
// Get next handlers
|
||||
nextHandlers := group.Handlers[index:]
|
||||
|
||||
// 加入Relay处理函数 c.Handler() => c.handlers.Last() => controller.Relay
|
||||
// Add Relay handler
|
||||
nextHandlers = append(nextHandlers, c.Handler())
|
||||
|
||||
// Retry
|
||||
maxRetryStr := c.Query("retry")
|
||||
maxRetry, err := strconv.Atoi(maxRetryStr)
|
||||
if err != nil || maxRetryStr == "" || maxRetry < 0 || maxRetry > common.RetryTimes {
|
||||
maxRetry = common.RetryTimes
|
||||
}
|
||||
retryDelay := time.Duration(common.RetryInterval) * time.Millisecond
|
||||
var openaiErr *OpenAIErrorWithStatusCode
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
if i == 0 {
|
||||
// 第一次请求, 直接执行使用c.Next()调用后续中间件, 防止直接使用handler 内部调用c.Next() 导致重复执行
|
||||
// First request, execute next middleware
|
||||
c.Next()
|
||||
fmt.Println("c.Next()")
|
||||
} else {
|
||||
// Clear errors to avoid confusion in next middleware
|
||||
c.Errors = c.Errors[:0]
|
||||
// 重试, 恢复请求头和请求体, 并执行后续中间件
|
||||
// Retry, restore request and execute next middleware
|
||||
c.Request.Header = backupHeader.Clone()
|
||||
if hasBody {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody))
|
||||
}
|
||||
for _, handler := range nextHandlers {
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// If no errors, return
|
||||
if len(c.Errors) == 0 {
|
||||
return
|
||||
}
|
||||
// c.index 指向 AbortIndex 可以防止出错时重复执行后续中间件
|
||||
c.Abort()
|
||||
// If errors, retry after delay
|
||||
time.Sleep(retryDelay)
|
||||
}
|
||||
_ = json.Unmarshal([]byte(c.Errors.Last().Error()), &openaiErr)
|
||||
c.JSON(openaiErr.StatusCode, gin.H{
|
||||
"error": openaiErr.OpenAIError,
|
||||
})
|
||||
}
|
||||
return retryHandler
|
||||
}
|
||||
Reference in New Issue
Block a user