netlink/netlink_test.go
Alex O'Regan aed23dbf5e Adds ConntrackCreate & ConntrackUpdate
- Also refactored setUpNetlinkTestWithKModule function to reduce redundant NS's created and checks made.

 - Add conntrack protoinfo TCP support + groundwork for other protocols.

 - Tests to cover the above.
2024-07-04 08:47:44 -07:00

279 lines
6.2 KiB
Go

//go:build linux
// +build linux
package netlink
import (
"bytes"
"crypto/rand"
"encoding/hex"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"runtime"
"strings"
"testing"
"github.com/vishvananda/netlink/nl"
"github.com/vishvananda/netns"
"golang.org/x/sys/unix"
)
// tearDownNetlinkTest is the cleanup callback returned by the setUp* helpers
// in this file. Calling it undoes whatever the matching setup did (closing
// namespace handles, unlocking the OS thread, and so on).
type tearDownNetlinkTest func()
// skipUnlessRoot skips the calling test or benchmark unless the process is
// running with uid 0. It marks itself as a test helper so the skip is
// attributed to the caller.
func skipUnlessRoot(t testing.TB) {
	t.Helper()
	if uid := os.Getuid(); uid == 0 {
		return
	}
	t.Skip("Test requires root privileges.")
}
// skipUnlessKModuleLoaded skips the test unless every name in moduleNames is
// listed in /proc/modules. Each missing module is logged before the test is
// skipped, so a single run reports all of them. Failing to read
// /proc/modules is fatal.
func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
	t.Helper()
	raw, err := ioutil.ReadFile("/proc/modules")
	if err != nil {
		t.Fatal("Failed to open /proc/modules", err)
	}

	// The module name is everything before the first space on a line.
	loaded := make(map[string]struct{})
	for _, entry := range strings.Split(string(raw), "\n") {
		if sp := strings.IndexByte(entry, ' '); sp >= 0 {
			entry = entry[:sp]
		}
		loaded[entry] = struct{}{}
	}

	missing := 0
	for _, want := range moduleNames {
		if _, ok := loaded[want]; ok {
			continue
		}
		t.Logf("Test requires missing kmodule %q.", want)
		missing++
	}
	if missing > 0 {
		t.SkipNow()
	}
}
// setUpNetlinkTest prepares an isolated environment for a netlink test. It
// skips the test unless running as root, locks the goroutine to its OS thread
// and creates a new network namespace with netns.New.
//
// The returned function closes the namespace handle and unlocks the OS
// thread.
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
	skipUnlessRoot(t)

	// new temporary namespace so we don't pollute the host
	// lock thread since the namespace is thread local
	runtime.LockOSThread()
	ns, err := netns.New()
	if err != nil {
		// Report the error, not the handle: the handle says nothing about
		// why creation failed.
		t.Fatal("Failed to create newns", err)
	}

	return func() {
		ns.Close()
		runtime.UnlockOSThread()
	}
}
// setUpNamedNetlinkTest creates a temporary named network namespace with a
// random name ("netlinktest-" followed by 8 hex digits) and enters it.
//
// It returns the namespace name and a teardown function. The teardown closes
// the new namespace handle, deletes the named namespace, switches back to the
// namespace that was current on entry, closes that saved handle and unlocks
// the OS thread. The test is skipped unless running as root.
func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
	skipUnlessRoot(t)

	// Network namespaces are per-thread, so pin the goroutine to its thread
	// before the current namespace is read or changed.
	runtime.LockOSThread()

	origNS, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		t.Fatal("Failed saving orig namespace", err)
	}

	// restore is best effort: it returns the thread to the original
	// namespace, releases the saved handle and unpins the goroutine.
	restore := func() {
		netns.Set(origNS)
		origNS.Close()
		runtime.UnlockOSThread()
	}

	// create a random name
	rnd := make([]byte, 4)
	if _, err := rand.Read(rnd); err != nil {
		restore()
		t.Fatal("failed creating random ns name", err)
	}
	name := "netlinktest-" + hex.EncodeToString(rnd)

	ns, err := netns.NewNamed(name)
	if err != nil {
		// NewNamed may already have switched the thread before failing, so
		// go back to the original namespace before bailing out.
		restore()
		t.Fatal("Failed to create new ns", err)
	}

	cleanup := func() {
		ns.Close()
		netns.DeleteNamed(name)
		restore()
	}

	if err := netns.Set(ns); err != nil {
		cleanup()
		t.Fatal("Failed entering new namespace", err)
	}

	return name, cleanup
}
// setUpNetlinkTestWithLoopback behaves like a plain namespace setup but also
// brings the "lo" interface up inside the new network namespace.
//
// The test is skipped unless running as root. Any failure is fatal. The
// returned function closes the namespace handle and unlocks the OS thread.
func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
	skipUnlessRoot(t)

	// The namespace is thread local, so pin the goroutine first.
	runtime.LockOSThread()
	ns, err := netns.New()
	if err != nil {
		// Report the error, not the handle.
		t.Fatal("Failed to create new netns", err)
	}

	link, err := LinkByName("lo")
	if err != nil {
		// No teardown is returned on this path, so release the handle here.
		ns.Close()
		t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
	}
	if err := LinkSetUp(link); err != nil {
		ns.Close()
		t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
	}

	return func() {
		ns.Close()
		runtime.UnlockOSThread()
	}
}
// setUpF writes value to the file at path, creating or truncating it. Here it
// is used to set /proc/sys knobs before a test runs.
//
// Open, write and close failures are all fatal: a knob that was silently left
// unset would only surface later as a confusing test failure.
func setUpF(t *testing.T, path, value string) {
	t.Helper()

	file, err := os.Create(path)
	if err != nil {
		t.Fatalf("Failed to open %s: %s", path, err)
	}
	if _, err := file.WriteString(value); err != nil {
		file.Close()
		t.Fatalf("Failed to write %q to %s: %s", value, path, err)
	}
	if err := file.Close(); err != nil {
		t.Fatalf("Failed to close %s: %s", path, err)
	}
}
// setUpMPLSNetlinkTest prepares a network namespace for MPLS tests. The test
// is skipped when /proc/sys/net/mpls/platform_labels cannot be stat'ed.
// Otherwise it runs setUpNetlinkTest, sets platform_labels to 1024 and
// enables MPLS input on "lo". The teardown from setUpNetlinkTest is returned.
func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
	_, statErr := os.Stat("/proc/sys/net/mpls/platform_labels")
	if statErr != nil {
		t.Skip("Test requires MPLS support.")
	}

	tearDown := setUpNetlinkTest(t)

	// Knobs are written in this order, after the new namespace is set up.
	knobs := [...]struct{ path, value string }{
		{"/proc/sys/net/mpls/platform_labels", "1024"},
		{"/proc/sys/net/mpls/conf/lo/input", "1"},
	}
	for _, knob := range knobs {
		setUpF(t, knob.path, knob.value)
	}
	return tearDown
}
// setUpSEG6NetlinkTest prepares a network namespace for SEG6 tests.
//
// It runs "uname -r" to find /boot/config-<release> and greps that file for
// CONFIG_IPV6_SEG6_LWTUNNEL=y. If grep fails (no match, or the file is
// missing) the test is skipped. Otherwise the result of setUpNetlinkTest is
// returned.
func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
	// check if SEG6 options are enabled in Kernel Config
	cmd := exec.Command("uname", "-r")
	var out bytes.Buffer
	cmd.Stdout = &out
	if err := cmd.Run(); err != nil {
		t.Fatalf("Failed to run: uname -r: %v", err)
	}
	filename := "/boot/config-" + strings.TrimRight(out.String(), "\n")

	// grepKey returns the lines of fname matching key. The error is non-nil
	// when grep exits non-zero, which includes "no line matched".
	grepKey := func(key, fname string) (string, error) {
		cmd := exec.Command("grep", key, fname)
		var out bytes.Buffer
		cmd.Stdout = &out
		err := cmd.Run()
		return strings.TrimRight(out.String(), "\n"), err
	}

	key := "CONFIG_IPV6_SEG6_LWTUNNEL=y"
	if _, err := grepKey(key, filename); err != nil {
		msg := "Skipped test because it requires SEG6_LWTUNNEL support."
		log.Println(msg)
		t.Skip(msg)
	}
	// Add CONFIG_IPV6_SEG6_HMAC to support seg6 HMAC
	// key := string("CONFIG_IPV6_SEG6_HMAC=y")

	return setUpNetlinkTest(t)
}
// setUpNetlinkTestWithKModule skips the test unless every kernel module in
// moduleNames is loaded, then sets up a fresh network namespace with
// setUpNetlinkTest and returns its teardown function.
func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
	// Mark this wrapper as a helper so missing-module logs point at the test.
	t.Helper()
	skipUnlessKModuleLoaded(t, moduleNames...)
	return setUpNetlinkTest(t)
}
// setUpNamedNetlinkTestWithKModule skips the test unless every kernel module
// in moduleNames is loaded, then creates and enters a named network namespace
// with setUpNamedNetlinkTest. It returns the namespace name and the teardown
// function.
func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
	// Mark this wrapper as a helper so missing-module logs point at the test.
	t.Helper()
	// Reuse the shared module check rather than a private copy of the
	// /proc/modules parsing.
	skipUnlessKModuleLoaded(t, moduleNames...)
	return setUpNamedNetlinkTest(t)
}
// remountSysfs replaces the sysfs instance mounted at /sys with a fresh one.
//
// It performs three steps and stops at the first error:
//  1. remount "/" with MS_SLAVE|MS_REC,
//  2. detach the current /sys mount (MNT_DETACH),
//  3. mount a new sysfs at /sys.
//
// NOTE(review): presumably used so /sys reflects the network namespace the
// thread is currently in — confirm against callers.
func remountSysfs() error {
	err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, "")
	if err == nil {
		err = unix.Unmount("/sys", unix.MNT_DETACH)
	}
	if err == nil {
		err = unix.Mount("", "/sys", "sysfs", 0, "")
	}
	return err
}
// minKernelRequired skips the test when the host kernel, as reported by
// KernelVersion, is older than kernel.major. Failing to determine the host
// version is fatal.
func minKernelRequired(t *testing.T, kernel, major int) {
	t.Helper()
	hostKernel, hostMajor, err := KernelVersion()
	if err != nil {
		t.Fatal(err)
	}

	switch {
	case hostKernel > kernel:
		return
	case hostKernel == kernel && hostMajor >= major:
		return
	}
	t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
		hostKernel, hostMajor, kernel, major)
}
// KernelVersion returns the first two numeric components of the running
// kernel's release string, as reported by uname (for example 5 and 15 for
// "5.15.0-generic").
//
// An error is returned when uname fails or when fewer than two numbers can be
// parsed from the release string.
func KernelVersion() (kernel, major int, err error) {
	var uts unix.Utsname
	if err = unix.Uname(&uts); err != nil {
		return 0, 0, err
	}

	// Release is a fixed-size, NUL-padded array; keep only the bytes before
	// the first NUL.
	release := make([]byte, 0, len(uts.Release))
	for i := 0; i < len(uts.Release) && uts.Release[i] != 0; i++ {
		release = append(release, byte(uts.Release[i]))
	}

	// The trailing %s swallows whatever follows "X.Y". Only the count of
	// scanned items matters, so the scan error is deliberately ignored.
	var suffix string
	parsed, _ := fmt.Sscanf(string(release), "%d.%d%s", &kernel, &major, &suffix)
	if parsed < 2 {
		return kernel, major, fmt.Errorf("can't parse kernel version in %q", string(release))
	}
	return kernel, major, nil
}
// TestMain is the entry point for this package's tests. It sets
// nl.EnableErrorMessageReporting to true before running the suite and exits
// with the suite's result code.
func TestMain(m *testing.M) {
	nl.EnableErrorMessageReporting = true

	exitCode := m.Run()
	os.Exit(exitCode)
}