package middleware import ( "bytes" "encoding/json" "fmt" "github.com/gin-gonic/gin" "io" "net/http" "one-api/common" "strconv" "time" ) type OpenAIErrorWithStatusCode struct { OpenAIError StatusCode int `json:"status_code"` } type OpenAIError struct { Message string `json:"message"` Type string `json:"type"` Param string `json:"param"` Code any `json:"code"` } func RetryHandler(group *gin.RouterGroup) gin.HandlerFunc { var retryHandler gin.HandlerFunc // 获取RetryHandler在当前HandlersChain的位置 index := len(group.Handlers) + 1 retryHandler = func(c *gin.Context) { // Backup request hasBody := c.Request.ContentLength > 0 backupHeader := c.Request.Header.Clone() var backupBody []byte var err error if hasBody { backupBody, err = io.ReadAll(c.Request.Body) if err != nil { abortWithMessage(c, http.StatusBadRequest, "Invalid request") return } _ = c.Request.Body.Close() c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody)) } // 获取 retryHandler 后续的中间件 // Get next handlers nextHandlers := group.Handlers[index:] // 加入Relay处理函数 c.Handler() => c.handlers.Last() => controller.Relay // Add Relay handler nextHandlers = append(nextHandlers, c.Handler()) // Retry maxRetryStr := c.Query("retry") maxRetry, err := strconv.Atoi(maxRetryStr) if err != nil || maxRetryStr == "" || maxRetry < 0 || maxRetry > common.RetryTimes { maxRetry = common.RetryTimes } retryDelay := time.Duration(common.RetryInterval) * time.Millisecond var openaiErr *OpenAIErrorWithStatusCode for i := 0; i < maxRetry; i++ { if i == 0 { // 第一次请求, 直接执行使用c.Next()调用后续中间件, 防止直接使用handler 内部调用c.Next() 导致重复执行 // First request, execute next middleware c.Next() fmt.Println("c.Next()") } else { // Clear errors to avoid confusion in next middleware c.Errors = c.Errors[:0] // 重试, 恢复请求头和请求体, 并执行后续中间件 // Retry, restore request and execute next middleware c.Request.Header = backupHeader.Clone() if hasBody { c.Request.Body = io.NopCloser(bytes.NewBuffer(backupBody)) } for _, handler := range nextHandlers { handler(c) } } // If no errors, return if len(c.Errors) == 0 { return } // c.index 指向 AbortIndex 可以防止出错时重复执行后续中间件 c.Abort() // If errors, retry after delay time.Sleep(retryDelay) } _ = json.Unmarshal([]byte(c.Errors.Last().Error()), &openaiErr) c.JSON(openaiErr.StatusCode, gin.H{ "error": openaiErr.OpenAIError, }) } return retryHandler }