nl: Use atomic load/store for fd field

This allows Close to be called concurrently with Receive without
triggering a data race.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
Aaron Lehmann 2017-06-01 08:46:57 +02:00 committed by Vish Ishaya
parent 7bd45e5974
commit bd6d5de5cc
1 changed files with 14 additions and 11 deletions

View File

@ -459,7 +459,7 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
}
type NetlinkSocket struct {
fd int
fd int32
lsa syscall.SockaddrNetlink
sync.Mutex
}
@ -470,7 +470,7 @@ func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
return nil, err
}
s := &NetlinkSocket{
fd: fd,
fd: int32(fd),
}
s.lsa.Family = syscall.AF_NETLINK
if err := syscall.Bind(fd, &s.lsa); err != nil {
@ -556,7 +556,7 @@ func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
return nil, err
}
s := &NetlinkSocket{
fd: fd,
fd: int32(fd),
}
s.lsa.Family = syscall.AF_NETLINK
@ -585,30 +585,32 @@ func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*Ne
}
func (s *NetlinkSocket) Close() {
syscall.Close(s.fd)
s.fd = -1
fd := int(atomic.SwapInt32(&s.fd, -1))
syscall.Close(fd)
}
func (s *NetlinkSocket) GetFd() int {
return s.fd
return int(atomic.LoadInt32(&s.fd))
}
func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
if s.fd < 0 {
fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return fmt.Errorf("Send called on a closed socket")
}
if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil {
if err := syscall.Sendto(fd, request.Serialize(), 0, &s.lsa); err != nil {
return err
}
return nil
}
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
if s.fd < 0 {
fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return nil, fmt.Errorf("Receive called on a closed socket")
}
rb := make([]byte, syscall.Getpagesize())
nr, _, err := syscall.Recvfrom(s.fd, rb, 0)
nr, _, err := syscall.Recvfrom(fd, rb, 0)
if err != nil {
return nil, err
}
@ -620,7 +622,8 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
}
func (s *NetlinkSocket) GetPid() (uint32, error) {
lsa, err := syscall.Getsockname(s.fd)
fd := int(atomic.LoadInt32(&s.fd))
lsa, err := syscall.Getsockname(fd)
if err != nil {
return 0, err
}