diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 8953d7a3..584759a8 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -1,13 +1,17 @@ package adaptor import ( + "bytes" + "context" "errors" "fmt" + "io" + "net/http" + "strings" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/client" "github.com/songquanpeng/one-api/relay/meta" - "io" - "net/http" ) func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { @@ -18,12 +22,32 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta } } +func NewReusableRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) { + switch v := body.(type) { + case *bytes.Buffer, *bytes.Reader, *strings.Reader: + return http.NewRequestWithContext(ctx, method, url, v) // 标准库会自动填 GetBody + default: + data, err := io.ReadAll(body) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, method, url, bytes.NewReader(data)) + if err != nil { + return nil, err + } + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } + return req, nil + } +} + func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { return nil, fmt.Errorf("get request url failed: %w", err) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) + req, err := NewReusableRequest(c.Request.Context(), c.Request.Method, fullRequestURL, requestBody) if err != nil { return nil, fmt.Errorf("new request failed: %w", err) }