mirror of
https://github.com/vishvananda/netlink
synced 2024-12-27 00:52:11 +00:00
- fixes vishvananda/netlink#815
- changes: 1. set userTmpl.Family to correct family 2. add Selector to XfrmState which is corresponding to XfrmUsersaInfo.Sel - update *XfrmAddress.ToIPNet method to support 0.0.0.0/0 and ::/0 correctly - update xfrmStateFromXfrmUsersaInfo to get XfrmState.Selector - extend TestXfrmStateAddGetDel for v6ov4 and v4ov6 cases
This commit is contained in:
parent
d3b7a6fadd
commit
d3c0a2caa5
@ -131,7 +131,15 @@ func (x *XfrmAddress) ToIP() net.IP {
|
||||
return ip
|
||||
}
|
||||
|
||||
func (x *XfrmAddress) ToIPNet(prefixlen uint8) *net.IPNet {
|
||||
// family is only used when x and prefixlen are both 0
|
||||
func (x *XfrmAddress) ToIPNet(prefixlen uint8, family uint16) *net.IPNet {
|
||||
empty := [SizeofXfrmAddress]byte{}
|
||||
if bytes.Equal(x[:], empty[:]) && prefixlen == 0 {
|
||||
if family == FAMILY_V6 {
|
||||
return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(int(prefixlen), 128)}
|
||||
}
|
||||
return &net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(int(prefixlen), 32)}
|
||||
}
|
||||
ip := x.ToIP()
|
||||
if GetIPFamily(ip) == FAMILY_V4 {
|
||||
return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)}
|
||||
|
@ -75,6 +75,7 @@ func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
|
||||
userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
|
||||
userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
|
||||
userTmpl.Saddr.FromIP(tmpl.Src)
|
||||
userTmpl.Family = uint16(nl.GetIPFamily(tmpl.Dst))
|
||||
userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
|
||||
userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
|
||||
userTmpl.Mode = uint8(tmpl.Mode)
|
||||
@ -223,8 +224,8 @@ func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
|
||||
|
||||
var policy XfrmPolicy
|
||||
|
||||
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
|
||||
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
|
||||
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, uint16(family))
|
||||
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, uint16(family))
|
||||
policy.Proto = Proto(msg.Sel.Proto)
|
||||
policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
|
||||
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
|
||||
|
@ -117,6 +117,7 @@ type XfrmState struct {
|
||||
DontEncapDSCP bool
|
||||
OSeqMayWrap bool
|
||||
Replay *XfrmReplayState
|
||||
Selector *XfrmPolicy
|
||||
}
|
||||
|
||||
func (sa XfrmState) String() string {
|
||||
|
@ -209,7 +209,6 @@ func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
|
||||
msg.Min = 0x100
|
||||
msg.Max = 0xffffffff
|
||||
req.AddData(msg)
|
||||
|
||||
if state.Mark != nil {
|
||||
out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
|
||||
req.AddData(out)
|
||||
@ -337,7 +336,6 @@ var familyError = fmt.Errorf("family error")
|
||||
|
||||
func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
|
||||
var state XfrmState
|
||||
|
||||
state.Dst = msg.Id.Daddr.ToIP()
|
||||
state.Src = msg.Saddr.ToIP()
|
||||
state.Proto = Proto(msg.Id.Proto)
|
||||
@ -347,20 +345,25 @@ func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
|
||||
state.ReplayWindow = int(msg.ReplayWindow)
|
||||
lftToLimits(&msg.Lft, &state.Limits)
|
||||
curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
|
||||
state.Selector = &XfrmPolicy{
|
||||
Dst: msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, msg.Sel.Family),
|
||||
Src: msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, msg.Sel.Family),
|
||||
Proto: Proto(msg.Sel.Proto),
|
||||
DstPort: int(nl.Swap16(msg.Sel.Dport)),
|
||||
SrcPort: int(nl.Swap16(msg.Sel.Sport)),
|
||||
Ifindex: int(msg.Sel.Ifindex),
|
||||
}
|
||||
|
||||
return &state
|
||||
}
|
||||
|
||||
func parseXfrmState(m []byte, family int) (*XfrmState, error) {
|
||||
msg := nl.DeserializeXfrmUsersaInfo(m)
|
||||
|
||||
// This is mainly for the state dump
|
||||
if family != FAMILY_ALL && family != int(msg.Family) {
|
||||
return nil, familyError
|
||||
}
|
||||
|
||||
state := xfrmStateFromXfrmUsersaInfo(msg)
|
||||
|
||||
attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -515,6 +518,9 @@ func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo {
|
||||
msg.Id.Spi = nl.Swap32(uint32(state.Spi))
|
||||
msg.Reqid = uint32(state.Reqid)
|
||||
msg.ReplayWindow = uint8(state.ReplayWindow)
|
||||
|
||||
msg.Sel = nl.XfrmSelector{}
|
||||
if state.Selector != nil {
|
||||
selFromPolicy(&msg.Sel, state.Selector)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
@ -1,3 +1,4 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package netlink
|
||||
@ -11,7 +12,12 @@ import (
|
||||
)
|
||||
|
||||
func TestXfrmStateAddGetDel(t *testing.T) {
|
||||
for _, s := range []*XfrmState{getBaseState(), getAeadState()} {
|
||||
for _, s := range []*XfrmState{
|
||||
getBaseState(),
|
||||
getAeadState(),
|
||||
getBaseStateV6oV4(),
|
||||
getBaseStateV4oV6(),
|
||||
} {
|
||||
testXfrmStateAddGetDel(t, s)
|
||||
}
|
||||
}
|
||||
@ -26,7 +32,6 @@ func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(states) != 1 {
|
||||
t.Fatal("State not added properly")
|
||||
}
|
||||
@ -77,6 +82,7 @@ func TestXfrmStateAllocSpi(t *testing.T) {
|
||||
t.Fatalf("SPI is not allocated")
|
||||
}
|
||||
rstate.Spi = 0
|
||||
|
||||
if !compareStates(state, rstate) {
|
||||
t.Fatalf("State not properly allocated")
|
||||
}
|
||||
@ -268,6 +274,21 @@ func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
func genStateSelectorForV6Payload() *XfrmPolicy {
|
||||
_, wildcardV6Net, _ := net.ParseCIDR("::/0")
|
||||
return &XfrmPolicy{
|
||||
Src: wildcardV6Net,
|
||||
Dst: wildcardV6Net,
|
||||
}
|
||||
}
|
||||
|
||||
func genStateSelectorForV4Payload() *XfrmPolicy {
|
||||
_, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0")
|
||||
return &XfrmPolicy{
|
||||
Src: wildcardV4Net,
|
||||
Dst: wildcardV4Net,
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseState() *XfrmState {
|
||||
return &XfrmState{
|
||||
@ -292,6 +313,54 @@ func getBaseState() *XfrmState {
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseStateV4oV6() *XfrmState {
|
||||
return &XfrmState{
|
||||
// Force 4 byte notation for the IPv4 addressesd
|
||||
Src: net.ParseIP("2001:dead::1").To16(),
|
||||
Dst: net.ParseIP("2001:beef::1").To16(),
|
||||
Proto: XFRM_PROTO_ESP,
|
||||
Mode: XFRM_MODE_TUNNEL,
|
||||
Spi: 1,
|
||||
Auth: &XfrmStateAlgo{
|
||||
Name: "hmac(sha256)",
|
||||
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||
},
|
||||
Crypt: &XfrmStateAlgo{
|
||||
Name: "cbc(aes)",
|
||||
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||
},
|
||||
Mark: &XfrmMark{
|
||||
Value: 0x12340000,
|
||||
Mask: 0xffff0000,
|
||||
},
|
||||
Selector: genStateSelectorForV4Payload(),
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseStateV6oV4() *XfrmState {
|
||||
return &XfrmState{
|
||||
// Force 4 byte notation for the IPv4 addressesd
|
||||
Src: net.ParseIP("192.168.1.1").To4(),
|
||||
Dst: net.ParseIP("192.168.2.2").To4(),
|
||||
Proto: XFRM_PROTO_ESP,
|
||||
Mode: XFRM_MODE_TUNNEL,
|
||||
Spi: 1,
|
||||
Auth: &XfrmStateAlgo{
|
||||
Name: "hmac(sha256)",
|
||||
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||
},
|
||||
Crypt: &XfrmStateAlgo{
|
||||
Name: "cbc(aes)",
|
||||
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||
},
|
||||
Mark: &XfrmMark{
|
||||
Value: 0x12340000,
|
||||
Mask: 0xffff0000,
|
||||
},
|
||||
Selector: genStateSelectorForV6Payload(),
|
||||
}
|
||||
}
|
||||
|
||||
func getAeadState() *XfrmState {
|
||||
// 128 key bits + 32 salt bits
|
||||
k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
|
||||
@ -309,6 +378,14 @@ func getAeadState() *XfrmState {
|
||||
},
|
||||
}
|
||||
}
|
||||
func compareSelector(a, b *XfrmPolicy) bool {
|
||||
return a.Src.String() == b.Src.String() &&
|
||||
a.Dst.String() == b.Dst.String() &&
|
||||
a.Proto == b.Proto &&
|
||||
a.DstPort == b.DstPort &&
|
||||
a.SrcPort == b.SrcPort &&
|
||||
a.Ifindex == b.Ifindex
|
||||
}
|
||||
|
||||
func compareStates(a, b *XfrmState) bool {
|
||||
if a == b {
|
||||
@ -317,6 +394,12 @@ func compareStates(a, b *XfrmState) bool {
|
||||
if a == nil || b == nil {
|
||||
return false
|
||||
}
|
||||
if a.Selector != nil && b.Selector != nil {
|
||||
if !compareSelector(a.Selector, b.Selector) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
|
||||
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
|
||||
a.Ifid == b.Ifid &&
|
||||
@ -325,6 +408,7 @@ func compareStates(a, b *XfrmState) bool {
|
||||
compareAlgo(a.Aead, b.Aead) &&
|
||||
compareMarks(a.Mark, b.Mark) &&
|
||||
compareMarks(a.OutputMark, b.OutputMark)
|
||||
|
||||
}
|
||||
|
||||
func compareLimits(a, b *XfrmState) bool {
|
||||
|
Loading…
Reference in New Issue
Block a user