SocketGet support udp and ipv6

Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
This commit is contained in:
Asutorufa 2023-09-19 21:05:05 +08:00 committed by Alessandro Boch
parent 3e28e6db88
commit d237ee16c3
2 changed files with 126 additions and 56 deletions

View File

@ -56,10 +56,8 @@ func (r *socketRequest) Serialize() []byte {
copy(b.Next(16), r.ID.Source) copy(b.Next(16), r.ID.Source)
copy(b.Next(16), r.ID.Destination) copy(b.Next(16), r.ID.Destination)
} else { } else {
copy(b.Next(4), r.ID.Source.To4()) copy(b.Next(16), r.ID.Source.To4())
b.Next(12) copy(b.Next(16), r.ID.Destination.To4())
copy(b.Next(4), r.ID.Destination.To4())
b.Next(12)
} }
native.PutUint32(b.Next(4), r.ID.Interface) native.PutUint32(b.Next(4), r.ID.Interface)
native.PutUint32(b.Next(4), r.ID.Cookie[0]) native.PutUint32(b.Next(4), r.ID.Cookie[0])
@ -160,20 +158,44 @@ func (u *UnixSocket) deserialize(b []byte) error {
// SocketGet returns the Socket identified by its local and remote addresses. // SocketGet returns the Socket identified by its local and remote addresses.
func SocketGet(local, remote net.Addr) (*Socket, error) { func SocketGet(local, remote net.Addr) (*Socket, error) {
localTCP, ok := local.(*net.TCPAddr) var protocol uint8
if !ok { var localIP, remoteIP net.IP
var localPort, remotePort uint16
switch l := local.(type) {
case *net.TCPAddr:
r, ok := remote.(*net.TCPAddr)
if !ok {
return nil, ErrNotImplemented
}
localIP = l.IP
localPort = uint16(l.Port)
remoteIP = r.IP
remotePort = uint16(r.Port)
protocol = unix.IPPROTO_TCP
case *net.UDPAddr:
r, ok := remote.(*net.UDPAddr)
if !ok {
return nil, ErrNotImplemented
}
localIP = l.IP
localPort = uint16(l.Port)
remoteIP = r.IP
remotePort = uint16(r.Port)
protocol = unix.IPPROTO_UDP
default:
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
remoteTCP, ok := remote.(*net.TCPAddr)
if !ok { var family uint8
return nil, ErrNotImplemented if localIP.To4() != nil && remoteIP.To4() != nil {
family = unix.AF_INET
} }
localIP := localTCP.IP.To4()
if localIP == nil { if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil {
return nil, ErrNotImplemented family = unix.AF_INET6
} }
remoteIP := remoteTCP.IP.To4()
if remoteIP == nil { if family == 0 {
return nil, ErrNotImplemented return nil, ErrNotImplemented
} }
@ -182,19 +204,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
return nil, err return nil, err
} }
defer s.Close() defer s.Close()
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0) req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
req.AddData(&socketRequest{ req.AddData(&socketRequest{
Family: unix.AF_INET, Family: family,
Protocol: unix.IPPROTO_TCP, Protocol: protocol,
States: 0xffffffff,
ID: SocketID{ ID: SocketID{
SourcePort: uint16(localTCP.Port), SourcePort: localPort,
DestinationPort: uint16(remoteTCP.Port), DestinationPort: remotePort,
Source: localIP, Source: localIP,
Destination: remoteIP, Destination: remoteIP,
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
}, },
}) })
s.Send(req)
if err := s.Send(req); err != nil {
return nil, err
}
msgs, from, err := s.Receive() msgs, from, err := s.Receive()
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -16,47 +16,90 @@ import (
func TestSocketGet(t *testing.T) { func TestSocketGet(t *testing.T) {
defer setUpNetlinkTestWithLoopback(t)() defer setUpNetlinkTestWithLoopback(t)()
addr, err := net.ResolveTCPAddr("tcp", "localhost:0") type Addr struct {
if err != nil { IP net.IP
log.Fatal(err) Port int
}
l, err := net.ListenTCP("tcp", addr)
if err != nil {
log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
localAddr := conn.LocalAddr().(*net.TCPAddr)
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
socket, err := SocketGet(localAddr, remoteAddr)
if err != nil {
t.Fatal(err)
} }
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { getAddr := func(a net.Addr) Addr {
t.Fatalf("local ip = %v, want %v", got, want) var addr Addr
switch v := a.(type) {
case *net.UDPAddr:
addr.IP = v.IP
addr.Port = v.Port
case *net.TCPAddr:
addr.IP = v.IP
addr.Port = v.Port
}
return addr
} }
if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
t.Fatalf("remote ip = %v, want %v", got, want) checkSocket := func(t *testing.T, local, remote net.Addr) {
socket, err := SocketGet(local, remote)
if err != nil {
t.Fatal(err)
}
localAddr, remoteAddr := getAddr(local), getAddr(remote)
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
t.Fatalf("local ip = %v, want %v", got, want)
}
if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
t.Fatalf("remote ip = %v, want %v", got, want)
}
if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
t.Fatalf("local port = %d, want %d", got, want)
}
if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
t.Fatalf("remote port = %d, want %d", got, want)
}
u, err := user.Current()
if err != nil {
t.Fatal(err)
}
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
t.Fatalf("UID = %s, want %s", got, want)
}
} }
if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
t.Fatalf("local port = %d, want %d", got, want) for _, v := range [...]string{"tcp4", "tcp6"} {
addr, err := net.ResolveTCPAddr(v, "localhost:0")
if err != nil {
log.Fatal(err)
}
l, err := net.ListenTCP(v, addr)
if err != nil {
log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
} }
if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
t.Fatalf("remote port = %d, want %d", got, want) for _, v := range [...]string{"udp4", "udp6"} {
} addr, err := net.ResolveUDPAddr(v, "localhost:0")
u, err := user.Current() if err != nil {
if err != nil { log.Fatal(err)
t.Fatal(err) }
} l, err := net.ListenUDP(v, addr)
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { if err != nil {
t.Fatalf("UID = %s, want %s", got, want) log.Fatal(err)
}
defer l.Close()
conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
if err != nil {
t.Fatal(err)
}
defer conn.Close()
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
} }
} }