nl: Use atomic load/store for fd field

This allows Close to be called concurrently with Receive without
triggering a data race.

Signed-off-by: Aaron Lehmann <aaron.lehmann@docker.com>
This commit is contained in:
Aaron Lehmann 2017-06-01 08:46:57 +02:00 committed by Vish Ishaya
parent 7bd45e5974
commit bd6d5de5cc
1 changed files with 14 additions and 11 deletions

View File

@ -459,7 +459,7 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest {
} }
type NetlinkSocket struct { type NetlinkSocket struct {
fd int fd int32
lsa syscall.SockaddrNetlink lsa syscall.SockaddrNetlink
sync.Mutex sync.Mutex
} }
@ -470,7 +470,7 @@ func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {
return nil, err return nil, err
} }
s := &NetlinkSocket{ s := &NetlinkSocket{
fd: fd, fd: int32(fd),
} }
s.lsa.Family = syscall.AF_NETLINK s.lsa.Family = syscall.AF_NETLINK
if err := syscall.Bind(fd, &s.lsa); err != nil { if err := syscall.Bind(fd, &s.lsa); err != nil {
@ -556,7 +556,7 @@ func Subscribe(protocol int, groups ...uint) (*NetlinkSocket, error) {
return nil, err return nil, err
} }
s := &NetlinkSocket{ s := &NetlinkSocket{
fd: fd, fd: int32(fd),
} }
s.lsa.Family = syscall.AF_NETLINK s.lsa.Family = syscall.AF_NETLINK
@ -585,30 +585,32 @@ func SubscribeAt(newNs, curNs netns.NsHandle, protocol int, groups ...uint) (*Ne
} }
func (s *NetlinkSocket) Close() { func (s *NetlinkSocket) Close() {
syscall.Close(s.fd) fd := int(atomic.SwapInt32(&s.fd, -1))
s.fd = -1 syscall.Close(fd)
} }
func (s *NetlinkSocket) GetFd() int { func (s *NetlinkSocket) GetFd() int {
return s.fd return int(atomic.LoadInt32(&s.fd))
} }
func (s *NetlinkSocket) Send(request *NetlinkRequest) error { func (s *NetlinkSocket) Send(request *NetlinkRequest) error {
if s.fd < 0 { fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return fmt.Errorf("Send called on a closed socket") return fmt.Errorf("Send called on a closed socket")
} }
if err := syscall.Sendto(s.fd, request.Serialize(), 0, &s.lsa); err != nil { if err := syscall.Sendto(fd, request.Serialize(), 0, &s.lsa); err != nil {
return err return err
} }
return nil return nil
} }
func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) { func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
if s.fd < 0 { fd := int(atomic.LoadInt32(&s.fd))
if fd < 0 {
return nil, fmt.Errorf("Receive called on a closed socket") return nil, fmt.Errorf("Receive called on a closed socket")
} }
rb := make([]byte, syscall.Getpagesize()) rb := make([]byte, syscall.Getpagesize())
nr, _, err := syscall.Recvfrom(s.fd, rb, 0) nr, _, err := syscall.Recvfrom(fd, rb, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -620,7 +622,8 @@ func (s *NetlinkSocket) Receive() ([]syscall.NetlinkMessage, error) {
} }
func (s *NetlinkSocket) GetPid() (uint32, error) { func (s *NetlinkSocket) GetPid() (uint32, error) {
lsa, err := syscall.Getsockname(s.fd) fd := int(atomic.LoadInt32(&s.fd))
lsa, err := syscall.Getsockname(fd)
if err != nil { if err != nil {
return 0, err return 0, err
} }