mirror of
https://github.com/vishvananda/netlink
synced 2025-01-14 02:51:12 +00:00
- fixes vishvananda/netlink#815
- changes: 1. set userTmpl.Family to correct family 2. add Selector to XfrmState which is corresponding to XfrmUsersaInfo.Sel - update *XfrmAddress.ToIPNet method to support 0.0.0.0/0 and ::/0 correctly - update xfrmStateFromXfrmUsersaInfo to get XfrmState.Selector - extend TestXfrmStateAddGetDel for v6ov4 and v4ov6 cases
This commit is contained in:
parent
d3b7a6fadd
commit
d3c0a2caa5
@ -131,7 +131,15 @@ func (x *XfrmAddress) ToIP() net.IP {
|
|||||||
return ip
|
return ip
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *XfrmAddress) ToIPNet(prefixlen uint8) *net.IPNet {
|
// family is only used when x and prefixlen are both 0
|
||||||
|
func (x *XfrmAddress) ToIPNet(prefixlen uint8, family uint16) *net.IPNet {
|
||||||
|
empty := [SizeofXfrmAddress]byte{}
|
||||||
|
if bytes.Equal(x[:], empty[:]) && prefixlen == 0 {
|
||||||
|
if family == FAMILY_V6 {
|
||||||
|
return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(int(prefixlen), 128)}
|
||||||
|
}
|
||||||
|
return &net.IPNet{IP: net.ParseIP("0.0.0.0"), Mask: net.CIDRMask(int(prefixlen), 32)}
|
||||||
|
}
|
||||||
ip := x.ToIP()
|
ip := x.ToIP()
|
||||||
if GetIPFamily(ip) == FAMILY_V4 {
|
if GetIPFamily(ip) == FAMILY_V4 {
|
||||||
return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)}
|
return &net.IPNet{IP: ip, Mask: net.CIDRMask(int(prefixlen), 32)}
|
||||||
|
@ -75,6 +75,7 @@ func (h *Handle) xfrmPolicyAddOrUpdate(policy *XfrmPolicy, nlProto int) error {
|
|||||||
userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
|
userTmpl := nl.DeserializeXfrmUserTmpl(tmplData[start : start+nl.SizeofXfrmUserTmpl])
|
||||||
userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
|
userTmpl.XfrmId.Daddr.FromIP(tmpl.Dst)
|
||||||
userTmpl.Saddr.FromIP(tmpl.Src)
|
userTmpl.Saddr.FromIP(tmpl.Src)
|
||||||
|
userTmpl.Family = uint16(nl.GetIPFamily(tmpl.Dst))
|
||||||
userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
|
userTmpl.XfrmId.Proto = uint8(tmpl.Proto)
|
||||||
userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
|
userTmpl.XfrmId.Spi = nl.Swap32(uint32(tmpl.Spi))
|
||||||
userTmpl.Mode = uint8(tmpl.Mode)
|
userTmpl.Mode = uint8(tmpl.Mode)
|
||||||
@ -223,8 +224,8 @@ func parseXfrmPolicy(m []byte, family int) (*XfrmPolicy, error) {
|
|||||||
|
|
||||||
var policy XfrmPolicy
|
var policy XfrmPolicy
|
||||||
|
|
||||||
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD)
|
policy.Dst = msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, uint16(family))
|
||||||
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS)
|
policy.Src = msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, uint16(family))
|
||||||
policy.Proto = Proto(msg.Sel.Proto)
|
policy.Proto = Proto(msg.Sel.Proto)
|
||||||
policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
|
policy.DstPort = int(nl.Swap16(msg.Sel.Dport))
|
||||||
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
|
policy.SrcPort = int(nl.Swap16(msg.Sel.Sport))
|
||||||
|
@ -117,6 +117,7 @@ type XfrmState struct {
|
|||||||
DontEncapDSCP bool
|
DontEncapDSCP bool
|
||||||
OSeqMayWrap bool
|
OSeqMayWrap bool
|
||||||
Replay *XfrmReplayState
|
Replay *XfrmReplayState
|
||||||
|
Selector *XfrmPolicy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sa XfrmState) String() string {
|
func (sa XfrmState) String() string {
|
||||||
|
@ -209,7 +209,6 @@ func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
|
|||||||
msg.Min = 0x100
|
msg.Min = 0x100
|
||||||
msg.Max = 0xffffffff
|
msg.Max = 0xffffffff
|
||||||
req.AddData(msg)
|
req.AddData(msg)
|
||||||
|
|
||||||
if state.Mark != nil {
|
if state.Mark != nil {
|
||||||
out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
|
out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
|
||||||
req.AddData(out)
|
req.AddData(out)
|
||||||
@ -337,7 +336,6 @@ var familyError = fmt.Errorf("family error")
|
|||||||
|
|
||||||
func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
|
func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
|
||||||
var state XfrmState
|
var state XfrmState
|
||||||
|
|
||||||
state.Dst = msg.Id.Daddr.ToIP()
|
state.Dst = msg.Id.Daddr.ToIP()
|
||||||
state.Src = msg.Saddr.ToIP()
|
state.Src = msg.Saddr.ToIP()
|
||||||
state.Proto = Proto(msg.Id.Proto)
|
state.Proto = Proto(msg.Id.Proto)
|
||||||
@ -347,20 +345,25 @@ func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
|
|||||||
state.ReplayWindow = int(msg.ReplayWindow)
|
state.ReplayWindow = int(msg.ReplayWindow)
|
||||||
lftToLimits(&msg.Lft, &state.Limits)
|
lftToLimits(&msg.Lft, &state.Limits)
|
||||||
curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
|
curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
|
||||||
|
state.Selector = &XfrmPolicy{
|
||||||
|
Dst: msg.Sel.Daddr.ToIPNet(msg.Sel.PrefixlenD, msg.Sel.Family),
|
||||||
|
Src: msg.Sel.Saddr.ToIPNet(msg.Sel.PrefixlenS, msg.Sel.Family),
|
||||||
|
Proto: Proto(msg.Sel.Proto),
|
||||||
|
DstPort: int(nl.Swap16(msg.Sel.Dport)),
|
||||||
|
SrcPort: int(nl.Swap16(msg.Sel.Sport)),
|
||||||
|
Ifindex: int(msg.Sel.Ifindex),
|
||||||
|
}
|
||||||
|
|
||||||
return &state
|
return &state
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseXfrmState(m []byte, family int) (*XfrmState, error) {
|
func parseXfrmState(m []byte, family int) (*XfrmState, error) {
|
||||||
msg := nl.DeserializeXfrmUsersaInfo(m)
|
msg := nl.DeserializeXfrmUsersaInfo(m)
|
||||||
|
|
||||||
// This is mainly for the state dump
|
// This is mainly for the state dump
|
||||||
if family != FAMILY_ALL && family != int(msg.Family) {
|
if family != FAMILY_ALL && family != int(msg.Family) {
|
||||||
return nil, familyError
|
return nil, familyError
|
||||||
}
|
}
|
||||||
|
|
||||||
state := xfrmStateFromXfrmUsersaInfo(msg)
|
state := xfrmStateFromXfrmUsersaInfo(msg)
|
||||||
|
|
||||||
attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
|
attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -515,6 +518,9 @@ func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo {
|
|||||||
msg.Id.Spi = nl.Swap32(uint32(state.Spi))
|
msg.Id.Spi = nl.Swap32(uint32(state.Spi))
|
||||||
msg.Reqid = uint32(state.Reqid)
|
msg.Reqid = uint32(state.Reqid)
|
||||||
msg.ReplayWindow = uint8(state.ReplayWindow)
|
msg.ReplayWindow = uint8(state.ReplayWindow)
|
||||||
|
msg.Sel = nl.XfrmSelector{}
|
||||||
|
if state.Selector != nil {
|
||||||
|
selFromPolicy(&msg.Sel, state.Selector)
|
||||||
|
}
|
||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
//go:build linux
|
||||||
// +build linux
|
// +build linux
|
||||||
|
|
||||||
package netlink
|
package netlink
|
||||||
@ -11,7 +12,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestXfrmStateAddGetDel(t *testing.T) {
|
func TestXfrmStateAddGetDel(t *testing.T) {
|
||||||
for _, s := range []*XfrmState{getBaseState(), getAeadState()} {
|
for _, s := range []*XfrmState{
|
||||||
|
getBaseState(),
|
||||||
|
getAeadState(),
|
||||||
|
getBaseStateV6oV4(),
|
||||||
|
getBaseStateV4oV6(),
|
||||||
|
} {
|
||||||
testXfrmStateAddGetDel(t, s)
|
testXfrmStateAddGetDel(t, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -26,7 +32,6 @@ func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(states) != 1 {
|
if len(states) != 1 {
|
||||||
t.Fatal("State not added properly")
|
t.Fatal("State not added properly")
|
||||||
}
|
}
|
||||||
@ -77,6 +82,7 @@ func TestXfrmStateAllocSpi(t *testing.T) {
|
|||||||
t.Fatalf("SPI is not allocated")
|
t.Fatalf("SPI is not allocated")
|
||||||
}
|
}
|
||||||
rstate.Spi = 0
|
rstate.Spi = 0
|
||||||
|
|
||||||
if !compareStates(state, rstate) {
|
if !compareStates(state, rstate) {
|
||||||
t.Fatalf("State not properly allocated")
|
t.Fatalf("State not properly allocated")
|
||||||
}
|
}
|
||||||
@ -268,6 +274,21 @@ func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func genStateSelectorForV6Payload() *XfrmPolicy {
|
||||||
|
_, wildcardV6Net, _ := net.ParseCIDR("::/0")
|
||||||
|
return &XfrmPolicy{
|
||||||
|
Src: wildcardV6Net,
|
||||||
|
Dst: wildcardV6Net,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func genStateSelectorForV4Payload() *XfrmPolicy {
|
||||||
|
_, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0")
|
||||||
|
return &XfrmPolicy{
|
||||||
|
Src: wildcardV4Net,
|
||||||
|
Dst: wildcardV4Net,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getBaseState() *XfrmState {
|
func getBaseState() *XfrmState {
|
||||||
return &XfrmState{
|
return &XfrmState{
|
||||||
@ -292,6 +313,54 @@ func getBaseState() *XfrmState {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getBaseStateV4oV6() *XfrmState {
|
||||||
|
return &XfrmState{
|
||||||
|
// Force 4 byte notation for the IPv4 addressesd
|
||||||
|
Src: net.ParseIP("2001:dead::1").To16(),
|
||||||
|
Dst: net.ParseIP("2001:beef::1").To16(),
|
||||||
|
Proto: XFRM_PROTO_ESP,
|
||||||
|
Mode: XFRM_MODE_TUNNEL,
|
||||||
|
Spi: 1,
|
||||||
|
Auth: &XfrmStateAlgo{
|
||||||
|
Name: "hmac(sha256)",
|
||||||
|
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||||
|
},
|
||||||
|
Crypt: &XfrmStateAlgo{
|
||||||
|
Name: "cbc(aes)",
|
||||||
|
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||||
|
},
|
||||||
|
Mark: &XfrmMark{
|
||||||
|
Value: 0x12340000,
|
||||||
|
Mask: 0xffff0000,
|
||||||
|
},
|
||||||
|
Selector: genStateSelectorForV4Payload(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBaseStateV6oV4() *XfrmState {
|
||||||
|
return &XfrmState{
|
||||||
|
// Force 4 byte notation for the IPv4 addressesd
|
||||||
|
Src: net.ParseIP("192.168.1.1").To4(),
|
||||||
|
Dst: net.ParseIP("192.168.2.2").To4(),
|
||||||
|
Proto: XFRM_PROTO_ESP,
|
||||||
|
Mode: XFRM_MODE_TUNNEL,
|
||||||
|
Spi: 1,
|
||||||
|
Auth: &XfrmStateAlgo{
|
||||||
|
Name: "hmac(sha256)",
|
||||||
|
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||||
|
},
|
||||||
|
Crypt: &XfrmStateAlgo{
|
||||||
|
Name: "cbc(aes)",
|
||||||
|
Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
|
||||||
|
},
|
||||||
|
Mark: &XfrmMark{
|
||||||
|
Value: 0x12340000,
|
||||||
|
Mask: 0xffff0000,
|
||||||
|
},
|
||||||
|
Selector: genStateSelectorForV6Payload(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func getAeadState() *XfrmState {
|
func getAeadState() *XfrmState {
|
||||||
// 128 key bits + 32 salt bits
|
// 128 key bits + 32 salt bits
|
||||||
k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
|
k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
|
||||||
@ -309,6 +378,14 @@ func getAeadState() *XfrmState {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
func compareSelector(a, b *XfrmPolicy) bool {
|
||||||
|
return a.Src.String() == b.Src.String() &&
|
||||||
|
a.Dst.String() == b.Dst.String() &&
|
||||||
|
a.Proto == b.Proto &&
|
||||||
|
a.DstPort == b.DstPort &&
|
||||||
|
a.SrcPort == b.SrcPort &&
|
||||||
|
a.Ifindex == b.Ifindex
|
||||||
|
}
|
||||||
|
|
||||||
func compareStates(a, b *XfrmState) bool {
|
func compareStates(a, b *XfrmState) bool {
|
||||||
if a == b {
|
if a == b {
|
||||||
@ -317,6 +394,12 @@ func compareStates(a, b *XfrmState) bool {
|
|||||||
if a == nil || b == nil {
|
if a == nil || b == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if a.Selector != nil && b.Selector != nil {
|
||||||
|
if !compareSelector(a.Selector, b.Selector) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
|
return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
|
||||||
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
|
a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
|
||||||
a.Ifid == b.Ifid &&
|
a.Ifid == b.Ifid &&
|
||||||
@ -325,6 +408,7 @@ func compareStates(a, b *XfrmState) bool {
|
|||||||
compareAlgo(a.Aead, b.Aead) &&
|
compareAlgo(a.Aead, b.Aead) &&
|
||||||
compareMarks(a.Mark, b.Mark) &&
|
compareMarks(a.Mark, b.Mark) &&
|
||||||
compareMarks(a.OutputMark, b.OutputMark)
|
compareMarks(a.OutputMark, b.OutputMark)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func compareLimits(a, b *XfrmState) bool {
|
func compareLimits(a, b *XfrmState) bool {
|
||||||
|
Loading…
Reference in New Issue
Block a user