diff --git a/netlink_unspecified.go b/netlink_unspecified.go index 0f6fde2..da12c42 100644 --- a/netlink_unspecified.go +++ b/netlink_unspecified.go @@ -283,3 +283,7 @@ func NeighDeserialize(m []byte) (*Neigh, error) { func SocketGet(local, remote net.Addr) (*Socket, error) { return nil, ErrNotImplemented } + +func SocketDestroy(local, remote net.Addr) (*Socket, error) { + return nil, ErrNotImplemented +} diff --git a/nl/syscall.go b/nl/syscall.go index bdf6ba6..b5ba039 100644 --- a/nl/syscall.go +++ b/nl/syscall.go @@ -46,6 +46,7 @@ const ( // socket diags related const ( SOCK_DIAG_BY_FAMILY = 20 /* linux.sock_diag.h */ + SOCK_DESTROY = 21 TCPDIAG_NOCOOKIE = 0xFFFFFFFF /* TCPDIAG_NOCOOKIE in net/ipv4/tcp_diag.h*/ ) diff --git a/socket_linux.go b/socket_linux.go index 1634609..b6182a5 100644 --- a/socket_linux.go +++ b/socket_linux.go @@ -215,6 +215,45 @@ func SocketGet(local, remote net.Addr) (*Socket, error) { return sock, nil } +// SocketDestroy kills the Socket identified by its local and remote addresses. +func SocketDestroy(local, remote net.Addr) error { + localTCP, ok := local.(*net.TCPAddr) + if !ok { + return ErrNotImplemented + } + remoteTCP, ok := remote.(*net.TCPAddr) + if !ok { + return ErrNotImplemented + } + localIP := localTCP.IP.To4() + if localIP == nil { + return ErrNotImplemented + } + remoteIP := remoteTCP.IP.To4() + if remoteIP == nil { + return ErrNotImplemented + } + + s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) + if err != nil { + return err + } + defer s.Close() + req := nl.NewNetlinkRequest(nl.SOCK_DESTROY, unix.NLM_F_ACK) + req.AddData(&socketRequest{ + Family: unix.AF_INET, + Protocol: unix.IPPROTO_TCP, + ID: SocketID{ + SourcePort: uint16(localTCP.Port), + DestinationPort: uint16(remoteTCP.Port), + Source: localIP, + Destination: remoteIP, + Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, + }, + }) + return s.Send(req) +} + // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { // Construct the request diff --git a/socket_test.go b/socket_test.go index 27a816e..d4dc84a 100644 --- a/socket_test.go +++ b/socket_test.go @@ -60,6 +60,33 @@ func TestSocketGet(t *testing.T) { } } +func TestSocketDestroy(t *testing.T) { + defer setUpNetlinkTestWithLoopback(t)() + + addr, err := net.ResolveTCPAddr("tcp", "localhost:0") + if err != nil { + log.Fatal(err) + } + l, err := net.ListenTCP("tcp", addr) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + localAddr := conn.LocalAddr().(*net.TCPAddr) + remoteAddr := conn.RemoteAddr().(*net.TCPAddr) + err = SocketDestroy(localAddr, remoteAddr) + if err != nil { + t.Fatal(err) + } +} + func TestSocketDiagTCPInfo(t *testing.T) { Family4 := uint8(syscall.AF_INET) Family6 := uint8(syscall.AF_INET6)