mirror of https://github.com/vishvananda/netlink
SocketGet support udp and ipv6
Signed-off-by: Asutorufa <16442314+Asutorufa@users.noreply.github.com>
This commit is contained in:
parent
3e28e6db88
commit
d237ee16c3
|
@ -56,10 +56,8 @@ func (r *socketRequest) Serialize() []byte {
|
||||||
copy(b.Next(16), r.ID.Source)
|
copy(b.Next(16), r.ID.Source)
|
||||||
copy(b.Next(16), r.ID.Destination)
|
copy(b.Next(16), r.ID.Destination)
|
||||||
} else {
|
} else {
|
||||||
copy(b.Next(4), r.ID.Source.To4())
|
copy(b.Next(16), r.ID.Source.To4())
|
||||||
b.Next(12)
|
copy(b.Next(16), r.ID.Destination.To4())
|
||||||
copy(b.Next(4), r.ID.Destination.To4())
|
|
||||||
b.Next(12)
|
|
||||||
}
|
}
|
||||||
native.PutUint32(b.Next(4), r.ID.Interface)
|
native.PutUint32(b.Next(4), r.ID.Interface)
|
||||||
native.PutUint32(b.Next(4), r.ID.Cookie[0])
|
native.PutUint32(b.Next(4), r.ID.Cookie[0])
|
||||||
|
@ -160,20 +158,44 @@ func (u *UnixSocket) deserialize(b []byte) error {
|
||||||
|
|
||||||
// SocketGet returns the Socket identified by its local and remote addresses.
|
// SocketGet returns the Socket identified by its local and remote addresses.
|
||||||
func SocketGet(local, remote net.Addr) (*Socket, error) {
|
func SocketGet(local, remote net.Addr) (*Socket, error) {
|
||||||
localTCP, ok := local.(*net.TCPAddr)
|
var protocol uint8
|
||||||
|
var localIP, remoteIP net.IP
|
||||||
|
var localPort, remotePort uint16
|
||||||
|
switch l := local.(type) {
|
||||||
|
case *net.TCPAddr:
|
||||||
|
r, ok := remote.(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrNotImplemented
|
return nil, ErrNotImplemented
|
||||||
}
|
}
|
||||||
remoteTCP, ok := remote.(*net.TCPAddr)
|
localIP = l.IP
|
||||||
|
localPort = uint16(l.Port)
|
||||||
|
remoteIP = r.IP
|
||||||
|
remotePort = uint16(r.Port)
|
||||||
|
protocol = unix.IPPROTO_TCP
|
||||||
|
case *net.UDPAddr:
|
||||||
|
r, ok := remote.(*net.UDPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrNotImplemented
|
return nil, ErrNotImplemented
|
||||||
}
|
}
|
||||||
localIP := localTCP.IP.To4()
|
localIP = l.IP
|
||||||
if localIP == nil {
|
localPort = uint16(l.Port)
|
||||||
|
remoteIP = r.IP
|
||||||
|
remotePort = uint16(r.Port)
|
||||||
|
protocol = unix.IPPROTO_UDP
|
||||||
|
default:
|
||||||
return nil, ErrNotImplemented
|
return nil, ErrNotImplemented
|
||||||
}
|
}
|
||||||
remoteIP := remoteTCP.IP.To4()
|
|
||||||
if remoteIP == nil {
|
var family uint8
|
||||||
|
if localIP.To4() != nil && remoteIP.To4() != nil {
|
||||||
|
family = unix.AF_INET
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == 0 && localIP.To16() != nil && remoteIP.To16() != nil {
|
||||||
|
family = unix.AF_INET6
|
||||||
|
}
|
||||||
|
|
||||||
|
if family == 0 {
|
||||||
return nil, ErrNotImplemented
|
return nil, ErrNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -182,19 +204,24 @@ func SocketGet(local, remote net.Addr) (*Socket, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
|
req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
|
||||||
req.AddData(&socketRequest{
|
req.AddData(&socketRequest{
|
||||||
Family: unix.AF_INET,
|
Family: family,
|
||||||
Protocol: unix.IPPROTO_TCP,
|
Protocol: protocol,
|
||||||
|
States: 0xffffffff,
|
||||||
ID: SocketID{
|
ID: SocketID{
|
||||||
SourcePort: uint16(localTCP.Port),
|
SourcePort: localPort,
|
||||||
DestinationPort: uint16(remoteTCP.Port),
|
DestinationPort: remotePort,
|
||||||
Source: localIP,
|
Source: localIP,
|
||||||
Destination: remoteIP,
|
Destination: remoteIP,
|
||||||
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
|
Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
s.Send(req)
|
|
||||||
|
if err := s.Send(req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
msgs, from, err := s.Receive()
|
msgs, from, err := s.Receive()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -16,28 +16,31 @@ import (
|
||||||
func TestSocketGet(t *testing.T) {
|
func TestSocketGet(t *testing.T) {
|
||||||
defer setUpNetlinkTestWithLoopback(t)()
|
defer setUpNetlinkTestWithLoopback(t)()
|
||||||
|
|
||||||
addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
|
type Addr struct {
|
||||||
if err != nil {
|
IP net.IP
|
||||||
log.Fatal(err)
|
Port int
|
||||||
}
|
}
|
||||||
l, err := net.ListenTCP("tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
|
getAddr := func(a net.Addr) Addr {
|
||||||
|
var addr Addr
|
||||||
|
switch v := a.(type) {
|
||||||
|
case *net.UDPAddr:
|
||||||
|
addr.IP = v.IP
|
||||||
|
addr.Port = v.Port
|
||||||
|
case *net.TCPAddr:
|
||||||
|
addr.IP = v.IP
|
||||||
|
addr.Port = v.Port
|
||||||
|
}
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
checkSocket := func(t *testing.T, local, remote net.Addr) {
|
||||||
|
socket, err := SocketGet(local, remote)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
localAddr := conn.LocalAddr().(*net.TCPAddr)
|
localAddr, remoteAddr := getAddr(local), getAddr(remote)
|
||||||
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
|
|
||||||
socket, err := SocketGet(localAddr, remoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
|
if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
|
||||||
t.Fatalf("local ip = %v, want %v", got, want)
|
t.Fatalf("local ip = %v, want %v", got, want)
|
||||||
|
@ -58,6 +61,46 @@ func TestSocketGet(t *testing.T) {
|
||||||
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
|
if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
|
||||||
t.Fatalf("UID = %s, want %s", got, want)
|
t.Fatalf("UID = %s, want %s", got, want)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range [...]string{"tcp4", "tcp6"} {
|
||||||
|
addr, err := net.ResolveTCPAddr(v, "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
l, err := net.ListenTCP(v, addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
|
||||||
|
conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, v := range [...]string{"udp4", "udp6"} {
|
||||||
|
addr, err := net.ResolveUDPAddr(v, "localhost:0")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
l, err := net.ListenUDP(v, addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer l.Close()
|
||||||
|
conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSocketDestroy(t *testing.T) {
|
func TestSocketDestroy(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue