diff --git a/handle_test.go b/handle_test.go index ac627ba..c7feb22 100644 --- a/handle_test.go +++ b/handle_test.go @@ -12,11 +12,9 @@ import ( "sync/atomic" "testing" "time" - "unsafe" "github.com/vishvananda/netlink/nl" "github.com/vishvananda/netns" - "golang.org/x/sys/unix" ) func TestHandleCreateClose(t *testing.T) { @@ -122,13 +120,22 @@ func TestHandleTimeout(t *testing.T) { defer h.Close() for _, sh := range h.sockets { - verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0}) + verifySockTimeVal(t, sh.Socket, time.Duration(0)) } - h.SetSocketTimeout(2*time.Second + 8*time.Millisecond) + const timeout = 2*time.Second + 8*time.Millisecond + h.SetSocketTimeout(timeout) for _, sh := range h.sockets { - verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000}) + verifySockTimeVal(t, sh.Socket, timeout) + } +} + +func verifySockTimeVal(t *testing.T, socket *nl.NetlinkSocket, expTimeout time.Duration) { + t.Helper() + send, receive := socket.GetTimeouts() + if send != expTimeout || receive != expTimeout { + t.Fatalf("Expected timeout: %v, got Send: %v, Receive: %v", expTimeout, send, receive) } } @@ -157,30 +164,6 @@ func TestHandleReceiveBuffer(t *testing.T) { } } -func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) { - var ( - tr unix.Timeval - v = uint32(0x10) - ) - _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0) - if errno != 0 { - t.Fatal(errno) - } - - if tr.Sec != tv.Sec || tr.Usec != tv.Usec { - t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv) - } - - _, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0) - if errno != 0 { - t.Fatal(errno) - } - - if tr.Sec != tv.Sec || tr.Usec != tv.Usec { - t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv) - } -} - var ( iter = 10 numThread = uint32(4) diff --git a/nl/nl_linux.go b/nl/nl_linux.go index 6cecc45..a11e6a9 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -4,6 +4,7 @@ package nl import ( "bytes" "encoding/binary" + "errors" "fmt" "net" "os" @@ -11,6 +12,7 @@ import ( "sync" "sync/atomic" "syscall" + "time" "unsafe" "github.com/vishvananda/netns" @@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest { } type NetlinkSocket struct { - fd int32 - file *os.File - lsa unix.SockaddrNetlink + fd int32 + file *os.File + lsa unix.SockaddrNetlink + sendTimeout int64 // Access using atomic.Load/StoreInt64 + receiveTimeout int64 // Access using atomic.Load/StoreInt64 sync.Mutex } @@ -802,8 +806,44 @@ func (s *NetlinkSocket) GetFd() int { return int(s.fd) } +func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) { + return time.Duration(atomic.LoadInt64(&s.sendTimeout)), + time.Duration(atomic.LoadInt64(&s.receiveTimeout)) +} + func (s *NetlinkSocket) Send(request *NetlinkRequest) error { - return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa) + rawConn, err := s.file.SyscallConn() + if err != nil { + return err + } + var ( + deadline time.Time + innerErr error + ) + sendTimeout := atomic.LoadInt64(&s.sendTimeout) + if sendTimeout != 0 { + deadline = time.Now().Add(time.Duration(sendTimeout)) + } + if err := s.file.SetWriteDeadline(deadline); err != nil { + return err + } + serializedReq := request.Serialize() + err = rawConn.Write(func(fd uintptr) (done bool) { + innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa) + return innerErr != unix.EWOULDBLOCK + }) + if innerErr != nil { + return innerErr + } + if err != nil { + // The timeout was previously implemented using SO_SNDTIMEO on a blocking + // socket. So, continue to return EAGAIN when the timeout is reached. + if errors.Is(err, os.ErrDeadlineExceeded) { + return unix.EAGAIN + } + return err + } + return nil } func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) { @@ -812,20 +852,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli return nil, nil, err } var ( + deadline time.Time fromAddr *unix.SockaddrNetlink rb [RECEIVE_BUFFER_SIZE]byte nr int from unix.Sockaddr innerErr error ) + receiveTimeout := atomic.LoadInt64(&s.receiveTimeout) + if receiveTimeout != 0 { + deadline = time.Now().Add(time.Duration(receiveTimeout)) + } + if err := s.file.SetReadDeadline(deadline); err != nil { + return nil, nil, err + } err = rawConn.Read(func(fd uintptr) (done bool) { nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0) return innerErr != unix.EWOULDBLOCK }) if innerErr != nil { - err = innerErr + return nil, nil, innerErr } if err != nil { + // The timeout was previously implemented using SO_RCVTIMEO on a blocking + // socket. So, continue to return EAGAIN when the timeout is reached. + if errors.Is(err, os.ErrDeadlineExceeded) { + return nil, nil, unix.EAGAIN + } return nil, nil, err } fromAddr, ok := from.(*unix.SockaddrNetlink) @@ -847,16 +900,14 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli // SetSendTimeout allows to set a send timeout on the socket func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error { - // Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine - // remains stuck on a send on a closed fd - return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout) + atomic.StoreInt64(&s.sendTimeout, timeout.Nano()) + return nil } // SetReceiveTimeout allows to set a receive timeout on the socket func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error { - // Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine - // remains stuck on a recvmsg on a closed fd - return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout) + atomic.StoreInt64(&s.receiveTimeout, timeout.Nano()) + return nil } // SetReceiveBufferSize allows to set a receive buffer size on the socket diff --git a/nl/nl_linux_test.go b/nl/nl_linux_test.go index 96de8d5..117a48a 100644 --- a/nl/nl_linux_test.go +++ b/nl/nl_linux_test.go @@ -97,6 +97,69 @@ func TestIfSocketCloses(t *testing.T) { } } +func TestReceiveTimeout(t *testing.T) { + nlSock, err := getNetlinkSocket(unix.NETLINK_ROUTE) + if err != nil { + t.Fatalf("Error creating the socket: %v", err) + } + // Even if the test fails because the timeout doesn't work, closing the + // socket at the end of the test should result in an EAGAIN (as long as + // TestIfSocketCloses completed, otherwise this test will leak the + // goroutines running the Receive). + defer nlSock.Close() + const failAfter = time.Second + + tests := []struct { + name string + timeout time.Duration + }{ + { + name: "1us timeout", // The smallest value accepted by Handle.SetSocketTimeout + timeout: time.Microsecond, + }, + { + name: "100ms timeout", + timeout: 100 * time.Millisecond, + }, + { + name: "500ms timeout", + timeout: 500 * time.Millisecond, + }, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + timeout := unix.NsecToTimeval(int64(tc.timeout)) + nlSock.SetReceiveTimeout(&timeout) + + doneC := make(chan time.Duration) + errC := make(chan error) + go func() { + start := time.Now() + _, _, err := nlSock.Receive() + dur := time.Since(start) + if err != unix.EAGAIN { + errC <- err + return + } + doneC <- dur + }() + + failTimerC := time.After(failAfter) + select { + case dur := <-doneC: + if dur < tc.timeout || dur > (tc.timeout+(100*time.Millisecond)) { + t.Fatalf("Expected timeout %v got %v", tc.timeout, dur) + } + case err := <-errC: + t.Fatalf("Expected EAGAIN, but got: %v", err) + case <-failTimerC: + t.Fatalf("No timeout received") + } + }) + } +} + func (msg *CnMsgOp) write(b []byte) { native := NativeEndian() native.PutUint32(b[0:4], msg.ID.Idx)