conntrack: allow to filter by subnet

Add a new conntrack filter to be able to filter by subnet, in
addition to current IP address filter.

Signed-off-by: Antonio Ojea <aojea@redhat.com>
This commit is contained in:
Antonio Ojea 2021-03-27 00:42:27 +01:00 committed by Alessandro Boch
parent dbf5d9e510
commit a3836f0b5d
2 changed files with 252 additions and 43 deletions

View File

@ -346,23 +346,34 @@ type CustomConntrackFilter interface {
}
type ConntrackFilter struct {
ipFilter map[ConntrackFilterType]net.IP
ipNetFilter map[ConntrackFilterType]*net.IPNet
portFilter map[ConntrackFilterType]uint16
protoFilter uint8
}
// AddIP adds an IP to the conntrack filter
func (f *ConntrackFilter) AddIP(tp ConntrackFilterType, ip net.IP) error {
if f.ipFilter == nil {
f.ipFilter = make(map[ConntrackFilterType]net.IP)
// AddIPNet adds a IP subnet to the conntrack filter
func (f *ConntrackFilter) AddIPNet(tp ConntrackFilterType, ipNet *net.IPNet) error {
if ipNet == nil {
return fmt.Errorf("Filter attribute empty")
}
if _, ok := f.ipFilter[tp]; ok {
if f.ipNetFilter == nil {
f.ipNetFilter = make(map[ConntrackFilterType]*net.IPNet)
}
if _, ok := f.ipNetFilter[tp]; ok {
return errors.New("Filter attribute already present")
}
f.ipFilter[tp] = ip
f.ipNetFilter[tp] = ipNet
return nil
}
// AddIP adds an IP to the conntrack filter
func (f *ConntrackFilter) AddIP(tp ConntrackFilterType, ip net.IP) error {
if ip == nil {
return fmt.Errorf("Filter attribute empty")
}
return f.AddIPNet(tp, NewIPNet(ip))
}
// AddPort adds a Port to the conntrack filter if the Layer 4 protocol allows it
func (f *ConntrackFilter) AddPort(tp ConntrackFilterType, port uint16) error {
switch f.protoFilter {
@ -394,7 +405,7 @@ func (f *ConntrackFilter) AddProtocol(proto uint8) error {
// MatchConntrackFlow applies the filter to the flow and returns true if the flow matches the filter
// false otherwise
func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool {
if len(f.ipFilter) == 0 && len(f.portFilter) == 0 && f.protoFilter == 0 {
if len(f.ipNetFilter) == 0 && len(f.portFilter) == 0 && f.protoFilter == 0 {
// empty filter always not match
return false
}
@ -408,30 +419,30 @@ func (f *ConntrackFilter) MatchConntrackFlow(flow *ConntrackFlow) bool {
match := true
// IP conntrack filter
if len(f.ipFilter) > 0 {
if len(f.ipNetFilter) > 0 {
// -orig-src ip Source address from original direction
if elem, found := f.ipFilter[ConntrackOrigSrcIP]; found {
match = match && elem.Equal(flow.Forward.SrcIP)
if elem, found := f.ipNetFilter[ConntrackOrigSrcIP]; found {
match = match && elem.Contains(flow.Forward.SrcIP)
}
// -orig-dst ip Destination address from original direction
if elem, found := f.ipFilter[ConntrackOrigDstIP]; match && found {
match = match && elem.Equal(flow.Forward.DstIP)
if elem, found := f.ipNetFilter[ConntrackOrigDstIP]; match && found {
match = match && elem.Contains(flow.Forward.DstIP)
}
// -src-nat ip Source NAT ip
if elem, found := f.ipFilter[ConntrackReplySrcIP]; match && found {
match = match && elem.Equal(flow.Reverse.SrcIP)
if elem, found := f.ipNetFilter[ConntrackReplySrcIP]; match && found {
match = match && elem.Contains(flow.Reverse.SrcIP)
}
// -dst-nat ip Destination NAT ip
if elem, found := f.ipFilter[ConntrackReplyDstIP]; match && found {
match = match && elem.Equal(flow.Reverse.DstIP)
if elem, found := f.ipNetFilter[ConntrackReplyDstIP]; match && found {
match = match && elem.Contains(flow.Reverse.DstIP)
}
// Match source or destination reply IP
if elem, found := f.ipFilter[ConntrackReplyAnyIP]; match && found {
match = match && (elem.Equal(flow.Reverse.SrcIP) || elem.Equal(flow.Reverse.DstIP))
if elem, found := f.ipNetFilter[ConntrackReplyAnyIP]; match && found {
match = match && (elem.Contains(flow.Reverse.SrcIP) || elem.Contains(flow.Reverse.DstIP))
}
}

View File

@ -381,11 +381,17 @@ func TestConntrackFilter(t *testing.T) {
// Adding same attribute should fail
filter := &ConntrackFilter{}
filter.AddIP(ConntrackOrigSrcIP, net.ParseIP("10.0.0.1"))
err := filter.AddIP(ConntrackOrigSrcIP, net.ParseIP("10.0.0.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if err := filter.AddIP(ConntrackOrigSrcIP, net.ParseIP("10.0.0.1")); err == nil {
t.Fatalf("Error, it should fail adding same attribute to the filter")
}
filter.AddProtocol(6)
err = filter.AddProtocol(6)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if err := filter.AddProtocol(17); err == nil {
t.Fatalf("Error, it should fail adding same attribute to the filter")
}
@ -402,17 +408,26 @@ func TestConntrackFilter(t *testing.T) {
// Can not add a Port filter if the Layer 4 protocol does not support it
filter = &ConntrackFilter{}
filter.AddProtocol(47)
err = filter.AddProtocol(47)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if err := filter.AddPort(ConntrackOrigSrcPort, 80); err == nil {
t.Fatalf("Error, it should fail adding a port filter with a wrong protocol")
}
// Proto filter
filterV4 := &ConntrackFilter{}
filterV4.AddProtocol(6)
err = filterV4.AddProtocol(6)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 := &ConntrackFilter{}
filterV6.AddProtocol(132)
err = filterV6.AddProtocol(132)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {
@ -421,10 +436,16 @@ func TestConntrackFilter(t *testing.T) {
// SrcIP filter
filterV4 = &ConntrackFilter{}
filterV4.AddIP(ConntrackOrigSrcIP, net.ParseIP("10.0.0.1"))
err = filterV4.AddIP(ConntrackOrigSrcIP, net.ParseIP("10.0.0.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddIP(ConntrackOrigSrcIP, net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"))
err = filterV6.AddIP(ConntrackOrigSrcIP, net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {
@ -433,10 +454,16 @@ func TestConntrackFilter(t *testing.T) {
// DstIp filter
filterV4 = &ConntrackFilter{}
filterV4.AddIP(ConntrackOrigDstIP, net.ParseIP("20.0.0.1"))
err = filterV4.AddIP(ConntrackOrigDstIP, net.ParseIP("20.0.0.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddIP(ConntrackOrigDstIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
err = filterV6.AddIP(ConntrackOrigDstIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {
@ -445,10 +472,16 @@ func TestConntrackFilter(t *testing.T) {
// SrcIP for NAT
filterV4 = &ConntrackFilter{}
filterV4.AddIP(ConntrackReplySrcIP, net.ParseIP("20.0.0.1"))
err = filterV4.AddIP(ConntrackReplySrcIP, net.ParseIP("20.0.0.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddIP(ConntrackReplySrcIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
err = filterV6.AddIP(ConntrackReplySrcIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {
@ -457,10 +490,16 @@ func TestConntrackFilter(t *testing.T) {
// DstIP for NAT
filterV4 = &ConntrackFilter{}
filterV4.AddIP(ConntrackReplyDstIP, net.ParseIP("192.168.1.1"))
err = filterV4.AddIP(ConntrackReplyDstIP, net.ParseIP("192.168.1.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddIP(ConntrackReplyDstIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
err = filterV6.AddIP(ConntrackReplyDstIP, net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 0 {
@ -469,24 +508,171 @@ func TestConntrackFilter(t *testing.T) {
// AnyIp for Nat
filterV4 = &ConntrackFilter{}
filterV4.AddIP(ConntrackReplyAnyIP, net.ParseIP("192.168.1.1"))
err = filterV4.AddIP(ConntrackReplyAnyIP, net.ParseIP("192.168.1.1"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddIP(ConntrackReplyAnyIP, net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"))
err = filterV6.AddIP(ConntrackReplyAnyIP, net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"))
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 1 {
t.Fatalf("Error, there should be an exact match, v4:%d, v6:%d", v4Match, v6Match)
}
// SrcPort filter
// SrcIPNet filter
filterV4 = &ConntrackFilter{}
filterV4.AddProtocol(6)
filterV4.AddPort(ConntrackOrigSrcPort, 5000)
ipNet, err := ParseIPNet("10.0.0.0/12")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddIPNet(ConntrackOrigSrcIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddProtocol(132)
filterV6.AddPort(ConntrackOrigSrcPort, 1000)
ipNet, err = ParseIPNet("eeee:eeee:eeee:eeee::/64")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddIPNet(ConntrackOrigSrcIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 1 {
t.Fatalf("Error, there should be only 1 match, v4:%d, v6:%d", v4Match, v6Match)
}
// DstIpNet filter
filterV4 = &ConntrackFilter{}
ipNet, err = ParseIPNet("20.0.0.0/12")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddIPNet(ConntrackOrigDstIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
ipNet, err = ParseIPNet("dddd:dddd:dddd:dddd::/64")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddIPNet(ConntrackOrigDstIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 1 {
t.Fatalf("Error, there should be only 1 match, v4:%d, v6:%d", v4Match, v6Match)
}
// SrcIPNet for NAT
filterV4 = &ConntrackFilter{}
ipNet, err = ParseIPNet("20.0.0.0/12")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddIPNet(ConntrackReplySrcIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
ipNet, err = ParseIPNet("dddd:dddd:dddd:dddd::/64")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddIPNet(ConntrackReplySrcIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 1 {
t.Fatalf("Error, there should be only 1 match, v4:%d, v6:%d", v4Match, v6Match)
}
// DstIPNet for NAT
filterV4 = &ConntrackFilter{}
ipNet, err = ParseIPNet("192.168.0.0/12")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddIPNet(ConntrackReplyDstIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
ipNet, err = ParseIPNet("dddd:dddd:dddd:dddd::/64")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddIPNet(ConntrackReplyDstIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 0 {
t.Fatalf("Error, there should be an exact match, v4:%d, v6:%d", v4Match, v6Match)
}
// AnyIpNet for Nat
filterV4 = &ConntrackFilter{}
ipNet, err = ParseIPNet("192.168.0.0/12")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddIPNet(ConntrackReplyAnyIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
ipNet, err = ParseIPNet("eeee:eeee:eeee:eeee::/64")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddIPNet(ConntrackReplyAnyIP, ipNet)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 2 || v6Match != 1 {
t.Fatalf("Error, there should be an exact match, v4:%d, v6:%d", v4Match, v6Match)
}
// SrcPort filter
filterV4 = &ConntrackFilter{}
err = filterV4.AddProtocol(6)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddPort(ConntrackOrigSrcPort, 5000)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
err = filterV6.AddProtocol(132)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddPort(ConntrackOrigSrcPort, 1000)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {
@ -495,12 +681,24 @@ func TestConntrackFilter(t *testing.T) {
// DstPort filter
filterV4 = &ConntrackFilter{}
filterV4.AddProtocol(6)
filterV4.AddPort(ConntrackOrigDstPort, 6000)
err = filterV4.AddProtocol(6)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV4.AddPort(ConntrackOrigDstPort, 6000)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
filterV6 = &ConntrackFilter{}
filterV6.AddProtocol(132)
filterV6.AddPort(ConntrackOrigDstPort, 2000)
err = filterV6.AddProtocol(132)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
err = filterV6.AddPort(ConntrackOrigDstPort, 2000)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
v4Match, v6Match = applyFilter(flowList, filterV4, filterV6)
if v4Match != 1 || v6Match != 1 {