- changes: 1. set userTmpl.Family to correct family 2. add Selector to XfrmState which is corresponding to XfrmUsersaInfo.Sel

- update *XfrmAddress.ToIPNet method to support 0.0.0.0/0 and ::/0 correctly
- update xfrmStateFromXfrmUsersaInfo to get XfrmState.Selector
- extend TestXfrmStateAddGetDel for v6ov4 and v4ov6 cases
This commit is contained in:
Hu Jun 2022-10-04 20:23:15 -07:00 committed by Alessandro Boch
parent d3b7a6fadd
commit d3c0a2caa5
5 changed files with 111 additions and 11 deletions

View File

@ -131,7 +131,15 @@ func (x *XfrmAddress) ToIP() net.IP {
return ip return ip
} }
func (x *XfrmAddress) ToIPNet(prefixlen uint8) *net.IPNet { // family is only used when x and prefixlen are both 0
func (x *XfrmAddress) ToIPNet(prefixlen uint8, family uint16) *net.IPNet {
empty := [SizeofXfrmAddress]byte{}
if bytes.Equal(x[:], empty[:]) && prefixlen == 0 {
if family == FAMILY_V6 {
return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(int(prefixlen), 128)}
}
return &net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(int(prefixlen), 32)}
}
ip := x.ToIP() ip := x.ToIP()
if GetIPFamily(ip) == FAMILY_V4 { if GetIPFamily(ip) == FAMILY_V4 {
return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)} return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)}

View File

@ -75,6 +75,7 @@ func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl]) userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst) userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
userTmpl.Saddr.FromIP(tmpl.Src) userTmpl.Saddr.FromIP(tmpl.Src)
userTmpl.Family = uint16(nl.GetIPFamily(tmpl.Dst))
userTmpl.XfrmId.Proto = uint8(tmpl.Proto) userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi)) userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
userTmpl.Mode = uint8(tmpl.Mode) userTmpl.Mode = uint8(tmpl.Mode)
@ -223,8 +224,8 @@ func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
var policy XfrmPolicy var policy XfrmPolicy
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD) policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, uint16(family))
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS) policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, uint16(family))
policy.Proto = Proto(msg.Sel.Proto) policy.Proto = Proto(msg.Sel.Proto)
policy.DstPort = int(nl.Swap16(msg.Sel.Dport)) policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport)) policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))

View File

@ -117,6 +117,7 @@ type XfrmState struct {
DontEncapDSCP bool DontEncapDSCP bool
OSeqMayWrap bool OSeqMayWrap bool
Replay *XfrmReplayState Replay *XfrmReplayState
Selector *XfrmPolicy
} }
func (sa XfrmState) String() string { func (sa XfrmState) String() string {

View File

@ -209,7 +209,6 @@ func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
msg.Min = 0x100 msg.Min = 0x100
msg.Max = 0xffffffff msg.Max = 0xffffffff
req.AddData(msg) req.AddData(msg)
if state.Mark != nil { if state.Mark != nil {
out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark)) out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
req.AddData(out) req.AddData(out)
@ -337,7 +336,6 @@ var familyError = fmt.Errorf("family error")
func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState { func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
var state XfrmState var state XfrmState
state.Dst = msg.Id.Daddr.ToIP() state.Dst = msg.Id.Daddr.ToIP()
state.Src = msg.Saddr.ToIP() state.Src = msg.Saddr.ToIP()
state.Proto = Proto(msg.Id.Proto) state.Proto = Proto(msg.Id.Proto)
@ -347,20 +345,25 @@ func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
state.ReplayWindow = int(msg.ReplayWindow) state.ReplayWindow = int(msg.ReplayWindow)
lftToLimits(&msg.Lft, &state.Limits) lftToLimits(&msg.Lft, &state.Limits)
curToStats(&msg.Curlft, &msg.Stats, &state.Statistics) curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
state.Selector = &XfrmPolicy{
Dst: msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, msg.Sel.Family),
Src: msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, msg.Sel.Family),
Proto: Proto(msg.Sel.Proto),
DstPort: int(nl.Swap16(msg.Sel.Dport)),
SrcPort: int(nl.Swap16(msg.Sel.Sport)),
Ifindex: int(msg.Sel.Ifindex),
}
return &state return &state
} }
func parseXfrmState(m []byte, family int) (*XfrmState, error) { func parseXfrmState(m []byte, family int) (*XfrmState, error) {
msg := nl.DeserializeXfrmUsersaInfo(m) msg := nl.DeserializeXfrmUsersaInfo(m)
// This is mainly for the state dump // This is mainly for the state dump
if family != FAMILY_ALL && family != int(msg.Family) { if family != FAMILY_ALL && family != int(msg.Family) {
return nil, familyError return nil, familyError
} }
state := xfrmStateFromXfrmUsersaInfo(msg) state := xfrmStateFromXfrmUsersaInfo(msg)
attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:]) attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
if err != nil { if err != nil {
return nil, err return nil, err
@ -515,6 +518,9 @@ func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo {
msg.Id.Spi = nl.Swap32(uint32(state.Spi)) msg.Id.Spi = nl.Swap32(uint32(state.Spi))
msg.Reqid = uint32(state.Reqid) msg.Reqid = uint32(state.Reqid)
msg.ReplayWindow = uint8(state.ReplayWindow) msg.ReplayWindow = uint8(state.ReplayWindow)
msg.Sel = nl.XfrmSelector{}
if state.Selector != nil {
selFromPolicy(&msg.Sel, state.Selector)
}
return msg return msg
} }

View File

@ -1,3 +1,4 @@
//go:build linux
// +build linux // +build linux
package netlink package netlink
@ -11,7 +12,12 @@ import (
) )
func TestXfrmStateAddGetDel(t *testing.T) { func TestXfrmStateAddGetDel(t *testing.T) {
for _, s := range []*XfrmState{getBaseState(), getAeadState()} { for _, s := range []*XfrmState{
getBaseState(),
getAeadState(),
getBaseStateV6oV4(),
getBaseStateV4oV6(),
} {
testXfrmStateAddGetDel(t, s) testXfrmStateAddGetDel(t, s)
} }
} }
@ -26,7 +32,6 @@ func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if len(states) != 1 { if len(states) != 1 {
t.Fatal("State not added properly") t.Fatal("State not added properly")
} }
@ -77,6 +82,7 @@ func TestXfrmStateAllocSpi(t *testing.T) {
t.Fatalf("SPI is not allocated") t.Fatalf("SPI is not allocated")
} }
rstate.Spi = 0 rstate.Spi = 0
if !compareStates(state, rstate) { if !compareStates(state, rstate) {
t.Fatalf("State not properly allocated") t.Fatalf("State not properly allocated")
} }
@ -268,6 +274,21 @@ func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
} }
func genStateSelectorForV6Payload() *XfrmPolicy {
_, wildcardV6Net, _ := net.ParseCIDR("::/0")
return &XfrmPolicy{
Src: wildcardV6Net,
Dst: wildcardV6Net,
}
}
func genStateSelectorForV4Payload() *XfrmPolicy {
_, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0")
return &XfrmPolicy{
Src: wildcardV4Net,
Dst: wildcardV4Net,
}
}
func getBaseState() *XfrmState { func getBaseState() *XfrmState {
return &XfrmState{ return &XfrmState{
@ -292,6 +313,54 @@ func getBaseState() *XfrmState {
} }
} }
func getBaseStateV4oV6() *XfrmState {
return &XfrmState{
// Force 4 byte notation for the IPv4 addressesd
Src: net.ParseIP("2001:dead::1").To16(),
Dst: net.ParseIP("2001:beef::1").To16(),
Proto: XFRM_PROTO_ESP,
Mode: XFRM_MODE_TUNNEL,
Spi: 1,
Auth: &XfrmStateAlgo{
Name: "hmac(sha256)",
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
},
Crypt: &XfrmStateAlgo{
Name: "cbc(aes)",
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
},
Mark: &XfrmMark{
Value: 0x12340000,
Mask: 0xffff0000,
},
Selector: genStateSelectorForV4Payload(),
}
}
func getBaseStateV6oV4() *XfrmState {
return &XfrmState{
// Force 4 byte notation for the IPv4 addressesd
Src: net.ParseIP("192.168.1.1").To4(),
Dst: net.ParseIP("192.168.2.2").To4(),
Proto: XFRM_PROTO_ESP,
Mode: XFRM_MODE_TUNNEL,
Spi: 1,
Auth: &XfrmStateAlgo{
Name: "hmac(sha256)",
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
},
Crypt: &XfrmStateAlgo{
Name: "cbc(aes)",
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
},
Mark: &XfrmMark{
Value: 0x12340000,
Mask: 0xffff0000,
},
Selector: genStateSelectorForV6Payload(),
}
}
func getAeadState() *XfrmState { func getAeadState() *XfrmState {
// 128 key bits + 32 salt bits // 128 key bits + 32 salt bits
k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177") k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
@ -309,6 +378,14 @@ func getAeadState() *XfrmState {
}, },
} }
} }
func compareSelector(a, b *XfrmPolicy) bool {
return a.Src.String() == b.Src.String() &&
a.Dst.String() == b.Dst.String() &&
a.Proto == b.Proto &&
a.DstPort == b.DstPort &&
a.SrcPort == b.SrcPort &&
a.Ifindex == b.Ifindex
}
func compareStates(a, b *XfrmState) bool { func compareStates(a, b *XfrmState) bool {
if a == b { if a == b {
@ -317,6 +394,12 @@ func compareStates(a, b *XfrmState) bool {
if a == nil || b == nil { if a == nil || b == nil {
return false return false
} }
if a.Selector != nil && b.Selector != nil {
if !compareSelector(a.Selector, b.Selector) {
return false
}
}
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) && return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto && a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
a.Ifid == b.Ifid && a.Ifid == b.Ifid &&
@ -325,6 +408,7 @@ func compareStates(a, b *XfrmState) bool {
compareAlgo(a.Aead, b.Aead) && compareAlgo(a.Aead, b.Aead) &&
compareMarks(a.Mark, b.Mark) && compareMarks(a.Mark, b.Mark) &&
compareMarks(a.OutputMark, b.OutputMark) compareMarks(a.OutputMark, b.OutputMark)
} }
func compareLimits(a, b *XfrmState) bool { func compareLimits(a, b *XfrmState) bool {