Provide method to query for specific policy (#115)

Signed-off-by: Alessandro Boch <aboch@docker.com>
This commit is contained in:
Alessandro Boch 2016-05-09 16:52:35 -07:00 committed by Vish Ishaya
parent a123807666
commit cb0b035c41
5 changed files with 162 additions and 57 deletions

View File

@ -1,6 +1,9 @@
package netlink
import (
"crypto/rand"
"encoding/hex"
"io"
"testing"
"github.com/vishvananda/netns"
@ -22,7 +25,11 @@ func TestHandleCreateDelete(t *testing.T) {
}
func TestHandleCreateNetns(t *testing.T) {
ifName := "dummy0"
id := make([]byte, 4)
if _, err := io.ReadFull(rand.Reader, id); err != nil {
t.Fatal(err)
}
ifName := "dummy-" + hex.EncodeToString(id)
// Create an handle on the current netns
curNs, err := netns.Get()

View File

@ -155,50 +155,96 @@ func (h *Handle) XfrmPolicyList(family int) ([]XfrmPolicy, error) {
var res []XfrmPolicy
for _, m := range msgs {
msg := nl.DeserializeXfrmUserpolicyInfo(m)
if family != FAMILY_ALL && family != int(msg.Sel.Family) {
if policy, err := parseXfrmPolicy(m, family); err == nil {
res = append(res, *policy)
} else if err == familyError {
continue
}
var policy XfrmPolicy
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
policy.Proto = Proto(msg.Sel.Proto)
policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
policy.Priority = int(msg.Priority)
policy.Index = int(msg.Index)
policy.Dir = Dir(msg.Dir)
attrs, err := nl.ParseRouteAttr(m[msg.Len():])
if err != nil {
} else {
return nil, err
}
for _, attr := range attrs {
switch attr.Attr.Type {
case nl.XFRMA_TMPL:
max := len(attr.Value)
for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
var resTmpl XfrmPolicyTmpl
tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
resTmpl.Src = tmpl.Saddr.ToIP()
resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
resTmpl.Mode = Mode(tmpl.Mode)
resTmpl.Reqid = int(tmpl.Reqid)
policy.Tmpls = append(policy.Tmpls, resTmpl)
}
case nl.XFRMA_MARK:
mark := nl.DeserializeXfrmMark(attr.Value[:])
policy.Mark = new(XfrmMark)
policy.Mark.Value = mark.Value
policy.Mark.Mask = mark.Mask
}
}
res = append(res, policy)
}
return res, nil
}
// XfrmPolicyGet gets a the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
func XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
h, err := NewHandle()
if err != nil {
return nil, err
}
defer h.Delete()
return h.XfrmPolicyGet(policy)
}
// XfrmPolicyGet gets a the policy described by the index or selector, if found.
// Equivalent to: `ip xfrm policy get { SELECTOR | index INDEX } dir DIR [ctx CTX ] [ mark MARK [ mask MASK ] ] [ ptype PTYPE ]`.
func (h *Handle) XfrmPolicyGet(policy *XfrmPolicy) (*XfrmPolicy, error) {
req := h.newNetlinkRequest(nl.XFRM_MSG_GETPOLICY, syscall.NLM_F_DUMP)
msg := &nl.XfrmUserpolicyInfo{}
selFromPolicy(&msg.Sel, policy)
msg.Index = uint32(policy.Index)
msg.Dir = uint8(policy.Dir)
req.AddData(msg)
msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWPOLICY)
if err != nil {
return nil, err
}
if policy, err := parseXfrmPolicy(msgs[0], FAMILY_ALL); err == nil {
return policy, nil
} else {
return nil, err
}
}
func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
msg := nl.DeserializeXfrmUserpolicyInfo(m)
// This is mainly for the policy dump
if family != FAMILY_ALL && family != int(msg.Sel.Family) {
return nil, familyError
}
var policy XfrmPolicy
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
policy.Proto = Proto(msg.Sel.Proto)
policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
policy.Priority = int(msg.Priority)
policy.Index = int(msg.Index)
policy.Dir = Dir(msg.Dir)
attrs, err := nl.ParseRouteAttr(m[msg.Len():])
if err != nil {
return nil, err
}
for _, attr := range attrs {
switch attr.Attr.Type {
case nl.XFRMA_TMPL:
max := len(attr.Value)
for i := 0; i < max; i += nl.SizeofXfrmUserTmpl {
var resTmpl XfrmPolicyTmpl
tmpl := nl.DeserializeXfrmUserTmpl(attr.Value[i : i+nl.SizeofXfrmUserTmpl])
resTmpl.Dst = tmpl.XfrmId.Daddr.ToIP()
resTmpl.Src = tmpl.Saddr.ToIP()
resTmpl.Proto = Proto(tmpl.XfrmId.Proto)
resTmpl.Mode = Mode(tmpl.Mode)
resTmpl.Reqid = int(tmpl.Reqid)
policy.Tmpls = append(policy.Tmpls, resTmpl)
}
case nl.XFRMA_MARK:
mark := nl.DeserializeXfrmMark(attr.Value[:])
policy.Mark = new(XfrmMark)
policy.Mark.Value = mark.Value
policy.Mark.Mask = mark.Mask
}
}
return &policy, nil
}

View File

@ -12,7 +12,7 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) {
src, _ := ParseIPNet("127.1.1.1/32")
dst, _ := ParseIPNet("127.1.1.2/32")
policy := XfrmPolicy{
policy := &XfrmPolicy{
Src: src,
Dst: dst,
Proto: 17,
@ -32,7 +32,7 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) {
Mode: XFRM_MODE_TUNNEL,
}
policy.Tmpls = append(policy.Tmpls, tmpl)
if err := XfrmPolicyAdd(&policy); err != nil {
if err := XfrmPolicyAdd(policy); err != nil {
t.Fatal(err)
}
policies, err := XfrmPolicyList(FAMILY_ALL)
@ -44,30 +44,34 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) {
t.Fatal("Policy not added properly")
}
// Verify Selector fields
if !compareIPNet(policies[0].Dst, policy.Dst) ||
!compareIPNet(policies[0].Src, policy.Src) ||
policies[0].Proto != policy.Proto ||
policies[0].DstPort != policy.DstPort ||
policies[0].SrcPort != policy.SrcPort {
t.Fatalf("Incorrect policy data retrieved. Expected %v. Got %v.",
policy, policies[0])
if !comparePolicies(policy, &policies[0]) {
t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", policy, policies[0])
}
// Look for a specific policy
sp, err := XfrmPolicyGet(policy)
if err != nil {
t.Fatal(err)
}
if !comparePolicies(policy, sp) {
t.Fatalf("unexpected policy returned")
}
// Modify the policy
policy.Priority = 100
if err := XfrmPolicyUpdate(&policy); err != nil {
if err := XfrmPolicyUpdate(policy); err != nil {
t.Fatal(err)
}
policies, err = XfrmPolicyList(FAMILY_ALL)
sp, err = XfrmPolicyGet(policy)
if err != nil {
t.Fatal(err)
}
if policies[0].Priority != 100 {
if sp.Priority != 100 {
t.Fatalf("failed to modify the policy")
}
if err = XfrmPolicyDel(&policy); err != nil {
if err = XfrmPolicyDel(policy); err != nil {
t.Fatal(err)
}
@ -80,6 +84,34 @@ func TestXfrmPolicyAddUpdateDel(t *testing.T) {
}
}
func comparePolicies(a, b *XfrmPolicy) bool {
if a == b {
return true
}
if a == nil || b == nil {
return false
}
// Do not check Index which is assigned by kernel
return a.Dir == b.Dir && a.Priority == b.Priority &&
compareIPNet(a.Src, b.Src) && compareIPNet(a.Dst, b.Dst) &&
a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask &&
compareTemplates(a.Tmpls, b.Tmpls)
}
func compareTemplates(a, b []XfrmPolicyTmpl) bool {
if len(a) != len(b) {
return false
}
for i, ta := range a {
tb := b[i]
if !ta.Dst.Equal(tb.Dst) || !ta.Src.Equal(tb.Src) ||
ta.Mode != tb.Mode || ta.Reqid != tb.Reqid || ta.Proto != tb.Proto {
return false
}
}
return true
}
func compareIPNet(a, b *net.IPNet) bool {
if a == b {
return true

View File

@ -208,7 +208,21 @@ func (h *Handle) XfrmStateList(family int) ([]XfrmState, error) {
// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
// mark is optional
func XfrmStateGet(state *XfrmState) (*XfrmState, error) {
req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
h, err := NewHandle()
if err != nil {
return nil, err
}
defer h.Delete()
return h.XfrmStateGet(state)
}
// XfrmStateGet gets the xfrm state described by the ID, if found.
// Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
// Only the fields which constitue the SA ID must be filled in:
// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
// mark is optional
func (h *Handle) XfrmStateGet(state *XfrmState) (*XfrmState, error) {
req := h.newNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
msg := &nl.XfrmUsersaInfo{}
msg.Family = uint16(nl.GetIPFamily(state.Dst))

View File

@ -69,6 +69,12 @@ func TestXfrmStateAddDel(t *testing.T) {
}
func compareStates(a, b *XfrmState) bool {
if a == b {
return true
}
if a == nil || b == nil {
return false
}
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
a.Auth.Name == b.Auth.Name && bytes.Equal(a.Auth.Key, b.Auth.Key) &&