diff --git a/handle_test.go b/handle_test.go index 9e62b89..f241b3d 100644 --- a/handle_test.go +++ b/handle_test.go @@ -1,6 +1,9 @@ package netlink import ( + "crypto/rand" + "encoding/hex" + "io" "testing" "github.com/vishvananda/netns" @@ -22,7 +25,11 @@ func TestHandleCreateDelete(t *testing.T) { } func TestHandleCreateNetns(t *testing.T) { - ifName := "dummy0" + id := make([]byte, 4) + if _, err := io.ReadFull(rand.Reader, id); err != nil { + t.Fatal(err) + } + ifName := "dummy-" + hex.EncodeToString(id) // Create an handle on the current netns curNs, err := netns.Get() diff --git a/xfrm_policy_linux.go b/xfrm_policy_linux.go index 7f9dda7..7d9a3db 100644 --- a/xfrm_policy_linux.go +++ b/xfrm_policy_linux.go @@ -155,50 +155,96 @@ func (h *Handle) XfrmPolicyList(family int) ([]XfrmPolicy, error) { var res []XfrmPolicy for _, m := range msgs { - msg := nl.DeserializeXfrmUserpolicyInfo(m) - - if family != FAMILY_ALL && family != int(msg.Sel.Family) { + if policy, err := parseXfrmPolicy(m, family); err == nil { + res = append(res, *policy) + } else if err == familyError { continue - } - - var policy XfrmPolicy - - policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD) - policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS) - policy.Proto = Proto(msg.Sel.Proto) - policy.DstPort = int(nl.Swap16(msg.Sel.Dport)) - policy.SrcPort = int(nl.Swap16(msg.Sel.Sport)) - policy.Priority = int(msg.Priority) - policy.Index = int(msg.Index) - policy.Dir = Dir(msg.Dir) - - attrs, err := nl.ParseRouteAttr(m[msg.Len():]) - if err != nil { + } else { return nil, err } - - for _, attr := range attrs { - switch attr.Attr.Type { - case nl.XFRMA_TMPL: - max := len(attr.Value) - for i := 0; i < max; i += nl.SizeofXfrmUserTmpl { - var resTmpl XfrmPolicyTmpl - tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl]) - resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP() - resTmpl.Src = tmpl.Saddr.ToIP() - resTmpl.Proto = Proto(tmpl.XfrmId.Proto) - resTmpl.Mode = Mode(tmpl.Mode) - resTmpl.Reqid = int(tmpl.Reqid) - policy.Tmpls = append(policy.Tmpls, resTmpl) - } - case nl.XFRMA_MARK: - mark := nl.DeserializeXfrmMark(attr.Value[:]) - policy.Mark = new(XfrmMark) - policy.Mark.Value = mark.Value - policy.Mark.Mask = mark.Mask - } - } - res = append(res, policy) } return res, nil } + +// XfrmPolicyGet gets a the policy described by the index or selector, if found. +// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`. +func XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) { + h, err := NewHandle() + if err != nil { + return nil, err + } + defer h.Delete() + return h.XfrmPolicyGet(policy) +} + +// XfrmPolicyGet gets a the policy described by the index or selector, if found. +// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`. +func (h *Handle) XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) { + req := h.newNetlinkRequest(nl.XFRM_MSG_GETPOLICY, syscall.NLM_F_DUMP) + + msg := &nl.XfrmUserpolicyInfo{} + selFromPolicy(&msg.Sel, policy) + msg.Index = uint32(policy.Index) + msg.Dir = uint8(policy.Dir) + req.AddData(msg) + + msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY) + if err != nil { + return nil, err + } + + if policy, err := parseXfrmPolicy(msgs[0], FAMILY_ALL); err == nil { + return policy, nil + } else { + return nil, err + } +} + +func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) { + msg := nl.DeserializeXfrmUserpolicyInfo(m) + + // This is mainly for the policy dump + if family != FAMILY_ALL && family != int(msg.Sel.Family) { + return nil, familyError + } + + var policy XfrmPolicy + + policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD) + policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS) + policy.Proto = Proto(msg.Sel.Proto) + policy.DstPort = int(nl.Swap16(msg.Sel.Dport)) + policy.SrcPort = int(nl.Swap16(msg.Sel.Sport)) + policy.Priority = int(msg.Priority) + policy.Index = int(msg.Index) + policy.Dir = Dir(msg.Dir) + + attrs, err := nl.ParseRouteAttr(m[msg.Len():]) + if err != nil { + return nil, err + } + + for _, attr := range attrs { + switch attr.Attr.Type { + case nl.XFRMA_TMPL: + max := len(attr.Value) + for i := 0; i < max; i += nl.SizeofXfrmUserTmpl { + var resTmpl XfrmPolicyTmpl + tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl]) + resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP() + resTmpl.Src = tmpl.Saddr.ToIP() + resTmpl.Proto = Proto(tmpl.XfrmId.Proto) + resTmpl.Mode = Mode(tmpl.Mode) + resTmpl.Reqid = int(tmpl.Reqid) + policy.Tmpls = append(policy.Tmpls, resTmpl) + } + case nl.XFRMA_MARK: + mark := nl.DeserializeXfrmMark(attr.Value[:]) + policy.Mark = new(XfrmMark) + policy.Mark.Value = mark.Value + policy.Mark.Mask = mark.Mask + } + } + + return &policy, nil +} diff --git a/xfrm_policy_test.go b/xfrm_policy_test.go index 526e3c3..0c90458 100644 --- a/xfrm_policy_test.go +++ b/xfrm_policy_test.go @@ -12,7 +12,7 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) { src, _ := ParseIPNet("127.1.1.1/32") dst, _ := ParseIPNet("127.1.1.2/32") - policy := XfrmPolicy{ + policy := &XfrmPolicy{ Src: src, Dst: dst, Proto: 17, @@ -32,7 +32,7 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) { Mode: XFRM_MODE_TUNNEL, } policy.Tmpls = append(policy.Tmpls, tmpl) - if err := XfrmPolicyAdd(&policy); err != nil { + if err := XfrmPolicyAdd(policy); err != nil { t.Fatal(err) } policies, err := XfrmPolicyList(FAMILY_ALL) @@ -44,30 +44,34 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) { t.Fatal("Policy not added properly") } - // Verify Selector fields - if !compareIPNet(policies[0].Dst, policy.Dst) || - !compareIPNet(policies[0].Src, policy.Src) || - policies[0].Proto != policy.Proto || - policies[0].DstPort != policy.DstPort || - policies[0].SrcPort != policy.SrcPort { - t.Fatalf("Incorrect policy data retrieved. Expected %v. Got %v.", - policy, policies[0]) + if !comparePolicies(policy, &policies[0]) { + t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", policy, policies[0]) + } + + // Look for a specific policy + sp, err := XfrmPolicyGet(policy) + if err != nil { + t.Fatal(err) + } + + if !comparePolicies(policy, sp) { + t.Fatalf("unexpected policy returned") } // Modify the policy policy.Priority = 100 - if err := XfrmPolicyUpdate(&policy); err != nil { + if err := XfrmPolicyUpdate(policy); err != nil { t.Fatal(err) } - policies, err = XfrmPolicyList(FAMILY_ALL) + sp, err = XfrmPolicyGet(policy) if err != nil { t.Fatal(err) } - if policies[0].Priority != 100 { + if sp.Priority != 100 { t.Fatalf("failed to modify the policy") } - if err = XfrmPolicyDel(&policy); err != nil { + if err = XfrmPolicyDel(policy); err != nil { t.Fatal(err) } @@ -80,6 +84,34 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) { } } +func comparePolicies(a, b *XfrmPolicy) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } + // Do not check Index which is assigned by kernel + return a.Dir == b.Dir && a.Priority == b.Priority && + compareIPNet(a.Src, b.Src) && compareIPNet(a.Dst, b.Dst) && + a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask && + compareTemplates(a.Tmpls, b.Tmpls) +} + +func compareTemplates(a, b []XfrmPolicyTmpl) bool { + if len(a) != len(b) { + return false + } + for i, ta := range a { + tb := b[i] + if !ta.Dst.Equal(tb.Dst) || !ta.Src.Equal(tb.Src) || + ta.Mode != tb.Mode || ta.Reqid != tb.Reqid || ta.Proto != tb.Proto { + return false + } + } + return true +} + func compareIPNet(a, b *net.IPNet) bool { if a == b { return true diff --git a/xfrm_state_linux.go b/xfrm_state_linux.go index 3c6226f..a9e2251 100644 --- a/xfrm_state_linux.go +++ b/xfrm_state_linux.go @@ -208,7 +208,21 @@ func (h *Handle) XfrmStateList(family int) ([]XfrmState, error) { // ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ] // mark is optional func XfrmStateGet(state *XfrmState) (*XfrmState, error) { - req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP) + h, err := NewHandle() + if err != nil { + return nil, err + } + defer h.Delete() + return h.XfrmStateGet(state) +} + +// XfrmStateGet gets the xfrm state described by the ID, if found. +// Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`. +// Only the fields which constitue the SA ID must be filled in: +// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ] +// mark is optional +func (h *Handle) XfrmStateGet(state *XfrmState) (*XfrmState, error) { + req := h.newNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP) msg := &nl.XfrmUsersaInfo{} msg.Family = uint16(nl.GetIPFamily(state.Dst)) diff --git a/xfrm_state_test.go b/xfrm_state_test.go index e5aa6ff..d8e5c80 100644 --- a/xfrm_state_test.go +++ b/xfrm_state_test.go @@ -69,6 +69,12 @@ func TestXfrmStateAddDel(t *testing.T) { } func compareStates(a, b *XfrmState) bool { + if a == b { + return true + } + if a == nil || b == nil { + return false + } return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) && a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto && a.Auth.Name == b.Auth.Name && bytes.Equal(a.Auth.Key, b.Auth.Key) &&