package netlink

import (
	"bytes"
	"encoding/hex"
	"net"
	"testing"
	"time"
)

// TestXfrmStateAddGetDel runs the add/list/get/delete round trip against
// each of the canned states built by the get*State helpers in this file.
func TestXfrmStateAddGetDel(t *testing.T) {
	for _, s := range []*XfrmState{
		getBaseState(),
		getAeadState(),
		getBaseStateV6oV4(),
		getBaseStateV4oV6(),
	} {
		testXfrmStateAddGetDel(t, s)
	}
}

// testXfrmStateAddGetDel adds state and checks that it is the only state
// listed and that XfrmStateGet returns an equivalent state. It then deletes
// the state and checks that it is gone from both the listing and
// XfrmStateGet.
func testXfrmStateAddGetDel(t *testing.T, state *XfrmState) {
	tearDown := setUpNetlinkTest(t)
	defer tearDown()

	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	states, err := XfrmStateList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}

	if len(states) != 1 {
		t.Fatal("State not added properly")
	}

	if !compareStates(state, &states[0]) {
		t.Fatalf("unexpected states returned")
	}

	// Get specific state
	sa, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}

	if !compareStates(state, sa) {
		t.Fatalf("unexpected state returned")
	}

	if err = XfrmStateDel(state); err != nil {
		t.Fatal(err)
	}

	states, err = XfrmStateList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(states) != 0 {
		t.Fatal("State not removed properly")
	}

	if _, err := XfrmStateGet(state); err == nil {
		t.Fatalf("Unexpected success")
	}
}

// TestXfrmStateAllocSpi requests an SPI for a state with Spi 0 and no
// auth/crypt algorithms. It checks that a non-zero SPI is allocated and that,
// apart from the SPI, the returned state matches the request.
func TestXfrmStateAllocSpi(t *testing.T) {
	defer setUpNetlinkTest(t)()

	state := getBaseState()
	state.Spi = 0
	state.Auth = nil
	state.Crypt = nil
	rstate, err := XfrmStateAllocSpi(state)
	if err != nil {
		t.Fatal(err)
	}
	if rstate.Spi == 0 {
		t.Fatalf("SPI is not allocated")
	}
	// Clear the allocated SPI so the remaining fields can be compared.
	rstate.Spi = 0
	if !compareStates(state, rstate) {
		t.Fatalf("State not properly allocated")
	}
}

// TestXfrmStateFlush exercises XfrmStateFlush with three arguments: a
// protocol that has no states, a specific protocol (AH), and 0, which is
// expected to remove every state.
func TestXfrmStateFlush(t *testing.T) {
	defer setUpNetlinkTest(t)()

	state1 := getBaseState()
	state2 := getBaseState()
	state2.Src = net.ParseIP("127.1.0.1")
	state2.Dst = net.ParseIP("127.1.0.2")
	state2.Proto = XFRM_PROTO_AH
	state2.Mode = XFRM_MODE_TUNNEL
	state2.Spi = 20
	state2.Mark = nil
	state2.Crypt = nil

	if err := XfrmStateAdd(state1); err != nil {
		t.Fatal(err)
	}
	if err := XfrmStateAdd(state2); err != nil {
		t.Fatal(err)
	}

	// flushing proto for which no state is present should return silently
	if err := XfrmStateFlush(XFRM_PROTO_COMP); err != nil {
		t.Fatal(err)
	}

	if err := XfrmStateFlush(XFRM_PROTO_AH); err != nil {
		t.Fatal(err)
	}

	if _, err := XfrmStateGet(state2); err == nil {
		t.Fatalf("Unexpected success")
	}

	// A proto-specific flush must leave states of other protocols alone:
	// the ESP state1 has to survive the AH flush.
	if _, err := XfrmStateGet(state1); err != nil {
		t.Fatalf("ESP state removed by AH flush: %v", err)
	}

	if err := XfrmStateAdd(state2); err != nil {
		t.Fatal(err)
	}

	if err := XfrmStateFlush(0); err != nil {
		t.Fatal(err)
	}

	states, err := XfrmStateList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(states) != 0 {
		t.Fatal("State not flushed properly")
	}
}

// TestXfrmStateUpdateLimits programs a state with time, packet, byte and
// use-time limits and verifies them. It then changes every limit with
// XfrmStateUpdate and verifies that all of them were applied.
func TestXfrmStateUpdateLimits(t *testing.T) {
	defer setUpNetlinkTest(t)()

	// Program state with limits
	state := getBaseState()
	state.Limits.TimeHard = 3600
	state.Limits.TimeSoft = 60
	state.Limits.PacketHard = 1000
	state.Limits.PacketSoft = 50
	state.Limits.ByteHard = 1000000
	state.Limits.ByteSoft = 50000
	state.Limits.TimeUseHard = 3000
	state.Limits.TimeUseSoft = 1500
	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	// Verify limits
	s, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}
	if !compareLimits(state, s) {
		t.Fatalf("Incorrect time hard/soft retrieved: %s", s.Print(true))
	}

	// Update limits
	state.Limits.TimeHard = 1800
	state.Limits.TimeSoft = 30
	state.Limits.PacketHard = 500
	state.Limits.PacketSoft = 25
	state.Limits.ByteHard = 500000
	state.Limits.ByteSoft = 25000
	state.Limits.TimeUseHard = 2000
	state.Limits.TimeUseSoft = 1000
	if err := XfrmStateUpdate(state); err != nil {
		t.Fatal(err)
	}

	// Verify new limits: every limit was changed above, so check all of
	// them rather than only the time limits.
	s, err = XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}
	if !compareLimits(state, s) {
		t.Fatalf("Incorrect limits retrieved after update: %s", s.Print(true))
	}
}

// TestXfrmStateStats checks the statistics of a freshly added state. The
// byte and packet counters must be zero, the use time must be unset, and the
// add time must fall inside the window in which the state was added.
func TestXfrmStateStats(t *testing.T) {
	defer setUpNetlinkTest(t)()

	// Program state, bracketing the add with timestamps. A single timestamp
	// taken before the add makes the check flaky when the add happens to
	// cross a second boundary.
	state := getBaseState()
	before := time.Now()
	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	after := time.Now()

	// Retrieve state
	s, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}

	// Verify stats: zero counters, an add time (second resolution) inside
	// [before, after], and unset use time.
	addTime := s.Statistics.AddTime
	if s.Statistics.Bytes != 0 || s.Statistics.Packets != 0 ||
		addTime < uint64(before.Unix()) || addTime > uint64(after.Unix()) ||
		s.Statistics.UseTime != 0 {
		t.Fatalf("Unexpected statistics (addTime window: %s - %s) for state:\n%s",
			before.Format(time.UnixDate), after.Format(time.UnixDate), s.Print(true))
	}
}

// TestXfrmStateWithIfid round-trips a state carrying an XFRM interface id
// (Ifid). It is gated on kernel 4.19 via minKernelRequired.
func TestXfrmStateWithIfid(t *testing.T) {
	minKernelRequired(t, 4, 19)
	defer setUpNetlinkTest(t)()

	state := getBaseState()
	state.Ifid = 54321
	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	s, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}
	if !compareStates(state, s) {
		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
	}
	if err = XfrmStateDel(s); err != nil {
		t.Fatal(err)
	}
}

// TestXfrmStateWithOutputMark round-trips a state with an output mark that
// sets only Value. It is gated on kernel 4.14 via minKernelRequired.
func TestXfrmStateWithOutputMark(t *testing.T) {
	minKernelRequired(t, 4, 14)
	defer setUpNetlinkTest(t)()

	state := getBaseState()
	state.OutputMark = &XfrmMark{
		Value: 0x0000000a,
	}
	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	s, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}
	if !compareStates(state, s) {
		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
	}
	if err = XfrmStateDel(s); err != nil {
		t.Fatal(err)
	}
}

// TestXfrmStateWithOutputMarkAndMask round-trips a state with an output mark
// that sets both Value and Mask. It is gated on kernel 4.19 via
// minKernelRequired.
func TestXfrmStateWithOutputMarkAndMask(t *testing.T) {
	minKernelRequired(t, 4, 19)
	defer setUpNetlinkTest(t)()

	state := getBaseState()
	state.OutputMark = &XfrmMark{
		Value: 0x0000000a,
		Mask:  0x0000000f,
	}
	if err := XfrmStateAdd(state); err != nil {
		t.Fatal(err)
	}
	s, err := XfrmStateGet(state)
	if err != nil {
		t.Fatal(err)
	}
	if !compareStates(state, s) {
		t.Fatalf("unexpected state returned.\nExpected: %v.\nGot %v", state, s)
	}
	if err = XfrmStateDel(s); err != nil {
		t.Fatal(err)
	}
}

// genStateSelectorForV6Payload returns a selector matching any IPv6
// source and destination (::/0).
func genStateSelectorForV6Payload() *XfrmPolicy {
	// The CIDR is a valid constant, so the parse error is ignored.
	_, wildcardV6Net, _ := net.ParseCIDR("::/0")
	return &XfrmPolicy{
		Src: wildcardV6Net,
		Dst: wildcardV6Net,
	}
}

// genStateSelectorForV4Payload returns a selector matching any IPv4
// source and destination (0.0.0.0/0).
func genStateSelectorForV4Payload() *XfrmPolicy {
	// The CIDR is a valid constant, so the parse error is ignored.
	_, wildcardV4Net, _ := net.ParseCIDR("0.0.0.0/0")
	return &XfrmPolicy{
		Src: wildcardV4Net,
		Dst: wildcardV4Net,
	}
}

// getBaseState returns an ESP tunnel-mode state between 127.0.0.1 and
// 127.0.0.2 with SPI 1, hmac(sha256) auth, cbc(aes) crypt and a masked mark.
func getBaseState() *XfrmState {
	return &XfrmState{
		// Force 4 byte notation for the IPv4 addresses
		Src:   net.ParseIP("127.0.0.1").To4(),
		Dst:   net.ParseIP("127.0.0.2").To4(),
		Proto: XFRM_PROTO_ESP,
		Mode:  XFRM_MODE_TUNNEL,
		Spi:   1,
		Auth: &XfrmStateAlgo{
			Name: "hmac(sha256)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Crypt: &XfrmStateAlgo{
			Name: "cbc(aes)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Mark: &XfrmMark{
			Value: 0x12340000,
			Mask:  0xffff0000,
		},
	}
}

// getBaseStateV4oV6 returns an ESP tunnel-mode state with IPv6 outer
// addresses and an IPv4 wildcard selector (IPv4 payload over IPv6).
func getBaseStateV4oV6() *XfrmState {
	return &XfrmState{
		// Force 16 byte notation for the IPv6 outer addresses
		Src:   net.ParseIP("2001:dead::1").To16(),
		Dst:   net.ParseIP("2001:beef::1").To16(),
		Proto: XFRM_PROTO_ESP,
		Mode:  XFRM_MODE_TUNNEL,
		Spi:   1,
		Auth: &XfrmStateAlgo{
			Name: "hmac(sha256)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Crypt: &XfrmStateAlgo{
			Name: "cbc(aes)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Mark: &XfrmMark{
			Value: 0x12340000,
			Mask:  0xffff0000,
		},
		Selector: genStateSelectorForV4Payload(),
	}
}

// getBaseStateV6oV4 returns an ESP tunnel-mode state with IPv4 outer
// addresses and an IPv6 wildcard selector (IPv6 payload over IPv4).
func getBaseStateV6oV4() *XfrmState {
	return &XfrmState{
		// Force 4 byte notation for the IPv4 outer addresses
		Src:   net.ParseIP("192.168.1.1").To4(),
		Dst:   net.ParseIP("192.168.2.2").To4(),
		Proto: XFRM_PROTO_ESP,
		Mode:  XFRM_MODE_TUNNEL,
		Spi:   1,
		Auth: &XfrmStateAlgo{
			Name: "hmac(sha256)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Crypt: &XfrmStateAlgo{
			Name: "cbc(aes)",
			Key:  []byte("abcdefghijklmnopqrstuvwzyzABCDEF"),
		},
		Mark: &XfrmMark{
			Value: 0x12340000,
			Mask:  0xffff0000,
		},
		Selector: genStateSelectorForV6Payload(),
	}
}

// getAeadState returns an ESP tunnel-mode state with SPI 2 that uses the
// rfc4106(gcm(aes)) AEAD algorithm with a 64-bit ICV.
func getAeadState() *XfrmState {
	// 128 key bits + 32 salt bits (20 bytes of hex). The literal is valid
	// hex, so the decode error is ignored.
	k, _ := hex.DecodeString("d0562776bf0e75830ba3f7f8eb6c09b555aa1177")
	return &XfrmState{
		// Unlike getBaseState, leave the IPv4 addresses in the 16-byte
		// IPv4-in-IPv6 form returned by net.ParseIP.
		Src:   net.ParseIP("192.168.1.1"),
		Dst:   net.ParseIP("192.168.2.2"),
		Proto: XFRM_PROTO_ESP,
		Mode:  XFRM_MODE_TUNNEL,
		Spi:   2,
		Aead: &XfrmStateAlgo{
			Name:   "rfc4106(gcm(aes))",
			Key:    k,
			ICVLen: 64,
		},
	}
}

// compareSelector reports whether two selectors have the same source and
// destination networks (compared by string form), protocol, ports and
// interface index.
func compareSelector(a, b *XfrmPolicy) bool {
	return a.Src.String() == b.Src.String() &&
		a.Dst.String() == b.Dst.String() &&
		a.Proto == b.Proto &&
		a.DstPort == b.DstPort &&
		a.SrcPort == b.SrcPort &&
		a.Ifindex == b.Ifindex
}

// compareStates reports whether two states match on addresses, mode, SPI,
// protocol, Ifid, algorithms and marks. Callers in this file pass the
// programmed state as a and the retrieved state as b. Selectors are compared
// only when both states carry one, so an expected state without a selector
// matches a retrieved state that has one.
func compareStates(a, b *XfrmState) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	if a.Selector != nil && b.Selector != nil {
		if !compareSelector(a.Selector, b.Selector) {
			return false
		}
	}

	return a.Src.Equal(b.Src) && a.Dst.Equal(b.Dst) &&
		a.Mode == b.Mode && a.Spi == b.Spi && a.Proto == b.Proto &&
		a.Ifid == b.Ifid &&
		compareAlgo(a.Auth, b.Auth) &&
		compareAlgo(a.Crypt, b.Crypt) &&
		compareAlgo(a.Aead, b.Aead) &&
		compareMarks(a.Mark, b.Mark) &&
		compareMarks(a.OutputMark, b.OutputMark)
}

// compareLimits reports whether all eight lifetime limits (time, packet,
// byte and use-time, hard and soft) of a and b are equal.
func compareLimits(a, b *XfrmState) bool {
	return a.Limits.TimeHard == b.Limits.TimeHard &&
		a.Limits.TimeSoft == b.Limits.TimeSoft &&
		a.Limits.PacketHard == b.Limits.PacketHard &&
		a.Limits.PacketSoft == b.Limits.PacketSoft &&
		a.Limits.ByteHard == b.Limits.ByteHard &&
		a.Limits.ByteSoft == b.Limits.ByteSoft &&
		a.Limits.TimeUseHard == b.Limits.TimeUseHard &&
		a.Limits.TimeUseSoft == b.Limits.TimeUseSoft
}

// compareAlgo reports whether two algorithms have the same name and key.
// TruncateLen and ICVLen are compared only when they are non-zero in a (the
// expected side). A zero value there means "don't care".
func compareAlgo(a, b *XfrmStateAlgo) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	return a.Name == b.Name && bytes.Equal(a.Key, b.Key) &&
		(a.TruncateLen == 0 || a.TruncateLen == b.TruncateLen) &&
		(a.ICVLen == 0 || a.ICVLen == b.ICVLen)
}

// compareMarks reports whether two marks are both nil or have equal Value
// and Mask.
func compareMarks(a, b *XfrmMark) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	return a.Value == b.Value && a.Mask == b.Mask
}