diff --git a/socket_linux.go b/socket_linux.go index e4b6fa7..4eb4aea 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -157,7 +157,7 @@ func (u *UnixSocket) deserialize(b []byte) error { } // SocketGet returns the Socket identified by its local and remote addresses. -func SocketGet(local, remote net.Addr) (*Socket, error) { +func (h *Handle) SocketGet(local, remote net.Addr) (*Socket, error) { var protocol uint8 var localIP, remoteIP net.IP var localPort, remotePort uint16 @@ -199,12 +199,7 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { return nil, ErrNotImplemented } - s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) - if err != nil { - return nil, err - } - defer s.Close() - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, Protocol: protocol, @@ -218,32 +213,31 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { }, }) - if err := s.Send(req); err != nil { - return nil, err - } - - msgs, from, err := s.Receive() + msgs, err := req.Execute(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY) if err != nil { return nil, err } - if from.Pid != nl.PidKernel { - return nil, fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) - } if len(msgs) == 0 { return nil, errors.New("no message nor error from netlink") } if len(msgs) > 2 { return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs)) } + sock := &Socket{} - if err := sock.deserialize(msgs[0].Data); err != nil { + if err := sock.deserialize(msgs[0]); err != nil { return nil, err } return sock, nil } +// SocketGet returns the Socket identified by its local and remote addresses. +func SocketGet(local, remote net.Addr) (*Socket, error) { + return pkgHandle.SocketGet(local, remote) +} + // SocketDestroy kills the Socket identified by its local and remote addresses. -func SocketDestroy(local, remote net.Addr) error { +func (h *Handle) SocketDestroy(local, remote net.Addr) error { localTCP, ok := local.(*net.TCPAddr) if !ok { return ErrNotImplemented @@ -266,7 +260,7 @@ func SocketDestroy(local, remote net.Addr) error { return err } defer s.Close() - req := nl.NewNetlinkRequest(nl.SOCK_DESTROY, unix.NLM_F_ACK) + req := h.newNetlinkRequest(nl.SOCK_DESTROY, unix.NLM_F_ACK) req.AddData(&socketRequest{ Family: unix.AF_INET, Protocol: unix.IPPROTO_TCP, @@ -278,13 +272,20 @@ func SocketDestroy(local, remote net.Addr) error { Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, }, }) - return s.Send(req) + + _, err = req.Execute(unix.NETLINK_INET_DIAG, 0) + return err +} + +// SocketDestroy kills the Socket identified by its local and remote addresses. +func SocketDestroy(local, remote net.Addr) error { + return pkgHandle.SocketDestroy(local, remote) } // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. -func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { +func (h *Handle) SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { // Construct the request - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, Protocol: unix.IPPROTO_TCP, @@ -294,34 +295,41 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { // Do the query and parse the result var result []*InetDiagTCPInfoResp - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { sockInfo := &Socket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err + if err = sockInfo.deserialize(msg); err != nil { + return false } - attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) - if err != nil { - return err + var attrs []syscall.NetlinkRouteAttr + if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil { + return false } - res, err := attrsToInetDiagTCPInfoResp(attrs, sockInfo) - if err != nil { - return err + var res *InetDiagTCPInfoResp + if res, err = attrsToInetDiagTCPInfoResp(attrs, sockInfo); err != nil { + return false } result = append(result, res) - return nil + return true }) + if err != nil { return nil, err } return result, nil } +// SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. +func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { + return pkgHandle.SocketDiagTCPInfo(family) +} + // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. -func SocketDiagTCP(family uint8) ([]*Socket, error) { +func (h *Handle) SocketDiagTCP(family uint8) ([]*Socket, error) { // Construct the request - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, Protocol: unix.IPPROTO_TCP, @@ -331,13 +339,14 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) { // Do the query and parse the result var result []*Socket - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { sockInfo := &Socket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err + if err = sockInfo.deserialize(msg); err != nil { + return false } result = append(result, sockInfo) - return nil + return true }) if err != nil { return nil, err @@ -345,15 +354,20 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) { return result, nil } +// SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. +func SocketDiagTCP(family uint8) ([]*Socket, error) { + return pkgHandle.SocketDiagTCP(family) +} + // SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info. -func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { +func (h *Handle) SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { // Construct the request var extensions uint8 extensions = 1 << (INET_DIAG_VEGASINFO - 1) extensions |= 1 << (INET_DIAG_INFO - 1) extensions |= 1 << (INET_DIAG_MEMINFO - 1) - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, Protocol: unix.IPPROTO_UDP, @@ -363,23 +377,25 @@ func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { // Do the query and parse the result var result []*InetDiagUDPInfoResp - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { sockInfo := &Socket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err - } - attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) - if err != nil { - return err + if err = sockInfo.deserialize(msg); err != nil { + return false } - res, err := attrsToInetDiagUDPInfoResp(attrs, sockInfo) - if err != nil { - return err + var attrs []syscall.NetlinkRouteAttr + if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil { + return false + } + + var res *InetDiagUDPInfoResp + if res, err = attrsToInetDiagUDPInfoResp(attrs, sockInfo); err != nil { + return false } result = append(result, res) - return nil + return true }) if err != nil { return nil, err @@ -387,10 +403,15 @@ func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { return result, nil } +// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info. +func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { + return pkgHandle.SocketDiagUDPInfo(family) +} + // SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket. -func SocketDiagUDP(family uint8) ([]*Socket, error) { +func (h *Handle) SocketDiagUDP(family uint8) ([]*Socket, error) { // Construct the request - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&socketRequest{ Family: family, Protocol: unix.IPPROTO_UDP, @@ -400,13 +421,64 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) { // Do the query and parse the result var result []*Socket - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { sockInfo := &Socket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err + if err = sockInfo.deserialize(msg); err != nil { + return false } result = append(result, sockInfo) - return nil + return true + }) + if err != nil { + return nil, err + } + return result, nil +} + +// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket. +func SocketDiagUDP(family uint8) ([]*Socket, error) { + return pkgHandle.SocketDiagUDP(family) +} + +// UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info. +func (h *Handle) UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) { + // Construct the request + var extensions uint8 + extensions = 1 << UNIX_DIAG_NAME + extensions |= 1 << UNIX_DIAG_PEER + extensions |= 1 << UNIX_DIAG_RQLEN + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&unixSocketRequest{ + Family: unix.AF_UNIX, + States: ^uint32(0), // all states + Show: uint32(extensions), + }) + + var result []*UnixDiagInfoResp + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { + sockInfo := &UnixSocket{} + if err = sockInfo.deserialize(msg); err != nil { + return false + } + + // Diagnosis also delivers sockets with AF_INET family, filter those + if sockInfo.Family != unix.AF_UNIX { + return false + } + + var attrs []syscall.NetlinkRouteAttr + if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil { + return false + } + + var res *UnixDiagInfoResp + if res, err = attrsToUnixDiagInfoResp(attrs, sockInfo); err != nil { + return false + } + result = append(result, res) + return true }) if err != nil { return nil, err @@ -416,41 +488,31 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) { // UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info. func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) { + return pkgHandle.UnixSocketDiagInfo() +} + +// UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets. +func (h *Handle) UnixSocketDiag() ([]*UnixSocket, error) { // Construct the request - var extensions uint8 - extensions = 1 << UNIX_DIAG_NAME - extensions |= 1 << UNIX_DIAG_PEER - extensions |= 1 << UNIX_DIAG_RQLEN - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req.AddData(&unixSocketRequest{ Family: unix.AF_UNIX, States: ^uint32(0), // all states - Show: uint32(extensions), }) - var result []*UnixDiagInfoResp - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + var result []*UnixSocket + var err error + err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool { sockInfo := &UnixSocket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err + if err = sockInfo.deserialize(msg); err != nil { + return false } // Diagnosis also delivers sockets with AF_INET family, filter those - if sockInfo.Family != unix.AF_UNIX { - return nil + if sockInfo.Family == unix.AF_UNIX { + result = append(result, sockInfo) } - - attrs, err := nl.ParseRouteAttr(m.Data[sizeofUnixSocket:]) - if err != nil { - return err - } - - res, err := attrsToUnixDiagInfoResp(attrs, sockInfo) - if err != nil { - return err - } - result = append(result, res) - return nil + return true }) if err != nil { return nil, err @@ -460,68 +522,7 @@ func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) { // UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets. func UnixSocketDiag() ([]*UnixSocket, error) { - // Construct the request - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) - req.AddData(&unixSocketRequest{ - Family: unix.AF_UNIX, - States: ^uint32(0), // all states - }) - - var result []*UnixSocket - err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { - sockInfo := &UnixSocket{} - if err := sockInfo.deserialize(m.Data); err != nil { - return err - } - - // Diagnosis also delivers sockets with AF_INET family, filter those - if sockInfo.Family == unix.AF_UNIX { - result = append(result, sockInfo) - } - return nil - }) - if err != nil { - return nil, err - } - return result, nil -} - -// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request. -func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error { - s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) - if err != nil { - return err - } - defer s.Close() - s.Send(req) - -loop: - for { - msgs, from, err := s.Receive() - if err != nil { - return err - } - if from.Pid != nl.PidKernel { - return fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) - } - if len(msgs) == 0 { - return errors.New("no message nor error from netlink") - } - - for _, m := range msgs { - switch m.Header.Type { - case unix.NLMSG_DONE: - break loop - case unix.NLMSG_ERROR: - error := int32(native.Uint32(m.Data[0:4])) - return syscall.Errno(-error) - } - if err := receiver(m); err != nil { - return err - } - } - } - return nil + return pkgHandle.UnixSocketDiag() } func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) {