diff --git a/netns_linux.go b/netns_linux.go new file mode 100644 index 0000000..9a0d662 --- /dev/null +++ b/netns_linux.go @@ -0,0 +1,141 @@ +package netlink + +// Network namespace ID functions +// +// The kernel has a weird concept called the network namespace ID. +// This is different from the file reference in proc (and any bind-mounted +// namespaces, etc.) +// +// Instead, namespaces can be assigned a numeric ID at any time. Once set, +// the ID is fixed. The ID can either be set manually by the user, or +// automatically, triggered by certain kernel actions. The most common kernel +// action that triggers namespace ID creation is moving one end of a veth pair +// in to that namespace. + +import ( + "fmt" + + "github.com/vishvananda/netlink/nl" + "golang.org/x/sys/unix" +) + +// These can be replaced by the values from sys/unix when it is next released. +const ( + _ = iota + NETNSA_NSID + NETNSA_PID + NETNSA_FD +) + +// GetNetNsIdByPid looks up the network namespace ID for a given pid (really thread id). +// Returns -1 if the namespace does not have an ID set. +func (h *Handle) GetNetNsIdByPid(pid int) (int, error) { + return h.getNetNsId(NETNSA_PID, uint32(pid)) +} + +// GetNetNsIdByPid looks up the network namespace ID for a given pid (really thread id). +// Returns -1 if the namespace does not have an ID set. +func GetNetNsIdByPid(pid int) (int, error) { + return pkgHandle.GetNetNsIdByPid(pid) +} + +// SetNetNSIdByPid sets the ID of the network namespace for a given pid (really thread id). +// The ID can only be set for namespaces without an ID already set. +func (h *Handle) SetNetNsIdByPid(pid, nsid int) error { + return h.setNetNsId(NETNSA_PID, uint32(pid), uint32(nsid)) +} + +// SetNetNSIdByPid sets the ID of the network namespace for a given pid (really thread id). +// The ID can only be set for namespaces without an ID already set. +func SetNetNsIdByPid(pid, nsid int) error { + return pkgHandle.SetNetNsIdByPid(pid, nsid) +} + +// GetNetNsIdByPid looks up the network namespace ID for a given fd. +// fd must be an open file descriptor to a namespace file. +// Returns -1 if the namespace does not have an ID set. +func (h *Handle) GetNetNsIdByFd(fd int) (int, error) { + return h.getNetNsId(NETNSA_FD, uint32(fd)) +} + +// GetNetNsIdByPid looks up the network namespace ID for a given fd. +// fd must be an open file descriptor to a namespace file. +// Returns -1 if the namespace does not have an ID set. +func GetNetNsIdByFd(fd int) (int, error) { + return pkgHandle.GetNetNsIdByFd(fd) +} + +// SetNetNSIdByFd sets the ID of the network namespace for a given fd. +// fd must be an open file descriptor to a namespace file. +// The ID can only be set for namespaces without an ID already set. +func (h *Handle) SetNetNsIdByFd(fd, nsid int) error { + return h.setNetNsId(NETNSA_FD, uint32(fd), uint32(nsid)) +} + +// SetNetNSIdByFd sets the ID of the network namespace for a given fd. +// fd must be an open file descriptor to a namespace file. +// The ID can only be set for namespaces without an ID already set. +func SetNetNsIdByFd(fd, nsid int) error { + return pkgHandle.SetNetNsIdByFd(fd, nsid) +} + +// getNetNsId requests the netnsid for a given type-val pair +// type should be either NETNSA_PID or NETNSA_FD +func (h *Handle) getNetNsId(attrType int, val uint32) (int, error) { + req := h.newNetlinkRequest(unix.RTM_GETNSID, unix.NLM_F_REQUEST) + + rtgen := nl.NewRtGenMsg() + req.AddData(rtgen) + + b := make([]byte, 4, 4) + native.PutUint32(b, val) + attr := nl.NewRtAttr(attrType, b) + req.AddData(attr) + + msgs, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWNSID) + + if err != nil { + return 0, err + } + + for _, m := range msgs { + msg := nl.DeserializeRtGenMsg(m) + + attrs, err := nl.ParseRouteAttr(m[msg.Len():]) + if err != nil { + return 0, err + } + + for _, attr := range attrs { + switch attr.Attr.Type { + case NETNSA_NSID: + return int(int32(native.Uint32(attr.Value))), nil + } + } + } + + return 0, fmt.Errorf("unexpected empty result") +} + +// setNetNsId sets the netnsid for a given type-val pair +// type should be either NETNSA_PID or NETNSA_FD +// The ID can only be set for namespaces without an ID already set +func (h *Handle) setNetNsId(attrType int, val uint32, newnsid uint32) error { + req := h.newNetlinkRequest(unix.RTM_NEWNSID, unix.NLM_F_REQUEST|unix.NLM_F_ACK) + + rtgen := nl.NewRtGenMsg() + req.AddData(rtgen) + + b := make([]byte, 4, 4) + native.PutUint32(b, val) + attr := nl.NewRtAttr(attrType, b) + req.AddData(attr) + + b1 := make([]byte, 4, 4) + native.PutUint32(b1, newnsid) + attr1 := nl.NewRtAttr(NETNSA_NSID, b1) + req.AddData(attr1) + + _, err := req.Execute(unix.NETLINK_ROUTE, unix.RTM_NEWNSID) + return err +} diff --git a/netns_test.go b/netns_test.go new file mode 100644 index 0000000..3415b9b --- /dev/null +++ b/netns_test.go @@ -0,0 +1,77 @@ +// +build linux + +package netlink + +import ( + "os" + "runtime" + "syscall" + "testing" + + "github.com/vishvananda/netns" +) + +// TestNetNsIdByFd tests setting and getting the network namespace ID +// by file descriptor. It opens a namespace fd, sets it to a random id, +// then retrieves the ID. +// This does not do any namespace switching. +func TestNetNsIdByFd(t *testing.T) { + // create a network namespace + ns, err := netns.New() + CheckErrorFail(t, err) + + // set its ID + // In an attempt to avoid namespace id collisions, set this to something + // insanely high. When the kernel assigns IDs, it does so starting from 0 + // So, just use our pid shifted up 16 bits + wantID := os.Getpid() << 16 + + h, err := NewHandle() + CheckErrorFail(t, err) + err = h.SetNetNsIdByFd(int(ns), wantID) + CheckErrorFail(t, err) + + // Get the ID back, make sure it matches + haveID, err := h.GetNetNsIdByFd(int(ns)) + if haveID != wantID { + t.Errorf("GetNetNsIdByFd returned %d, want %d", haveID, wantID) + } + + ns.Close() +} + +// TestNetNsIdByPid tests manipulating namespace IDs by pid (really, task / thread id) +// Does the same as TestNetNsIdByFd, but we need to change namespaces so we +// actually have a pid in that namespace +func TestNetNsIdByPid(t *testing.T) { + runtime.LockOSThread() // we need a constant OS thread + origNs, _ := netns.Get() + + // create and enter a new netns + ns, err := netns.New() + CheckErrorFail(t, err) + err = netns.Set(ns) + CheckErrorFail(t, err) + // make sure we go back to the original namespace when done + defer func() { + err := netns.Set(origNs) + if err != nil { + panic("failed to restore network ns, bailing!") + } + runtime.UnlockOSThread() + }() + + // As above, we'll pick a crazy large netnsid to avoid collisions + wantID := syscall.Gettid() << 16 + + h, err := NewHandle() + CheckErrorFail(t, err) + err = h.SetNetNsIdByPid(syscall.Gettid(), wantID) + CheckErrorFail(t, err) + + //Get the ID and see if it worked + haveID, err := h.GetNetNsIdByPid(syscall.Gettid()) + if haveID != wantID { + t.Errorf("GetNetNsIdByPid returned %d, want %d", haveID, wantID) + } +} diff --git a/netns_unspecified.go b/netns_unspecified.go new file mode 100644 index 0000000..5c5899e --- /dev/null +++ b/netns_unspecified.go @@ -0,0 +1,19 @@ +// +build !linux + +package netlink + +func GetNetNsIdByPid(pid int) (int, error) { + return 0, ErrNotImplemented +} + +func SetNetNsIdByPid(pid, nsid int) error { + return ErrNotImplemented +} + +func GetNetNsIdByFd(fd int) (int, error) { + return 0, ErrNotImplemented +} + +func SetNetNsIdByFd(fd, nsid int) error { + return ErrNotImplemented +} diff --git a/nl/route_linux.go b/nl/route_linux.go index f6906fc..03c1900 100644 --- a/nl/route_linux.go +++ b/nl/route_linux.go @@ -79,3 +79,29 @@ func (msg *RtNexthop) Serialize() []byte { } return buf } + +type RtGenMsg struct { + unix.RtGenmsg +} + +func NewRtGenMsg() *RtGenMsg { + return &RtGenMsg{ + RtGenmsg: unix.RtGenmsg{ + Family: unix.AF_UNSPEC, + }, + } +} + +func (msg *RtGenMsg) Len() int { + return rtaAlignOf(unix.SizeofRtGenmsg) +} + +func DeserializeRtGenMsg(b []byte) *RtGenMsg { + return &RtGenMsg{RtGenmsg: unix.RtGenmsg{Family: b[0]}} +} + +func (msg *RtGenMsg) Serialize() []byte { + out := make([]byte, msg.Len()) + out[0] = msg.Family + return out +}