add route filtration structure

This commit is contained in:
marek-polewski 2015-11-26 16:31:53 +01:00
parent dfdad47336
commit b5b45b160b
2 changed files with 93 additions and 13 deletions

View File

@ -10,6 +10,37 @@ import (
// RtAttr is shared so it is in netlink_linux.go // RtAttr is shared so it is in netlink_linux.go
// RouteFilter represents filter that can be apply to RouteList function.
type RouteFilter struct {
Table int
Protocol int
Scope Scope
Type int
Tos int
Iif int
Oif int
Dst *net.IPNet
Src net.IP
Gw net.IP
FlagMask uint64
}
// Flag mask for router filters. RouterFilter.FlagMask must be set to on
// for filter to work.
const (
RT_FILTER_PROTOCOL uint64 = 1 << (1 + iota)
RT_FILTER_SCOPE
RT_FILTER_TYPE
RT_FILTER_TOS
RT_FILTER_IIF
RT_FILTER_OIF
RT_FILTER_DST
RT_FILTER_SRC
RT_FILTER_GW
RT_FILTER_TABLE
)
// RouteAdd will add a route to the system. // RouteAdd will add a route to the system.
// Equivalent to: `ip route add $route` // Equivalent to: `ip route add $route`
func RouteAdd(route *Route) error { func RouteAdd(route *Route) error {
@ -126,6 +157,19 @@ func routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
// Equivalent to: `ip route show`. // Equivalent to: `ip route show`.
// The list can be filtered by link and ip family. // The list can be filtered by link and ip family.
func RouteList(link Link, family int) ([]Route, error) { func RouteList(link Link, family int) ([]Route, error) {
var rf *RouteFilter
if link != nil {
rf = &RouteFilter{
Oif: link.Attrs().Index,
FlagMask: RT_FILTER_OIF,
}
}
return RouteListFiltered(family, rf)
}
// RouteListFiltered gets a list of routes in the system filtered with specified rules.
// All rules must be defined in RouteFilter struct
func RouteListFiltered(family int, filter *RouteFilter) ([]Route, error) {
req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP) req := nl.NewNetlinkRequest(syscall.RTM_GETROUTE, syscall.NLM_F_DUMP)
infmsg := nl.NewIfInfomsg(family) infmsg := nl.NewIfInfomsg(family)
req.AddData(infmsg) req.AddData(infmsg)
@ -135,13 +179,6 @@ func RouteList(link Link, family int) ([]Route, error) {
return nil, err return nil, err
} }
index := 0
if link != nil {
base := link.Attrs()
ensureIndex(base)
index = base.Index
}
var res []Route var res []Route
for _, m := range msgs { for _, m := range msgs {
msg := nl.DeserializeRtMsg(m) msg := nl.DeserializeRtMsg(m)
@ -152,8 +189,10 @@ func RouteList(link Link, family int) ([]Route, error) {
} }
if msg.Table != syscall.RT_TABLE_MAIN { if msg.Table != syscall.RT_TABLE_MAIN {
// Ignore non-main tables if filter == nil || filter != nil && filter.FlagMask&RT_FILTER_TABLE == 0 {
continue // Ignore non-main tables
continue
}
} }
route, err := deserializeRoute(m) route, err := deserializeRoute(m)
@ -161,10 +200,39 @@ func RouteList(link Link, family int) ([]Route, error) {
return nil, err return nil, err
} }
if link != nil && route.LinkIndex != index { if filter != nil {
// Ignore routes from other interfaces f := filter.FlagMask
continue switch {
case f&RT_FILTER_TABLE != 0 && filter.Table != route.Table:
continue
case f&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
continue
case f&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
continue
case f&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
continue
case f&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
continue
case f&RT_FILTER_OIF != 0 && filter.Oif != route.LinkIndex:
continue
case f&RT_FILTER_IIF != 0 && filter.Iif != route.Iif:
continue
case f&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
continue
case f&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
continue
case f&RT_FILTER_DST != 0 && filter.Dst != nil:
if route.Dst == nil {
continue
}
aMaskLen, aMaskBits := route.Dst.Mask.Size()
bMaskLen, bMaskBits := filter.Dst.Mask.Size()
if !(route.Dst.IP.Equal(filter.Dst.IP) && aMaskLen == bMaskLen && aMaskBits == bMaskBits) {
continue
}
}
} }
res = append(res, route) res = append(res, route)
} }

View File

@ -172,11 +172,20 @@ func TestRouteExtraFields(t *testing.T) {
Priority: 13, Priority: 13,
Table: syscall.RT_TABLE_MAIN, Table: syscall.RT_TABLE_MAIN,
Type: syscall.RTN_UNICAST, Type: syscall.RTN_UNICAST,
Tos: 14,
} }
if err := RouteAdd(&route); err != nil { if err := RouteAdd(&route); err != nil {
t.Fatal(err) t.Fatal(err)
} }
routes, err := RouteList(link, FAMILY_V4) routes, err := RouteListFiltered(FAMILY_V4, &RouteFilter{
Dst: dst,
Src: src,
Scope: syscall.RT_SCOPE_LINK,
Table: syscall.RT_TABLE_MAIN,
Type: syscall.RTN_UNICAST,
Tos: 14,
FlagMask: RT_FILTER_DST | RT_FILTER_SRC | RT_FILTER_SCOPE | RT_FILTER_TABLE | RT_FILTER_TYPE | RT_FILTER_TOS,
})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -196,4 +205,7 @@ func TestRouteExtraFields(t *testing.T) {
if routes[0].Type != syscall.RTN_UNICAST { if routes[0].Type != syscall.RTN_UNICAST {
t.Fatal("Invalid Type. Route not added properly") t.Fatal("Invalid Type. Route not added properly")
} }
if routes[0].Tos != 14 {
t.Fatal("Invalid Tos. Route not added properly")
}
} }