diff --git a/socket.go b/socket.go index 41aa726..ebcf842 100644 --- a/socket.go +++ b/socket.go @@ -25,3 +25,13 @@ type Socket struct { UID uint32 INode uint32 } + +// UnixSocket represents a netlink unix socket. +type UnixSocket struct { + Type uint8 + Family uint8 + State uint8 + pad uint8 + INode uint32 + Cookie [2]uint32 +} diff --git a/socket_linux.go b/socket_linux.go index f39c2a6..c60a75c 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -11,9 +11,11 @@ import ( ) const ( - sizeofSocketID = 0x30 - sizeofSocketRequest = sizeofSocketID + 0x8 - sizeofSocket = sizeofSocketID + 0x18 + sizeofSocketID = 0x30 + sizeofSocketRequest = sizeofSocketID + 0x8 + sizeofSocket = sizeofSocketID + 0x18 + sizeofUnixSocketRequest = 0x18 // 24 byte + sizeofUnixSocket = 0x10 // 16 byte ) type socketRequest struct { @@ -67,6 +69,32 @@ func (r *socketRequest) Serialize() []byte { func (r *socketRequest) Len() int { return sizeofSocketRequest } +// According to linux/include/uapi/linux/unix_diag.h +type unixSocketRequest struct { + Family uint8 + Protocol uint8 + pad uint16 + States uint32 + INode uint32 + Show uint32 + Cookie [2]uint32 +} + +func (r *unixSocketRequest) Serialize() []byte { + b := writeBuffer{Bytes: make([]byte, sizeofUnixSocketRequest)} + b.Write(r.Family) + b.Write(r.Protocol) + native.PutUint16(b.Next(2), r.pad) + native.PutUint32(b.Next(4), r.States) + native.PutUint32(b.Next(4), r.INode) + native.PutUint32(b.Next(4), r.Show) + native.PutUint32(b.Next(4), r.Cookie[0]) + native.PutUint32(b.Next(4), r.Cookie[1]) + return b.Bytes +} + +func (r *unixSocketRequest) Len() int { return sizeofUnixSocketRequest } + type readBuffer struct { Bytes []byte pos int @@ -115,6 +143,21 @@ func (s *Socket) deserialize(b []byte) error { return nil } +func (u *UnixSocket) deserialize(b []byte) error { + if len(b) < sizeofUnixSocket { + return fmt.Errorf("unix diag data short read (%d); want %d", len(b), sizeofUnixSocket) + } + rb := readBuffer{Bytes: b} + u.Type = rb.Read() + u.Family = rb.Read() + u.State = rb.Read() + u.pad = rb.Read() + u.INode = native.Uint32(rb.Next(4)) + u.Cookie[0] = native.Uint32(rb.Next(4)) + u.Cookie[1] = native.Uint32(rb.Next(4)) + return nil +} + // SocketGet returns the Socket identified by its local and remote addresses. func SocketGet(local, remote net.Addr) (*Socket, error) { localTCP, ok := local.(*net.TCPAddr) @@ -157,7 +200,7 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { return nil, err } if from.Pid != nl.PidKernel { - return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) + return nil, fmt.Errorf("wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) } if len(msgs) == 0 { return nil, errors.New("no message nor error from netlink") @@ -305,6 +348,78 @@ func SocketDiagUDP(family uint8) ([]*Socket, error) { return result, nil } +// UnixSocketDiagInfo requests UNIX_DIAG_INFO for unix sockets and return with extension info. +func UnixSocketDiagInfo() ([]*UnixDiagInfoResp, error) { + // Construct the request + var extensions uint8 + extensions = 1 << UNIX_DIAG_NAME + extensions |= 1 << UNIX_DIAG_PEER + extensions |= 1 << UNIX_DIAG_RQLEN + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&unixSocketRequest{ + Family: unix.AF_UNIX, + States: ^uint32(0), // all states + Show: uint32(extensions), + }) + + var result []*UnixDiagInfoResp + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &UnixSocket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + + // Diagnosis also delivers sockets with AF_INET family, filter those + if sockInfo.Family != unix.AF_UNIX { + return nil + } + + attrs, err := nl.ParseRouteAttr(m.Data[sizeofUnixSocket:]) + if err != nil { + return err + } + + res, err := attrsToUnixDiagInfoResp(attrs, sockInfo) + if err != nil { + return err + } + result = append(result, res) + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + +// UnixSocketDiag requests UNIX_DIAG_INFO for unix sockets. +func UnixSocketDiag() ([]*UnixSocket, error) { + // Construct the request + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) + req.AddData(&unixSocketRequest{ + Family: unix.AF_UNIX, + States: ^uint32(0), // all states + }) + + var result []*UnixSocket + err := socketDiagExecutor(req, func(m syscall.NetlinkMessage) error { + sockInfo := &UnixSocket{} + if err := sockInfo.deserialize(m.Data); err != nil { + return err + } + + // Diagnosis also delivers sockets with AF_INET family, filter those + if sockInfo.Family == unix.AF_UNIX { + result = append(result, sockInfo) + } + return nil + }) + if err != nil { + return nil, err + } + return result, nil +} + // socketDiagExecutor requests diagnoses info from the NETLINK_INET_DIAG socket for the specified request. func socketDiagExecutor(req *nl.NetlinkRequest, receiver func(syscall.NetlinkMessage) error) error { s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) @@ -381,3 +496,28 @@ func attrsToInetDiagUDPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Sock return info, nil } + +func attrsToUnixDiagInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *UnixSocket) (*UnixDiagInfoResp, error) { + info := &UnixDiagInfoResp{ + DiagMsg: sockInfo, + } + for _, a := range attrs { + switch a.Attr.Type { + case UNIX_DIAG_NAME: + name := string(a.Value[:a.Attr.Len]) + info.Name = &name + case UNIX_DIAG_PEER: + peer := native.Uint32(a.Value) + info.Peer = &peer + case UNIX_DIAG_RQLEN: + info.Queue = &QueueInfo{ + RQueue: native.Uint32(a.Value[:4]), + WQueue: native.Uint32(a.Value[4:]), + } + // default: + // fmt.Println("unknown unix attribute type", a.Attr.Type, "with data", a.Value) + } + } + + return info, nil +} diff --git a/socket_test.go b/socket_test.go index 84605c2..27a816e 100644 --- a/socket_test.go +++ b/socket_test.go @@ -4,6 +4,7 @@ package netlink import ( + "fmt" "log" "net" "os/user" @@ -91,3 +92,18 @@ func TestSocketDiagUDPnfo(t *testing.T) { } } } + +func TestUnixSocketDiagInfo(t *testing.T) { + want := syscall.AF_UNIX + result, err := UnixSocketDiagInfo() + if err != nil { + t.Fatal(err) + } + + for i, r := range result { + fmt.Println(r.DiagMsg) + if got := r.DiagMsg.Family; got != uint8(want) { + t.Fatalf("%d: protocol family = %v, want %v", i, got, want) + } + } +} diff --git a/unix_diag.go b/unix_diag.go new file mode 100644 index 0000000..d81776f --- /dev/null +++ b/unix_diag.go @@ -0,0 +1,27 @@ +package netlink + +// According to linux/include/uapi/linux/unix_diag.h +const ( + UNIX_DIAG_NAME = iota + UNIX_DIAG_VFS + UNIX_DIAG_PEER + UNIX_DIAG_ICONS + UNIX_DIAG_RQLEN + UNIX_DIAG_MEMINFO + UNIX_DIAG_SHUTDOWN + UNIX_DIAG_UID + UNIX_DIAG_MAX +) + +type UnixDiagInfoResp struct { + DiagMsg *UnixSocket + Name *string + Peer *uint32 + Queue *QueueInfo + Shutdown *uint8 +} + +type QueueInfo struct { + RQueue uint32 + WQueue uint32 +}