mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	feat: able to set multiple subnets
This commit is contained in:
		@@ -5,9 +5,18 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/songquanpeng/one-api/common/logger"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func IsValidSubnet(subnet string) error {
 | 
			
		||||
func splitSubnets(subnets string) []string {
 | 
			
		||||
	res := strings.Split(subnets, ",")
 | 
			
		||||
	for i := 0; i < len(res); i++ {
 | 
			
		||||
		res[i] = strings.TrimSpace(res[i])
 | 
			
		||||
	}
 | 
			
		||||
	return res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func isValidSubnet(subnet string) error {
 | 
			
		||||
	_, _, err := net.ParseCIDR(subnet)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("failed to parse subnet: %w", err)
 | 
			
		||||
@@ -15,7 +24,7 @@ func IsValidSubnet(subnet string) error {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
 | 
			
		||||
func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
 | 
			
		||||
	_, ipNet, err := net.ParseCIDR(subnet)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
 | 
			
		||||
@@ -23,3 +32,21 @@ func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
 | 
			
		||||
	}
 | 
			
		||||
	return ipNet.Contains(net.ParseIP(ip))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsValidSubnets(subnets string) error {
 | 
			
		||||
	for _, subnet := range splitSubnets(subnets) {
 | 
			
		||||
		if err := isValidSubnet(subnet); err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
 | 
			
		||||
	for _, subnet := range splitSubnets(subnets) {
 | 
			
		||||
		if isIpInSubnet(ctx, ip, subnet) {
 | 
			
		||||
			return true
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return false
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -13,7 +13,7 @@ func TestIsIpInSubnet(t *testing.T) {
 | 
			
		||||
	ip2 := "125.216.250.89"
 | 
			
		||||
	subnet := "192.168.0.0/24"
 | 
			
		||||
	Convey("TestIsIpInSubnet", t, func() {
 | 
			
		||||
		So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
 | 
			
		||||
		So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
 | 
			
		||||
		So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
 | 
			
		||||
		So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -111,7 +111,7 @@ func validateToken(c *gin.Context, token model.Token) error {
 | 
			
		||||
		return fmt.Errorf("令牌名称过长")
 | 
			
		||||
	}
 | 
			
		||||
	if token.Subnet != nil && *token.Subnet != "" {
 | 
			
		||||
		err := network.IsValidSubnet(*token.Subnet)
 | 
			
		||||
		err := network.IsValidSubnets(*token.Subnet)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("无效的网段:%s", err.Error())
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -102,7 +102,7 @@ func TokenAuth() func(c *gin.Context) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		if token.Subnet != nil && *token.Subnet != "" {
 | 
			
		||||
			if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) {
 | 
			
		||||
			if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
 | 
			
		||||
				abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 
 | 
			
		||||
@@ -158,7 +158,7 @@ const EditToken = () => {
 | 
			
		||||
            <Form.Input
 | 
			
		||||
              label='IP 限制'
 | 
			
		||||
              name='subnet'
 | 
			
		||||
              placeholder={'请输入允许访问的网段,例如:192.168.0.0/24'}
 | 
			
		||||
              placeholder={'请输入允许访问的网段,例如:192.168.0.0/24,请使用英文逗号分隔多个网段'}
 | 
			
		||||
              onChange={handleInputChange}
 | 
			
		||||
              value={inputs.subnet}
 | 
			
		||||
              autoComplete='new-password'
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user