SocketGet support udp and ipv6

Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
This commit is contained in:
Asutorufa 2023-09-19 21:05:05 +08:00 committed by Alessandro Boch
parent 3e28e6db88
commit d237ee16c3
2 changed files with 126 additions and 56 deletions

View File

@ -56,10 +56,8 @@ func (r *socketRequest) Serialize() []byte {
copy(b.Next(16), r.ID.Source)
copy(b.Next(16), r.ID.Destination)
} else {
copy(b.Next(4), r.ID.Source.To4())
b.Next(12)
copy(b.Next(4), r.ID.Destination.To4())
b.Next(12)
copy(b.Next(16), r.ID.Source.To4())
copy(b.Next(16), r.ID.Destination.To4())
}
native.PutUint32(b.Next(4), r.ID.Interface)
native.PutUint32(b.Next(4), r.ID.Cookie[0])
@ -160,20 +158,44 @@ func (u *UnixSocket) deserialize(b []byte) error {
// SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(local, remote net.Addr) (*Socket, error) {
localTCP, ok := local.(*net.TCPAddr)
var protocol uint8
var localIP, remoteIP net.IP
var localPort, remotePort uint16
switch l := local.(type) {
case *net.TCPAddr:
r, ok := remote.(*net.TCPAddr)
if !ok {
return nil, ErrNotImplemented
}
remoteTCP, ok := remote.(*net.TCPAddr)
localIP = l.IP
localPort = uint16(l.Port)
remoteIP = r.IP
remotePort = uint16(r.Port)
protocol = unix.IPPROTO_TCP
case *net.UDPAddr:
r, ok := remote.(*net.UDPAddr)
if !ok {
return nil, ErrNotImplemented
}
localIP := localTCP.IP.To4()
if localIP == nil {
localIP = l.IP
localPort = uint16(l.Port)
remoteIP = r.IP
remotePort = uint16(r.Port)
protocol = unix.IPPROTO_UDP
default:
return nil, ErrNotImplemented
}
remoteIP := remoteTCP.IP.To4()
if remoteIP == nil {
var family uint8
if localIP.To4() != nil && remoteIP.To4() != nil {
family = unix.AF_INET
}
if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil {
family = unix.AF_INET6
}
if family == 0 {
return nil, ErrNotImplemented
}
@ -182,19 +204,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
return nil, err
}
defer s.Close()
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{
Family: unix.AF_INET,
Protocol: unix.IPPROTO_TCP,
Family: family,
Protocol: protocol,
States: 0xffffffff,
ID: SocketID{
SourcePort: uint16(localTCP.Port),
DestinationPort: uint16(remoteTCP.Port),
SourcePort: localPort,
DestinationPort: remotePort,
Source: localIP,
Destination: remoteIP,
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
},
})
s.Send(req)
if err := s.Send(req); err != nil {
return nil, err
}
msgs, from, err := s.Receive()
if err != nil {
return nil, err

View File

@ -16,28 +16,31 @@ import (
func TestSocketGet(t *testing.T) {
defer setUpNetlinkTestWithLoopback(t)()
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
if err != nil {
log.Fatal(err)
type Addr struct {
IP net.IP
Port int
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
getAddr := func(a net.Addr) Addr {
var addr Addr
switch v := a.(type) {
case *net.UDPAddr:
addr.IP = v.IP
addr.Port = v.Port
case *net.TCPAddr:
addr.IP = v.IP
addr.Port = v.Port
}
return addr
}
checkSocket := func(t *testing.T, local, remote net.Addr) {
socket, err := SocketGet(local, remote)
if err != nil {
t.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.TCPAddr)
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
socket, err := SocketGet(localAddr, remoteAddr)
if err != nil {
t.Fatal(err)
}
localAddr, remoteAddr := getAddr(local), getAddr(remote)
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
t.Fatalf("local ip = %v, want %v", got, want)
@ -60,6 +63,46 @@ func TestSocketGet(t *testing.T) {
}
}
for _, v := range [...]string{"tcp4", "tcp6"} {
addr, err := net.ResolveTCPAddr(v, "localhost:0")
if err != nil {
log.Fatal(err)
}
l, err := net.ListenTCP(v, addr)
if err != nil {
log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
}
for _, v := range [...]string{"udp4", "udp6"} {
addr, err := net.ResolveUDPAddr(v, "localhost:0")
if err != nil {
log.Fatal(err)
}
l, err := net.ListenUDP(v, addr)
if err != nil {
log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
}
}
func TestSocketDestroy(t *testing.T) {
defer setUpNetlinkTestWithLoopback(t)()