From fe3b5664d23a11b52ba59bece4ff29c52772a56b Mon Sep 17 00:00:00 2001
From: ISHIDA Wataru
Date: Sun, 19 Feb 2017 03:22:09 +0000
Subject: [PATCH] support MPLS

$ ip -M route add 100 dev eth0
$ ip -M route add 100 as to 200/300 dev eth0
$ ip -M route add 100 nexthop dev eth0 as to 200 \
                      nexthop dev eth1 as to 300
$ ip route add 10.10.0.0/24 encap mpls 200/300 dev eth0
$ ip route add 10.0.0.0/24 nexthop encap mpls 200 dev eth0 \
                           nexthop encap mpls 300 dev eth1

Signed-off-by: ISHIDA Wataru
---
Review notes (placed after the three-dash line, so not part of the commit message):
- nl/mpls_linux.go: DecodeMPLSStack shifts the entry while it is still a
  uint32; int(l)>>shift yields a negative label on 32-bit platforms when the
  entry's top bit is set. Doc comments added (36 -> 45 lines).
- netlink_test.go: setUpF defers file.Close() only after the os.Create error
  check.
- Line breaks and indentation were restored following gofmt; the blob ids on
  the index lines are the original ones and do not reflect these edits.

 netlink_linux.go  |   7 +-
 netlink_test.go   |  20 ++++
 nl/mpls_linux.go  |  45 ++++++++++
 nl/nl_linux.go    |   7 +-
 nl/route_linux.go |  30 +++++-
 nl/syscall.go     |  25 +++++
 route.go          |  61 +++++++++++--
 route_linux.go    | 228 +++++++++++++++++++++++++++++++++++++++++-----
 route_test.go     |  47 ++++++++++
 9 files changed, 432 insertions(+), 38 deletions(-)
 create mode 100644 nl/mpls_linux.go

diff --git a/netlink_linux.go b/netlink_linux.go
index 32d8537..a20d293 100644
--- a/netlink_linux.go
+++ b/netlink_linux.go
@@ -4,7 +4,8 @@ import "github.com/vishvananda/netlink/nl"
 
 // Family type definitions
 const (
-	FAMILY_ALL = nl.FAMILY_ALL
-	FAMILY_V4  = nl.FAMILY_V4
-	FAMILY_V6  = nl.FAMILY_V6
+	FAMILY_ALL  = nl.FAMILY_ALL
+	FAMILY_V4   = nl.FAMILY_V4
+	FAMILY_V6   = nl.FAMILY_V6
+	FAMILY_MPLS = nl.FAMILY_MPLS
 )
diff --git a/netlink_test.go b/netlink_test.go
index 90aa74d..5037b7f 100644
--- a/netlink_test.go
+++ b/netlink_test.go
@@ -36,3 +36,23 @@ func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest {
 		runtime.UnlockOSThread()
 	}
 }
+
+func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
+	if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil {
+		msg := "Skipped test because it requires MPLS support."
+		log.Printf(msg)
+		t.Skip(msg)
+	}
+	f := setUpNetlinkTest(t)
+	setUpF := func(path, value string) {
+		file, err := os.Create(path)
+		if err != nil {
+			t.Fatalf("Failed to open %s: %s", path, err)
+		}
+		defer file.Close()
+		file.WriteString(value)
+	}
+	setUpF("/proc/sys/net/mpls/platform_labels", "1024")
+	setUpF("/proc/sys/net/mpls/conf/lo/input", "1")
+	return f
+}
diff --git a/nl/mpls_linux.go b/nl/mpls_linux.go
new file mode 100644
index 0000000..3915b7e
--- /dev/null
+++ b/nl/mpls_linux.go
@@ -0,0 +1,45 @@
+package nl
+
+import "encoding/binary"
+
+// Bit offsets of the label and bottom-of-stack fields within a 32-bit MPLS
+// label stack entry.
+const (
+	MPLS_LS_LABEL_SHIFT = 12
+	MPLS_LS_S_SHIFT     = 8
+)
+
+// EncodeMPLSStack packs labels into consecutive big-endian 32-bit entries and
+// sets the bottom-of-stack bit on the last one.
+func EncodeMPLSStack(labels ...int) []byte {
+	b := make([]byte, 4*len(labels))
+	for idx, label := range labels {
+		l := label << MPLS_LS_LABEL_SHIFT
+		if idx == len(labels)-1 {
+			l |= 1 << MPLS_LS_S_SHIFT
+		}
+		binary.BigEndian.PutUint32(b[idx*4:], uint32(l))
+	}
+	return b
+}
+
+// DecodeMPLSStack returns the labels found in buf, stopping after the first
+// entry whose bottom-of-stack bit is set. It returns nil when len(buf) is not
+// a multiple of 4.
+func DecodeMPLSStack(buf []byte) []int {
+	if len(buf)%4 != 0 {
+		return nil
+	}
+	stack := make([]int, 0, len(buf)/4)
+	for len(buf) > 0 {
+		l := binary.BigEndian.Uint32(buf[:4])
+		buf = buf[4:]
+		// Shift while still unsigned: int(l) is negative on 32-bit platforms
+		// when the entry's top bit is set.
+		stack = append(stack, int(l>>MPLS_LS_LABEL_SHIFT))
+		if (l>>MPLS_LS_S_SHIFT)&1 > 0 {
+			break
+		}
+	}
+	return stack
+}
diff --git a/nl/nl_linux.go b/nl/nl_linux.go
index ab41fa1..fb9031e 100644
--- a/nl/nl_linux.go
+++ b/nl/nl_linux.go
@@ -17,9 +17,10 @@ import (
 
 const (
 	// Family type definitions
-	FAMILY_ALL = syscall.AF_UNSPEC
-	FAMILY_V4  = syscall.AF_INET
-	FAMILY_V6  = syscall.AF_INET6
+	FAMILY_ALL  = syscall.AF_UNSPEC
+	FAMILY_V4   = syscall.AF_INET
+	FAMILY_V6   = syscall.AF_INET6
+	FAMILY_MPLS = AF_MPLS
 )
 
 // SupportedNlFamilies contains the list of netlink families this netlink package supports
diff --git a/nl/route_linux.go b/nl/route_linux.go
index f7db88d..1a064d6 100644
--- a/nl/route_linux.go
+++ b/nl/route_linux.go
@@ -43,12 +43,38 @@ func (msg *RtMsg) Serialize() []byte {
 
 type RtNexthop struct {
 	syscall.RtNexthop
+	Children []NetlinkRequestData
 }
 
 func DeserializeRtNexthop(b []byte) *RtNexthop {
 	return (*RtNexthop)(unsafe.Pointer(&b[0:syscall.SizeofRtNexthop][0]))
 }
 
-func (msg *RtNexthop) Serialize() []byte {
-	return (*(*[syscall.SizeofRtNexthop]byte)(unsafe.Pointer(msg)))[:]
+func (msg *RtNexthop) Len() int {
+	if len(msg.Children) == 0 {
+		return syscall.SizeofRtNexthop
+	}
+
+	l := 0
+	for _, child := range msg.Children {
+		l += rtaAlignOf(child.Len())
+	}
+	l += syscall.SizeofRtNexthop
+	return rtaAlignOf(l)
+}
+
+func (msg *RtNexthop) Serialize() []byte {
+	length := msg.Len()
+	msg.RtNexthop.Len = uint16(length)
+	buf := make([]byte, length)
+	copy(buf, (*(*[syscall.SizeofRtNexthop]byte)(unsafe.Pointer(msg)))[:])
+	next := rtaAlignOf(syscall.SizeofRtNexthop)
+	if len(msg.Children) > 0 {
+		for _, child := range msg.Children {
+			childBuf := child.Serialize()
+			copy(buf[next:], childBuf)
+			next += rtaAlignOf(len(childBuf))
+		}
+	}
+	return buf
 }
diff --git a/nl/syscall.go b/nl/syscall.go
index 18d54a4..3473e53 100644
--- a/nl/syscall.go
+++ b/nl/syscall.go
@@ -41,3 +41,28 @@ const (
 	SOCK_DIAG_BY_FAMILY = 20         /* linux.sock_diag.h */
 	TCPDIAG_NOCOOKIE    = 0xFFFFFFFF /* TCPDIAG_NOCOOKIE in net/ipv4/tcp_diag.h*/
 )
+
+const (
+	AF_MPLS = 28
+)
+
+const (
+	RTA_NEWDST     = 0x13
+	RTA_ENCAP_TYPE = 0x15
+	RTA_ENCAP      = 0x16
+)
+
+// RTA_ENCAP subtype
+const (
+	MPLS_IPTUNNEL_UNSPEC = iota
+	MPLS_IPTUNNEL_DST
+)
+
+// light weight tunnel encap types
+const (
+	LWTUNNEL_ENCAP_NONE = iota
+	LWTUNNEL_ENCAP_MPLS
+	LWTUNNEL_ENCAP_IP
+	LWTUNNEL_ENCAP_ILA
+	LWTUNNEL_ENCAP_IP6
+)
diff --git a/route.go b/route.go
index 335d703..03ac4b2 100644
--- a/route.go
+++ b/route.go
@@ -3,6 +3,7 @@ package netlink
 import (
 	"fmt"
 	"net"
+	"strings"
 )
 
 // Scope is an enum representing a route scope.
@@ -10,6 +11,20 @@ type Scope uint8
 
 type NextHopFlag int
 
+type Destination interface {
+	Family() int
+	Decode([]byte) error
+	Encode() ([]byte, error)
+	String() string
+}
+
+type Encap interface {
+	Type() int
+	Decode([]byte) error
+	Encode() ([]byte, error)
+	String() string
+}
+
 // Route represents a netlink route.
 type Route struct {
 	LinkIndex  int
@@ -25,15 +40,36 @@ type Route struct {
 	Type       int
 	Tos        int
 	Flags      int
+	MPLSDst    *int
+	NewDst     Destination
+	Encap      Encap
 }
 
 func (r Route) String() string {
-	if len(r.MultiPath) > 0 {
-		return fmt.Sprintf("{Dst: %s Src: %s Gw: %s Flags: %s Table: %d}", r.Dst,
-			r.Src, r.MultiPath, r.ListFlags(), r.Table)
+	elems := []string{}
+	if len(r.MultiPath) == 0 {
+		elems = append(elems, fmt.Sprintf("Ifindex: %d", r.LinkIndex))
 	}
-	return fmt.Sprintf("{Ifindex: %d Dst: %s Src: %s Gw: %s Flags: %s Table: %d}", r.LinkIndex, r.Dst,
-		r.Src, r.Gw, r.ListFlags(), r.Table)
+	if r.MPLSDst != nil {
+		elems = append(elems, fmt.Sprintf("Dst: %d", *r.MPLSDst)) // MPLSDst is *int: print the label, not the pointer
+	} else {
+		elems = append(elems, fmt.Sprintf("Dst: %s", r.Dst))
+	}
+	if r.NewDst != nil {
+		elems = append(elems, fmt.Sprintf("NewDst: %s", r.NewDst))
+	}
+	if r.Encap != nil {
+		elems = append(elems, fmt.Sprintf("Encap: %s", r.Encap))
+	}
+	elems = append(elems, fmt.Sprintf("Src: %s", r.Src))
+	if len(r.MultiPath) > 0 {
+		elems = append(elems, fmt.Sprintf("Gw: %s", r.MultiPath))
+	} else {
+		elems = append(elems, fmt.Sprintf("Gw: %s", r.Gw))
+	}
+	elems = append(elems, fmt.Sprintf("Flags: %s", r.ListFlags()))
+	elems = append(elems, fmt.Sprintf("Table: %d", r.Table))
+	return fmt.Sprintf("{%s}", strings.Join(elems, " "))
 }
 
 func (r *Route) SetFlag(flag NextHopFlag) {
@@ -60,8 +96,21 @@ type NexthopInfo struct {
 	Hops      int
 	Gw        net.IP
 	Flags     int
+	NewDst    Destination
+	Encap     Encap
 }
 
 func (n *NexthopInfo) String() string {
-	return fmt.Sprintf("{Ifindex: %d Weight: %d Gw: %s Flags: %s}", n.LinkIndex, n.Hops+1, n.Gw, n.ListFlags())
+	elems := []string{}
+	elems = append(elems, fmt.Sprintf("Ifindex: %d", n.LinkIndex))
+	if n.NewDst != nil {
+		elems = append(elems, fmt.Sprintf("NewDst: %s", n.NewDst))
+	}
+	if n.Encap != nil {
+		elems = append(elems, fmt.Sprintf("Encap: %s", n.Encap))
+	}
+	elems = append(elems, fmt.Sprintf("Weight: %d", n.Hops+1))
+	elems = append(elems, fmt.Sprintf("Gw: %s", n.Gw)) // %s as before this patch; %d would print the raw address bytes
+	elems = append(elems, fmt.Sprintf("Flags: %s", n.ListFlags()))
+	return fmt.Sprintf("{%s}", strings.Join(elems, " "))
 }
diff --git a/route_linux.go b/route_linux.go
index c9b1fec..9e0f1f9 100644
--- a/route_linux.go
+++ b/route_linux.go
@@ -3,6 +3,7 @@ package netlink
 import (
 	"fmt"
 	"net"
+	"strings"
 	"syscall"
 
 	"github.com/vishvananda/netlink/nl"
@@ -60,6 +61,74 @@ func (n *NexthopInfo) ListFlags() []string {
 	return listFlags(n.Flags)
 }
 
+type MPLSDestination struct {
+	Labels []int
+}
+
+func (d *MPLSDestination) Family() int {
+	return nl.FAMILY_MPLS
+}
+
+func (d *MPLSDestination) Decode(buf []byte) error {
+	d.Labels = nl.DecodeMPLSStack(buf)
+	return nil
+}
+
+func (d *MPLSDestination) Encode() ([]byte, error) {
+	return nl.EncodeMPLSStack(d.Labels...), nil
+}
+
+func (d *MPLSDestination) String() string {
+	s := make([]string, 0, len(d.Labels))
+	for _, l := range d.Labels {
+		s = append(s, fmt.Sprintf("%d", l))
+	}
+	return strings.Join(s, "/")
+}
+
+type MPLSEncap struct {
+	Labels []int
+}
+
+func (e *MPLSEncap) Type() int {
+	return nl.LWTUNNEL_ENCAP_MPLS
+}
+
+func (e *MPLSEncap) Decode(buf []byte) error {
+	if len(buf) < 4 {
+		return fmt.Errorf("Lack of bytes")
+	}
+	native := nl.NativeEndian()
+	l := native.Uint16(buf)
+	if l < 4 || len(buf) < int(l) { // l must cover the 4 header bytes sliced below
+		return fmt.Errorf("Lack of bytes")
+	}
+	buf = buf[:l]
+	typ := native.Uint16(buf[2:])
+	if typ != nl.MPLS_IPTUNNEL_DST {
+		return fmt.Errorf("Unknown MPLS Encap Type: %d", typ)
+	}
+	e.Labels = nl.DecodeMPLSStack(buf[4:])
+	return nil
+}
+
+func (e *MPLSEncap) Encode() ([]byte, error) {
+	s := nl.EncodeMPLSStack(e.Labels...)
+	native := nl.NativeEndian()
+	hdr := make([]byte, 4)
+	native.PutUint16(hdr, uint16(len(s)+4))
+	native.PutUint16(hdr[2:], nl.MPLS_IPTUNNEL_DST)
+	return append(hdr, s...), nil
+}
+
+func (e *MPLSEncap) String() string {
+	s := make([]string, 0, len(e.Labels))
+	for _, l := range e.Labels {
+		s = append(s, fmt.Sprintf("%d", l))
+	}
+	return strings.Join(s, "/")
+}
+
 // RouteAdd will add a route to the system.
 // Equivalent to: `ip route add $route`
 func RouteAdd(route *Route) error {
@@ -102,7 +171,7 @@ func (h *Handle) RouteDel(route *Route) error {
 }
 
 func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
-	if (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil {
+	if (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil {
 		return fmt.Errorf("one of Dst.IP, Src, or Gw must not be nil")
 	}
 
@@ -121,6 +190,33 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
 			dstData = route.Dst.IP.To16()
 		}
 		rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_DST, dstData))
+	} else if route.MPLSDst != nil {
+		family = nl.FAMILY_MPLS
+		msg.Dst_len = uint8(20)
+		msg.Type = syscall.RTN_UNICAST
+		rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_DST, nl.EncodeMPLSStack(*route.MPLSDst)))
+	}
+
+	if route.NewDst != nil {
+		if family != -1 && family != route.NewDst.Family() {
+			return fmt.Errorf("new destination and destination are not the same address family")
+		}
+		buf, err := route.NewDst.Encode()
+		if err != nil {
+			return err
+		}
+		rtAttrs = append(rtAttrs, nl.NewRtAttr(nl.RTA_NEWDST, buf))
+	}
+
+	if route.Encap != nil {
+		buf := make([]byte, 2)
+		native.PutUint16(buf, uint16(route.Encap.Type()))
+		rtAttrs = append(rtAttrs, nl.NewRtAttr(nl.RTA_ENCAP_TYPE, buf))
+		buf, err := route.Encap.Encode()
+		if err != nil {
+			return err
+		}
+		rtAttrs = append(rtAttrs, nl.NewRtAttr(nl.RTA_ENCAP, buf))
 	}
 
 	if route.Src != nil {
@@ -161,27 +257,43 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
 				RtNexthop: syscall.RtNexthop{
 					Hops:    uint8(nh.Hops),
 					Ifindex: int32(nh.LinkIndex),
-					Len:     uint16(syscall.SizeofRtNexthop),
 					Flags:   uint8(nh.Flags),
 				},
 			}
-			var gwData []byte
+			children := []nl.NetlinkRequestData{}
 			if nh.Gw != nil {
 				gwFamily := nl.GetIPFamily(nh.Gw)
 				if family != -1 && family != gwFamily {
 					return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
 				}
-				var gw *nl.RtAttr
 				if gwFamily == FAMILY_V4 {
-					gw = nl.NewRtAttr(syscall.RTA_GATEWAY, []byte(nh.Gw.To4()))
+					children = append(children, nl.NewRtAttr(syscall.RTA_GATEWAY, []byte(nh.Gw.To4())))
 				} else {
-					gw = nl.NewRtAttr(syscall.RTA_GATEWAY, []byte(nh.Gw.To16()))
+					children = append(children, nl.NewRtAttr(syscall.RTA_GATEWAY, []byte(nh.Gw.To16())))
 				}
-				gwData = gw.Serialize()
-				rtnh.Len += uint16(len(gwData))
 			}
+			if nh.NewDst != nil {
+				if family != -1 && family != nh.NewDst.Family() {
+					return fmt.Errorf("new destination and destination are not the same address family")
+				}
+				buf, err := nh.NewDst.Encode()
+				if err != nil {
+					return err
+				}
+				children = append(children, nl.NewRtAttr(nl.RTA_NEWDST, buf))
+			}
+			if nh.Encap != nil {
+				buf := make([]byte, 2)
+				native.PutUint16(buf, uint16(nh.Encap.Type()))
+				children = append(children, nl.NewRtAttr(nl.RTA_ENCAP_TYPE, buf)) // belongs to this nexthop, next to its RTA_ENCAP
+				buf, err := nh.Encap.Encode()
+				if err != nil {
+					return err
+				}
+				children = append(children, nl.NewRtAttr(nl.RTA_ENCAP, buf))
+			}
+			rtnh.Children = children
 			buf = append(buf, rtnh.Serialize()...)
-			buf = append(buf, gwData...)
 		}
 		rtAttrs = append(rtAttrs, nl.NewRtAttr(syscall.RTA_MULTIPATH, buf))
 	}
@@ -308,18 +420,20 @@ func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64)
 			case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
 				continue
 			case filterMask&RT_FILTER_DST != 0:
-				if filter.Dst == nil {
-					if route.Dst != nil {
-						continue
-					}
-				} else {
-					if route.Dst == nil {
-						continue
-					}
-					aMaskLen, aMaskBits := route.Dst.Mask.Size()
-					bMaskLen, bMaskBits := filter.Dst.Mask.Size()
-					if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) {
-						continue
+				if filter.MPLSDst == nil || route.MPLSDst == nil || (*filter.MPLSDst) != (*route.MPLSDst) {
+					if filter.Dst == nil {
+						if route.Dst != nil || filter.MPLSDst != nil { // an unmatched label filter must not pass as "no Dst"
+							continue
+						}
+					} else {
+						if route.Dst == nil {
+							continue
+						}
+						aMaskLen, aMaskBits := route.Dst.Mask.Size()
+						bMaskLen, bMaskBits := filter.Dst.Mask.Size()
+						if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) {
+							continue
+						}
 					}
 				}
 			}
@@ -346,6 +460,7 @@ func deserializeRoute(m []byte) (Route, error) {
 	}
 
 	native := nl.NativeEndian()
+	var encap, encapType syscall.NetlinkRouteAttr
 	for _, attr := range attrs {
 		switch attr.Attr.Type {
 		case syscall.RTA_GATEWAY:
@@ -353,9 +468,17 @@ func deserializeRoute(m []byte) (Route, error) {
 		case syscall.RTA_PREFSRC:
 			route.Src = net.IP(attr.Value)
 		case syscall.RTA_DST:
-			route.Dst = &net.IPNet{
-				IP:   attr.Value,
-				Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
+			if msg.Family == nl.FAMILY_MPLS {
+				stack := nl.DecodeMPLSStack(attr.Value)
+				if len(stack) == 0 || len(stack) > 1 {
+					return route, fmt.Errorf("invalid MPLS RTA_DST")
+				}
+				route.MPLSDst = &stack[0]
+			} else {
+				route.Dst = &net.IPNet{
+					IP:   attr.Value,
+					Mask: net.CIDRMask(int(msg.Dst_len), 8*len(attr.Value)),
+				}
 			}
 		case syscall.RTA_OIF:
 			route.LinkIndex = int(native.Uint32(attr.Value[0:4]))
@@ -383,12 +506,41 @@ func deserializeRoute(m []byte) (Route, error) {
 				if err != nil {
 					return nil, nil, err
 				}
+				var encap, encapType syscall.NetlinkRouteAttr
 				for _, attr := range attrs {
 					switch attr.Attr.Type {
 					case syscall.RTA_GATEWAY:
 						info.Gw = net.IP(attr.Value)
+					case nl.RTA_NEWDST:
+						// Only MPLS is decodable here; skip other families rather than call Decode on a nil Destination.
+						if msg.Family != nl.FAMILY_MPLS {
+							continue
+						}
+						d := &MPLSDestination{}
+						if err := d.Decode(attr.Value); err != nil {
+							return nil, nil, err
+						}
+						info.NewDst = d
+					case nl.RTA_ENCAP_TYPE:
+						encapType = attr
+					case nl.RTA_ENCAP:
+						encap = attr
 					}
 				}
+
+				if len(encap.Value) != 0 && len(encapType.Value) >= 2 {
+					typ := int(native.Uint16(encapType.Value[0:2]))
+					var e Encap
+					switch typ {
+					case nl.LWTUNNEL_ENCAP_MPLS:
+						e = &MPLSEncap{}
+						if err := e.Decode(encap.Value); err != nil {
+							return nil, nil, err
+						}
+					}
+					info.Encap = e
+				}
+
 				return info, value[int(nh.RtNexthop.Len):], nil
 			}
 			rest := attr.Value
@@ -400,8 +552,36 @@ func deserializeRoute(m []byte) (Route, error) {
 				route.MultiPath = append(route.MultiPath, info)
 				rest = buf
 			}
+		case nl.RTA_NEWDST:
+			// Only MPLS is decodable here; skip other families rather than call Decode on a nil Destination.
+			if msg.Family != nl.FAMILY_MPLS {
+				continue
+			}
+			d := &MPLSDestination{}
+			if err := d.Decode(attr.Value); err != nil {
+				return route, err
+			}
+			route.NewDst = d
+		case nl.RTA_ENCAP_TYPE:
+			encapType = attr
+		case nl.RTA_ENCAP:
+			encap = attr
 		}
 	}
+
+	if len(encap.Value) != 0 && len(encapType.Value) >= 2 {
+		typ := int(native.Uint16(encapType.Value[0:2]))
+		var e Encap
+		switch typ {
+		case nl.LWTUNNEL_ENCAP_MPLS:
+			e = &MPLSEncap{}
+			if err := e.Decode(encap.Value); err != nil {
+				return route, err
+			}
+		}
+		route.Encap = e
+	}
+
 	return route, nil
 }
 
diff --git a/route_test.go b/route_test.go
index a6d10fc..2435001 100644
--- a/route_test.go
+++ b/route_test.go
@@ -451,3 +451,50 @@ func TestFilterDefaultRoute(t *testing.T) {
 	}
 
 }
+
+func TestMPLSRouteAddDel(t *testing.T) {
+	tearDown := setUpMPLSNetlinkTest(t)
+	defer tearDown()
+
+	// get loopback interface
+	link, err := LinkByName("lo")
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// bring the interface up
+	if err := LinkSetUp(link); err != nil {
+		t.Fatal(err)
+	}
+
+	mplsDst := 100
+	route := Route{
+		LinkIndex: link.Attrs().Index,
+		MPLSDst:   &mplsDst,
+		NewDst: &MPLSDestination{
+			Labels: []int{200, 300},
+		},
+	}
+	if err := RouteAdd(&route); err != nil {
+		t.Fatal(err)
+	}
+	routes, err := RouteList(link, FAMILY_MPLS)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(routes) != 1 {
+		t.Fatal("Route not added properly")
+	}
+
+	if err := RouteDel(&route); err != nil {
+		t.Fatal(err)
+	}
+	routes, err = RouteList(link, FAMILY_MPLS)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(routes) != 0 {
+		t.Fatal("Route not removed properly")
+	}
+
+}