diff --git a/nl/xfrm_linux.go b/nl/xfrm_linux.go index dce9073..cdb318b 100644 --- a/nl/xfrm_linux.go +++ b/nl/xfrm_linux.go @@ -131,7 +131,15 @@ func (x *XfrmAddress) ToIP() net.IP { return ip } -func (x *XfrmAddress) ToIPNet(prefixlen uint8) *net.IPNet { +// family is only used when x and prefixlen are both 0 +func (x *XfrmAddress) ToIPNet(prefixlen uint8, family uint16) *net.IPNet { + empty := [SizeofXfrmAddress]byte{} + if bytes.Equal(x[:], empty[:]) && prefixlen == 0 { + if family == FAMILY_V6 { + return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(int(prefixlen), 128)} + } + return &net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(int(prefixlen), 32)} + } ip := x.ToIP() if GetIPFamily(ip) == FAMILY_V4 { return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)} diff --git a/xfrm_policy_linux.go b/xfrm_policy_linux.go index 3584968..cca5260 100644 --- a/xfrm_policy_linux.go +++ b/xfrm_policy_linux.go @@ -75,6 +75,7 @@ func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error { userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl]) userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst) userTmpl.Saddr.FromIP(tmpl.Src) + userTmpl.Family = uint16(nl.GetIPFamily(tmpl.Dst)) userTmpl.XfrmId.Proto = uint8(tmpl.Proto) userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi)) userTmpl.Mode = uint8(tmpl.Mode) @@ -223,8 +224,8 @@ func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) { var policy XfrmPolicy - policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD) - policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS) + policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, uint16(family)) + policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, uint16(family)) policy.Proto = Proto(msg.Sel.Proto) policy.DstPort = int(nl.Swap16(msg.Sel.Dport)) policy.SrcPort = int(nl.Swap16(msg.Sel.Sport)) diff --git a/xfrm_state.go b/xfrm_state.go index 3cd01cb..0095832 100644 --- a/xfrm_state.go +++ b/xfrm_state.go @@ -117,6 +117,7 @@ type XfrmState struct { DontEncapDSCP bool OSeqMayWrap bool Replay *XfrmReplayState + Selector *XfrmPolicy } func (sa XfrmState) String() string { diff --git a/xfrm_state_linux.go b/xfrm_state_linux.go index 5c14120..6b7bb8e 100644 --- a/xfrm_state_linux.go +++ b/xfrm_state_linux.go @@ -209,7 +209,6 @@ func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) { msg.Min = 0x100 msg.Max = 0xffffffff req.AddData(msg) - if state.Mark != nil { out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark)) req.AddData(out) @@ -337,7 +336,6 @@ var familyError = fmt.Errorf("family error") func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState { var state XfrmState - state.Dst = msg.Id.Daddr.ToIP() state.Src = msg.Saddr.ToIP() state.Proto = Proto(msg.Id.Proto) @@ -347,20 +345,25 @@ func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState { state.ReplayWindow = int(msg.ReplayWindow) lftToLimits(&msg.Lft, &state.Limits) curToStats(&msg.Curlft, &msg.Stats, &state.Statistics) + state.Selector = &XfrmPolicy{ + Dst: msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, msg.Sel.Family), + Src: msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, msg.Sel.Family), + Proto: Proto(msg.Sel.Proto), + DstPort: int(nl.Swap16(msg.Sel.Dport)), + SrcPort: int(nl.Swap16(msg.Sel.Sport)), + Ifindex: int(msg.Sel.Ifindex), + } return &state } func parseXfrmState(m []byte, family int) (*XfrmState, error) { msg := nl.DeserializeXfrmUsersaInfo(m) - // This is mainly for the state dump if family != FAMILY_ALL && family != int(msg.Family) { return nil, familyError } - state := xfrmStateFromXfrmUsersaInfo(msg) - attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:]) if err != nil { return nil, err @@ -515,6 +518,9 @@ func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo { msg.Id.Spi = nl.Swap32(uint32(state.Spi)) msg.Reqid = uint32(state.Reqid) msg.ReplayWindow = uint8(state.ReplayWindow) - + msg.Sel = nl.XfrmSelector{} + if state.Selector != nil { + selFromPolicy(&msg.Sel, state.Selector) + } return msg } diff --git a/xfrm_state_test.go b/xfrm_state_test.go index bf15d1e..2c59b9d 100644 --- a/xfrm_state_test.go +++ b/xfrm_state_test.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package netlink @@ -11,7 +12,12 @@ import ( ) func TestXfrmStateAddGetDel(t *testing.T) { - for _, s := range []*XfrmState{getBaseState(), getAeadState()} { + for _, s := range []*XfrmState{ + getBaseState(), + getAeadState(), + getBaseStateV6oV4(), + getBaseStateV4oV6(), + } { testXfrmStateAddGetDel(t, s) } } @@ -26,7 +32,6 @@ func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) { if err != nil { t.Fatal(err) } - if len(states) != 1 { t.Fatal("State not added properly") } @@ -77,6 +82,7 @@ func TestXfrmStateAllocSpi(t *testing.T) { t.Fatalf("SPI is not allocated") } rstate.Spi = 0 + if !compareStates(state, rstate) { t.Fatalf("State not properly allocated") } @@ -268,6 +274,21 @@ func TestXfrmStateWithOutputMarkAndMask(t *testing.T) { t.Fatal(err) } } +func genStateSelectorForV6Payload() *XfrmPolicy { + _, wildcardV6Net, _ := net.ParseCIDR("::/0") + return &XfrmPolicy{ + Src: wildcardV6Net, + Dst: wildcardV6Net, + } +} + +func genStateSelectorForV4Payload() *XfrmPolicy { + _, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0") + return &XfrmPolicy{ + Src: wildcardV4Net, + Dst: wildcardV4Net, + } +} func getBaseState() *XfrmState { return &XfrmState{ @@ -292,6 +313,54 @@ func getBaseState() *XfrmState { } } +func getBaseStateV4oV6() *XfrmState { + return &XfrmState{ + // Force 4 byte notation for the IPv4 addressesd + Src: net.ParseIP("2001:dead::1").To16(), + Dst: net.ParseIP("2001:beef::1").To16(), + Proto: XFRM_PROTO_ESP, + Mode: XFRM_MODE_TUNNEL, + Spi: 1, + Auth: &XfrmStateAlgo{ + Name: "hmac(sha256)", + Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), + }, + Crypt: &XfrmStateAlgo{ + Name: "cbc(aes)", + Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), + }, + Mark: &XfrmMark{ + Value: 0x12340000, + Mask: 0xffff0000, + }, + Selector: genStateSelectorForV4Payload(), + } +} + +func getBaseStateV6oV4() *XfrmState { + return &XfrmState{ + // Force 4 byte notation for the IPv4 addressesd + Src: net.ParseIP("192.168.1.1").To4(), + Dst: net.ParseIP("192.168.2.2").To4(), + Proto: XFRM_PROTO_ESP, + Mode: XFRM_MODE_TUNNEL, + Spi: 1, + Auth: &XfrmStateAlgo{ + Name: "hmac(sha256)", + Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), + }, + Crypt: &XfrmStateAlgo{ + Name: "cbc(aes)", + Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), + }, + Mark: &XfrmMark{ + Value: 0x12340000, + Mask: 0xffff0000, + }, + Selector: genStateSelectorForV6Payload(), + } +} + func getAeadState() *XfrmState { // 128 key bits + 32 salt bits k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177") @@ -309,6 +378,14 @@ func getAeadState() *XfrmState { }, } } +func compareSelector(a, b *XfrmPolicy) bool { + return a.Src.String() == b.Src.String() && + a.Dst.String() == b.Dst.String() && + a.Proto == b.Proto && + a.DstPort == b.DstPort && + a.SrcPort == b.SrcPort && + a.Ifindex == b.Ifindex +} func compareStates(a, b *XfrmState) bool { if a == b { @@ -317,6 +394,12 @@ func compareStates(a, b *XfrmState) bool { if a == nil || b == nil { return false } + if a.Selector != nil && b.Selector != nil { + if !compareSelector(a.Selector, b.Selector) { + return false + } + } + return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) && a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto && a.Ifid == b.Ifid && @@ -325,6 +408,7 @@ func compareStates(a, b *XfrmState) bool { compareAlgo(a.Aead, b.Aead) && compareMarks(a.Mark, b.Mark) && compareMarks(a.OutputMark, b.OutputMark) + } func compareLimits(a, b *XfrmState) bool {