mirror of https://github.com/vishvananda/netlink
Fix SetSendTimeout/SetReceiveTimeout
They were implemented using SO_SNDTIMEO/SO_RCVTIMEO on the socket descriptor - but that doesn't work now the socket is non-blocking. Instead, set deadlines on the file read/write. Signed-off-by: Rob Murray <rob.murray@docker.com>
This commit is contained in:
parent
0cd1f7961c
commit
e194da52b1
|
@ -12,11 +12,9 @@ import (
|
|||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/vishvananda/netlink/nl"
|
||||
"github.com/vishvananda/netns"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func TestHandleCreateClose(t *testing.T) {
|
||||
|
@ -122,13 +120,22 @@ func TestHandleTimeout(t *testing.T) {
|
|||
defer h.Close()
|
||||
|
||||
for _, sh := range h.sockets {
|
||||
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0})
|
||||
verifySockTimeVal(t, sh.Socket, time.Duration(0))
|
||||
}
|
||||
|
||||
h.SetSocketTimeout(2*time.Second + 8*time.Millisecond)
|
||||
const timeout = 2*time.Second + 8*time.Millisecond
|
||||
h.SetSocketTimeout(timeout)
|
||||
|
||||
for _, sh := range h.sockets {
|
||||
verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000})
|
||||
verifySockTimeVal(t, sh.Socket, timeout)
|
||||
}
|
||||
}
|
||||
|
||||
func verifySockTimeVal(t *testing.T, socket *nl.NetlinkSocket, expTimeout time.Duration) {
|
||||
t.Helper()
|
||||
send, receive := socket.GetTimeouts()
|
||||
if send != expTimeout || receive != expTimeout {
|
||||
t.Fatalf("Expected timeout: %v, got Send: %v, Receive: %v", expTimeout, send, receive)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -157,30 +164,6 @@ func TestHandleReceiveBuffer(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) {
|
||||
var (
|
||||
tr unix.Timeval
|
||||
v = uint32(0x10)
|
||||
)
|
||||
_, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
|
||||
if errno != 0 {
|
||||
t.Fatal(errno)
|
||||
}
|
||||
|
||||
if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
|
||||
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
|
||||
}
|
||||
|
||||
_, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0)
|
||||
if errno != 0 {
|
||||
t.Fatal(errno)
|
||||
}
|
||||
|
||||
if tr.Sec != tv.Sec || tr.Usec != tv.Usec {
|
||||
t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
iter = 10
|
||||
numThread = uint32(4)
|
||||
|
|
|
@ -4,6 +4,7 @@ package nl
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
@ -11,6 +12,7 @@ import (
|
|||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/vishvananda/netns"
|
||||
|
@ -656,9 +658,11 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
|
|||
}
|
||||
|
||||
type NetlinkSocket struct {
|
||||
fd int32
|
||||
file *os.File
|
||||
lsa unix.SockaddrNetlink
|
||||
fd int32
|
||||
file *os.File
|
||||
lsa unix.SockaddrNetlink
|
||||
sendTimeout int64 // Access using atomic.Load/StoreInt64
|
||||
receiveTimeout int64 // Access using atomic.Load/StoreInt64
|
||||
sync.Mutex
|
||||
}
|
||||
|
||||
|
@ -802,8 +806,44 @@ func (s *NetlinkSocket) GetFd() int {
|
|||
return int(s.fd)
|
||||
}
|
||||
|
||||
func (s *NetlinkSocket) GetTimeouts() (send, receive time.Duration) {
|
||||
return time.Duration(atomic.LoadInt64(&s.sendTimeout)),
|
||||
time.Duration(atomic.LoadInt64(&s.receiveTimeout))
|
||||
}
|
||||
|
||||
func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
|
||||
return unix.Sendto(int(s.fd), request.Serialize(), 0, &s.lsa)
|
||||
rawConn, err := s.file.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
deadline time.Time
|
||||
innerErr error
|
||||
)
|
||||
sendTimeout := atomic.LoadInt64(&s.sendTimeout)
|
||||
if sendTimeout != 0 {
|
||||
deadline = time.Now().Add(time.Duration(sendTimeout))
|
||||
}
|
||||
if err := s.file.SetWriteDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
serializedReq := request.Serialize()
|
||||
err = rawConn.Write(func(fd uintptr) (done bool) {
|
||||
innerErr = unix.Sendto(int(s.fd), serializedReq, 0, &s.lsa)
|
||||
return innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
return innerErr
|
||||
}
|
||||
if err != nil {
|
||||
// The timeout was previously implemented using SO_SNDTIMEO on a blocking
|
||||
// socket. So, continue to return EAGAIN when the timeout is reached.
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
return unix.EAGAIN
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetlink, error) {
|
||||
|
@ -812,20 +852,33 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
|
|||
return nil, nil, err
|
||||
}
|
||||
var (
|
||||
deadline time.Time
|
||||
fromAddr *unix.SockaddrNetlink
|
||||
rb [RECEIVE_BUFFER_SIZE]byte
|
||||
nr int
|
||||
from unix.Sockaddr
|
||||
innerErr error
|
||||
)
|
||||
receiveTimeout := atomic.LoadInt64(&s.receiveTimeout)
|
||||
if receiveTimeout != 0 {
|
||||
deadline = time.Now().Add(time.Duration(receiveTimeout))
|
||||
}
|
||||
if err := s.file.SetReadDeadline(deadline); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
err = rawConn.Read(func(fd uintptr) (done bool) {
|
||||
nr, from, innerErr = unix.Recvfrom(int(fd), rb[:], 0)
|
||||
return innerErr != unix.EWOULDBLOCK
|
||||
})
|
||||
if innerErr != nil {
|
||||
err = innerErr
|
||||
return nil, nil, innerErr
|
||||
}
|
||||
if err != nil {
|
||||
// The timeout was previously implemented using SO_RCVTIMEO on a blocking
|
||||
// socket. So, continue to return EAGAIN when the timeout is reached.
|
||||
if errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
return nil, nil, unix.EAGAIN
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
fromAddr, ok := from.(*unix.SockaddrNetlink)
|
||||
|
@ -847,16 +900,14 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, *unix.SockaddrNetli
|
|||
|
||||
// SetSendTimeout allows to set a send timeout on the socket
|
||||
func (s *NetlinkSocket) SetSendTimeout(timeout *unix.Timeval) error {
|
||||
// Set a send timeout of SOCKET_SEND_TIMEOUT, this will allow the Send to periodically unblock and avoid that a routine
|
||||
// remains stuck on a send on a closed fd
|
||||
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, timeout)
|
||||
atomic.StoreInt64(&s.sendTimeout, timeout.Nano())
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReceiveTimeout allows to set a receive timeout on the socket
|
||||
func (s *NetlinkSocket) SetReceiveTimeout(timeout *unix.Timeval) error {
|
||||
// Set a read timeout of SOCKET_READ_TIMEOUT, this will allow the Read to periodically unblock and avoid that a routine
|
||||
// remains stuck on a recvmsg on a closed fd
|
||||
return unix.SetsockoptTimeval(int(s.fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, timeout)
|
||||
atomic.StoreInt64(&s.receiveTimeout, timeout.Nano())
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetReceiveBufferSize allows to set a receive buffer size on the socket
|
||||
|
|
|
@ -97,6 +97,69 @@ func TestIfSocketCloses(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestReceiveTimeout(t *testing.T) {
|
||||
nlSock, err := getNetlinkSocket(unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
t.Fatalf("Error creating the socket: %v", err)
|
||||
}
|
||||
// Even if the test fails because the timeout doesn't work, closing the
|
||||
// socket at the end of the test should result in an EAGAIN (as long as
|
||||
// TestIfSocketCloses completed, otherwise this test will leak the
|
||||
// goroutines running the Receive).
|
||||
defer nlSock.Close()
|
||||
const failAfter = time.Second
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
timeout time.Duration
|
||||
}{
|
||||
{
|
||||
name: "1us timeout", // The smallest value accepted by Handle.SetSocketTimeout
|
||||
timeout: time.Microsecond,
|
||||
},
|
||||
{
|
||||
name: "100ms timeout",
|
||||
timeout: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "500ms timeout",
|
||||
timeout: 500 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
timeout := unix.NsecToTimeval(int64(tc.timeout))
|
||||
nlSock.SetReceiveTimeout(&timeout)
|
||||
|
||||
doneC := make(chan time.Duration)
|
||||
errC := make(chan error)
|
||||
go func() {
|
||||
start := time.Now()
|
||||
_, _, err := nlSock.Receive()
|
||||
dur := time.Since(start)
|
||||
if err != unix.EAGAIN {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
doneC <- dur
|
||||
}()
|
||||
|
||||
failTimerC := time.After(failAfter)
|
||||
select {
|
||||
case dur := <-doneC:
|
||||
if dur < tc.timeout || dur > (tc.timeout+(100*time.Millisecond)) {
|
||||
t.Fatalf("Expected timeout %v got %v", tc.timeout, dur)
|
||||
}
|
||||
case err := <-errC:
|
||||
t.Fatalf("Expected EAGAIN, but got: %v", err)
|
||||
case <-failTimerC:
|
||||
t.Fatalf("No timeout received")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (msg *CnMsgOp) write(b []byte) {
|
||||
native := NativeEndian()
|
||||
native.PutUint32(b[0:4], msg.ID.Idx)
|
||||
|
|
Loading…
Reference in New Issue