diff --git a/inet_diag.go b/inet_diag.go index bee391a..a483ee1 100644 --- a/inet_diag.go +++ b/inet_diag.go @@ -29,3 +29,8 @@ type InetDiagTCPInfoResp struct { TCPInfo *TCPInfo TCPBBRInfo *TCPBBRInfo } + +type InetDiagUDPInfoResp struct { + InetDiagMsg *Socket + Memory *MemInfo +} diff --git a/socket_linux.go b/socket_linux.go index b881fe4..f39c2a6 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -174,8 +174,18 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_TCP, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result var result []*InetDiagTCPInfoResp - err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { sockInfo := &Socket{} if err := sockInfo.deserialize(m.Data); err != nil { return err @@ -201,8 +211,18 @@ func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. func SocketDiagTCP(family uint8) ([]*Socket, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_TCP, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result var result []*Socket - err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { sockInfo := &Socket{} if err := sockInfo.deserialize(m.Data); err != nil { return err @@ -216,21 +236,82 @@ func SocketDiagTCP(family uint8) ([]*Socket, error) { return result, nil } -// socketDiagTCPExecutor requests INET_DIAG_INFO for TCP protocol for specified family type. -func socketDiagTCPExecutor(family uint8, receiver func(syscall.NetlinkMessage) error) error { +// SocketDiagUDPInfo requests INET_DIAG_INFO for UDP protocol for specified family type and return with extension info. +func SocketDiagUDPInfo(family uint8) ([]*InetDiagUDPInfoResp, error) { + // Construct the request + var extensions uint8 + extensions = 1 << (INET_DIAG_VEGASINFO - 1) + extensions |= 1 << (INET_DIAG_INFO - 1) + extensions |= 1 << (INET_DIAG_MEMINFO - 1) + + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_UDP, + Ext: extensions, + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result + var result []*InetDiagUDPInfoResp + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &Socket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) + if err != nil { + return err + } + + res, err := attrsToInetDiagUDPInfoResp(attrs, sockInfo) + if err != nil { + return err + } + + result = append(result, res) + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// SocketDiagUDP requests INET_DIAG_INFO for UDP protocol for specified family type and return related socket. +func SocketDiagUDP(family uint8) ([]*Socket, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&socketRequest{ + Family: family, + Protocol: unix.IPPROTO_UDP, + Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), + States: uint32(0xfff), // all states + }) + + // Do the query and parse the result + var result []*Socket + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &Socket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + result = append(result, sockInfo) + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request. +func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error { s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) if err != nil { return err } defer s.Close() - - req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) - req.AddData(&socketRequest{ - Family: family, - Protocol: unix.IPPROTO_TCP, - Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), - States: uint32(0xfff), // All TCP states - }) s.Send(req) loop: @@ -240,7 +321,7 @@ loop: return err } if from.Pid != nl.PidKernel { - return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) + return fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) } if len(msgs) == 0 { return errors.New("no message nor error from netlink") @@ -263,29 +344,40 @@ loop: } func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) { - var tcpInfo *TCPInfo - var tcpBBRInfo *TCPBBRInfo + info := &InetDiagTCPInfoResp{ + InetDiagMsg: sockInfo, + } for _, a := range attrs { - if a.Attr.Type == INET_DIAG_INFO { - tcpInfo = &TCPInfo{} - if err := tcpInfo.deserialize(a.Value); err != nil { + switch a.Attr.Type { + case INET_DIAG_INFO: + info.TCPInfo = &TCPInfo{} + if err := info.TCPInfo.deserialize(a.Value); err != nil { return nil, err } - continue - } - - if a.Attr.Type == INET_DIAG_BBRINFO { - tcpBBRInfo = &TCPBBRInfo{} - if err := tcpBBRInfo.deserialize(a.Value); err != nil { + case INET_DIAG_BBRINFO: + info.TCPBBRInfo = &TCPBBRInfo{} + if err := info.TCPBBRInfo.deserialize(a.Value); err != nil { return nil, err } - continue } } - return &InetDiagTCPInfoResp{ - InetDiagMsg: sockInfo, - TCPInfo: tcpInfo, - TCPBBRInfo: tcpBBRInfo, - }, nil + return info, nil +} + +func attrsToInetDiagUDPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagUDPInfoResp, error) { + info := &InetDiagUDPInfoResp{ + InetDiagMsg: sockInfo, + } + for _, a := range attrs { + switch a.Attr.Type { + case INET_DIAG_MEMINFO: + info.Memory = &MemInfo{} + if err := info.Memory.deserialize(a.Value); err != nil { + return nil, err + } + } + } + + return info, nil } diff --git a/socket_test.go b/socket_test.go index f21520f..84605c2 100644 --- a/socket_test.go +++ b/socket_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -75,3 +76,18 @@ func TestSocketDiagTCPInfo(t *testing.T) { } } } + +func TestSocketDiagUDPnfo(t *testing.T) { + for _, want := range []uint8{syscall.AF_INET, syscall.AF_INET6} { + result, err := SocketDiagUDPInfo(want) + if err != nil { + t.Fatal(err) + } + + for _, r := range result { + if got := r.InetDiagMsg.Family; got != want { + t.Fatalf("protocol family = %v, want %v", got, want) + } + } + } +} diff --git a/tcp.go b/tcp.go index 23ca014..43f80a0 100644 --- a/tcp.go +++ b/tcp.go @@ -82,3 +82,11 @@ type TCPBBRInfo struct { BBRPacingGain uint32 BBRCwndGain uint32 } + +// According to https://man7.org/linux/man-pages/man7/sock_diag.7.html +type MemInfo struct { + RMem uint32 + WMem uint32 + FMem uint32 + TMem uint32 +} diff --git a/tcp_linux.go b/tcp_linux.go index 2938587..e98036d 100644 --- a/tcp_linux.go +++ b/tcp_linux.go @@ -8,6 +8,7 @@ import ( const ( tcpBBRInfoLen = 20 + memInfoLen = 16 ) func checkDeserErr(err error) error { @@ -351,3 +352,17 @@ func (t *TCPBBRInfo) deserialize(b []byte) error { return nil } + +func (m *MemInfo) deserialize(b []byte) error { + if len(b) != memInfoLen { + return errors.New("Invalid length") + } + + rb := bytes.NewBuffer(b) + m.RMem = native.Uint32(rb.Next(4)) + m.WMem = native.Uint32(rb.Next(4)) + m.FMem = native.Uint32(rb.Next(4)) + m.TMem = native.Uint32(rb.Next(4)) + + return nil +}