diff --git a/fou.go b/fou.go index 71e73c3..ea9f6cf 100644 --- a/fou.go +++ b/fou.go @@ -1,16 +1,7 @@ package netlink import ( - "errors" -) - -var ( - // ErrAttrHeaderTruncated is returned when a netlink attribute's header is - // truncated. - ErrAttrHeaderTruncated = errors.New("attribute header truncated") - // ErrAttrBodyTruncated is returned when a netlink attribute's body is - // truncated. - ErrAttrBodyTruncated = errors.New("attribute body truncated") + "net" ) type Fou struct { @@ -18,4 +9,8 @@ type Fou struct { Port int Protocol int EncapType int + Local net.IP + Peer net.IP + PeerPort int + IfIndex int } diff --git a/fou_linux.go b/fou_linux.go index ed55b2b..72bfb61 100644 --- a/fou_linux.go +++ b/fou_linux.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -5,6 +6,8 @@ package netlink import ( "encoding/binary" "errors" + "log" + "net" "github.com/vishvananda/netlink/nl" "golang.org/x/sys/unix" @@ -29,6 +32,12 @@ const ( FOU_ATTR_IPPROTO FOU_ATTR_TYPE FOU_ATTR_REMCSUM_NOPARTIAL + FOU_ATTR_LOCAL_V4 + FOU_ATTR_LOCAL_V6 + FOU_ATTR_PEER_V4 + FOU_ATTR_PEER_V6 + FOU_ATTR_PEER_PORT + FOU_ATTR_IFINDEX FOU_ATTR_MAX = FOU_ATTR_REMCSUM_NOPARTIAL ) @@ -169,41 +178,28 @@ func (h *Handle) FouList(fam int) ([]Fou, error) { } func deserializeFouMsg(msg []byte) (Fou, error) { - // we'll skip to byte 4 to first attribute - msg = msg[3:] - var shift int fou := Fou{} - for { - // attribute header is at least 16 bits - if len(msg) < 4 { - return fou, ErrAttrHeaderTruncated - } - - lgt := int(binary.BigEndian.Uint16(msg[0:2])) - if len(msg) < lgt+4 { - return fou, ErrAttrBodyTruncated - } - attr := binary.BigEndian.Uint16(msg[2:4]) - - shift = lgt + 3 - switch attr { + for attr := range nl.ParseAttributes(msg[4:]) { + switch attr.Type { case FOU_ATTR_AF: - fou.Family = int(msg[5]) + fou.Family = int(attr.Value[0]) case FOU_ATTR_PORT: - fou.Port = int(binary.BigEndian.Uint16(msg[5:7])) - // port is 2 bytes - shift = lgt + 2 + fou.Port = int(networkOrder.Uint16(attr.Value)) case FOU_ATTR_IPPROTO: - fou.Protocol = int(msg[5]) + fou.Protocol = int(attr.Value[0]) case FOU_ATTR_TYPE: - fou.EncapType = int(msg[5]) - } - - msg = msg[shift:] - - if len(msg) < 4 { - break + fou.EncapType = int(attr.Value[0]) + case FOU_ATTR_LOCAL_V4, FOU_ATTR_LOCAL_V6: + fou.Local = net.IP(attr.Value) + case FOU_ATTR_PEER_V4, FOU_ATTR_PEER_V6: + fou.Peer = net.IP(attr.Value) + case FOU_ATTR_PEER_PORT: + fou.PeerPort = int(networkOrder.Uint16(attr.Value)) + case FOU_ATTR_IFINDEX: + fou.IfIndex = int(native.Uint16(attr.Value)) + default: + log.Printf("unknown fou attribute from kernel: %+v %v", attr, attr.Type&nl.NLA_TYPE_MASK) } } diff --git a/fou_test.go b/fou_test.go index b252bba..9667fcd 100644 --- a/fou_test.go +++ b/fou_test.go @@ -1,8 +1,10 @@ +//go:build linux // +build linux package netlink import ( + "net" "testing" ) @@ -33,20 +35,50 @@ func TestFouDeserializeMsg(t *testing.T) { } } - // deserialize truncated attribute header - msg = []byte{3, 1, 0, 0, 5, 0} - if _, err := deserializeFouMsg(msg); err == nil { - t.Error("expected attribute header truncated error") - } else if err != ErrAttrHeaderTruncated { - t.Errorf("unexpected error: %s", err.Error()) + // deserialize a valid message(kernel >= 5.2) + msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0, 0, 6, 0, 1, 0, 43, 103, 0, 0, 6, 0, 10, 0, 86, 206, 0, 0, 5, 0, 3, 0, 0, 0, 0, 0, 5, 0, 4, 0, 2, 0, 0, 0, 8, 0, 11, 0, 0, 0, 0, 0, 8, 0, 6, 0, 1, 2, 3, 4, 8, 0, 8, 0, 5, 6, 7, 8} + if fou, err := deserializeFouMsg(msg); err != nil { + t.Error(err.Error()) + } else { + if fou.Family != FAMILY_V4 { + t.Errorf("expected family %d, got %d", FAMILY_V4, fou.Family) + } + + if fou.Port != 11111 { + t.Errorf("expected port 5555, got %d", fou.Port) + } + + if fou.Protocol != 0 { // gue + t.Errorf("expected protocol 0, got %d", fou.Protocol) + } + + if fou.IfIndex != 0 { + t.Errorf("expected ifindex 0, got %d", fou.Protocol) + } + + if fou.EncapType != FOU_ENCAP_GUE { + t.Errorf("expected encap type %d, got %d", FOU_ENCAP_GUE, fou.EncapType) + } + + if expected := net.IPv4(1, 2, 3, 4); !fou.Local.Equal(expected) { + t.Errorf("expected local %v, got %v", expected, fou.Local) + } + + if expected := net.IPv4(5, 6, 7, 8); !fou.Peer.Equal(expected) { + t.Errorf("expected peer %v, got %v", expected, fou.Peer) + } + + if fou.PeerPort != 22222 { + t.Errorf("expected peer port 0, got %d", fou.PeerPort) + } } - // deserialize truncated attribute header - msg = []byte{3, 1, 0, 0, 5, 0, 2, 0, 2, 0, 0} - if _, err := deserializeFouMsg(msg); err == nil { - t.Error("expected attribute body truncated error") - } else if err != ErrAttrBodyTruncated { + // unknown attribute should be skipped + msg = []byte{3, 1, 0, 0, 5, 0, 112, 0, 2, 0, 0, 0, 5, 0, 2, 0, 2, 0, 0} + if fou, err := deserializeFouMsg(msg); err != nil { t.Errorf("unexpected error: %s", err.Error()) + } else if fou.Family != 2 { + t.Errorf("expected family 2, got %d", fou.Family) } } diff --git a/fou_unspecified.go b/fou_unspecified.go index 3a8365b..7e55015 100644 --- a/fou_unspecified.go +++ b/fou_unspecified.go @@ -1,3 +1,4 @@ +//go:build !linux // +build !linux package netlink