fix: 限定下只允许注入 x- 开头的 headers

This commit is contained in:
CaiCandong 2025-03-20 23:40:13 +08:00
parent 27cb506900
commit 819b5807a8
2 changed files with 8 additions and 8 deletions

View File

@ -3,16 +3,17 @@ package adaptor
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/relay/meta"
"io" "io"
"net/http" "net/http"
"strings" "strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/client"
"github.com/songquanpeng/one-api/relay/meta"
) )
const ( const (
extraRequestHeaderPrefix = "X-Oneapi-" extraRequestHeaderPrefix = "X-"
) )
func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
@ -20,9 +21,8 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta
req.Header.Set("Accept", c.Request.Header.Get("Accept")) req.Header.Set("Accept", c.Request.Header.Get("Accept"))
for key, values := range c.Request.Header { for key, values := range c.Request.Header {
if strings.HasPrefix(key, extraRequestHeaderPrefix) { if strings.HasPrefix(key, extraRequestHeaderPrefix) {
headerKey := strings.TrimPrefix(key, extraRequestHeaderPrefix)
for _, value := range values { for _, value := range values {
req.Header.Add(headerKey, value) req.Header.Add(key, value)
} }
} }
} }

View File

@ -17,7 +17,7 @@ func TestSetupCommonRequestHeader(t *testing.T) {
} }
c.Request.Header.Set("Content-Type", "application/json") c.Request.Header.Set("Content-Type", "application/json")
c.Request.Header.Set("Accept", "application/json") c.Request.Header.Set("Accept", "application/json")
c.Request.Header.Set("x-oneapi-test-header", "test-value") c.Request.Header.Set("x-test-header", "test-value")
// 创建测试用的http请求 // 创建测试用的http请求
req, _ := http.NewRequest("GET", "http://example.com", nil) req, _ := http.NewRequest("GET", "http://example.com", nil)
@ -33,5 +33,5 @@ func TestSetupCommonRequestHeader(t *testing.T) {
// 验证结果 // 验证结果
assert.Equal(t, "application/json", req.Header.Get("Content-Type")) assert.Equal(t, "application/json", req.Header.Get("Content-Type"))
assert.Equal(t, "application/json", req.Header.Get("Accept")) assert.Equal(t, "application/json", req.Header.Get("Accept"))
assert.Equal(t, "test-value", req.Header.Get("test-header")) assert.Equal(t, "test-value", req.Header.Get("x-test-header"))
} }