Adds ConntrackCreate & ConntrackUpdate

- Also refactored setUpNetlinkTestWithKModule function to reduce redundant NS's created and checks made.

 - Add conntrack protoinfo TCP support + groundwork for other protocols.

 - Tests to cover the above.
This commit is contained in:
Alex O'Regan 2023-11-24 16:56:37 +00:00 committed by Alessandro Boch
parent a1c5e0237d
commit aed23dbf5e
6 changed files with 1008 additions and 42 deletions

1
.gitignore vendored
View File

@ -1 +1,2 @@
.idea/
.vscode/

View File

@ -55,6 +55,18 @@ func ConntrackTableFlush(table ConntrackTableType) error {
return pkgHandle.ConntrackTableFlush(table)
}
// ConntrackCreate creates a new conntrack flow in the desired table
// conntrack -I [table] Create a conntrack or expectation
func ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
return pkgHandle.ConntrackCreate(table, family, flow)
}
// ConntrackUpdate updates an existing conntrack flow in the desired table using the handle
// conntrack -U [table] Update a conntrack
func ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
return pkgHandle.ConntrackUpdate(table, family, flow)
}
// ConntrackDeleteFilter deletes entries on the specified table on the base of the filter
// conntrack -D [table] parameters Delete conntrack or expectation
func ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) {
@ -87,6 +99,40 @@ func (h *Handle) ConntrackTableFlush(table ConntrackTableType) error {
return err
}
// ConntrackCreate creates a new conntrack flow in the desired table using the handle
// conntrack -I [table] Create a conntrack or expectation
func (h *Handle) ConntrackCreate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_CREATE)
attr, err := flow.toNlData()
if err != nil {
return err
}
for _, a := range attr {
req.AddData(a)
}
_, err = req.Execute(unix.NETLINK_NETFILTER, 0)
return err
}
// ConntrackUpdate updates an existing conntrack flow in the desired table using the handle
// conntrack -U [table] Update a conntrack
func (h *Handle) ConntrackUpdate(table ConntrackTableType, family InetFamily, flow *ConntrackFlow) error {
req := h.newConntrackRequest(table, family, nl.IPCTNL_MSG_CT_NEW, unix.NLM_F_ACK|unix.NLM_F_REPLACE)
attr, err := flow.toNlData()
if err != nil {
return err
}
for _, a := range attr {
req.AddData(a)
}
_, err = req.Execute(unix.NETLINK_NETFILTER, 0)
return err
}
// ConntrackDeleteFilter deletes entries on the specified table on the base of the filter using the netlink handle passed
// conntrack -D [table] parameters Delete conntrack or expectation
func (h *Handle) ConntrackDeleteFilter(table ConntrackTableType, family InetFamily, filter CustomConntrackFilter) (uint, error) {
@ -128,10 +174,44 @@ func (h *Handle) dumpConntrackTable(table ConntrackTableType, family InetFamily)
return req.Execute(unix.NETLINK_NETFILTER, 0)
}
// ProtoInfo wraps an L4-protocol structure - roughly corresponds to the
// __nfct_protoinfo union found in libnetfilter_conntrack/include/internal/object.h.
// Currently, only protocol names, and TCP state is supported.
type ProtoInfo interface {
Protocol() string
}
// ProtoInfoTCP corresponds to the `tcp` struct of the __nfct_protoinfo union.
// Only TCP state is currently supported.
type ProtoInfoTCP struct {
State uint8
}
// Protocol returns "tcp".
func (*ProtoInfoTCP) Protocol() string {return "tcp"}
func (p *ProtoInfoTCP) toNlData() ([]*nl.RtAttr, error) {
ctProtoInfo := nl.NewRtAttr(unix.NLA_F_NESTED | nl.CTA_PROTOINFO, []byte{})
ctProtoInfoTCP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_PROTOINFO_TCP, []byte{})
ctProtoInfoTCPState := nl.NewRtAttr(nl.CTA_PROTOINFO_TCP_STATE, nl.Uint8Attr(p.State))
ctProtoInfoTCP.AddChild(ctProtoInfoTCPState)
ctProtoInfo.AddChild(ctProtoInfoTCP)
return []*nl.RtAttr{ctProtoInfo}, nil
}
// ProtoInfoSCTP only supports the protocol name.
type ProtoInfoSCTP struct {}
// Protocol returns "sctp".
func (*ProtoInfoSCTP) Protocol() string {return "sctp"}
// ProtoInfoDCCP only supports the protocol name.
type ProtoInfoDCCP struct {}
// Protocol returns "dccp".
func (*ProtoInfoDCCP) Protocol() string {return "dccp"}
// The full conntrack flow structure is very complicated and can be found in the file:
// http://git.netfilter.org/libnetfilter_conntrack/tree/include/internal/object.h
// For the time being, the structure below allows to parse and extract the base information of a flow
type ipTuple struct {
type IPTuple struct {
Bytes uint64
DstIP net.IP
DstPort uint16
@ -141,16 +221,49 @@ type ipTuple struct {
SrcPort uint16
}
// toNlData generates the inner fields of a nested tuple netlink datastructure
// does not generate the "nested"-flagged outer message.
func (t *IPTuple) toNlData(family uint8) ([]*nl.RtAttr, error) {
var srcIPsFlag, dstIPsFlag int
if family == nl.FAMILY_V4 {
srcIPsFlag = nl.CTA_IP_V4_SRC
dstIPsFlag = nl.CTA_IP_V4_DST
} else if family == nl.FAMILY_V6 {
srcIPsFlag = nl.CTA_IP_V6_SRC
dstIPsFlag = nl.CTA_IP_V6_DST
} else {
return []*nl.RtAttr{}, fmt.Errorf("couldn't generate netlink message for tuple due to unrecognized FamilyType '%d'", family)
}
ctTupleIP := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_IP, nil)
ctTupleIPSrc := nl.NewRtAttr(srcIPsFlag, t.SrcIP)
ctTupleIP.AddChild(ctTupleIPSrc)
ctTupleIPDst := nl.NewRtAttr(dstIPsFlag, t.DstIP)
ctTupleIP.AddChild(ctTupleIPDst)
ctTupleProto := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_PROTO, nil)
ctTupleProtoNum := nl.NewRtAttr(nl.CTA_PROTO_NUM, []byte{t.Protocol})
ctTupleProto.AddChild(ctTupleProtoNum)
ctTupleProtoSrcPort := nl.NewRtAttr(nl.CTA_PROTO_SRC_PORT, nl.BEUint16Attr(t.SrcPort))
ctTupleProto.AddChild(ctTupleProtoSrcPort)
ctTupleProtoDstPort := nl.NewRtAttr(nl.CTA_PROTO_DST_PORT, nl.BEUint16Attr(t.DstPort))
ctTupleProto.AddChild(ctTupleProtoDstPort, )
return []*nl.RtAttr{ctTupleIP, ctTupleProto}, nil
}
type ConntrackFlow struct {
FamilyType uint8
Forward ipTuple
Reverse ipTuple
Forward IPTuple
Reverse IPTuple
Mark uint32
Zone uint16
TimeStart uint64
TimeStop uint64
TimeOut uint32
Labels []byte
ProtoInfo ProtoInfo
}
func (s *ConntrackFlow) String() string {
@ -175,6 +288,85 @@ func (s *ConntrackFlow) String() string {
return res
}
// toNlData generates netlink messages representing the flow.
func (s *ConntrackFlow) toNlData() ([]*nl.RtAttr, error) {
var payload []*nl.RtAttr
// The message structure is built as follows:
// <len, NLA_F_NESTED|CTA_TUPLE_ORIG>
// <len, NLA_F_NESTED|CTA_TUPLE_IP>
// <len, [CTA_IP_V4_SRC|CTA_IP_V6_SRC]>
// <IP>
// <len, [CTA_IP_V4_DST|CTA_IP_V6_DST]>
// <IP>
// <len, NLA_F_NESTED|nl.CTA_TUPLE_PROTO>
// <len, CTA_PROTO_NUM>
// <uint8>
// <len, CTA_PROTO_SRC_PORT>
// <BEuint16>
// <len, CTA_PROTO_DST_PORT>
// <BEuint16>
// <len, NLA_F_NESTED|CTA_TUPLE_REPLY>
// <len, NLA_F_NESTED|CTA_TUPLE_IP>
// <len, [CTA_IP_V4_SRC|CTA_IP_V6_SRC]>
// <IP>
// <len, [CTA_IP_V4_DST|CTA_IP_V6_DST]>
// <IP>
// <len, NLA_F_NESTED|nl.CTA_TUPLE_PROTO>
// <len, CTA_PROTO_NUM>
// <uint8>
// <len, CTA_PROTO_SRC_PORT>
// <BEuint16>
// <len, CTA_PROTO_DST_PORT>
// <BEuint16>
// <len, CTA_STATUS>
// <uint64>
// <len, CTA_MARK>
// <BEuint64>
// <len, CTA_TIMEOUT>
// <BEuint64>
// <len, NLA_F_NESTED|CTA_PROTOINFO>
// CTA_TUPLE_ORIG
ctTupleOrig := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_ORIG, nil)
forwardFlowAttrs, err := s.Forward.toNlData(s.FamilyType)
if err != nil {
return nil, fmt.Errorf("couldn't generate netlink data for conntrack forward flow: %w", err)
}
for _, a := range forwardFlowAttrs {
ctTupleOrig.AddChild(a)
}
// CTA_TUPLE_REPLY
ctTupleReply := nl.NewRtAttr(unix.NLA_F_NESTED|nl.CTA_TUPLE_REPLY, nil)
reverseFlowAttrs, err := s.Reverse.toNlData(s.FamilyType)
if err != nil {
return nil, fmt.Errorf("couldn't generate netlink data for conntrack reverse flow: %w", err)
}
for _, a := range reverseFlowAttrs {
ctTupleReply.AddChild(a)
}
ctMark := nl.NewRtAttr(nl.CTA_MARK, nl.BEUint32Attr(s.Mark))
ctTimeout := nl.NewRtAttr(nl.CTA_TIMEOUT, nl.BEUint32Attr(s.TimeOut))
payload = append(payload, ctTupleOrig, ctTupleReply, ctMark, ctTimeout)
if s.ProtoInfo != nil {
switch p := s.ProtoInfo.(type) {
case *ProtoInfoTCP:
attrs, err := p.toNlData()
if err != nil {
return nil, fmt.Errorf("couldn't generate netlink data for conntrack flow's TCP protoinfo: %w", err)
}
payload = append(payload, attrs...)
default:
return nil, errors.New("couldn't generate netlink data for conntrack: field 'ProtoInfo' only supports TCP or nil")
}
}
return payload, nil
}
// This method parse the ip tuple structure
// The message structure is the following:
// <len, [CTA_IP_V4_SRC|CTA_IP_V6_SRC], 16 bytes for the IP>
@ -182,7 +374,7 @@ func (s *ConntrackFlow) String() string {
// <len, NLA_F_NESTED|nl.CTA_TUPLE_PROTO, 1 byte for the protocol, 3 bytes of padding>
// <len, CTA_PROTO_SRC_PORT, 2 bytes for the source port, 2 bytes of padding>
// <len, CTA_PROTO_DST_PORT, 2 bytes for the source port, 2 bytes of padding>
func parseIpTuple(reader *bytes.Reader, tpl *ipTuple) uint8 {
func parseIpTuple(reader *bytes.Reader, tpl *IPTuple) uint8 {
for i := 0; i < 2; i++ {
_, t, _, v := parseNfAttrTLV(reader)
switch t {
@ -201,7 +393,7 @@ func parseIpTuple(reader *bytes.Reader, tpl *ipTuple) uint8 {
tpl.Protocol = uint8(v[0])
}
// We only parse TCP & UDP headers. Skip the others.
if tpl.Protocol != 6 && tpl.Protocol != 17 {
if tpl.Protocol != unix.IPPROTO_TCP && tpl.Protocol != unix.IPPROTO_UDP {
// skip the rest
bytesRemaining := protoInfoTotalLen - protoInfoBytesRead
reader.Seek(int64(bytesRemaining), seekCurrent)
@ -250,9 +442,13 @@ func parseNfAttrTL(r *bytes.Reader) (isNested bool, attrType, len uint16) {
return isNested, attrType, len
}
func skipNfAttrValue(r *bytes.Reader, len uint16) {
// skipNfAttrValue seeks `r` past attr of length `len`.
// Maintains buffer alignment.
// Returns length of the seek performed.
func skipNfAttrValue(r *bytes.Reader, len uint16) uint16 {
len = (len + nl.NLA_ALIGNTO - 1) & ^(nl.NLA_ALIGNTO - 1)
r.Seek(int64(len), seekCurrent)
return len
}
func parseBERaw16(r *bytes.Reader, v *uint16) {
@ -267,6 +463,10 @@ func parseBERaw64(r *bytes.Reader, v *uint64) {
binary.Read(r, binary.BigEndian, v)
}
func parseRaw32(r *bytes.Reader, v *uint32) {
binary.Read(r, nl.NativeEndian(), v)
}
func parseByteAndPacketCounters(r *bytes.Reader) (bytes, packets uint64) {
for i := 0; i < 2; i++ {
switch _, t, _ := parseNfAttrTL(r); t {
@ -306,6 +506,60 @@ func parseTimeStamp(r *bytes.Reader, readSize uint16) (tstart, tstop uint64) {
}
func parseProtoInfoTCPState(r *bytes.Reader) (s uint8) {
binary.Read(r, binary.BigEndian, &s)
r.Seek(nl.SizeofNfattr - 1, seekCurrent)
return s
}
// parseProtoInfoTCP reads the entire nested protoinfo structure, but only parses the state attr.
func parseProtoInfoTCP(r *bytes.Reader, attrLen uint16) (*ProtoInfoTCP) {
p := new(ProtoInfoTCP)
bytesRead := 0
for bytesRead < int(attrLen) {
_, t, l := parseNfAttrTL(r)
bytesRead += nl.SizeofNfattr
switch t {
case nl.CTA_PROTOINFO_TCP_STATE:
p.State = parseProtoInfoTCPState(r)
bytesRead += nl.SizeofNfattr
default:
bytesRead += int(skipNfAttrValue(r, l))
}
}
return p
}
func parseProtoInfo(r *bytes.Reader, attrLen uint16) (p ProtoInfo) {
bytesRead := 0
for bytesRead < int(attrLen) {
_, t, l := parseNfAttrTL(r)
bytesRead += nl.SizeofNfattr
switch t {
case nl.CTA_PROTOINFO_TCP:
p = parseProtoInfoTCP(r, l)
bytesRead += int(l)
// No inner fields of DCCP / SCTP currently supported.
case nl.CTA_PROTOINFO_DCCP:
p = new(ProtoInfoDCCP)
skipped := skipNfAttrValue(r, l)
bytesRead += int(skipped)
case nl.CTA_PROTOINFO_SCTP:
p = new(ProtoInfoSCTP)
skipped := skipNfAttrValue(r, l)
bytesRead += int(skipped)
default:
skipped := skipNfAttrValue(r, l)
bytesRead += int(skipped)
}
}
return p
}
func parseTimeOut(r *bytes.Reader) (ttimeout uint32) {
parseBERaw32(r, &ttimeout)
return
@ -365,7 +619,7 @@ func parseRawData(data []byte) *ConntrackFlow {
case nl.CTA_TIMESTAMP:
s.TimeStart, s.TimeStop = parseTimeStamp(reader, l)
case nl.CTA_PROTOINFO:
skipNfAttrValue(reader, l)
s.ProtoInfo = parseProtoInfo(reader, l)
default:
skipNfAttrValue(reader, l)
}
@ -373,11 +627,11 @@ func parseRawData(data []byte) *ConntrackFlow {
switch t {
case nl.CTA_MARK:
s.Mark = parseConnectionMark(reader)
case nl.CTA_LABELS:
case nl.CTA_LABELS:
s.Labels = parseConnectionLabels(reader)
case nl.CTA_TIMEOUT:
s.TimeOut = parseTimeOut(reader)
case nl.CTA_STATUS, nl.CTA_USE, nl.CTA_ID:
case nl.CTA_ID, nl.CTA_STATUS, nl.CTA_USE:
skipNfAttrValue(reader, l)
case nl.CTA_ZONE:
s.Zone = parseConnectionZone(reader)

View File

@ -253,8 +253,8 @@ func TestConntrackTableDelete(t *testing.T) {
t.Skipf("Fails in CI: Flow creation fails")
}
skipUnlessRoot(t)
setUpNetlinkTestWithKModule(t, "nf_conntrack")
setUpNetlinkTestWithKModule(t, "nf_conntrack_netlink")
requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"}
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
@ -262,9 +262,11 @@ func TestConntrackTableDelete(t *testing.T) {
// conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
setUpNetlinkTestWithKModule(t, "nf_conntrack_ipv4")
requiredModules = append(requiredModules, "nf_conntrack_ipv4")
}
setUpNetlinkTestWithKModule(t, requiredModules...)
// Creates a new namespace and bring up the loopback interface
origns, ns, h := nsCreateAndEnter(t)
defer netns.Set(*origns)
@ -348,32 +350,32 @@ func TestConntrackTableDelete(t *testing.T) {
func TestConntrackFilter(t *testing.T) {
var flowList []ConntrackFlow
flowList = append(flowList, ConntrackFlow{
FamilyType: unix.AF_INET,
Forward: ipTuple{
SrcIP: net.ParseIP("10.0.0.1"),
DstIP: net.ParseIP("20.0.0.1"),
SrcPort: 1000,
DstPort: 2000,
Protocol: 17,
FamilyType: unix.AF_INET,
Forward: IPTuple{
SrcIP: net.ParseIP("10.0.0.1"),
DstIP: net.ParseIP("20.0.0.1"),
SrcPort: 1000,
DstPort: 2000,
Protocol: 17,
},
Reverse: IPTuple{
SrcIP: net.ParseIP("20.0.0.1"),
DstIP: net.ParseIP("192.168.1.1"),
SrcPort: 2000,
DstPort: 1000,
Protocol: 17,
},
},
Reverse: ipTuple{
SrcIP: net.ParseIP("20.0.0.1"),
DstIP: net.ParseIP("192.168.1.1"),
SrcPort: 2000,
DstPort: 1000,
Protocol: 17,
},
},
ConntrackFlow{
FamilyType: unix.AF_INET,
Forward: ipTuple{
Forward: IPTuple{
SrcIP: net.ParseIP("10.0.0.2"),
DstIP: net.ParseIP("20.0.0.2"),
SrcPort: 5000,
DstPort: 6000,
Protocol: 6,
},
Reverse: ipTuple{
Reverse: IPTuple{
SrcIP: net.ParseIP("20.0.0.2"),
DstIP: net.ParseIP("192.168.1.1"),
SrcPort: 6000,
@ -385,14 +387,14 @@ func TestConntrackFilter(t *testing.T) {
},
ConntrackFlow{
FamilyType: unix.AF_INET6,
Forward: ipTuple{
Forward: IPTuple{
SrcIP: net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"),
DstIP: net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"),
SrcPort: 1000,
DstPort: 2000,
Protocol: 132,
},
Reverse: ipTuple{
Reverse: IPTuple{
SrcIP: net.ParseIP("dddd:dddd:dddd:dddd:dddd:dddd:dddd:dddd"),
DstIP: net.ParseIP("eeee:eeee:eeee:eeee:eeee:eeee:eeee:eeee"),
SrcPort: 2000,
@ -979,3 +981,613 @@ func TestParseRawData(t *testing.T) {
})
}
}
// TestConntrackUpdateV4 first tries to update a non-existant IPv4 conntrack and asserts that an error occurs.
// It then creates a conntrack entry using and adjacent API method (ConntrackCreate), and attempts to update the value of the created conntrack.
func TestConntrackUpdateV4(t *testing.T) {
// Print timestamps in UTC
os.Setenv("TZ", "")
requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"}
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
}
// Conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
requiredModules = append(requiredModules, "nf_conntrack_ipv4")
}
// Implicitly skips test if not root:
nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...)
defer teardown()
ns, err := netns.GetFromName(nsStr)
if err != nil {
t.Fatalf("couldn't get handle to generated namespace: %s", err)
}
h, err := NewHandleAt(ns, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to create netlink handle: %s", err)
}
flow := ConntrackFlow{
FamilyType: FAMILY_V4,
Forward: IPTuple{
SrcIP: net.IP{234,234,234,234},
DstIP: net.IP{123,123,123,123},
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.IP{123,123,123,123},
DstIP: net.IP{234,234,234,234},
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
// No point checking equivalence of timeout, but value must
// be reasonable to allow for a potentially slow subsequent read.
TimeOut: 100,
Mark: 12,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_SYN_SENT2,
},
}
err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow)
if err == nil {
t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow)
}
err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow)
if err != nil {
t.Fatalf("failed to insert conntrack: %s", err)
}
flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to list conntracks following successful insert: %s", err)
}
filter := ConntrackFilter{
ipNetFilter: map[ConntrackFilterType]*net.IPNet{
ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP),
ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP),
ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
},
portFilter: map[ConntrackFilterType]uint16{
ConntrackOrigSrcPort: flow.Forward.SrcPort,
ConntrackOrigDstPort: flow.Forward.DstPort,
},
protoFilter:unix.IPPROTO_TCP,
}
var match *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
match = f
break
}
}
if match == nil {
t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
}
checkFlowsEqual(t, &flow, match)
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
// Change the conntrack and update the kernel entry.
flow.Mark = 10
flow.ProtoInfo = &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
}
err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V4, &flow)
if err != nil {
t.Fatalf("failed to update conntrack with new mark: %s", err)
}
// Look for updated conntrack.
flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to list conntracks following successful update: %s", err)
}
var updatedMatch *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
updatedMatch = f
break
}
}
if updatedMatch == nil {
t.Fatalf("Didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels)
}
checkFlowsEqual(t, &flow, updatedMatch)
checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo)
}
// TestConntrackUpdateV6 first tries to update a non-existant IPv6 conntrack and asserts that an error occurs.
// It then creates a conntrack entry using and adjacent API method (ConntrackCreate), and attempts to update the value of the created conntrack.
func TestConntrackUpdateV6(t *testing.T) {
// Print timestamps in UTC
os.Setenv("TZ", "")
requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"}
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
}
// Conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
requiredModules = append(requiredModules, "nf_conntrack_ipv4")
}
// Implicitly skips test if not root:
nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...)
defer teardown()
ns, err := netns.GetFromName(nsStr)
if err != nil {
t.Fatalf("couldn't get handle to generated namespace: %s", err)
}
h, err := NewHandleAt(ns, nl.FAMILY_V6)
if err != nil {
t.Fatalf("failed to create netlink handle: %s", err)
}
flow := ConntrackFlow{
FamilyType: FAMILY_V6,
Forward: IPTuple{
SrcIP: net.ParseIP("2001:db8::68"),
DstIP: net.ParseIP("2001:db9::32"),
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.ParseIP("2001:db9::32"),
DstIP: net.ParseIP("2001:db8::68"),
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
// No point checking equivalence of timeout, but value must
// be reasonable to allow for a potentially slow subsequent read.
TimeOut: 100,
Mark: 12,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_SYN_SENT2,
},
}
err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V6, &flow)
if err == nil {
t.Fatalf("expected an error to occur when trying to update a non-existant conntrack: %+v", flow)
}
err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V6, &flow)
if err != nil {
t.Fatalf("failed to insert conntrack: %s", err)
}
flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6)
if err != nil {
t.Fatalf("failed to list conntracks following successful insert: %s", err)
}
filter := ConntrackFilter{
ipNetFilter: map[ConntrackFilterType]*net.IPNet{
ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP),
ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP),
ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
},
portFilter: map[ConntrackFilterType]uint16{
ConntrackOrigSrcPort: flow.Forward.SrcPort,
ConntrackOrigDstPort: flow.Forward.DstPort,
},
protoFilter:unix.IPPROTO_TCP,
}
var match *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
match = f
break
}
}
if match == nil {
t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
}
checkFlowsEqual(t, &flow, match)
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
// Change the conntrack and update the kernel entry.
flow.Mark = 10
flow.ProtoInfo = &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
}
err = h.ConntrackUpdate(ConntrackTable, nl.FAMILY_V6, &flow)
if err != nil {
t.Fatalf("failed to update conntrack with new mark: %s", err)
}
// Look for updated conntrack.
flows, err = h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6)
if err != nil {
t.Fatalf("failed to list conntracks following successful update: %s", err)
}
var updatedMatch *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
updatedMatch = f
break
}
}
if updatedMatch == nil {
t.Fatalf("didn't find any matching conntrack entries for updated flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("found entry in conntrack table matching updated flow: %+v labels=%+v", updatedMatch, updatedMatch.Labels)
}
checkFlowsEqual(t, &flow, updatedMatch)
checkProtoInfosEqual(t, flow.ProtoInfo, updatedMatch.ProtoInfo)
}
func TestConntrackCreateV4(t *testing.T) {
// Print timestamps in UTC
os.Setenv("TZ", "")
requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"}
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
}
// Conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
requiredModules = append(requiredModules, "nf_conntrack_ipv4")
}
// Implicitly skips test if not root:
nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...)
defer teardown()
ns, err := netns.GetFromName(nsStr)
if err != nil {
t.Fatalf("couldn't get handle to generated namespace: %s", err)
}
h, err := NewHandleAt(ns, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to create netlink handle: %s", err)
}
flow := ConntrackFlow{
FamilyType: FAMILY_V4,
Forward: IPTuple{
SrcIP: net.IP{234,234,234,234},
DstIP: net.IP{123,123,123,123},
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.IP{123,123,123,123},
DstIP: net.IP{234,234,234,234},
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
// No point checking equivalence of timeout, but value must
// be reasonable to allow for a potentially slow subsequent read.
TimeOut: 100,
Mark: 12,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
},
}
err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V4, &flow)
if err != nil {
t.Fatalf("failed to insert conntrack: %s", err)
}
flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V4)
if err != nil {
t.Fatalf("failed to list conntracks following successful insert: %s", err)
}
filter := ConntrackFilter{
ipNetFilter: map[ConntrackFilterType]*net.IPNet{
ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP),
ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP),
ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
},
portFilter: map[ConntrackFilterType]uint16{
ConntrackOrigSrcPort: flow.Forward.SrcPort,
ConntrackOrigDstPort: flow.Forward.DstPort,
},
protoFilter:unix.IPPROTO_TCP,
}
var match *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
match = f
break
}
}
if match == nil {
t.Fatalf("didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
}
checkFlowsEqual(t, &flow, match)
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
}
func TestConntrackCreateV6(t *testing.T) {
// Print timestamps in UTC
os.Setenv("TZ", "")
requiredModules := []string{"nf_conntrack", "nf_conntrack_netlink"}
k, m, err := KernelVersion()
if err != nil {
t.Fatal(err)
}
// Conntrack l3proto was unified since 4.19
// https://github.com/torvalds/linux/commit/a0ae2562c6c4b2721d9fddba63b7286c13517d9f
if k < 4 || k == 4 && m < 19 {
requiredModules = append(requiredModules, "nf_conntrack_ipv4")
}
// Implicitly skips test if not root:
nsStr, teardown := setUpNamedNetlinkTestWithKModule(t, requiredModules...)
defer teardown()
ns, err := netns.GetFromName(nsStr)
if err != nil {
t.Fatalf("couldn't get handle to generated namespace: %s", err)
}
h, err := NewHandleAt(ns, nl.FAMILY_V6)
if err != nil {
t.Fatalf("failed to create netlink handle: %s", err)
}
flow := ConntrackFlow{
FamilyType: FAMILY_V6,
Forward: IPTuple{
SrcIP: net.ParseIP("2001:db8::68"),
DstIP: net.ParseIP("2001:db9::32"),
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.ParseIP("2001:db9::32"),
DstIP: net.ParseIP("2001:db8::68"),
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
// No point checking equivalence of timeout, but value must
// be reasonable to allow for a potentially slow subsequent read.
TimeOut: 100,
Mark: 12,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
},
}
err = h.ConntrackCreate(ConntrackTable, nl.FAMILY_V6, &flow)
if err != nil {
t.Fatalf("failed to insert conntrack: %s", err)
}
flows, err := h.ConntrackTableList(ConntrackTable, nl.FAMILY_V6)
if err != nil {
t.Fatalf("failed to list conntracks following successful insert: %s", err)
}
filter := ConntrackFilter{
ipNetFilter: map[ConntrackFilterType]*net.IPNet{
ConntrackOrigSrcIP: NewIPNet(flow.Forward.SrcIP),
ConntrackOrigDstIP: NewIPNet(flow.Forward.DstIP),
ConntrackReplySrcIP: NewIPNet(flow.Reverse.SrcIP),
ConntrackReplyDstIP: NewIPNet(flow.Reverse.DstIP),
},
portFilter: map[ConntrackFilterType]uint16{
ConntrackOrigSrcPort: flow.Forward.SrcPort,
ConntrackOrigDstPort: flow.Forward.DstPort,
},
protoFilter:unix.IPPROTO_TCP,
}
var match *ConntrackFlow
for _, f := range flows {
if filter.MatchConntrackFlow(f) {
match = f
break
}
}
if match == nil {
t.Fatalf("Didn't find any matching conntrack entries for original flow: %+v\n Filter used: %+v", flow, filter)
} else {
t.Logf("Found entry in conntrack table matching original flow: %+v labels=%+v", match, match.Labels)
}
// Other fields are implicitly correct due to the filter/match logic.
if match.Mark != flow.Mark {
t.Logf("Matched kernel entry did not have correct mark. Kernel: %d, Expected: %d", flow.Mark, match.Mark)
t.Fail()
}
checkProtoInfosEqual(t, flow.ProtoInfo, match.ProtoInfo)
}
// TestConntrackFlowToNlData generates a serialized representation of a
// ConntrackFlow and runs the resulting bytes back through `parseRawData` to validate.
func TestConntrackFlowToNlData(t *testing.T) {
flowV4 := ConntrackFlow{
FamilyType: FAMILY_V4,
Forward: IPTuple{
SrcIP: net.IP{234,234,234,234},
DstIP: net.IP{123,123,123,123},
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.IP{123,123,123,123},
DstIP: net.IP{234,234,234,234},
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
Mark: 5,
TimeOut: 10,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
},
}
flowV6 := ConntrackFlow {
FamilyType: FAMILY_V6,
Forward: IPTuple{
SrcIP: net.ParseIP("2001:db8::68"),
DstIP: net.ParseIP("2001:db9::32"),
SrcPort: 48385,
DstPort: 53,
Protocol: unix.IPPROTO_TCP,
},
Reverse: IPTuple{
SrcIP: net.ParseIP("2001:db9::32"),
DstIP: net.ParseIP("2001:db8::68"),
SrcPort: 53,
DstPort: 48385,
Protocol: unix.IPPROTO_TCP,
},
Mark: 5,
TimeOut: 10,
ProtoInfo: &ProtoInfoTCP{
State: nl.TCP_CONNTRACK_ESTABLISHED,
},
}
var bytesV4, bytesV6 []byte
attrsV4, err := flowV4.toNlData()
if err != nil {
t.Fatalf("Error converting ConntrackFlow to netlink messages: %s", err)
}
// Mock nfgenmsg header
bytesV4 = append(bytesV4, flowV4.FamilyType,0,0,0)
for _, a := range attrsV4 {
bytesV4 = append(bytesV4, a.Serialize()...)
}
attrsV6, err := flowV6.toNlData()
if err != nil {
t.Fatalf("Error converting ConntrackFlow to netlink messages: %s", err)
}
// Mock nfgenmsg header
bytesV6 = append(bytesV6, flowV6.FamilyType,0,0,0)
for _, a := range attrsV6 {
bytesV6 = append(bytesV6, a.Serialize()...)
}
parsedFlowV4 := parseRawData(bytesV4)
checkFlowsEqual(t, &flowV4, parsedFlowV4)
checkProtoInfosEqual(t, flowV4.ProtoInfo, parsedFlowV4.ProtoInfo)
parsedFlowV6 := parseRawData(bytesV6)
checkFlowsEqual(t, &flowV6, parsedFlowV6)
checkProtoInfosEqual(t, flowV6.ProtoInfo, parsedFlowV6.ProtoInfo)
}
func checkFlowsEqual(t *testing.T, f1, f2 *ConntrackFlow) {
// No point checking timeout as it will differ between reads.
// Timestart and timestop may also differ.
if f1.FamilyType != f2.FamilyType {
t.Logf("Conntrack flow FamilyTypes differ. Tuple1: %d, Tuple2: %d.\n", f1.FamilyType, f2.FamilyType)
t.Fail()
}
if f1.Mark != f2.Mark {
t.Logf("Conntrack flow Marks differ. Tuple1: %d, Tuple2: %d.\n", f1.Mark, f2.Mark)
t.Fail()
}
if !tuplesEqual(f1.Forward, f2.Forward) {
t.Logf("Forward tuples mismatch. Tuple1 forward flow: %+v, Tuple2 forward flow: %+v.\n", f1.Forward, f2.Forward)
t.Fail()
}
if !tuplesEqual(f1.Reverse, f2.Reverse) {
t.Logf("Reverse tuples mismatch. Tuple1 reverse flow: %+v, Tuple2 reverse flow: %+v.\n", f1.Reverse, f2.Reverse)
t.Fail()
}
}
func checkProtoInfosEqual(t *testing.T, p1, p2 ProtoInfo) {
t.Logf("Checking protoinfo fields equal:\n\t p1: %+v\n\t p2: %+v", p1, p2)
if !protoInfosEqual(p1, p2) {
t.Logf("Protoinfo structs differ: P1: %+v, P2: %+v", p1, p2)
t.Fail()
}
}
func protoInfosEqual(p1, p2 ProtoInfo) bool {
if p1 == nil {
return p2 == nil
} else if p2 != nil {
return p1.Protocol() == p2.Protocol()
}
return false
}
func tuplesEqual(t1, t2 IPTuple) bool {
if t1.Bytes != t2.Bytes {
return false
}
if !t1.DstIP.Equal(t2.DstIP) {
return false
}
if !t1.SrcIP.Equal(t2.SrcIP) {
return false
}
if t1.DstPort != t2.DstPort {
return false
}
if t1.SrcPort != t2.SrcPort {
return false
}
if t1.Packets != t2.Packets {
return false
}
if t1.Protocol != t2.Protocol {
return false
}
return true
}

View File

@ -31,25 +31,36 @@ func skipUnlessRoot(t testing.TB) {
}
}
func skipUnlessKModuleLoaded(t *testing.T, module ...string) {
func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
t.Helper()
file, err := ioutil.ReadFile("/proc/modules")
if err != nil {
t.Fatal("Failed to open /proc/modules", err)
}
for _, mod := range module {
found := false
for _, line := range strings.Split(string(file), "\n") {
foundRequiredMods := make(map[string]bool)
lines := strings.Split(string(file), "\n")
for _, name := range moduleNames {
foundRequiredMods[name] = false
for _, line := range lines {
n := strings.Split(line, " ")[0]
if n == mod {
found = true
if n == name {
foundRequiredMods[name] = true
break
}
}
}
failed := false
for _, name := range moduleNames {
if found, _ := foundRequiredMods[name]; !found {
t.Logf("Test requires missing kmodule %q.", name)
failed = true
}
if !found {
t.Skipf("Test requires kmodule %q.", mod)
}
}
if failed {
t.SkipNow()
}
}
@ -180,10 +191,43 @@ func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
return setUpNetlinkTest(t)
}
func setUpNetlinkTestWithKModule(t *testing.T, name string) tearDownNetlinkTest {
skipUnlessKModuleLoaded(t, name)
func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
skipUnlessKModuleLoaded(t, moduleNames...)
return setUpNetlinkTest(t)
}
func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
file, err := ioutil.ReadFile("/proc/modules")
if err != nil {
t.Fatal("Failed to open /proc/modules", err)
}
foundRequiredMods := make(map[string]bool)
lines := strings.Split(string(file), "\n")
for _, name := range moduleNames {
foundRequiredMods[name] = false
for _, line := range lines {
n := strings.Split(line, " ")[0]
if n == name {
foundRequiredMods[name] = true
break
}
}
}
failed := false
for _, name := range moduleNames {
if found, _ := foundRequiredMods[name]; !found {
t.Logf("Test requires missing kmodule %q.", name)
failed = true
}
}
if failed {
t.SkipNow()
}
return setUpNamedNetlinkTest(t)
}
func remountSysfs() error {
if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {

View File

@ -15,6 +15,38 @@ var L4ProtoMap = map[uint8]string{
17: "udp",
}
// From https://git.netfilter.org/libnetfilter_conntrack/tree/include/libnetfilter_conntrack/libnetfilter_conntrack_tcp.h
// enum tcp_state {
// TCP_CONNTRACK_NONE,
// TCP_CONNTRACK_SYN_SENT,
// TCP_CONNTRACK_SYN_RECV,
// TCP_CONNTRACK_ESTABLISHED,
// TCP_CONNTRACK_FIN_WAIT,
// TCP_CONNTRACK_CLOSE_WAIT,
// TCP_CONNTRACK_LAST_ACK,
// TCP_CONNTRACK_TIME_WAIT,
// TCP_CONNTRACK_CLOSE,
// TCP_CONNTRACK_LISTEN, /* obsolete */
// #define TCP_CONNTRACK_SYN_SENT2 TCP_CONNTRACK_LISTEN
// TCP_CONNTRACK_MAX,
// TCP_CONNTRACK_IGNORE
// };
const (
TCP_CONNTRACK_NONE = 0
TCP_CONNTRACK_SYN_SENT = 1
TCP_CONNTRACK_SYN_RECV = 2
TCP_CONNTRACK_ESTABLISHED = 3
TCP_CONNTRACK_FIN_WAIT = 4
TCP_CONNTRACK_CLOSE_WAIT = 5
TCP_CONNTRACK_LAST_ACK = 6
TCP_CONNTRACK_TIME_WAIT = 7
TCP_CONNTRACK_CLOSE = 8
TCP_CONNTRACK_LISTEN = 9
TCP_CONNTRACK_SYN_SENT2 = 9
TCP_CONNTRACK_MAX = 10
TCP_CONNTRACK_IGNORE = 11
)
// All the following constants are coming from:
// https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/nfnetlink_conntrack.h
@ -31,6 +63,7 @@ var L4ProtoMap = map[uint8]string{
// IPCTNL_MSG_MAX
// };
const (
IPCTNL_MSG_CT_NEW = 0
IPCTNL_MSG_CT_GET = 1
IPCTNL_MSG_CT_DELETE = 2
)
@ -91,6 +124,7 @@ const (
CTA_ZONE = 18
CTA_TIMESTAMP = 20
CTA_LABELS = 22
CTA_LABELS_MASK = 23
)
// enum ctattr_tuple {
@ -151,7 +185,10 @@ const (
// };
// #define CTA_PROTOINFO_MAX (__CTA_PROTOINFO_MAX - 1)
const (
CTA_PROTOINFO_UNSPEC = 0
CTA_PROTOINFO_TCP = 1
CTA_PROTOINFO_DCCP = 2
CTA_PROTOINFO_SCTP = 3
)
// enum ctattr_protoinfo_tcp {

View File

@ -909,6 +909,12 @@ func Uint16Attr(v uint16) []byte {
return bytes
}
func BEUint16Attr(v uint16) []byte {
bytes := make([]byte, 2)
binary.BigEndian.PutUint16(bytes, v)
return bytes
}
func Uint32Attr(v uint32) []byte {
native := NativeEndian()
bytes := make([]byte, 4)
@ -916,6 +922,12 @@ func Uint32Attr(v uint32) []byte {
return bytes
}
func BEUint32Attr(v uint32) []byte {
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, v)
return bytes
}
func Uint64Attr(v uint64) []byte {
native := NativeEndian()
bytes := make([]byte, 8)
@ -923,6 +935,12 @@ func Uint64Attr(v uint64) []byte {
return bytes
}
func BEUint64Attr(v uint64) []byte {
bytes := make([]byte, 8)
binary.BigEndian.PutUint64(bytes, v)
return bytes
}
func ParseRouteAttr(b []byte) ([]syscall.NetlinkRouteAttr, error) {
var attrs []syscall.NetlinkRouteAttr
for len(b) >= unix.SizeofRtAttr {