Add RouteListFilteredIter API.

Allows for listing large numbers of routes without
buffering the whole list in memory at once.

Add benchmarks for RouteListFiltered variants.
This commit is contained in:
Shaun Crampton 2024-04-09 10:04:24 +01:00 committed by Alessandro Boch
parent b7b7ca8632
commit b54f85093f
4 changed files with 263 additions and 58 deletions

View File

@ -23,7 +23,7 @@ import (
type tearDownNetlinkTest func()
func skipUnlessRoot(t *testing.T) {
func skipUnlessRoot(t testing.TB) {
t.Helper()
if os.Getuid() != 0 {
@ -53,7 +53,7 @@ func skipUnlessKModuleLoaded(t *testing.T, module ...string) {
}
}
func setUpNetlinkTest(t *testing.T) tearDownNetlinkTest {
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
skipUnlessRoot(t)
// new temporary namespace so we don't pollute the host

View File

@ -488,10 +488,30 @@ func (req *NetlinkRequest) AddRawData(data []byte) {
req.RawData = append(req.RawData, data...)
}
// Execute the request against a the given sockType.
// Execute the request against the given sockType.
// Returns a list of netlink messages in serialized format, optionally filtered
// by resType.
func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, error) {
	var collected [][]byte
	// Accumulate every message handed to the callback; returning true asks
	// ExecuteIter to keep delivering messages.
	collect := func(msg []byte) bool {
		collected = append(collected, msg)
		return true
	}
	if err := req.ExecuteIter(sockType, resType, collect); err != nil {
		// Drop any partially collected messages so callers never see a
		// half-filled result alongside an error.
		return nil, err
	}
	return collected, nil
}
// ExecuteIter executes the request against the given sockType.
// Calls the provided callback func once for each netlink message.
// If the callback returns false, it is not called again, but
// the remaining messages are consumed/discarded.
//
// Thread safety: ExecuteIter holds a lock on the socket until
// it finishes iteration so the callback must not call back into
// the netlink API.
func (req *NetlinkRequest) ExecuteIter(sockType int, resType uint16, f func(msg []byte) bool) error {
var (
s *NetlinkSocket
err error
@ -508,18 +528,18 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
if s == nil {
s, err = getNetlinkSocket(sockType)
if err != nil {
return nil, err
return err
}
if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil {
return nil, err
return err
}
if EnableErrorMessageReporting {
if err := s.SetExtAck(true); err != nil {
return nil, err
return err
}
}
@ -530,38 +550,36 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro
}
if err := s.Send(req); err != nil {
return nil, err
return err
}
pid, err := s.GetPid()
if err != nil {
return nil, err
return err
}
var res [][]byte
done:
for {
msgs, from, err := s.Receive()
if err != nil {
return nil, err
return err
}
if from.Pid != PidKernel {
return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, PidKernel)
}
for _, m := range msgs {
if m.Header.Seq != req.Seq {
if sharedSocket {
continue
}
return nil, fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
return fmt.Errorf("Wrong Seq nr %d, expected %d", m.Header.Seq, req.Seq)
}
if m.Header.Pid != pid {
continue
}
if m.Header.Flags&unix.NLM_F_DUMP_INTR != 0 {
return nil, syscall.Errno(unix.EINTR)
return syscall.Errno(unix.EINTR)
}
if m.Header.Type == unix.NLMSG_DONE || m.Header.Type == unix.NLMSG_ERROR {
@ -600,18 +618,26 @@ done:
}
}
return nil, err
return err
}
if resType != 0 && m.Header.Type != resType {
continue
}
res = append(res, m.Data)
if cont := f(m.Data); !cont {
// Drain the rest of the messages from the kernel but don't
// pass them to the iterator func.
f = dummyMsgIterFunc
}
if m.Header.Flags&unix.NLM_F_MULTI == 0 {
break done
}
}
}
return res, nil
return nil
}
// dummyMsgIterFunc is a no-op message callback that ignores its argument and
// always reports that iteration should continue. ExecuteIter swaps it in once
// the caller's callback has returned false, so the remaining messages are
// drained without being delivered.
func dummyMsgIterFunc(_ []byte) bool {
	return true
}
// Create a new netlink request from proto and flags

View File

@ -436,7 +436,7 @@ func (e *SEG6LocalEncap) String() string {
}
if e.Flags[nl.SEG6_LOCAL_SRH] {
segs := make([]string, 0, len(e.Segments))
//append segment backwards (from n to 0) since seg#0 is the last segment.
// append segment backwards (from n to 0) since seg#0 is the last segment.
for i := len(e.Segments); i > 0; i-- {
segs = append(segs, e.Segments[i-1].String())
}
@ -874,8 +874,22 @@ func (h *Handle) RouteDel(route *Route) error {
}
// routeHandle fills req from route and msg via prepareRouteReq, executes it on
// a NETLINK_ROUTE socket and returns all response messages buffered in a
// slice.
func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) ([][]byte, error) {
	// Build the request first; nothing is sent if preparation fails.
	err := h.prepareRouteReq(route, req, msg)
	if err != nil {
		return nil, err
	}
	// A resType of 0 disables filtering of responses by message type.
	return req.Execute(unix.NETLINK_ROUTE, 0)
}
// routeHandleIter fills req from route and msg via prepareRouteReq, executes
// it on a NETLINK_ROUTE socket and hands each response message to f instead
// of buffering them.
func (h *Handle) routeHandleIter(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg, f func(msg []byte) bool) error {
	// Build the request first; nothing is sent if preparation fails.
	err := h.prepareRouteReq(route, req, msg)
	if err != nil {
		return err
	}
	// A resType of 0 disables filtering of responses by message type.
	return req.ExecuteIter(unix.NETLINK_ROUTE, 0, f)
}
func (h *Handle) prepareRouteReq(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg) error {
if req.NlMsghdr.Type != unix.RTM_GETROUTE && (route.Dst == nil || route.Dst.IP == nil) && route.Src == nil && route.Gw == nil && route.MPLSDst == nil {
return nil, fmt.Errorf("Either Dst.IP, Src.IP or Gw must be set")
return fmt.Errorf("either Dst.IP, Src.IP or Gw must be set")
}
family := -1
@ -902,11 +916,11 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
if route.NewDst != nil {
if family != -1 && family != route.NewDst.Family() {
return nil, fmt.Errorf("new destination and destination are not the same address family")
return fmt.Errorf("new destination and destination are not the same address family")
}
buf, err := route.NewDst.Encode()
if err != nil {
return nil, err
return err
}
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_NEWDST, buf))
}
@ -917,7 +931,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
buf, err := route.Encap.Encode()
if err != nil {
return nil, err
return err
}
switch route.Encap.Type() {
case nl.LWTUNNEL_ENCAP_BPF:
@ -931,7 +945,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
if route.Src != nil {
srcFamily := nl.GetIPFamily(route.Src)
if family != -1 && family != srcFamily {
return nil, fmt.Errorf("source and destination ip are not the same IP family")
return fmt.Errorf("source and destination ip are not the same IP family")
}
family = srcFamily
var srcData []byte
@ -947,7 +961,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
if route.Gw != nil {
gwFamily := nl.GetIPFamily(route.Gw)
if family != -1 && family != gwFamily {
return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family")
return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
}
family = gwFamily
var gwData []byte
@ -962,7 +976,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
if route.Via != nil {
buf, err := route.Via.Encode()
if err != nil {
return nil, fmt.Errorf("failed to encode RTA_VIA: %v", err)
return fmt.Errorf("failed to encode RTA_VIA: %v", err)
}
rtAttrs = append(rtAttrs, nl.NewRtAttr(unix.RTA_VIA, buf))
}
@ -981,7 +995,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
if nh.Gw != nil {
gwFamily := nl.GetIPFamily(nh.Gw)
if family != -1 && family != gwFamily {
return nil, fmt.Errorf("gateway, source, and destination ip are not the same IP family")
return fmt.Errorf("gateway, source, and destination ip are not the same IP family")
}
if gwFamily == FAMILY_V4 {
children = append(children, nl.NewRtAttr(unix.RTA_GATEWAY, []byte(nh.Gw.To4())))
@ -991,11 +1005,11 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
}
if nh.NewDst != nil {
if family != -1 && family != nh.NewDst.Family() {
return nil, fmt.Errorf("new destination and destination are not the same address family")
return fmt.Errorf("new destination and destination are not the same address family")
}
buf, err := nh.NewDst.Encode()
if err != nil {
return nil, err
return err
}
children = append(children, nl.NewRtAttr(unix.RTA_NEWDST, buf))
}
@ -1005,14 +1019,14 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
children = append(children, nl.NewRtAttr(unix.RTA_ENCAP_TYPE, buf))
buf, err := nh.Encap.Encode()
if err != nil {
return nil, err
return err
}
children = append(children, nl.NewRtAttr(unix.RTA_ENCAP, buf))
}
if nh.Via != nil {
buf, err := nh.Via.Encode()
if err != nil {
return nil, err
return err
}
children = append(children, nl.NewRtAttr(unix.RTA_VIA, buf))
}
@ -1143,8 +1157,7 @@ func (h *Handle) routeHandle(route *Route, req *nl.NetlinkRequest, msg *nl.RtMsg
native.PutUint32(b, uint32(route.LinkIndex))
req.AddData(nl.NewRtAttr(unix.RTA_OIF, b))
}
return req.Execute(unix.NETLINK_ROUTE, 0)
return nil
}
// RouteList gets a list of routes in the system.
@ -1176,73 +1189,94 @@ func RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, e
// RouteListFiltered gets a list of routes in the system filtered with specified rules.
// All rules must be defined in RouteFilter struct
func (h *Handle) RouteListFiltered(family int, filter *Route, filterMask uint64) ([]Route, error) {
req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP)
rtmsg := &nl.RtMsg{}
rtmsg.Family = uint8(family)
msgs, err := h.routeHandle(filter, req, rtmsg)
var res []Route
err := h.RouteListFilteredIter(family, filter, filterMask, func(route Route) (cont bool) {
res = append(res, route)
return true
})
if err != nil {
return nil, err
}
return res, nil
}
var res []Route
for _, m := range msgs {
// RouteListFilteredIter passes each route that matches the filter to the given
// iterator func f. Iteration continues until all routes are loaded or f
// returns false.
//
// This is the package-level variant: it delegates to the RouteListFilteredIter
// method of pkgHandle.
//
// f is invoked from inside the netlink message callback of ExecuteIter, which
// holds a lock on the socket until iteration finishes, so f must not call back
// into the netlink API.
func RouteListFilteredIter(family int, filter *Route, filterMask uint64, f func(Route) (cont bool)) error {
return pkgHandle.RouteListFilteredIter(family, filter, filterMask, f)
}
func (h *Handle) RouteListFilteredIter(family int, filter *Route, filterMask uint64, f func(Route) (cont bool)) error {
req := h.newNetlinkRequest(unix.RTM_GETROUTE, unix.NLM_F_DUMP)
rtmsg := &nl.RtMsg{}
rtmsg.Family = uint8(family)
var parseErr error
err := h.routeHandleIter(filter, req, rtmsg, func(m []byte) bool {
msg := nl.DeserializeRtMsg(m)
if family != FAMILY_ALL && msg.Family != uint8(family) {
// Ignore routes not matching requested family
continue
return true
}
if msg.Flags&unix.RTM_F_CLONED != 0 {
// Ignore cloned routes
continue
return true
}
if msg.Table != unix.RT_TABLE_MAIN {
if filter == nil || filterMask&RT_FILTER_TABLE == 0 {
// Ignore non-main tables
continue
return true
}
}
route, err := deserializeRoute(m)
if err != nil {
return nil, err
parseErr = err
return false
}
if filter != nil {
switch {
case filterMask&RT_FILTER_TABLE != 0 && filter.Table != unix.RT_TABLE_UNSPEC && route.Table != filter.Table:
continue
return true
case filterMask&RT_FILTER_PROTOCOL != 0 && route.Protocol != filter.Protocol:
continue
return true
case filterMask&RT_FILTER_SCOPE != 0 && route.Scope != filter.Scope:
continue
return true
case filterMask&RT_FILTER_TYPE != 0 && route.Type != filter.Type:
continue
return true
case filterMask&RT_FILTER_TOS != 0 && route.Tos != filter.Tos:
continue
return true
case filterMask&RT_FILTER_REALM != 0 && route.Realm != filter.Realm:
continue
return true
case filterMask&RT_FILTER_OIF != 0 && route.LinkIndex != filter.LinkIndex:
continue
return true
case filterMask&RT_FILTER_IIF != 0 && route.ILinkIndex != filter.ILinkIndex:
continue
return true
case filterMask&RT_FILTER_GW != 0 && !route.Gw.Equal(filter.Gw):
continue
return true
case filterMask&RT_FILTER_SRC != 0 && !route.Src.Equal(filter.Src):
continue
return true
case filterMask&RT_FILTER_DST != 0:
if filter.MPLSDst == nil || route.MPLSDst == nil || (*filter.MPLSDst) != (*route.MPLSDst) {
if filter.Dst == nil {
filter.Dst = genZeroIPNet(family)
}
if !ipNetEqual(route.Dst, filter.Dst) {
continue
return true
}
}
case filterMask&RT_FILTER_HOPLIMIT != 0 && route.Hoplimit != filter.Hoplimit:
continue
return true
}
}
res = append(res, route)
return f(route)
})
if err != nil {
return err
}
return res, nil
if parseErr != nil {
return parseErr
}
return nil
}
// deserializeRoute decodes a binary netlink message into a Route struct
@ -1762,7 +1796,7 @@ func (p RouteProtocol) String() string {
return "gated"
case unix.RTPROT_ISIS:
return "isis"
//case unix.RTPROT_KEEPALIVED:
// case unix.RTPROT_KEEPALIVED:
// return "keepalived"
case unix.RTPROT_KERNEL:
return "kernel"

View File

@ -949,6 +949,151 @@ func TestRouteFilterByFamily(t *testing.T) {
}
}
// TestRouteFilterIterCanStop checks that RouteListFilteredIter stops handing
// routes to the callback once the callback returns false: three routes are
// installed in table 1000 and the callback asks to stop after the second one.
func TestRouteFilterIterCanStop(t *testing.T) {
	tearDown := setUpNetlinkTest(t)
	defer tearDown()

	// The routes are attached to the loopback interface, which must be up.
	link, err := LinkByName("lo")
	if err != nil {
		t.Fatal(err)
	}
	if err = LinkSetUp(link); err != nil {
		t.Fatal(err)
	}

	// Install three routes to the same /32 destination that differ only in
	// their priority (1, 2 and 3).
	dst := &net.IPNet{
		IP:   net.IPv4(1, 1, 1, 1),
		Mask: net.CIDRMask(32, 32),
	}
	for prio := 1; prio <= 3; prio++ {
		route := Route{
			LinkIndex: link.Attrs().Index,
			Dst:       dst,
			Scope:     unix.RT_SCOPE_LINK,
			Priority:  prio,
			Table:     1000,
			Type:      unix.RTN_UNICAST,
		}
		if err := RouteAdd(&route); err != nil {
			t.Fatal(err)
		}
	}

	// Only RT_FILTER_TABLE is set in the mask passed below.
	filter := &Route{
		Dst:   dst,
		Scope: unix.RT_SCOPE_LINK,
		Table: 1000,
		Type:  unix.RTN_UNICAST,
	}
	var seen []Route
	stopAfterTwo := func(route Route) (cont bool) {
		seen = append(seen, route)
		return len(seen) < 2
	}
	err = RouteListFilteredIter(FAMILY_V4, filter, RT_FILTER_TABLE, stopAfterTwo)
	if err != nil {
		t.Fatal(err)
	}
	if len(seen) != 2 {
		t.Fatal("Unexpected number of iterations")
	}

	// Every route that was delivered must be one of the routes added above.
	for _, route := range seen {
		switch {
		case route.Scope != unix.RT_SCOPE_LINK:
			t.Fatal("Invalid Scope. Route not added properly")
		case route.Priority < 1 || route.Priority > 3:
			t.Fatal("Priority outside expected range. Route not added properly")
		case route.Table != 1000:
			t.Fatalf("Invalid Table %d. Route not added properly", route.Table)
		case route.Type != unix.RTN_UNICAST:
			t.Fatal("Invalid Type. Route not added properly")
		}
	}
}
// BenchmarkRouteListFilteredNew measures RouteListFiltered, which buffers the
// whole result in a slice, when listing the 65535 routes installed on the
// loopback interface by setUpRoutesBench.
func BenchmarkRouteListFilteredNew(b *testing.B) {
	tearDown := setUpNetlinkTest(b)
	defer tearDown()

	link, err := setUpRoutesBench(b)
	if err != nil {
		// Previously this error was overwritten in the loop without ever
		// being inspected; never time a run against a broken setup.
		b.Fatal(err)
	}

	// Installing the routes is setup and must not count towards the result.
	b.ResetTimer()
	b.ReportAllocs()
	var routes []Route
	for i := 0; i < b.N; i++ {
		routes, err = pkgHandle.RouteListFiltered(FAMILY_V4, &Route{
			LinkIndex: link.Attrs().Index,
		}, RT_FILTER_OIF)
		if err != nil {
			b.Fatal(err)
		}
		if len(routes) != 65535 {
			b.Fatal("Incorrect number of routes.", len(routes))
		}
	}
	// Keep the last result reachable until the loop has finished.
	runtime.KeepAlive(routes)
}
// BenchmarkRouteListIter measures RouteListFilteredIter, which hands routes to
// a callback one at a time, when listing the 65535 routes installed on the
// loopback interface by setUpRoutesBench.
func BenchmarkRouteListIter(b *testing.B) {
	tearDown := setUpNetlinkTest(b)
	defer tearDown()

	link, err := setUpRoutesBench(b)
	if err != nil {
		// Previously this error was overwritten in the loop without ever
		// being inspected; never time a run against a broken setup.
		b.Fatal(err)
	}

	// Installing the routes is setup and must not count towards the result.
	b.ResetTimer()
	b.ReportAllocs()
	for i := 0; i < b.N; i++ {
		// Only count the routes; nothing is retained between callbacks.
		var routes int
		err = RouteListFilteredIter(FAMILY_V4, &Route{
			LinkIndex: link.Attrs().Index,
		}, RT_FILTER_OIF, func(route Route) (cont bool) {
			routes++
			return true
		})
		if err != nil {
			b.Fatal(err)
		}
		if routes != 65535 {
			b.Fatal("Incorrect number of routes.", routes)
		}
	}
}
// setUpRoutesBench brings the loopback interface up and installs 65535
// link-scope unicast /32 routes on it, covering 1.1.0.0 through 1.1.255.254.
// It returns the loopback link so callers can filter on its index.
//
// Every failure aborts the benchmark through b.Fatal, so the returned error is
// always nil; it is kept only so existing callers continue to compile.
func setUpRoutesBench(b *testing.B) (Link, error) {
	// Attribute failures to the calling benchmark, as skipUnlessRoot does.
	b.Helper()

	// get loopback interface
	link, err := LinkByName("lo")
	if err != nil {
		b.Fatal(err)
	}

	// bring the interface up
	if err = LinkSetUp(link); err != nil {
		b.Fatal(err)
	}

	// The link index is the same for every route, so look it up once.
	linkIndex := link.Attrs().Index

	// add one /32 route per address, using i as the low 16 bits
	for i := 0; i < 65535; i++ {
		dst := &net.IPNet{
			IP:   net.IPv4(1, 1, byte(i>>8), byte(i&0xff)),
			Mask: net.CIDRMask(32, 32),
		}
		route := Route{
			LinkIndex: linkIndex,
			Dst:       dst,
			Scope:     unix.RT_SCOPE_LINK,
			Priority:  10,
			Type:      unix.RTN_UNICAST,
		}
		if err := RouteAdd(&route); err != nil {
			b.Fatal(err)
		}
	}
	return link, nil
}
func tableIDIn(ids []int, id int) bool {
for _, v := range ids {
if v == id {