mirror of https://github.com/vishvananda/netlink
SocketGet support udp and ipv6
Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
This commit is contained in:
parent
3e28e6db88
commit
d237ee16c3
|
@ -56,10 +56,8 @@ func (r *socketRequest) Serialize() []byte {
|
|||
copy(b.Next(16), r.ID.Source)
|
||||
copy(b.Next(16), r.ID.Destination)
|
||||
} else {
|
||||
copy(b.Next(4), r.ID.Source.To4())
|
||||
b.Next(12)
|
||||
copy(b.Next(4), r.ID.Destination.To4())
|
||||
b.Next(12)
|
||||
copy(b.Next(16), r.ID.Source.To4())
|
||||
copy(b.Next(16), r.ID.Destination.To4())
|
||||
}
|
||||
native.PutUint32(b.Next(4), r.ID.Interface)
|
||||
native.PutUint32(b.Next(4), r.ID.Cookie[0])
|
||||
|
@ -160,20 +158,44 @@ func (u *UnixSocket) deserialize(b []byte) error {
|
|||
|
||||
// SocketGet returns the Socket identified by its local and remote addresses.
|
||||
func SocketGet(local, remote net.Addr) (*Socket, error) {
|
||||
localTCP, ok := local.(*net.TCPAddr)
|
||||
if !ok {
|
||||
var protocol uint8
|
||||
var localIP, remoteIP net.IP
|
||||
var localPort, remotePort uint16
|
||||
switch l := local.(type) {
|
||||
case *net.TCPAddr:
|
||||
r, ok := remote.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
localIP = l.IP
|
||||
localPort = uint16(l.Port)
|
||||
remoteIP = r.IP
|
||||
remotePort = uint16(r.Port)
|
||||
protocol = unix.IPPROTO_TCP
|
||||
case *net.UDPAddr:
|
||||
r, ok := remote.(*net.UDPAddr)
|
||||
if !ok {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
localIP = l.IP
|
||||
localPort = uint16(l.Port)
|
||||
remoteIP = r.IP
|
||||
remotePort = uint16(r.Port)
|
||||
protocol = unix.IPPROTO_UDP
|
||||
default:
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
remoteTCP, ok := remote.(*net.TCPAddr)
|
||||
if !ok {
|
||||
return nil, ErrNotImplemented
|
||||
|
||||
var family uint8
|
||||
if localIP.To4() != nil && remoteIP.To4() != nil {
|
||||
family = unix.AF_INET
|
||||
}
|
||||
localIP := localTCP.IP.To4()
|
||||
if localIP == nil {
|
||||
return nil, ErrNotImplemented
|
||||
|
||||
if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil {
|
||||
family = unix.AF_INET6
|
||||
}
|
||||
remoteIP := remoteTCP.IP.To4()
|
||||
if remoteIP == nil {
|
||||
|
||||
if family == 0 {
|
||||
return nil, ErrNotImplemented
|
||||
}
|
||||
|
||||
|
@ -182,19 +204,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
|
|||
return nil, err
|
||||
}
|
||||
defer s.Close()
|
||||
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
|
||||
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
|
||||
req.AddData(&socketRequest{
|
||||
Family: unix.AF_INET,
|
||||
Protocol: unix.IPPROTO_TCP,
|
||||
Family: family,
|
||||
Protocol: protocol,
|
||||
States: 0xffffffff,
|
||||
ID: SocketID{
|
||||
SourcePort: uint16(localTCP.Port),
|
||||
DestinationPort: uint16(remoteTCP.Port),
|
||||
SourcePort: localPort,
|
||||
DestinationPort: remotePort,
|
||||
Source: localIP,
|
||||
Destination: remoteIP,
|
||||
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
|
||||
},
|
||||
})
|
||||
s.Send(req)
|
||||
|
||||
if err := s.Send(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msgs, from, err := s.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
115
socket_test.go
115
socket_test.go
|
@ -16,47 +16,90 @@ import (
|
|||
func TestSocketGet(t *testing.T) {
|
||||
defer setUpNetlinkTestWithLoopback(t)()
|
||||
|
||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
l, err := net.ListenTCP("tcp", addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
localAddr := conn.LocalAddr().(*net.TCPAddr)
|
||||
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
socket, err := SocketGet(localAddr, remoteAddr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
type Addr struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
}
|
||||
|
||||
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
|
||||
t.Fatalf("local ip = %v, want %v", got, want)
|
||||
getAddr := func(a net.Addr) Addr {
|
||||
var addr Addr
|
||||
switch v := a.(type) {
|
||||
case *net.UDPAddr:
|
||||
addr.IP = v.IP
|
||||
addr.Port = v.Port
|
||||
case *net.TCPAddr:
|
||||
addr.IP = v.IP
|
||||
addr.Port = v.Port
|
||||
}
|
||||
return addr
|
||||
}
|
||||
if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
|
||||
t.Fatalf("remote ip = %v, want %v", got, want)
|
||||
|
||||
checkSocket := func(t *testing.T, local, remote net.Addr) {
|
||||
socket, err := SocketGet(local, remote)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
localAddr, remoteAddr := getAddr(local), getAddr(remote)
|
||||
|
||||
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
|
||||
t.Fatalf("local ip = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
|
||||
t.Fatalf("remote ip = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
|
||||
t.Fatalf("local port = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
|
||||
t.Fatalf("remote port = %d, want %d", got, want)
|
||||
}
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
|
||||
t.Fatalf("UID = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
|
||||
t.Fatalf("local port = %d, want %d", got, want)
|
||||
|
||||
for _, v := range [...]string{"tcp4", "tcp6"} {
|
||||
addr, err := net.ResolveTCPAddr(v, "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
l, err := net.ListenTCP(v, addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
|
||||
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
|
||||
t.Fatalf("remote port = %d, want %d", got, want)
|
||||
}
|
||||
u, err := user.Current()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
|
||||
t.Fatalf("UID = %s, want %s", got, want)
|
||||
|
||||
for _, v := range [...]string{"udp4", "udp6"} {
|
||||
addr, err := net.ResolveUDPAddr(v, "localhost:0")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
l, err := net.ListenUDP(v, addr)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue