diff --git a/handle_linux.go b/handle_linux.go index 26887b7..6535667 100644 --- a/handle_linux.go +++ b/handle_linux.go @@ -21,6 +21,22 @@ type Handle struct { lookupByDump bool } +// SetSocketTimeout configures timeout for default netlink sockets +func SetSocketTimeout(to time.Duration) error { + if to < time.Microsecond { + return fmt.Errorf("invalid timeout, minimul value is %s", time.Microsecond) + } + + nl.SocketTimeoutTv = unix.NsecToTimeval(to.Nanoseconds()) + return nil +} + +// GetSocketTimeout returns the timeout value used by default netlink sockets +func GetSocketTimeout() time.Duration { + nsec := unix.TimevalToNsec(nl.SocketTimeoutTv) + return time.Duration(nsec) * time.Nanosecond +} + // SupportsNetlinkFamily reports whether the passed netlink family is supported by this Handle func (h *Handle) SupportsNetlinkFamily(nlFamily int) bool { _, ok := h.sockets[nlFamily] diff --git a/handle_linux_test.go b/handle_linux_test.go new file mode 100644 index 0000000..a4514b6 --- /dev/null +++ b/handle_linux_test.go @@ -0,0 +1,17 @@ +package netlink + +import ( + "testing" + "time" +) + +func TestSetGetSocketTimeout(t *testing.T) { + timeout := 10 * time.Second + if err := SetSocketTimeout(10 * time.Second); err != nil { + t.Fatalf("Set socket timeout for default handle failed: %v", err) + } + + if val := GetSocketTimeout(); val != timeout { + t.Fatalf("Unexpcted socket timeout value: got=%v, expected=%v", val, timeout) + } +} diff --git a/nl/nl_linux.go b/nl/nl_linux.go index cef64b8..dcd4b94 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -35,6 +35,9 @@ var SupportedNlFamilies = []int{unix.NETLINK_ROUTE, unix.NETLINK_XFRM, unix.NETL var nextSeqNr uint32 +// Default netlink socket timeout, 60s +var SocketTimeoutTv = unix.Timeval{Sec: 60, Usec: 0} + // GetIPFamily returns the family type of a net.IP. func GetIPFamily(ip net.IP) int { if len(ip) <= net.IPv4len { @@ -426,6 +429,14 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro if err != nil { return nil, err } + + if err := s.SetSendTimeout(&SocketTimeoutTv); err != nil { + return nil, err + } + if err := s.SetReceiveTimeout(&SocketTimeoutTv); err != nil { + return nil, err + } + defer s.Close() } else { s.Lock()