diff --git a/Makefile b/Makefile index 75f3429..8dc5a92 100644 --- a/Makefile +++ b/Makefile @@ -18,7 +18,7 @@ $(call goroot,$(DEPS)): .PHONY: $(call testdirs,$(DIRS)) $(call testdirs,$(DIRS)): - sudo -E go test -v github.com/vishvananda/netlink/$@ + sudo -E go test -test.parallel 4 -timeout 60s -v github.com/vishvananda/netlink/$@ $(call fmt,$(call testdirs,$(DIRS))): ! gofmt -l $(subst fmt-,,$@)/*.go | grep '' diff --git a/handle_test.go b/handle_test.go index f241b3d..b4e6965 100644 --- a/handle_test.go +++ b/handle_test.go @@ -3,7 +3,11 @@ package netlink import ( "crypto/rand" "encoding/hex" + "fmt" "io" + "net" + "sync" + "sync/atomic" "testing" "github.com/vishvananda/netns" @@ -95,3 +99,188 @@ func TestHandleCreateNetns(t *testing.T) { t.Fatalf("Unexpected link found on netns %s: %v", curNs, ll) } } + +var ( + iter = 10 + numThread = uint32(4) + prefix = "iface" + handle1 *Handle + handle2 *Handle + ns1 netns.NsHandle + ns2 netns.NsHandle + done uint32 + initError error + once sync.Once +) + +func getXfrmState(thread int) *XfrmState { + return &XfrmState{ + Src: net.IPv4(byte(192), byte(168), 1, byte(1+thread)), + Dst: net.IPv4(byte(192), byte(168), 2, byte(1+thread)), + Proto: XFRM_PROTO_AH, + Mode: XFRM_MODE_TUNNEL, + Spi: thread, + Auth: &XfrmStateAlgo{ + Name: "hmac(sha256)", + Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), + }, + } +} + +func getXfrmPolicy(thread int) *XfrmPolicy { + return &XfrmPolicy{ + Src: &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}}, + Dst: &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}}, + Proto: 17, + DstPort: 1234, + SrcPort: 5678, + Dir: XFRM_DIR_OUT, + Tmpls: []XfrmPolicyTmpl{ + { + Src: net.IPv4(byte(192), byte(168), 1, byte(thread)), + Dst: net.IPv4(byte(192), byte(168), 2, byte(thread)), + Proto: XFRM_PROTO_ESP, + Mode: XFRM_MODE_TUNNEL, + }, + }, + } +} +func initParallel() { + ns1, initError = netns.New() + if initError != nil { + return + } + handle1, initError = NewHandleAt(ns1) + if initError != nil { + return + } + ns2, initError = netns.New() + if initError != nil { + return + } + handle2, initError = NewHandleAt(ns2) + if initError != nil { + return + } +} + +func parallelDone() { + atomic.AddUint32(&done, 1) + if done == numThread { + if ns1.IsOpen() { + ns1.Close() + } + if ns2.IsOpen() { + ns2.Close() + } + if handle1 != nil { + handle1.Delete() + } + if handle2 != nil { + handle2.Delete() + } + } +} + +// Do few route and xfrm operation on the two handles in parallel +func runParallelTests(t *testing.T, thread int) { + defer parallelDone() + + t.Parallel() + + once.Do(initParallel) + if initError != nil { + t.Fatal(initError) + } + + state := getXfrmState(thread) + policy := getXfrmPolicy(thread) + for i := 0; i < iter; i++ { + ifName := fmt.Sprintf("%s_%d_%d", prefix, thread, i) + link := &Dummy{LinkAttrs{Name: ifName}} + err := handle1.LinkAdd(link) + if err != nil { + t.Fatal(err) + } + l, err := handle1.LinkByName(ifName) + if err != nil { + t.Fatal(err) + } + err = handle1.LinkSetUp(l) + if err != nil { + t.Fatal(err) + } + handle1.LinkSetNsFd(l, int(ns2)) + if err != nil { + t.Fatal(err) + } + err = handle1.XfrmStateAdd(state) + if err != nil { + t.Fatal(err) + } + err = handle1.XfrmPolicyAdd(policy) + if err != nil { + t.Fatal(err) + } + err = handle2.LinkSetDown(l) + if err != nil { + t.Fatal(err) + } + err = handle2.XfrmStateAdd(state) + if err != nil { + t.Fatal(err) + } + err = handle2.XfrmPolicyAdd(policy) + if err != nil { + t.Fatal(err) + } + _, err = handle2.LinkByName(ifName) + if err != nil { + t.Fatal(err) + } + handle2.LinkSetNsFd(l, int(ns1)) + if err != nil { + t.Fatal(err) + } + err = handle1.LinkSetUp(l) + if err != nil { + t.Fatal(err) + } + l, err = handle1.LinkByName(ifName) + if err != nil { + t.Fatal(err) + } + err = handle1.XfrmPolicyDel(policy) + if err != nil { + t.Fatal(err) + } + err = handle2.XfrmPolicyDel(policy) + if err != nil { + t.Fatal(err) + } + err = handle1.XfrmStateDel(state) + if err != nil { + t.Fatal(err) + } + err = handle2.XfrmStateDel(state) + if err != nil { + t.Fatal(err) + } + } +} + +func TestHandleParallel1(t *testing.T) { + runParallelTests(t, 1) +} + +func TestHandleParallel2(t *testing.T) { + runParallelTests(t, 2) +} + +func TestHandleParallel3(t *testing.T) { + runParallelTests(t, 3) +} + +func TestHandleParallel4(t *testing.T) { + runParallelTests(t, 4) +} diff --git a/nl/nl_linux.go b/nl/nl_linux.go index 1e5233b..8306890 100644 --- a/nl/nl_linux.go +++ b/nl/nl_linux.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "runtime" + "sync" "sync/atomic" "syscall" "unsafe" @@ -233,6 +234,9 @@ func (req *NetlinkRequest) Execute(sockType int, resType uint16) ([][]byte, erro return nil, err } defer s.Close() + } else { + s.Lock() + defer s.Unlock() } if err := s.Send(req); err != nil { @@ -302,6 +306,7 @@ func NewNetlinkRequest(proto, flags int) *NetlinkRequest { type NetlinkSocket struct { fd int lsa syscall.SockaddrNetlink + sync.Mutex } func getNetlinkSocket(protocol int) (*NetlinkSocket, error) {