Add handle support for socket

This commit is contained in:
Hasan Mahmood 2024-07-29 19:23:29 -05:00 committed by Alessandro Boch
parent aaf4f9866c
commit a57a7bd6b2
1 changed files with 143 additions and 142 deletions

View File

@ -157,7 +157,7 @@ func (u *UnixSocket) deserialize(b []byte) error {
} }
// SocketGet returns the Socket identified by its local and remote addresses. // SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(local, remote net.Addr) (*Socket, error) { func (h *Handle) SocketGet(local, remote net.Addr) (*Socket, error) {
var protocol uint8 var protocol uint8
var localIP, remoteIP net.IP var localIP, remoteIP net.IP
var localPort, remotePort uint16 var localPort, remotePort uint16
@ -199,12 +199,7 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
if err != nil {
return nil, err
}
defer s.Close()
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: family, Family: family,
Protocol: protocol, Protocol: protocol,
@ -218,32 +213,31 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
}, },
}) })
if err := s.Send(req); err != nil { msgs, err := req.Execute(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY)
return nil, err
}
msgs, from, err := s.Receive()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if from.Pid != nl.PidKernel {
return nil, fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
}
if len(msgs) == 0 { if len(msgs) == 0 {
return nil, errors.New("no message nor error from netlink") return nil, errors.New("no message nor error from netlink")
} }
if len(msgs) > 2 { if len(msgs) > 2 {
return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs)) return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
} }
sock := &Socket{} sock := &Socket{}
if err := sock.deserialize(msgs[0].Data); err != nil { if err := sock.deserialize(msgs[0]); err != nil {
return nil, err return nil, err
} }
return sock, nil return sock, nil
} }
// SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(local, remote net.Addr) (*Socket, error) {
return pkgHandle.SocketGet(local, remote)
}
// SocketDestroy kills the Socket identified by its local and remote addresses. // SocketDestroy kills the Socket identified by its local and remote addresses.
func SocketDestroy(local, remote net.Addr) error { func (h *Handle) SocketDestroy(local, remote net.Addr) error {
localTCP, ok := local.(*net.TCPAddr) localTCP, ok := local.(*net.TCPAddr)
if !ok { if !ok {
return ErrNotImplemented return ErrNotImplemented
@ -266,7 +260,7 @@ func SocketDestroy(local, remote net.Addr) error {
return err return err
} }
defer s.Close() defer s.Close()
req := nl.NewNetlinkRequest(nl.SOCK_DESTROY, unix.NLM_F_ACK) req := h.newNetlinkRequest(nl.SOCK_DESTROY, unix.NLM_F_ACK)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: unix.AF_INET, Family: unix.AF_INET,
Protocol: unix.IPPROTO_TCP, Protocol: unix.IPPROTO_TCP,
@ -278,13 +272,20 @@ func SocketDestroy(local, remote net.Addr) error {
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
}, },
}) })
return s.Send(req)
_, err = req.Execute(unix.NETLINK_INET_DIAG, 0)
return err
}
// SocketDestroy kills the Socket identified by its local and remote addresses.
func SocketDestroy(local, remote net.Addr) error {
return pkgHandle.SocketDestroy(local, remote)
} }
// SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info.
func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { func (h *Handle) SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {
// Construct the request // Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: family, Family: family,
Protocol: unix.IPPROTO_TCP, Protocol: unix.IPPROTO_TCP,
@ -294,34 +295,41 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {
// Do the query and parse the result // Do the query and parse the result
var result []*InetDiagTCPInfoResp var result []*InetDiagTCPInfoResp
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &Socket{} sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil { if err = sockInfo.deserialize(msg); err != nil {
return err return false
} }
attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) var attrs []syscall.NetlinkRouteAttr
if err != nil { if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil {
return err return false
} }
res, err := attrsToInetDiagTCPInfoResp(attrs, sockInfo) var res *InetDiagTCPInfoResp
if err != nil { if res, err = attrsToInetDiagTCPInfoResp(attrs, sockInfo); err != nil {
return err return false
} }
result = append(result, res) result = append(result, res)
return nil return true
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return result, nil return result, nil
} }
// SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info.
func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {
return pkgHandle.SocketDiagTCPInfo(family)
}
// SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket.
func SocketDiagTCP(family uint8) ([]*Socket, error) { func (h *Handle) SocketDiagTCP(family uint8) ([]*Socket, error) {
// Construct the request // Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: family, Family: family,
Protocol: unix.IPPROTO_TCP, Protocol: unix.IPPROTO_TCP,
@ -331,13 +339,14 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) {
// Do the query and parse the result // Do the query and parse the result
var result []*Socket var result []*Socket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &Socket{} sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil { if err = sockInfo.deserialize(msg); err != nil {
return err return false
} }
result = append(result, sockInfo) result = append(result, sockInfo)
return nil return true
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -345,15 +354,20 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) {
return result, nil return result, nil
} }
// SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket.
func SocketDiagTCP(family uint8) ([]*Socket, error) {
return pkgHandle.SocketDiagTCP(family)
}
// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info. // SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info.
func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { func (h *Handle) SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) {
// Construct the request // Construct the request
var extensions uint8 var extensions uint8
extensions = 1 << (INET_DIAG_VEGASINFO - 1) extensions = 1 << (INET_DIAG_VEGASINFO - 1)
extensions |= 1 << (INET_DIAG_INFO - 1) extensions |= 1 << (INET_DIAG_INFO - 1)
extensions |= 1 << (INET_DIAG_MEMINFO - 1) extensions |= 1 << (INET_DIAG_MEMINFO - 1)
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: family, Family: family,
Protocol: unix.IPPROTO_UDP, Protocol: unix.IPPROTO_UDP,
@ -363,23 +377,25 @@ func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) {
// Do the query and parse the result // Do the query and parse the result
var result []*InetDiagUDPInfoResp var result []*InetDiagUDPInfoResp
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &Socket{} sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil { if err = sockInfo.deserialize(msg); err != nil {
return err return false
}
attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:])
if err != nil {
return err
} }
res, err := attrsToInetDiagUDPInfoResp(attrs, sockInfo) var attrs []syscall.NetlinkRouteAttr
if err != nil { if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil {
return err return false
}
var res *InetDiagUDPInfoResp
if res, err = attrsToInetDiagUDPInfoResp(attrs, sockInfo); err != nil {
return false
} }
result = append(result, res) result = append(result, res)
return nil return true
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -387,10 +403,15 @@ func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) {
return result, nil return result, nil
} }
// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info.
func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) {
return pkgHandle.SocketDiagUDPInfo(family)
}
// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket. // SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket.
func SocketDiagUDP(family uint8) ([]*Socket, error) { func (h *Handle) SocketDiagUDP(family uint8) ([]*Socket, error) {
// Construct the request // Construct the request
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: family, Family: family,
Protocol: unix.IPPROTO_UDP, Protocol: unix.IPPROTO_UDP,
@ -400,13 +421,64 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) {
// Do the query and parse the result // Do the query and parse the result
var result []*Socket var result []*Socket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &Socket{} sockInfo := &Socket{}
if err := sockInfo.deserialize(m.Data); err != nil { if err = sockInfo.deserialize(msg); err != nil {
return err return false
} }
result = append(result, sockInfo) result = append(result, sockInfo)
return nil return true
})
if err != nil {
return nil, err
}
return result, nil
}
// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket.
func SocketDiagUDP(family uint8) ([]*Socket, error) {
return pkgHandle.SocketDiagUDP(family)
}
// UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info.
func (h *Handle) UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) {
// Construct the request
var extensions uint8
extensions = 1 << UNIX_DIAG_NAME
extensions |= 1 << UNIX_DIAG_PEER
extensions |= 1 << UNIX_DIAG_RQLEN
req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&unixSocketRequest{
Family: unix.AF_UNIX,
States: ^uint32(0), // all states
Show: uint32(extensions),
})
var result []*UnixDiagInfoResp
var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &UnixSocket{}
if err = sockInfo.deserialize(msg); err != nil {
return false
}
// Diagnosis also delivers sockets with AF_INET family, filter those
if sockInfo.Family != unix.AF_UNIX {
return false
}
var attrs []syscall.NetlinkRouteAttr
if attrs, err = nl.ParseRouteAttr(msg[sizeofSocket:]); err != nil {
return false
}
var res *UnixDiagInfoResp
if res, err = attrsToUnixDiagInfoResp(attrs, sockInfo); err != nil {
return false
}
result = append(result, res)
return true
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -416,41 +488,31 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) {
// UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info. // UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info.
func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) { func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) {
return pkgHandle.UnixSocketDiagInfo()
}
// UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets.
func (h *Handle) UnixSocketDiag() ([]*UnixSocket, error) {
// Construct the request // Construct the request
var extensions uint8 req := h.newNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
extensions = 1 << UNIX_DIAG_NAME
extensions |= 1 << UNIX_DIAG_PEER
extensions |= 1 << UNIX_DIAG_RQLEN
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&unixSocketRequest{ req.AddData(&unixSocketRequest{
Family: unix.AF_UNIX, Family: unix.AF_UNIX,
States: ^uint32(0), // all states States: ^uint32(0), // all states
Show: uint32(extensions),
}) })
var result []*UnixDiagInfoResp var result []*UnixSocket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { var err error
err = req.ExecuteIter(unix.NETLINK_INET_DIAG, nl.SOCK_DIAG_BY_FAMILY, func(msg []byte) bool {
sockInfo := &UnixSocket{} sockInfo := &UnixSocket{}
if err := sockInfo.deserialize(m.Data); err != nil { if err = sockInfo.deserialize(msg); err != nil {
return err return false
} }
// Diagnosis also delivers sockets with AF_INET family, filter those // Diagnosis also delivers sockets with AF_INET family, filter those
if sockInfo.Family != unix.AF_UNIX { if sockInfo.Family == unix.AF_UNIX {
return nil result = append(result, sockInfo)
} }
return true
attrs, err := nl.ParseRouteAttr(m.Data[sizeofUnixSocket:])
if err != nil {
return err
}
res, err := attrsToUnixDiagInfoResp(attrs, sockInfo)
if err != nil {
return err
}
result = append(result, res)
return nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@ -460,68 +522,7 @@ func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) {
// UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets. // UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets.
func UnixSocketDiag() ([]*UnixSocket, error) { func UnixSocketDiag() ([]*UnixSocket, error) {
// Construct the request return pkgHandle.UnixSocketDiag()
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&unixSocketRequest{
Family: unix.AF_UNIX,
States: ^uint32(0), // all states
})
var result []*UnixSocket
err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error {
sockInfo := &UnixSocket{}
if err := sockInfo.deserialize(m.Data); err != nil {
return err
}
// Diagnosis also delivers sockets with AF_INET family, filter those
if sockInfo.Family == unix.AF_UNIX {
result = append(result, sockInfo)
}
return nil
})
if err != nil {
return nil, err
}
return result, nil
}
// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request.
func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error {
s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
if err != nil {
return err
}
defer s.Close()
s.Send(req)
loop:
for {
msgs, from, err := s.Receive()
if err != nil {
return err
}
if from.Pid != nl.PidKernel {
return fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
}
if len(msgs) == 0 {
return errors.New("no message nor error from netlink")
}
for _, m := range msgs {
switch m.Header.Type {
case unix.NLMSG_DONE:
break loop
case unix.NLMSG_ERROR:
error := int32(native.Uint32(m.Data[0:4]))
return syscall.Errno(-error)
}
if err := receiver(m); err != nil {
return err
}
}
}
return nil
} }
func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) { func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) {