diff --git a/cmd/ipset-test/main.go b/cmd/ipset-test/main.go index 84c2a27..368f891 100644 --- a/cmd/ipset-test/main.go +++ b/cmd/ipset-test/main.go @@ -1,3 +1,4 @@ +//go:build linux // +build linux package main @@ -28,6 +29,7 @@ var ( "listall": {cmdListAll, "list all ipsets", 0}, "add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2}, "del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2}, + "test": {cmdTest, "test whether an entry is in a set or not", 2}, } timeoutVal *uint32 @@ -140,6 +142,21 @@ func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) { } } +func cmdTest(args []string) { + setName := args[0] + element := args[1] + ip := net.ParseIP(element) + entry := &netlink.IPSetEntry{ + Timeout: timeoutVal, + IP: ip, + Comment: *comment, + Replace: *replace, + } + exist, err := netlink.IpsetTest(setName, entry) + check(err) + log.Printf("existence: %t\n", exist) +} + // panic on error func check(err error) { if err != nil { diff --git a/ipset_linux.go b/ipset_linux.go index cde2be2..f4c0522 100644 --- a/ipset_linux.go +++ b/ipset_linux.go @@ -121,6 +121,11 @@ func IpsetDel(setname string, entry *IPSetEntry) error { return pkgHandle.IpsetDel(setname, entry) } +// IpsetTest tests whether an entry is in a set or not. +func IpsetTest(setname string, entry *IPSetEntry) (bool, error) { + return pkgHandle.IpsetTest(setname, entry) +} + func (h *Handle) IpsetProtocol() (protocol uint8, minVersion uint8, err error) { req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL) msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0) @@ -270,16 +275,9 @@ func encodeIP(ip net.IP) (*nl.RtAttr, error) { return nl.NewRtAttr(typ, ip), nil } -func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error { - req := h.newIpsetRequest(nlCmd) - req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) - +func buildEntryData(entry *IPSetEntry) (*nl.RtAttr, error) { data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil) - if !entry.Replace { - req.Flags |= unix.NLM_F_EXCL - } - if entry.Comment != "" { data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_COMMENT, nl.ZeroTerminated(entry.Comment))) } @@ -291,7 +289,7 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error if entry.IP != nil { nestedData, err := encodeIP(entry.IP) if err != nil { - return err + return nil, err } data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NESTED), nestedData.Serialize())) } @@ -307,7 +305,7 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error if entry.IP2 != nil { nestedData, err := encodeIP(entry.IP2) if err != nil { - return err + return nil, err } data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP2|int(nl.NLA_F_NESTED), nestedData.Serialize())) } @@ -335,14 +333,53 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error if entry.Mark != nil { data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_MARK | nl.NLA_F_NET_BYTEORDER, Value: *entry.Mark}) } + return data, nil +} +func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error { + req := h.newIpsetRequest(nlCmd) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + + if !entry.Replace { + req.Flags |= unix.NLM_F_EXCL + } + + data, err := buildEntryData(entry) + if err != nil { + return err + } data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0}) req.AddData(data) - _, err := ipsetExecute(req) + _, err = ipsetExecute(req) return err } +func (h *Handle) IpsetTest(setname string, entry *IPSetEntry) (bool, error) { + req := h.newIpsetRequest(nl.IPSET_CMD_TEST) + req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname))) + + if !entry.Replace { + req.Flags |= unix.NLM_F_EXCL + } + + data, err := buildEntryData(entry) + if err != nil { + return false, err + } + req.AddData(data) + + _, err = ipsetExecute(req) + if err != nil { + if err == nl.IPSetError(nl.IPSET_ERR_EXIST) { + // not exist + return false, nil + } + return false, err + } + return true, nil +} + func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest { req := h.newNetlinkRequest(cmd|(unix.NFNL_SUBSYS_IPSET<<8), nl.GetIpsetFlags(cmd)) diff --git a/ipset_linux_test.go b/ipset_linux_test.go index bd2a2ff..298c3a3 100644 --- a/ipset_linux_test.go +++ b/ipset_linux_test.go @@ -143,9 +143,20 @@ func TestIpsetCreateListAddDelDestroy(t *testing.T) { t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout) } + ip := net.ParseIP("10.99.99.99") + exist, err := IpsetTest("my-test-ipset-1", &IPSetEntry{ + IP: ip, + }) + if err != nil { + t.Fatal(err) + } + if exist { + t.Errorf("entry should not exist before being added: %s", ip.String()) + } + err = IpsetAdd("my-test-ipset-1", &IPSetEntry{ Comment: "test comment", - IP: net.ParseIP("10.99.99.99"), + IP: ip, Replace: false, }) @@ -153,6 +164,16 @@ func TestIpsetCreateListAddDelDestroy(t *testing.T) { t.Fatal(err) } + exist, err = IpsetTest("my-test-ipset-1", &IPSetEntry{ + IP: ip, + }) + if err != nil { + t.Fatal(err) + } + if !exist { + t.Errorf("entry should exist after being added: %s", ip.String()) + } + result, err := IpsetList("my-test-ipset-1") if err != nil { @@ -455,6 +476,14 @@ func TestIpsetCreateListAddDelDestroyWithTestCases(t *testing.T) { t.Fatal(err) } + exist, err := IpsetTest(tC.setname, tC.entry) + if err != nil { + t.Fatal(err) + } + if !exist { + t.Errorf("entry should exist, but 'test' got false, case: %s", tC.desc) + } + result, err = IpsetList(tC.setname) if err != nil { @@ -526,6 +555,14 @@ func TestIpsetCreateListAddDelDestroyWithTestCases(t *testing.T) { t.Fatal(err) } + exist, err = IpsetTest(tC.setname, tC.entry) + if err != nil { + t.Fatal(err) + } + if exist { + t.Errorf("entry should be deleted, but 'test' got true, case: %s", tC.desc) + } + result, err = IpsetList(tC.setname) if err != nil { t.Fatal(err)