From a3f0be63522b2d1827372f2510eb6e53477960c4 Mon Sep 17 00:00:00 2001 From: Sebastien Boving Date: Thu, 2 Feb 2017 15:07:22 -0800 Subject: [PATCH] Add support for tcp diags. --- link_linux.go | 21 ------ netlink.go | 10 ++- netlink_unspecified.go | 13 ++-- nl/syscall.go | 6 ++ order.go | 32 +++++++++ socket.go | 27 +++++++ socket_linux.go | 159 +++++++++++++++++++++++++++++++++++++++++ socket_test.go | 54 ++++++++++++++ 8 files changed, 292 insertions(+), 30 deletions(-) create mode 100644 order.go create mode 100644 socket.go create mode 100644 socket_linux.go create mode 100644 socket_test.go diff --git a/link_linux.go b/link_linux.go index bfde63f..56409eb 100644 --- a/link_linux.go +++ b/link_linux.go @@ -29,7 +29,6 @@ const ( TUNTAP_ONE_QUEUE TuntapFlag = syscall.IFF_ONE_QUEUE ) -var native = nl.NativeEndian() var lookupByDump = false var macvlanModes = [...]uint32{ @@ -1443,26 +1442,6 @@ func linkFlags(rawFlags uint32) net.Flags { return f } -func htonl(val uint32) []byte { - bytes := make([]byte, 4) - binary.BigEndian.PutUint32(bytes, val) - return bytes -} - -func htons(val uint16) []byte { - bytes := make([]byte, 2) - binary.BigEndian.PutUint16(bytes, val) - return bytes -} - -func ntohl(buf []byte) uint32 { - return binary.BigEndian.Uint32(buf) -} - -func ntohs(buf []byte) uint16 { - return binary.BigEndian.Uint16(buf) -} - func addGretapAttrs(gretap *Gretap, linkInfo *nl.RtAttr) { data := nl.NewRtAttrChild(linkInfo, nl.IFLA_INFO_DATA, nil) diff --git a/netlink.go b/netlink.go index d8e02f4..fb15952 100644 --- a/netlink.go +++ b/netlink.go @@ -8,7 +8,15 @@ // interface that is loosly modeled on the iproute2 cli. package netlink -import "net" +import ( + "errors" + "net" +) + +var ( + // ErrNotImplemented is returned when a requested feature is not implemented. + ErrNotImplemented = errors.New("not implemented") +) // ParseIPNet parses a string in ip/net format and returns a net.IPNet. // This is valuable because addresses in netlink are often IPNets and diff --git a/netlink_unspecified.go b/netlink_unspecified.go index 8a3c23e..fa421e4 100644 --- a/netlink_unspecified.go +++ b/netlink_unspecified.go @@ -2,14 +2,7 @@ package netlink -import ( - "errors" - "net" -) - -var ( - ErrNotImplemented = errors.New("not implemented") -) +import "net" func LinkSetUp(link Link) error { return ErrNotImplemented @@ -214,3 +207,7 @@ func NeighList(linkIndex, family int) ([]Neigh, error) { func NeighDeserialize(m []byte) (*Neigh, error) { return nil, ErrNotImplemented } + +func SocketGet(local, remote net.Addr) (*Socket, error) { + return nil, ErrNotImplemented +} diff --git a/nl/syscall.go b/nl/syscall.go index 47aa632..18d54a4 100644 --- a/nl/syscall.go +++ b/nl/syscall.go @@ -35,3 +35,9 @@ const ( FR_ACT_UNREACHABLE /* Drop with ENETUNREACH */ FR_ACT_PROHIBIT /* Drop with EACCES */ ) + +// socket diags related +const ( + SOCK_DIAG_BY_FAMILY = 20 /* linux.sock_diag.h */ + TCPDIAG_NOCOOKIE = 0xFFFFFFFF /* TCPDIAG_NOCOOKIE in net/ipv4/tcp_diag.h*/ +) diff --git a/order.go b/order.go new file mode 100644 index 0000000..e28e153 --- /dev/null +++ b/order.go @@ -0,0 +1,32 @@ +package netlink + +import ( + "encoding/binary" + + "github.com/vishvananda/netlink/nl" +) + +var ( + native = nl.NativeEndian() + networkOrder = binary.BigEndian +) + +func htonl(val uint32) []byte { + bytes := make([]byte, 4) + binary.BigEndian.PutUint32(bytes, val) + return bytes +} + +func htons(val uint16) []byte { + bytes := make([]byte, 2) + binary.BigEndian.PutUint16(bytes, val) + return bytes +} + +func ntohl(buf []byte) uint32 { + return binary.BigEndian.Uint32(buf) +} + +func ntohs(buf []byte) uint16 { + return binary.BigEndian.Uint16(buf) +} diff --git a/socket.go b/socket.go new file mode 100644 index 0000000..41aa726 --- /dev/null +++ b/socket.go @@ -0,0 +1,27 @@ +package netlink + +import "net" + +// SocketID identifies a single socket. +type SocketID struct { + SourcePort uint16 + DestinationPort uint16 + Source net.IP + Destination net.IP + Interface uint32 + Cookie [2]uint32 +} + +// Socket represents a netlink socket. +type Socket struct { + Family uint8 + State uint8 + Timer uint8 + Retrans uint8 + ID SocketID + Expires uint32 + RQueue uint32 + WQueue uint32 + UID uint32 + INode uint32 +} diff --git a/socket_linux.go b/socket_linux.go new file mode 100644 index 0000000..b42b84f --- /dev/null +++ b/socket_linux.go @@ -0,0 +1,159 @@ +package netlink + +import ( + "errors" + "fmt" + "net" + "syscall" + + "github.com/vishvananda/netlink/nl" +) + +const ( + sizeofSocketID = 0x30 + sizeofSocketRequest = sizeofSocketID + 0x8 + sizeofSocket = sizeofSocketID + 0x18 +) + +type socketRequest struct { + Family uint8 + Protocol uint8 + Ext uint8 + pad uint8 + States uint32 + ID SocketID +} + +type writeBuffer struct { + Bytes []byte + pos int +} + +func (b *writeBuffer) Write(c byte) { + b.Bytes[b.pos] = c + b.pos++ +} + +func (b *writeBuffer) Next(n int) []byte { + s := b.Bytes[b.pos : b.pos+n] + b.pos += n + return s +} + +func (r *socketRequest) Serialize() []byte { + b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)} + b.Write(r.Family) + b.Write(r.Protocol) + b.Write(r.Ext) + b.Write(r.pad) + native.PutUint32(b.Next(4), r.States) + networkOrder.PutUint16(b.Next(2), r.ID.SourcePort) + networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort) + copy(b.Next(4), r.ID.Source.To4()) + b.Next(12) + copy(b.Next(4), r.ID.Destination.To4()) + b.Next(12) + native.PutUint32(b.Next(4), r.ID.Interface) + native.PutUint32(b.Next(4), r.ID.Cookie[0]) + native.PutUint32(b.Next(4), r.ID.Cookie[1]) + return b.Bytes +} + +func (r *socketRequest) Len() int { return sizeofSocketRequest } + +type readBuffer struct { + Bytes []byte + pos int +} + +func (b *readBuffer) Read() byte { + c := b.Bytes[b.pos] + b.pos++ + return c +} + +func (b *readBuffer) Next(n int) []byte { + s := b.Bytes[b.pos : b.pos+n] + b.pos += n + return s +} + +func (s *Socket) deserialize(b []byte) error { + if len(b) < sizeofSocket { + return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket) + } + rb := readBuffer{Bytes: b} + s.Family = rb.Read() + s.State = rb.Read() + s.Timer = rb.Read() + s.Retrans = rb.Read() + s.ID.SourcePort = networkOrder.Uint16(rb.Next(2)) + s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2)) + s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) + rb.Next(12) + s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) + rb.Next(12) + s.ID.Interface = native.Uint32(rb.Next(4)) + s.ID.Cookie[0] = native.Uint32(rb.Next(4)) + s.ID.Cookie[1] = native.Uint32(rb.Next(4)) + s.Expires = native.Uint32(rb.Next(4)) + s.RQueue = native.Uint32(rb.Next(4)) + s.WQueue = native.Uint32(rb.Next(4)) + s.UID = native.Uint32(rb.Next(4)) + s.INode = native.Uint32(rb.Next(4)) + return nil +} + +// SocketGet returns the Socket identified by its local and remote addresses. +func SocketGet(local, remote net.Addr) (*Socket, error) { + localTCP, ok := local.(*net.TCPAddr) + if !ok { + return nil, ErrNotImplemented + } + remoteTCP, ok := remote.(*net.TCPAddr) + if !ok { + return nil, ErrNotImplemented + } + localIP := localTCP.IP.To4() + if localIP == nil { + return nil, ErrNotImplemented + } + remoteIP := remoteTCP.IP.To4() + if remoteIP == nil { + return nil, ErrNotImplemented + } + + s, err := nl.Subscribe(syscall.NETLINK_INET_DIAG) + if err != nil { + return nil, err + } + defer s.Close() + req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0) + req.AddData(&socketRequest{ + Family: syscall.AF_INET, + Protocol: syscall.IPPROTO_TCP, + ID: SocketID{ + SourcePort: uint16(localTCP.Port), + DestinationPort: uint16(remoteTCP.Port), + Source: localIP, + Destination: remoteIP, + Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, + }, + }) + s.Send(req) + msgs, err := s.Receive() + if err != nil { + return nil, err + } + if len(msgs) == 0 { + return nil, errors.New("no message nor error from netlink") + } + if len(msgs) > 2 { + return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs)) + } + sock := &Socket{} + if err := sock.deserialize(msgs[0].Data); err != nil { + return nil, err + } + return sock, nil +} diff --git a/socket_test.go b/socket_test.go new file mode 100644 index 0000000..70fd440 --- /dev/null +++ b/socket_test.go @@ -0,0 +1,54 @@ +package netlink + +import ( + "log" + "net" + "os/user" + "strconv" + "testing" +) + +func TestSocketGet(t *testing.T) { + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.TCPAddr) + remoteAddr := conn.RemoteAddr().(*net.TCPAddr) + socket, err := SocketGet(localAddr, remoteAddr) + if err != nil { + t.Fatal(err) + } + + if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { + t.Fatalf("local ip = %v, want %v", got, want) + } + if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) { + t.Fatalf("remote ip = %v, want %v", got, want) + } + if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want { + t.Fatalf("local port = %d, want %d", got, want) + } + if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want { + t.Fatalf("remote port = %d, want %d", got, want) + } + u, err := user.Current() + if err != nil { + t.Fatal(err) + } + if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { + t.Fatalf("UID = %s, want %s", got, want) + } +}