netlink/rule_test.go
2023-07-28 08:26:29 -07:00

438 lines
10 KiB
Go

//go:build linux
// +build linux
package netlink
import (
"net"
"testing"
"golang.org/x/sys/unix"
)
// TestRuleAddDel installs a single IPv4 rule with many options set, verifies
// that the rule list grows by exactly one entry and contains a matching rule,
// then deletes the rule and verifies the list is back to its original length.
func TestRuleAddDel(t *testing.T) {
	skipUnlessRoot(t)
	tearDown := setUpNetlinkTest(t)
	defer tearDown()

	// Length of this snapshot is the baseline for both count checks below.
	before, err := RuleList(FAMILY_V4)
	if err != nil {
		t.Fatal(err)
	}

	rule := NewRule()
	rule.Family = FAMILY_V4
	rule.Table = unix.RT_TABLE_MAIN
	rule.Src = &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}
	rule.Dst = &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}
	rule.Priority = 5
	rule.OifName = "lo"
	rule.IifName = "lo"
	rule.Invert = true
	rule.Tos = 0x10
	rule.Dport = NewRulePortRange(80, 80)
	rule.Sport = NewRulePortRange(1000, 1024)
	rule.IPProto = unix.IPPROTO_UDP
	rule.UIDRange = NewRuleUIDRange(100, 100)
	rule.Protocol = unix.RTPROT_KERNEL

	if err = RuleAdd(rule); err != nil {
		t.Fatal(err)
	}

	afterAdd, err := RuleList(FAMILY_V4)
	if err != nil {
		t.Fatal(err)
	}
	if want := len(before) + 1; len(afterAdd) != want {
		t.Fatal("Rule not added properly")
	}

	// The new rule must be present with the options compared by ruleEquals.
	if !ruleExists(afterAdd, *rule) {
		t.Fatal("Rule has diffrent options than one added")
	}

	if err = RuleDel(rule); err != nil {
		t.Fatal(err)
	}

	afterDel, err := RuleList(FAMILY_V4)
	if err != nil {
		t.Fatal(err)
	}
	if len(afterDel) != len(before) {
		t.Fatal("Rule not removed properly")
	}
}
// TestRuleListFiltered runs the filtered-listing scenarios once per address
// family, IPv4 first and then IPv6, inside a single netlink test setup.
func TestRuleListFiltered(t *testing.T) {
	skipUnlessRoot(t)
	tearDown := setUpNetlinkTest(t)
	defer tearDown()

	families := []struct {
		name string
		run  func(*testing.T)
	}{
		{name: "IPv4", run: testRuleListFilteredIPv4},
		{name: "IPv6", run: testRuleListFilteredIPv6},
	}
	for _, fam := range families {
		t.Run(fam.name, fam.run)
	}
}
// testRuleListFilteredIPv4 runs the shared filtered-listing scenarios with
// FAMILY_V4, using 172.16.0.1/16 as source and 172.16.1.1/24 as destination.
func testRuleListFilteredIPv4(t *testing.T) {
	src := net.IPNet{
		IP:   net.IPv4(172, 16, 0, 1),
		Mask: net.CIDRMask(16, 32),
	}
	dst := net.IPNet{
		IP:   net.IPv4(172, 16, 1, 1),
		Mask: net.CIDRMask(24, 32),
	}
	runRuleListFiltered(t, FAMILY_V4, &src, &dst)
}
// testRuleListFilteredIPv6 runs the shared filtered-listing scenarios with
// FAMILY_V6, using a /64 source prefix and a /96 destination prefix.
func testRuleListFilteredIPv6(t *testing.T) {
	src := net.IPNet{
		IP:   net.ParseIP("fd56:6b58:db28:2913::"),
		Mask: net.CIDRMask(64, 128),
	}
	dst := net.IPNet{
		IP:   net.ParseIP("fde9:379f:3b35:6635::"),
		Mask: net.CIDRMask(96, 128),
	}
	runRuleListFiltered(t, FAMILY_V6, &src, &dst)
}
// runRuleListFiltered exercises RuleListFiltered for the given address family.
//
// Every table entry installs its own sample rule(s) through preRun, queries
// RuleListFiltered with ruleFilter and filterMask, removes the sample rule(s)
// through postRun, and finally compares the listing (and whether an error
// occurred) against what setupWant returns. srcNet and dstNet supply the
// source and destination prefixes used by the sample rules and the filters.
func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) {
	// The unfiltered listing is the expected result of the "returns all rules"
	// case, so failing to obtain it must stop the test instead of being ignored.
	defaultRules, err := RuleList(family)
	if err != nil {
		t.Fatalf("listing default rules: %v", err)
	}

	tests := []struct {
		name       string
		ruleFilter *Rule
		filterMask uint64
		preRun     func() *Rule                 // Creates sample rule harness
		postRun    func(*Rule)                  // Deletes sample rule harness
		setupWant  func(*Rule) ([]Rule, bool)   // Expected rules and whether an error is expected
	}{
		{
			name:       "returns all rules",
			ruleFilter: nil,
			filterMask: 0,
			preRun:     func() *Rule { return nil },
			postRun:    func(r *Rule) {},
			setupWant: func(_ *Rule) ([]Rule, bool) {
				return defaultRules, false
			},
		},
		{
			name:       "returns one rule filtered by Src",
			ruleFilter: &Rule{Src: srcNet},
			filterMask: RT_FILTER_SRC,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
		{
			name:       "returns one rule filtered by Dst",
			ruleFilter: &Rule{Dst: dstNet},
			filterMask: RT_FILTER_DST,
			preRun: func() *Rule {
				r := NewRule()
				r.Dst = dstNet
				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
		{
			name:       "returns two rules filtered by Dst",
			ruleFilter: &Rule{Dst: dstNet},
			filterMask: RT_FILTER_DST,
			preRun: func() *Rule {
				r := NewRule()
				r.Dst = dstNet
				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				RuleAdd(r)

				rc := *r // Create almost identical copy
				rc.Src = srcNet
				RuleAdd(&rc)

				return r
			},
			postRun: func(r *Rule) {
				RuleDel(r)

				rc := *r // Delete the almost identical copy
				rc.Src = srcNet
				RuleDel(&rc)
			},
			setupWant: func(r *Rule) ([]Rule, bool) {
				rc := *r // Append the almost identical copy
				rc.Src = srcNet
				return []Rule{*r, rc}, false
			},
		},
		{
			name:       "returns one rule filtered by Src when two rules exist",
			ruleFilter: &Rule{Src: srcNet},
			filterMask: RT_FILTER_SRC,
			preRun: func() *Rule {
				r := NewRule()
				r.Dst = dstNet
				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				RuleAdd(r)

				rc := *r // Create almost identical copy
				rc.Src = srcNet
				RuleAdd(&rc)

				return r
			},
			postRun: func(r *Rule) {
				RuleDel(r)

				rc := *r // Delete the almost identical copy
				rc.Src = srcNet
				RuleDel(&rc)
			},
			setupWant: func(r *Rule) ([]Rule, bool) {
				// Do not include `r`: it has no Src, so only the copy is expected.
				rc := *r
				rc.Src = srcNet
				return []Rule{rc}, false
			},
		},
		{
			name:       "returns rules with specific priority",
			ruleFilter: &Rule{Priority: 5},
			filterMask: RT_FILTER_PRIORITY,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 5
				r.Family = family
				r.Table = 1
				RuleAdd(r)

				for i := 2; i < 5; i++ {
					rc := *r // Create almost identical copy
					rc.Table = i
					RuleAdd(&rc)
				}

				return r
			},
			postRun: func(r *Rule) {
				RuleDel(r)

				for i := 2; i < 5; i++ {
					rc := *r // Delete the almost identical copy
					// NOTE(review): Table is -1 here rather than i as in preRun;
					// presumably an unspecified table lets each delete match one
					// of the remaining copies — confirm against RuleDel.
					rc.Table = -1
					RuleDel(&rc)
				}
			},
			setupWant: func(r *Rule) ([]Rule, bool) {
				rs := []Rule{*r}

				for i := 2; i < 5; i++ {
					rc := *r // Append the almost identical copy
					rc.Table = i
					rs = append(rs, rc)
				}

				return rs, false
			},
		},
		{
			name:       "returns rules filtered by Table",
			ruleFilter: &Rule{Table: 199},
			filterMask: RT_FILTER_TABLE,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 1 // Must add priority otherwise it's auto-assigned
				r.Family = family
				r.Table = 199
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
		{
			name:       "returns rules filtered by Mask",
			ruleFilter: &Rule{Mask: 0x5},
			filterMask: RT_FILTER_MASK,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 1 // Must add priority and table otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				r.Mask = 0x5
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
		{
			name:       "returns rules filtered by Mark",
			ruleFilter: &Rule{Mark: 0xbb},
			filterMask: RT_FILTER_MARK,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned
				r.Family = family
				r.Table = 1
				r.Mask = 0xff
				r.Mark = 0xbb
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
		{
			name:       "returns rules filtered by Tos",
			ruleFilter: &Rule{Tos: 12},
			filterMask: RT_FILTER_TOS,
			preRun: func() *Rule {
				r := NewRule()
				r.Src = srcNet
				r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned
				r.Family = family
				r.Table = 12
				r.Tos = 12 // Tos must equal table
				RuleAdd(r)
				return r
			},
			postRun: func(r *Rule) { RuleDel(r) },
			setupWant: func(r *Rule) ([]Rule, bool) {
				return []Rule{*r}, false
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rule := tt.preRun()
			rules, err := RuleListFiltered(family, tt.ruleFilter, tt.filterMask)
			// The sample rules are removed right after listing, before any
			// assertion runs, so they are gone by the time the next case starts.
			tt.postRun(rule)

			wantRules, wantErr := tt.setupWant(rule)

			if len(wantRules) != len(rules) {
				t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules))
			} else {
				// Lengths match, so compare position by position.
				for i := range wantRules {
					if !ruleEquals(wantRules[i], rules[i]) {
						t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i])
					}
				}
			}

			// Report the expectation as "want" and the observed outcome as "got".
			if gotErr := err != nil; gotErr != wantErr {
				t.Errorf("Error expectation not met, want %v, got %v", wantErr, gotErr)
			}
		})
	}
}
// TestRuleString checks the text returned by Rule.String for a zero-value
// rule, for a rule whose Src/Dst carry no mask, and for a rule with fully
// specified /24 source and destination networks.
func TestRuleString(t *testing.T) {
	t.Parallel()

	type stringCase struct {
		name string
		rule Rule
		want string
	}

	cases := []stringCase{
		{
			name: "empty rule",
			want: "ip rule 0: from all to all table 0",
		},
		{
			name: "rule with src and dst equivalent to <nil>",
			rule: Rule{
				Priority: 100,
				Src:      &net.IPNet{IP: net.IPv4(10, 0, 0, 0)},
				Dst:      &net.IPNet{IP: net.IPv4(20, 0, 0, 0)},
				Table:    99,
			},
			want: "ip rule 100: from all to all table 99",
		},
		{
			name: "rule with src and dst",
			rule: Rule{
				Priority: 100,
				Src:      &net.IPNet{IP: net.IPv4(10, 0, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
				Dst:      &net.IPNet{IP: net.IPv4(20, 0, 0, 0), Mask: net.IPv4Mask(255, 255, 255, 0)},
				Table:    99,
			},
			want: "ip rule 100: from 10.0.0.0/24 to 20.0.0.0/24 table 99",
		},
	}

	for i := range cases {
		// Per-iteration copy: the parallel subtest closure outlives the loop step.
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			if got := tc.rule.String(); got != tc.want {
				t.Errorf("expected %q but got %q", tc.want, got)
			}
		})
	}
}
// ruleExists reports whether any element of rules is equal to rule according
// to ruleEquals.
func ruleExists(rules []Rule, rule Rule) bool {
	for _, candidate := range rules {
		if ruleEquals(candidate, rule) {
			return true
		}
	}
	return false
}
// ruleEquals reports whether a and b agree on the fields this test file
// compares: Table, Src, Dst, OifName, Priority, Family, IifName, Invert, Tos,
// IPProto and Protocol. Src and Dst match when both are nil or when both are
// non-nil and have the same String() form. Other Rule fields are not compared.
func ruleEquals(a, b Rule) bool {
	// Source network: nil-ness must agree, then the textual form.
	if (a.Src == nil) != (b.Src == nil) {
		return false
	}
	if a.Src != nil && a.Src.String() != b.Src.String() {
		return false
	}

	// Destination network: same treatment as the source.
	if (a.Dst == nil) != (b.Dst == nil) {
		return false
	}
	if a.Dst != nil && a.Dst.String() != b.Dst.String() {
		return false
	}

	// Any differing scalar field makes the rules unequal.
	switch {
	case a.Table != b.Table,
		a.OifName != b.OifName,
		a.Priority != b.Priority,
		a.Family != b.Family,
		a.IifName != b.IifName,
		a.Invert != b.Invert,
		a.Tos != b.Tos,
		a.IPProto != b.IPProto,
		a.Protocol != b.Protocol:
		return false
	}
	return true
}