package netlink

import (
	"bytes"
	"net"
	"testing"
)

const zeroCIDR = "0.0.0.0/0"

func TestXfrmPolicyAddUpdateDel(t *testing.T) {
	tearDown := setUpNetlinkTest(t)
	defer tearDown()

	policy := getPolicy()
	if err := XfrmPolicyAdd(policy); err != nil {
		t.Fatal(err)
	}
	policies, err := XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}

	if len(policies) != 1 {
		t.Fatal("Policy not added properly")
	}

	if !comparePolicies(policy, &policies[0]) {
		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", policy, policies[0])
	}

	if policies[0].Ifindex != 0 {
		t.Fatalf("default policy has a non-zero interface index.\nGot %d", policies[0].Ifindex)
	}

	if policies[0].Ifid != 0 {
		t.Fatalf("default policy has non-zero if_id.\nGot %d", policies[0].Ifid)
	}

	if policies[0].Action != XFRM_POLICY_ALLOW {
		t.Fatalf("default policy has non-allow action.\nGot %s", policies[0].Action)
	}

	// Look for a specific policy
	sp, err := XfrmPolicyGet(policy)
	if err != nil {
		t.Fatal(err)
	}

	if !comparePolicies(policy, sp) {
		t.Fatalf("unexpected policy returned")
	}

	// Modify the policy
	policy.Priority = 100
	if err := XfrmPolicyUpdate(policy); err != nil {
		t.Fatal(err)
	}
	sp, err = XfrmPolicyGet(policy)
	if err != nil {
		t.Fatal(err)
	}
	if sp.Priority != 100 {
		t.Fatalf("failed to modify the policy")
	}

	if err = XfrmPolicyDel(policy); err != nil {
		t.Fatal(err)
	}

	policies, err = XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 0 {
		t.Fatal("Policy not removed properly")
	}

	// Src and dst are not mandatory field. Creation should succeed
	policy.Src = nil
	policy.Dst = nil
	if err = XfrmPolicyAdd(policy); err != nil {
		t.Fatal(err)
	}

	sp, err = XfrmPolicyGet(policy)
	if err != nil {
		t.Fatal(err)
	}

	if !comparePolicies(policy, sp) {
		t.Fatalf("unexpected policy returned")
	}

	if err = XfrmPolicyDel(policy); err != nil {
		t.Fatal(err)
	}

	if _, err := XfrmPolicyGet(policy); err == nil {
		t.Fatalf("Unexpected success")
	}
}

func TestXfrmPolicyFlush(t *testing.T) {
	defer setUpNetlinkTest(t)()

	p1 := getPolicy()
	if err := XfrmPolicyAdd(p1); err != nil {
		t.Fatal(err)
	}

	p1.Dir = XFRM_DIR_IN
	s := p1.Src
	p1.Src = p1.Dst
	p1.Dst = s
	if err := XfrmPolicyAdd(p1); err != nil {
		t.Fatal(err)
	}

	policies, err := XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 2 {
		t.Fatalf("unexpected number of policies: %d", len(policies))
	}

	if err := XfrmPolicyFlush(); err != nil {
		t.Fatal(err)
	}

	policies, err = XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 0 {
		t.Fatalf("unexpected number of policies: %d", len(policies))
	}

}

func TestXfrmPolicyBlockWithIfindex(t *testing.T) {
	defer setUpNetlinkTest(t)()

	pBlock := getPolicy()
	pBlock.Action = XFRM_POLICY_BLOCK
	pBlock.Ifindex = 1 // loopback interface
	if err := XfrmPolicyAdd(pBlock); err != nil {
		t.Fatal(err)
	}
	policies, err := XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 1 {
		t.Fatalf("unexpected number of policies: %d", len(policies))
	}
	if !comparePolicies(pBlock, &policies[0]) {
		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pBlock, policies[0])
	}
	if err = XfrmPolicyDel(pBlock); err != nil {
		t.Fatal(err)
	}
}

func TestXfrmPolicyWithIfid(t *testing.T) {
	minKernelRequired(t, 4, 19)
	defer setUpNetlinkTest(t)()

	pol := getPolicy()
	pol.Ifid = 54321

	if err := XfrmPolicyAdd(pol); err != nil {
		t.Fatal(err)
	}
	policies, err := XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 1 {
		t.Fatalf("unexpected number of policies: %d", len(policies))
	}
	if !comparePolicies(pol, &policies[0]) {
		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pol, policies[0])
	}
	if err = XfrmPolicyDel(&policies[0]); err != nil {
		t.Fatal(err)
	}
}

func TestXfrmPolicyWithOptional(t *testing.T) {
	minKernelRequired(t, 4, 19)
	defer setUpNetlinkTest(t)()

	pol := getPolicy()
	pol.Dir = XFRM_DIR_IN
	pol.Tmpls[0].Optional = 1

	if err := XfrmPolicyAdd(pol); err != nil {
		t.Fatal(err)
	}
	policies, err := XfrmPolicyList(FAMILY_ALL)
	if err != nil {
		t.Fatal(err)
	}
	if len(policies) != 1 {
		t.Fatalf("unexpected number of policies: %d", len(policies))
	}
	if !comparePolicies(pol, &policies[0]) {
		t.Fatalf("unexpected policy returned.\nExpected: %v.\nGot %v", pol, policies[0])
	}
	if err = XfrmPolicyDel(&policies[0]); err != nil {
		t.Fatal(err)
	}
}

func comparePolicies(a, b *XfrmPolicy) bool {
	if a == b {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	// Do not check Index which is assigned by kernel
	return a.Dir == b.Dir && a.Priority == b.Priority &&
		compareIPNet(a.Src, b.Src) && compareIPNet(a.Dst, b.Dst) &&
		a.Action == b.Action && a.Ifindex == b.Ifindex &&
		a.Mark.Value == b.Mark.Value && a.Mark.Mask == b.Mark.Mask &&
		a.Ifid == b.Ifid && compareTemplates(a.Tmpls, b.Tmpls)
}

func compareTemplates(a, b []XfrmPolicyTmpl) bool {
	if len(a) != len(b) {
		return false
	}
	for i, ta := range a {
		tb := b[i]
		if !ta.Dst.Equal(tb.Dst) || !ta.Src.Equal(tb.Src) || ta.Spi != tb.Spi ||
			ta.Mode != tb.Mode || ta.Reqid != tb.Reqid || ta.Proto != tb.Proto ||
			ta.Optional != tb.Optional {
			return false
		}
	}
	return true
}

func compareIPNet(a, b *net.IPNet) bool {
	if a == b {
		return true
	}
	// For unspecified src/dst parseXfrmPolicy would set the zero address cidr
	if (a == nil && b.String() == zeroCIDR) || (b == nil && a.String() == zeroCIDR) {
		return true
	}
	if a == nil || b == nil {
		return false
	}
	return a.IP.Equal(b.IP) && bytes.Equal(a.Mask, b.Mask)
}

func getPolicy() *XfrmPolicy {
	src, _ := ParseIPNet("127.1.1.1/32")
	dst, _ := ParseIPNet("127.1.1.2/32")
	policy := &XfrmPolicy{
		Src:     src,
		Dst:     dst,
		Proto:   17,
		DstPort: 1234,
		SrcPort: 5678,
		Dir:     XFRM_DIR_OUT,
		Mark: &XfrmMark{
			Value: 0xabff22,
			Mask:  0xffffffff,
		},
		Priority: 10,
	}
	tmpl := XfrmPolicyTmpl{
		Src:   net.ParseIP("127.0.0.1"),
		Dst:   net.ParseIP("127.0.0.2"),
		Proto: XFRM_PROTO_ESP,
		Mode:  XFRM_MODE_TUNNEL,
		Spi:   0x1bcdef99,
	}
	policy.Tmpls = append(policy.Tmpls, tmpl)
	return policy
}