Implement ip xfrm state get (#114)

Signed-off-by: Alessandro Boch <aboch@docker.com>
This commit is contained in:
Alessandro Boch 2016-05-08 11:32:17 -07:00 committed by Vish Ishaya
parent 7ec3682687
commit 096107b4d7
2 changed files with 136 additions and 67 deletions

View File

@ -133,7 +133,7 @@ func XfrmStateDel(state *XfrmState) error {
}
// XfrmStateList gets a list of xfrm states in the system.
// Equivalent to: `ip xfrm state show`.
// Equivalent to: `ip [-4|-6] xfrm state show`.
// The list can be filtered by ip family.
func XfrmStateList(family int) ([]XfrmState, error) {
req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
@ -145,67 +145,113 @@ func XfrmStateList(family int) ([]XfrmState, error) {
var res []XfrmState
for _, m := range msgs {
msg := nl.DeserializeXfrmUsersaInfo(m)
if family != FAMILY_ALL && family != int(msg.Family) {
if state, err := parseXfrmState(m, family); err == nil {
res = append(res, *state)
} else if err == familyError {
continue
}
var state XfrmState
state.Dst = msg.Id.Daddr.ToIP()
state.Src = msg.Saddr.ToIP()
state.Proto = Proto(msg.Id.Proto)
state.Mode = Mode(msg.Mode)
state.Spi = int(nl.Swap32(msg.Id.Spi))
state.Reqid = int(msg.Reqid)
state.ReplayWindow = int(msg.ReplayWindow)
attrs, err := nl.ParseRouteAttr(m[msg.Len():])
if err != nil {
} else {
return nil, err
}
for _, attr := range attrs {
switch attr.Attr.Type {
case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
var resAlgo *XfrmStateAlgo
if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
if state.Auth == nil {
state.Auth = new(XfrmStateAlgo)
}
resAlgo = state.Auth
} else {
state.Crypt = new(XfrmStateAlgo)
resAlgo = state.Crypt
}
algo := nl.DeserializeXfrmAlgo(attr.Value[:])
(*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
(*resAlgo).Key = algo.AlgKey
case nl.XFRMA_ALG_AUTH_TRUNC:
if state.Auth == nil {
state.Auth = new(XfrmStateAlgo)
}
algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
state.Auth.Name = nl.BytesToString(algo.AlgName[:])
state.Auth.Key = algo.AlgKey
state.Auth.TruncateLen = int(algo.AlgTruncLen)
case nl.XFRMA_ENCAP:
encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
state.Encap = new(XfrmStateEncap)
state.Encap.Type = EncapType(encap.EncapType)
state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
state.Encap.OriginalAddress = encap.EncapOa.ToIP()
case nl.XFRMA_MARK:
mark := nl.DeserializeXfrmMark(attr.Value[:])
state.Mark = new(XfrmMark)
state.Mark.Value = mark.Value
state.Mark.Mask = mark.Mask
}
}
res = append(res, state)
}
return res, nil
}
// XfrmStateGet gets the xfrm state described by the ID, if found.
// Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
// Only the fields which constitue the SA ID must be filled in:
// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
// mark is optional
func XfrmStateGet(state *XfrmState) (*XfrmState, error) {
req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
msg := &nl.XfrmUsersaInfo{}
msg.Family = uint16(nl.GetIPFamily(state.Dst))
msg.Id.Daddr.FromIP(state.Dst)
msg.Saddr.FromIP(state.Src)
msg.Id.Proto = uint8(state.Proto)
msg.Id.Spi = nl.Swap32(uint32(state.Spi))
req.AddData(msg)
if state.Mark != nil {
out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
req.AddData(out)
}
msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
if err != nil {
return nil, err
}
if state, err := parseXfrmState(msgs[0], FAMILY_ALL); err == nil {
return state, nil
} else {
return nil, err
}
}
var familyError = fmt.Errorf("family error")
func parseXfrmState(m []byte, family int) (*XfrmState, error) {
msg := nl.DeserializeXfrmUsersaInfo(m)
// This is mainly for the state dump
if family != FAMILY_ALL && family != int(msg.Family) {
return nil, familyError
}
var state XfrmState
state.Dst = msg.Id.Daddr.ToIP()
state.Src = msg.Saddr.ToIP()
state.Proto = Proto(msg.Id.Proto)
state.Mode = Mode(msg.Mode)
state.Spi = int(nl.Swap32(msg.Id.Spi))
state.Reqid = int(msg.Reqid)
state.ReplayWindow = int(msg.ReplayWindow)
attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
if err != nil {
return nil, err
}
for _, attr := range attrs {
switch attr.Attr.Type {
case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
var resAlgo *XfrmStateAlgo
if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
if state.Auth == nil {
state.Auth = new(XfrmStateAlgo)
}
resAlgo = state.Auth
} else {
state.Crypt = new(XfrmStateAlgo)
resAlgo = state.Crypt
}
algo := nl.DeserializeXfrmAlgo(attr.Value[:])
(*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
(*resAlgo).Key = algo.AlgKey
case nl.XFRMA_ALG_AUTH_TRUNC:
if state.Auth == nil {
state.Auth = new(XfrmStateAlgo)
}
algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
state.Auth.Name = nl.BytesToString(algo.AlgName[:])
state.Auth.Key = algo.AlgKey
state.Auth.TruncateLen = int(algo.AlgTruncLen)
case nl.XFRMA_ENCAP:
encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
state.Encap = new(XfrmStateEncap)
state.Encap.Type = EncapType(encap.EncapType)
state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
state.Encap.OriginalAddress = encap.EncapOa.ToIP()
case nl.XFRMA_MARK:
mark := nl.DeserializeXfrmMark(attr.Value[:])
state.Mark = new(XfrmMark)
state.Mark.Value = mark.Value
state.Mark.Mask = mark.Mask
}
}
return &state, nil
}

View File

@ -1,6 +1,7 @@
package netlink
import (
"bytes"
"net"
"testing"
)
@ -9,7 +10,7 @@ func TestXfrmStateAddDel(t *testing.T) {
tearDown := setUpNetlinkTest(t)
defer tearDown()
state := XfrmState{
state := &XfrmState{
Src: net.ParseIP("127.0.0.1"),
Dst: net.ParseIP("127.0.0.2"),
Proto: XFRM_PROTO_ESP,
@ -28,27 +29,49 @@ func TestXfrmStateAddDel(t *testing.T) {
Mask: 0xffff0000,
},
}
if err := XfrmStateAdd(&state); err != nil {
if err := XfrmStateAdd(state); err != nil {
t.Fatal(err)
}
policies, err := XfrmStateList(FAMILY_ALL)
states, err := XfrmStateList(FAMILY_ALL)
if err != nil {
t.Fatal(err)
}
if len(policies) != 1 {
if len(states) != 1 {
t.Fatal("State not added properly")
}
if err = XfrmStateDel(&state); err != nil {
t.Fatal(err)
if !compareStates(state, &states[0]) {
t.Fatalf("unexpected states returned")
}
policies, err = XfrmStateList(FAMILY_ALL)
// Get specific state
sa, err := XfrmStateGet(state)
if err != nil {
t.Fatal(err)
}
if len(policies) != 0 {
if !compareStates(state, sa) {
t.Fatalf("unexpected state returned")
}
if err = XfrmStateDel(state); err != nil {
t.Fatal(err)
}
states, err = XfrmStateList(FAMILY_ALL)
if err != nil {
t.Fatal(err)
}
if len(states) != 0 {
t.Fatal("State not removed properly")
}
}
func compareStates(a, b *XfrmState) bool {
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
a.Auth.Name == b.Auth.Name && bytes.Equal(a.Auth.Key, b.Auth.Key) &&
a.Crypt.Name == b.Crypt.Name && bytes.Equal(a.Crypt.Key, b.Crypt.Key) &&
a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask
}