From b5b45b160b3781bdff3ca5c9b01a2303aeafcbe6 Mon Sep 17 00:00:00 2001 From: marek-polewski Date: Thu, 26 Nov 2015 16:31:53 +0100 Subject: [PATCH] add route filtration structure --- route_linux.go | 92 +++++++++++++++++++++++++++++++++++++++++++------- route_test.go | 14 +++++++- 2 files changed, 93 insertions(+), 13 deletions(-) diff --git a/route_linux.go b/route_linux.go index 0c3e355..cbd525f 100644 --- a/route_linux.go +++ b/route_linux.go @@ -10,6 +10,37 @@ import ( // RtAttr is shared so it is in netlink_linux.go +// RouteFilter represents filter that can be apply to RouteList function. +type RouteFilter struct { + Table int + Protocol int + Scope Scope + Type int + Tos int + Iif int + Oif int + Dst *net.IPNet + Src net.IP + Gw net.IP + + FlagMask uint64 +} + +// Flag mask for router filters. RouterFilter.FlagMask must be set to on +// for filter to work. +const ( + RT_FILTER_PROTOCOL uint64 = 1 << (1 + iota) + RT_FILTER_SCOPE + RT_FILTER_TYPE + RT_FILTER_TOS + RT_FILTER_IIF + RT_FILTER_OIF + RT_FILTER_DST + RT_FILTER_SRC + RT_FILTER_GW + RT_FILTER_TABLE +) + // RouteAdd will add a route to the system. // Equivalent to: `ip route add $route` func RouteAdd(route *Route) error { @@ -126,6 +157,19 @@ func routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error { // Equivalent to: `ip route show`. // The list can be filtered by link and ip family. func RouteList(link Link, family int) ([]Route, error) { + var rf *RouteFilter + if link != nil { + rf = &RouteFilter{ + Oif: link.Attrs().Index, + FlagMask: RT_FILTER_OIF, + } + } + return RouteListFiltered(family, rf) +} + +// RouteListFiltered gets a list of routes in the system filtered with specified rules. +// All rules must be defined in RouteFilter struct +func RouteListFiltered(family int, filter *RouteFilter) ([]Route, error) { req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP) infmsg := nl.NewIfInfomsg(family) req.AddData(infmsg) @@ -135,13 +179,6 @@ func RouteList(link Link, family int) ([]Route, error) { return nil, err } - index := 0 - if link != nil { - base := link.Attrs() - ensureIndex(base) - index = base.Index - } - var res []Route for _, m := range msgs { msg := nl.DeserializeRtMsg(m) @@ -152,8 +189,10 @@ func RouteList(link Link, family int) ([]Route, error) { } if msg.Table != syscall.RT_TABLE_MAIN { - // Ignore non-main tables - continue + if filter == nil || filter != nil && filter.FlagMask&RT_FILTER_TABLE == 0 { + // Ignore non-main tables + continue + } } route, err := deserializeRoute(m) @@ -161,10 +200,39 @@ func RouteList(link Link, family int) ([]Route, error) { return nil, err } - if link != nil && route.LinkIndex != index { - // Ignore routes from other interfaces - continue + if filter != nil { + f := filter.FlagMask + switch { + case f&RT_FILTER_TABLE != 0 && filter.Table != route.Table: + continue + case f&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol: + continue + case f&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope: + continue + case f&RT_FILTER_TYPE != 0 && route.Type != filter.Type: + continue + case f&RT_FILTER_TOS != 0 && route.Tos != filter.Tos: + continue + case f&RT_FILTER_OIF != 0 && filter.Oif != route.LinkIndex: + continue + case f&RT_FILTER_IIF != 0 && filter.Iif != route.Iif: + continue + case f&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw): + continue + case f&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src): + continue + case f&RT_FILTER_DST != 0 && filter.Dst != nil: + if route.Dst == nil { + continue + } + aMaskLen, aMaskBits := route.Dst.Mask.Size() + bMaskLen, bMaskBits := filter.Dst.Mask.Size() + if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) { + continue + } + } } + res = append(res, route) } diff --git a/route_test.go b/route_test.go index 6f8cdab..345c09e 100644 --- a/route_test.go +++ b/route_test.go @@ -172,11 +172,20 @@ func TestRouteExtraFields(t *testing.T) { Priority: 13, Table: syscall.RT_TABLE_MAIN, Type: syscall.RTN_UNICAST, + Tos: 14, } if err := RouteAdd(&route); err != nil { t.Fatal(err) } - routes, err := RouteList(link, FAMILY_V4) + routes, err := RouteListFiltered(FAMILY_V4, &RouteFilter{ + Dst: dst, + Src: src, + Scope: syscall.RT_SCOPE_LINK, + Table: syscall.RT_TABLE_MAIN, + Type: syscall.RTN_UNICAST, + Tos: 14, + FlagMask: RT_FILTER_DST | RT_FILTER_SRC | RT_FILTER_SCOPE | RT_FILTER_TABLE | RT_FILTER_TYPE | RT_FILTER_TOS, + }) if err != nil { t.Fatal(err) } @@ -196,4 +205,7 @@ func TestRouteExtraFields(t *testing.T) { if routes[0].Type != syscall.RTN_UNICAST { t.Fatal("Invalid Type. Route not added properly") } + if routes[0].Tos != 14 { + t.Fatal("Invalid Tos. Route not added properly") + } }