diff --git a/route_linux.go b/route_linux.go index 6b915e5..0cd4f83 100644 --- a/route_linux.go +++ b/route_linux.go @@ -1521,7 +1521,7 @@ type RouteGetOptions struct { VrfName string SrcAddr net.IP UID *uint32 - Mark int + Mark uint32 FIBMatch bool } @@ -1630,7 +1630,7 @@ func (h *Handle) RouteGetWithOptions(destination net.IP, options *RouteGetOption if options.Mark > 0 { b := make([]byte, 4) - native.PutUint32(b, uint32(options.Mark)) + native.PutUint32(b, options.Mark) req.AddData(nl.NewRtAttr(unix.RTA_MARK, b)) } diff --git a/route_test.go b/route_test.go index c844157..7e5e6b0 100644 --- a/route_test.go +++ b/route_test.go @@ -2449,27 +2449,41 @@ func TestRouteFWMarkOption(t *testing.T) { } // a table different than unix.RT_TABLE_MAIN - testtable := 1000 + testTable0 := 254 + testTable1 := 1000 + testTable2 := 1001 - gw1 := net.IPv4(192, 168, 1, 254) - gw2 := net.IPv4(192, 168, 2, 254) + gw0 := net.IPv4(192, 168, 1, 254) + gw1 := net.IPv4(192, 168, 2, 254) + gw2 := net.IPv4(192, 168, 3, 254) - // add default route via gw1 (in main route table by default) + // add default route via gw0 (in main route table by default) defaultRouteMain := Route{ - Dst: nil, - Gw: gw1, + Dst: nil, + Gw: gw0, + Table: testTable0, } if err := RouteAdd(&defaultRouteMain); err != nil { t.Fatal(err) } + // add default route via gw1 in test route table + defaultRouteTest1 := Route{ + Dst: nil, + Gw: gw1, + Table: testTable1, + } + if err := RouteAdd(&defaultRouteTest1); err != nil { + t.Fatal(err) + } + // add default route via gw2 in test route table - defaultRouteTest := Route{ + defaultRouteTest2 := Route{ Dst: nil, Gw: gw2, - Table: testtable, + Table: testTable2, } - if err := RouteAdd(&defaultRouteTest); err != nil { + if err := RouteAdd(&defaultRouteTest2); err != nil { t.Fatal(err) } @@ -2481,21 +2495,48 @@ func TestRouteFWMarkOption(t *testing.T) { if err != nil { t.Fatal(err) } - if len(routes) != 2 || routes[0].Table == routes[1].Table { + if len(routes) != 3 || routes[0].Table == routes[1].Table || routes[1].Table == routes[2].Table || + routes[0].Table == routes[2].Table { t.Fatal("Routes not added properly") } // add a rule that fwmark match should result in route lookup of test table - fwmark := 1000 + fwmark1 := uint32(0xAFFFFFFF) + fwmark2 := uint32(0xBFFFFFFF) rule := NewRule() - rule.Mark = fwmark - rule.Mask = 0xFFFFFFFF - rule.Table = testtable + rule.Mark = fwmark1 + rule.Mask = &[]uint32{0xFFFFFFFF}[0] + + rule.Table = testTable1 if err := RuleAdd(rule); err != nil { t.Fatal(err) } + rule = NewRule() + rule.Mark = fwmark2 + rule.Mask = &[]uint32{0xFFFFFFFF}[0] + rule.Table = testTable2 + if err := RuleAdd(rule); err != nil { + t.Fatal(err) + } + + rules, err := RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark1}, RT_FILTER_MARK) + if err != nil { + t.Fatal(err) + } + if len(rules) != 1 || rules[0].Table != testTable1 || rules[0].Mark != fwmark1 { + t.Fatal("Rules not added properly") + } + + rules, err = RuleListFiltered(FAMILY_V4, &Rule{Mark: fwmark2}, RT_FILTER_MARK) + if err != nil { + t.Fatal(err) + } + if len(rules) != 1 || rules[0].Table != testTable2 || rules[0].Mark != fwmark2 { + t.Fatal("Rules not added properly") + } + dstIP := net.IPv4(10, 1, 1, 1) // check getting route without FWMark option @@ -2503,12 +2544,21 @@ func TestRouteFWMarkOption(t *testing.T) { if err != nil { t.Fatal(err) } + if len(routes) != 1 || !routes[0].Gw.Equal(gw0) { + t.Fatal(routes) + } + + // check getting route with FWMark option + routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark1}) + if err != nil { + t.Fatal(err) + } if len(routes) != 1 || !routes[0].Gw.Equal(gw1) { t.Fatal(routes) } // check getting route with FWMark option - routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark}) + routes, err = RouteGetWithOptions(dstIP, &RouteGetOptions{Mark: fwmark2}) if err != nil { t.Fatal(err) } diff --git a/rule.go b/rule.go index 9a6b57d..cc73945 100644 --- a/rule.go +++ b/rule.go @@ -10,8 +10,8 @@ type Rule struct { Priority int Family int Table int - Mark int - Mask int + Mark uint32 + Mask *uint32 Tos uint TunID uint Goto int @@ -51,8 +51,8 @@ func NewRule() *Rule { SuppressIfgroup: -1, SuppressPrefixlen: -1, Priority: -1, - Mark: -1, - Mask: -1, + Mark: 0, + Mask: nil, Goto: -1, Flow: -1, } diff --git a/rule_linux.go b/rule_linux.go index e919892..18c03a3 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -102,14 +102,14 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { native.PutUint32(b, uint32(rule.Priority)) req.AddData(nl.NewRtAttr(nl.FRA_PRIORITY, b)) } - if rule.Mark >= 0 { + if rule.Mark != 0 || rule.Mask != nil { b := make([]byte, 4) - native.PutUint32(b, uint32(rule.Mark)) + native.PutUint32(b, rule.Mark) req.AddData(nl.NewRtAttr(nl.FRA_FWMARK, b)) } - if rule.Mask >= 0 { + if rule.Mask != nil { b := make([]byte, 4) - native.PutUint32(b, uint32(rule.Mask)) + native.PutUint32(b, *rule.Mask) req.AddData(nl.NewRtAttr(nl.FRA_FWMASK, b)) } if rule.Flow >= 0 { @@ -242,9 +242,10 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attrs[j].Value)), } case nl.FRA_FWMARK: - rule.Mark = int(native.Uint32(attrs[j].Value[0:4])) + rule.Mark = native.Uint32(attrs[j].Value[0:4]) case nl.FRA_FWMASK: - rule.Mask = int(native.Uint32(attrs[j].Value[0:4])) + mask := native.Uint32(attrs[j].Value[0:4]) + rule.Mask = &mask case nl.FRA_TUN_ID: rule.TunID = uint(native.Uint64(attrs[j].Value[0:8])) case nl.FRA_IIFNAME: @@ -297,7 +298,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( continue case filterMask&RT_FILTER_MARK != 0 && rule.Mark != filter.Mark: continue - case filterMask&RT_FILTER_MASK != 0 && rule.Mask != filter.Mask: + case filterMask&RT_FILTER_MASK != 0 && !ptrEqual(rule.Mask, filter.Mask): continue } } @@ -321,3 +322,13 @@ func (pr *RuleUIDRange) toRtAttrData() []byte { native.PutUint32(b[1], pr.End) return bytes.Join(b, []byte{}) } + +func ptrEqual(a, b *uint32) bool { + if a == b { + return true + } + if (a == nil) || (b == nil) { + return false + } + return *a == *b +} diff --git a/rule_test.go b/rule_test.go index ab9c577..c452fe9 100644 --- a/rule_test.go +++ b/rule_test.go @@ -252,7 +252,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { r.Family = family r.Table = 1 RuleAdd(r) - + r.Priority = 32765 // Set priority for assertion return r }, @@ -325,7 +325,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { }, { name: "returns rules filtered by Mask", - ruleFilter: &Rule{Mask: 0x5}, + ruleFilter: &Rule{Mask: &[]uint32{0x5}[0]}, filterMask: RT_FILTER_MASK, preRun: func() *Rule { r := NewRule() @@ -333,7 +333,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { r.Priority = 1 // Must add priority and table otherwise it's auto-assigned r.Family = family r.Table = 1 - r.Mask = 0x5 + r.Mask = &[]uint32{0x5}[0] RuleAdd(r) return r }, @@ -352,7 +352,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { r.Priority = 1 // Must add priority, table, mask otherwise it's auto-assigned r.Family = family r.Table = 1 - r.Mask = 0xff + r.Mask = &[]uint32{0xff}[0] r.Mark = 0xbb RuleAdd(r) return r @@ -362,6 +362,204 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { return []Rule{*r}, false }, }, + { + name: "returns rules filtered by fwmark 0", + ruleFilter: &Rule{Mark: 0, Mask: nil, Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0 + r.Mask = nil + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0/0xFFFFFFFF", + ruleFilter: &Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0 + r.Mask = &[]uint32{0xFFFFFFFF}[0] + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0x1234/0", + ruleFilter: &Rule{Mark: 0x1234, Mask: &[]uint32{0}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0x1234 + r.Mask = &[]uint32{0}[0] + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0/0xFFFFFFFF", + ruleFilter: &Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0 + r.Mask = &[]uint32{0xFFFFFFFF}[0] + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0xFFFFFFFF", + ruleFilter: &Rule{Mark: 0xFFFFFFFF, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0xFFFFFFFF + r.Mask = nil + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0x1234", + ruleFilter: &Rule{Mark: 0x1234, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0x1234 + r.Mask = nil + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0x12345678", + ruleFilter: &Rule{Mark: 0x12345678, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0x12345678 + r.Mask = nil + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0xFFFFFFFF/0", + ruleFilter: &Rule{Mark: 0xFFFFFFFF, Mask: &[]uint32{0}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0xFFFFFFFF + r.Mask = &[]uint32{0}[0] + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, + { + name: "returns rules filtered by fwmark 0xFFFFFFFF/0xFFFFFFFF", + ruleFilter: &Rule{Mark: 0xFFFFFFFF, Mask: &[]uint32{0xFFFFFFFF}[0], Table: 100}, + filterMask: RT_FILTER_MARK | RT_FILTER_MASK | RT_FILTER_TABLE, + preRun: func() *Rule { + r := NewRule() + r.Src = srcNet + r.Priority = 1 + r.Family = family + r.Table = 100 + r.Mark = 0xFFFFFFFF + r.Mask = &[]uint32{0xFFFFFFFF}[0] + if err := RuleAdd(r); err != nil { + t.Fatal(err) + } + return r + }, + postRun: func(r *Rule) { RuleDel(r) }, + setupWant: func(r *Rule) ([]Rule, bool) { + return []Rule{*r}, false + }, + }, { name: "returns rules filtered by Tos", ruleFilter: &Rule{Tos: 12}, @@ -474,5 +672,8 @@ func ruleEquals(a, b Rule) bool { a.Invert == b.Invert && a.Tos == b.Tos && a.IPProto == b.IPProto && - a.Protocol == b.Protocol + a.Protocol == b.Protocol && + a.Mark == b.Mark && + (ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 && + (a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF))) }