diff --git a/rule.go b/rule.go index b151ca2..9a6b57d 100644 --- a/rule.go +++ b/rule.go @@ -27,6 +27,7 @@ type Rule struct { Sport *RulePortRange IPProto int UIDRange *RuleUIDRange + Protocol uint8 } func (r Rule) String() string { diff --git a/rule_linux.go b/rule_linux.go index ecd408e..69f53e4 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -173,6 +173,10 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { req.AddData(nl.NewRtAttr(nl.FRA_UID_RANGE, b)) } + if rule.Protocol > 0 { + req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol))) + } + _, err := req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -269,6 +273,8 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( rule.Sport = NewRulePortRange(native.Uint16(attrs[j].Value[0:2]), native.Uint16(attrs[j].Value[2:4])) case nl.FRA_UID_RANGE: rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8])) + case nl.FRA_PROTOCOL: + rule.Protocol = uint8(attrs[j].Value[0]) } } diff --git a/rule_test.go b/rule_test.go index 4cdc356..f43d436 100644 --- a/rule_test.go +++ b/rule_test.go @@ -35,6 +35,7 @@ func TestRuleAddDel(t *testing.T) { rule.Sport = NewRulePortRange(1000, 1024) rule.IPProto = unix.IPPROTO_UDP rule.UIDRange = NewRuleUIDRange(100, 100) + rule.Protocol = unix.RTPROT_KERNEL if err := RuleAdd(rule); err != nil { t.Fatal(err) } @@ -420,5 +421,6 @@ func ruleEquals(a, b Rule) bool { a.IifName == b.IifName && a.Invert == b.Invert && a.Tos == b.Tos && - a.IPProto == b.IPProto + a.IPProto == b.IPProto && + a.Protocol == b.Protocol }