From 8715fe718dfdf487a919acb6df7da109346bbfd6 Mon Sep 17 00:00:00 2001
From: Damien Tournoud <damien@platform.sh>
Date: Thu, 21 Jul 2022 13:44:19 -0700
Subject: [PATCH] ipset: Expose MaxElements to IpsetCreate

---
 ipset_linux.go      | 17 ++++++++++------
 ipset_linux_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/ipset_linux.go b/ipset_linux.go
index 94177e3..30ae878 100644
--- a/ipset_linux.go
+++ b/ipset_linux.go
@@ -67,12 +67,13 @@ type IpsetCreateOptions struct {
 	Comments bool
 	Skbinfo  bool
 
-	Family   uint8
-	Revision uint8
-	IPFrom   net.IP
-	IPTo     net.IP
-	PortFrom uint16
-	PortTo   uint16
+	Family      uint8
+	Revision    uint8
+	IPFrom      net.IP
+	IPTo        net.IP
+	PortFrom    uint16
+	PortTo      uint16
+	MaxElements uint32
 }
 
 // IpsetProtocol returns the ipset protocol version from the kernel
@@ -167,6 +168,10 @@ func (h *Handle) IpsetCreate(setname, typename string, options IpsetCreateOption
 
 	req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_FAMILY, nl.Uint8Attr(family)))
 
+	if options.MaxElements != 0 {
+		data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_MAXELEM | nl.NLA_F_NET_BYTEORDER, Value: options.MaxElements})
+	}
+
 	if timeout := options.Timeout; timeout != nil {
 		data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_TIMEOUT | nl.NLA_F_NET_BYTEORDER, Value: *timeout})
 	}
diff --git a/ipset_linux_test.go b/ipset_linux_test.go
index fa9877b..27c2e90 100644
--- a/ipset_linux_test.go
+++ b/ipset_linux_test.go
@@ -673,3 +673,52 @@ func TestIpsetSwap(t *testing.T) {
 	assertIsEmpty(ipset1)
 	assertHasOneEntry(ipset2)
 }
+
+func nextIP(ip net.IP) {
+	for j := len(ip) - 1; j >= 0; j-- {
+		ip[j]++
+		if ip[j] > 0 {
+			break
+		}
+	}
+}
+
+// TestIpsetMaxElements tests that we can create an ipset containing
+// 128k elements, which is double the default size (64k elements).
+func TestIpsetMaxElements(t *testing.T) {
+	tearDown := setUpNetlinkTest(t)
+	defer tearDown()
+
+	ipsetName := "my-test-ipset-max"
+	maxElements := uint32(128 << 10)
+
+	err := IpsetCreate(ipsetName, "hash:ip", IpsetCreateOptions{
+		Replace:     true,
+		MaxElements: maxElements,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		_ = IpsetDestroy(ipsetName)
+	}()
+
+	ip := net.ParseIP("10.0.0.0")
+	for i := uint32(0); i < maxElements; i++ {
+		err = IpsetAdd(ipsetName, &IPSetEntry{
+			IP: ip,
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		nextIP(ip)
+	}
+
+	result, err := IpsetList(ipsetName)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(result.Entries) != int(maxElements) {
+		t.Fatalf("expected '%d' entry be created, got '%d'", maxElements, len(result.Entries))
+	}
+}