Support ipset test entry existence

This commit is contained in:
major1201 2023-04-04 00:30:26 +08:00 committed by Alessandro Boch
parent ced5aaba43
commit 7350a0539f
3 changed files with 103 additions and 12 deletions

View File

@ -1,3 +1,4 @@
//go:build linux
// +build linux
package main
@ -28,6 +29,7 @@ var (
"listall": {cmdListAll, "list all ipsets", 0},
"add": {cmdAddDel(netlink.IpsetAdd), "add entry", 2},
"del": {cmdAddDel(netlink.IpsetDel), "delete entry", 2},
"test": {cmdTest, "test whether an entry is in a set or not", 2},
}
timeoutVal *uint32
@ -140,6 +142,21 @@ func cmdAddDel(f func(string, *netlink.IPSetEntry) error) func([]string) {
}
}
func cmdTest(args []string) {
setName := args[0]
element := args[1]
ip := net.ParseIP(element)
entry := &netlink.IPSetEntry{
Timeout: timeoutVal,
IP: ip,
Comment: *comment,
Replace: *replace,
}
exist, err := netlink.IpsetTest(setName, entry)
check(err)
log.Printf("existence: %t\n", exist)
}
// panic on error
func check(err error) {
if err != nil {

View File

@ -121,6 +121,11 @@ func IpsetDel(setname string, entry *IPSetEntry) error {
return pkgHandle.IpsetDel(setname, entry)
}
// IpsetTest tests whether an entry is in a set or not.
func IpsetTest(setname string, entry *IPSetEntry) (bool, error) {
return pkgHandle.IpsetTest(setname, entry)
}
func (h *Handle) IpsetProtocol() (protocol uint8, minVersion uint8, err error) {
req := h.newIpsetRequest(nl.IPSET_CMD_PROTOCOL)
msgs, err := req.Execute(unix.NETLINK_NETFILTER, 0)
@ -270,16 +275,9 @@ func encodeIP(ip net.IP) (*nl.RtAttr, error) {
return nl.NewRtAttr(typ, ip), nil
}
func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error {
req := h.newIpsetRequest(nlCmd)
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
func buildEntryData(entry *IPSetEntry) (*nl.RtAttr, error) {
data := nl.NewRtAttr(nl.IPSET_ATTR_DATA|int(nl.NLA_F_NESTED), nil)
if !entry.Replace {
req.Flags |= unix.NLM_F_EXCL
}
if entry.Comment != "" {
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_COMMENT, nl.ZeroTerminated(entry.Comment)))
}
@ -291,7 +289,7 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error
if entry.IP != nil {
nestedData, err := encodeIP(entry.IP)
if err != nil {
return err
return nil, err
}
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP|int(nl.NLA_F_NESTED), nestedData.Serialize()))
}
@ -307,7 +305,7 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error
if entry.IP2 != nil {
nestedData, err := encodeIP(entry.IP2)
if err != nil {
return err
return nil, err
}
data.AddChild(nl.NewRtAttr(nl.IPSET_ATTR_IP2|int(nl.NLA_F_NESTED), nestedData.Serialize()))
}
@ -335,14 +333,53 @@ func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error
if entry.Mark != nil {
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_MARK | nl.NLA_F_NET_BYTEORDER, Value: *entry.Mark})
}
return data, nil
}
func (h *Handle) ipsetAddDel(nlCmd int, setname string, entry *IPSetEntry) error {
req := h.newIpsetRequest(nlCmd)
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
if !entry.Replace {
req.Flags |= unix.NLM_F_EXCL
}
data, err := buildEntryData(entry)
if err != nil {
return err
}
data.AddChild(&nl.Uint32Attribute{Type: nl.IPSET_ATTR_LINENO | nl.NLA_F_NET_BYTEORDER, Value: 0})
req.AddData(data)
_, err := ipsetExecute(req)
_, err = ipsetExecute(req)
return err
}
func (h *Handle) IpsetTest(setname string, entry *IPSetEntry) (bool, error) {
req := h.newIpsetRequest(nl.IPSET_CMD_TEST)
req.AddData(nl.NewRtAttr(nl.IPSET_ATTR_SETNAME, nl.ZeroTerminated(setname)))
if !entry.Replace {
req.Flags |= unix.NLM_F_EXCL
}
data, err := buildEntryData(entry)
if err != nil {
return false, err
}
req.AddData(data)
_, err = ipsetExecute(req)
if err != nil {
if err == nl.IPSetError(nl.IPSET_ERR_EXIST) {
// not exist
return false, nil
}
return false, err
}
return true, nil
}
func (h *Handle) newIpsetRequest(cmd int) *nl.NetlinkRequest {
req := h.newNetlinkRequest(cmd|(unix.NFNL_SUBSYS_IPSET<<8), nl.GetIpsetFlags(cmd))

View File

@ -143,9 +143,20 @@ func TestIpsetCreateListAddDelDestroy(t *testing.T) {
t.Errorf("expected timeout to be 3, but got '%d'", *results[0].Timeout)
}
ip := net.ParseIP("10.99.99.99")
exist, err := IpsetTest("my-test-ipset-1", &IPSetEntry{
IP: ip,
})
if err != nil {
t.Fatal(err)
}
if exist {
t.Errorf("entry should not exist before being added: %s", ip.String())
}
err = IpsetAdd("my-test-ipset-1", &IPSetEntry{
Comment: "test comment",
IP: net.ParseIP("10.99.99.99"),
IP: ip,
Replace: false,
})
@ -153,6 +164,16 @@ func TestIpsetCreateListAddDelDestroy(t *testing.T) {
t.Fatal(err)
}
exist, err = IpsetTest("my-test-ipset-1", &IPSetEntry{
IP: ip,
})
if err != nil {
t.Fatal(err)
}
if !exist {
t.Errorf("entry should exist after being added: %s", ip.String())
}
result, err := IpsetList("my-test-ipset-1")
if err != nil {
@ -455,6 +476,14 @@ func TestIpsetCreateListAddDelDestroyWithTestCases(t *testing.T) {
t.Fatal(err)
}
exist, err := IpsetTest(tC.setname, tC.entry)
if err != nil {
t.Fatal(err)
}
if !exist {
t.Errorf("entry should exist, but 'test' got false, case: %s", tC.desc)
}
result, err = IpsetList(tC.setname)
if err != nil {
@ -526,6 +555,14 @@ func TestIpsetCreateListAddDelDestroyWithTestCases(t *testing.T) {
t.Fatal(err)
}
exist, err = IpsetTest(tC.setname, tC.entry)
if err != nil {
t.Fatal(err)
}
if exist {
t.Errorf("entry should be deleted, but 'test' got true, case: %s", tC.desc)
}
result, err = IpsetList(tC.setname)
if err != nil {
t.Fatal(err)