From 9ada19101fc5585d550e5cc0b43c28873214820a Mon Sep 17 00:00:00 2001 From: Wu Zongyong Date: Tue, 13 Apr 2021 16:01:12 +0800 Subject: [PATCH] filter: add support for police action This patch adds support for tc police action. And codes of fw filter have been refactored with the police action for reducing redundant codes. Signed-off-by: Wu Zongyong --- filter.go | 62 ++++++++++---- filter_linux.go | 182 ++++++++++++++++++++--------------------- filter_test.go | 212 ++++++++++++++++++++++++++++++++++++++++++++---- qdisc_linux.go | 4 + 4 files changed, 339 insertions(+), 121 deletions(-) diff --git a/filter.go b/filter.go index 2dc34b9..413abdb 100644 --- a/filter.go +++ b/filter.go @@ -260,6 +260,40 @@ func NewSkbEditAction() *SkbEditAction { } } +type PoliceAction struct { + ActionAttrs + Rate uint32 // in byte per second + Burst uint32 // in byte + RCellLog int + Mtu uint32 + Mpu uint16 // in byte + PeakRate uint32 // in byte per second + PCellLog int + AvRate uint32 // in byte per second + Overhead uint16 + LinkLayer int + ExceedAction TcPolAct + NotExceedAction TcPolAct +} + +func (action *PoliceAction) Type() string { + return "police" +} + +func (action *PoliceAction) Attrs() *ActionAttrs { + return &action.ActionAttrs +} + +func NewPoliceAction() *PoliceAction { + return &PoliceAction{ + RCellLog: -1, + PCellLog: -1, + LinkLayer: 1, // ETHERNET + ExceedAction: TC_POLICE_RECLASSIFY, + NotExceedAction: TC_POLICE_OK, + } +} + // MatchAll filters match all packets type MatchAll struct { FilterAttrs @@ -275,20 +309,20 @@ func (filter *MatchAll) Type() string { return "matchall" } -type FilterFwAttrs struct { - ClassId uint32 - InDev string - Mask uint32 - Index uint32 - Buffer uint32 - Mtu uint32 - Mpu uint16 - Rate uint32 - AvRate uint32 - PeakRate uint32 - Action TcPolAct - Overhead uint16 - LinkLayer int +type FwFilter struct { + FilterAttrs + ClassId uint32 + InDev string + Mask uint32 + Police *PoliceAction +} + +func (filter *FwFilter) Attrs() *FilterAttrs { + return &filter.FilterAttrs +} + +func (filter *FwFilter) Type() string { + return "fw" } type BpfFilter struct { diff --git a/filter_linux.go b/filter_linux.go index d5d45c8..3cfea45 100644 --- a/filter_linux.go +++ b/filter_linux.go @@ -51,76 +51,6 @@ func (filter *U32) Type() string { return "u32" } -// Fw filter filters on firewall marks -// NOTE: this is in filter_linux because it refers to nl.TcPolice which -// is defined in nl/tc_linux.go -type Fw struct { - FilterAttrs - ClassId uint32 - // TODO remove nl type from interface - Police nl.TcPolice - InDev string - // TODO Action - Mask uint32 - AvRate uint32 - Rtab [256]uint32 - Ptab [256]uint32 -} - -func NewFw(attrs FilterAttrs, fattrs FilterFwAttrs) (*Fw, error) { - var rtab [256]uint32 - var ptab [256]uint32 - rcellLog := -1 - pcellLog := -1 - avrate := fattrs.AvRate / 8 - police := nl.TcPolice{} - police.Rate.Rate = fattrs.Rate / 8 - police.PeakRate.Rate = fattrs.PeakRate / 8 - buffer := fattrs.Buffer - linklayer := nl.LINKLAYER_ETHERNET - - if fattrs.LinkLayer != nl.LINKLAYER_UNSPEC { - linklayer = fattrs.LinkLayer - } - - police.Action = int32(fattrs.Action) - if police.Rate.Rate != 0 { - police.Rate.Mpu = fattrs.Mpu - police.Rate.Overhead = fattrs.Overhead - if CalcRtable(&police.Rate, rtab[:], rcellLog, fattrs.Mtu, linklayer) < 0 { - return nil, errors.New("TBF: failed to calculate rate table") - } - police.Burst = Xmittime(uint64(police.Rate.Rate), uint32(buffer)) - } - police.Mtu = fattrs.Mtu - if police.PeakRate.Rate != 0 { - police.PeakRate.Mpu = fattrs.Mpu - police.PeakRate.Overhead = fattrs.Overhead - if CalcRtable(&police.PeakRate, ptab[:], pcellLog, fattrs.Mtu, linklayer) < 0 { - return nil, errors.New("POLICE: failed to calculate peak rate table") - } - } - - return &Fw{ - FilterAttrs: attrs, - ClassId: fattrs.ClassId, - InDev: fattrs.InDev, - Mask: fattrs.Mask, - Police: police, - AvRate: avrate, - Rtab: rtab, - Ptab: ptab, - }, nil -} - -func (filter *Fw) Attrs() *FilterAttrs { - return &filter.FilterAttrs -} - -func (filter *Fw) Type() string { - return "fw" -} - type Flower struct { FilterAttrs DestIP net.IP @@ -362,7 +292,7 @@ func (h *Handle) filterModify(filter Filter, flags int) error { if err := EncodeActions(actionsAttr, filter.Actions); err != nil { return err } - case *Fw: + case *FwFilter: if filter.Mask != 0 { b := make([]byte, 4) native.PutUint32(b, filter.Mask) @@ -371,17 +301,10 @@ func (h *Handle) filterModify(filter Filter, flags int) error { if filter.InDev != "" { options.AddRtAttr(nl.TCA_FW_INDEV, nl.ZeroTerminated(filter.InDev)) } - if (filter.Police != nl.TcPolice{}) { - + if filter.Police != nil { police := options.AddRtAttr(nl.TCA_FW_POLICE, nil) - police.AddRtAttr(nl.TCA_POLICE_TBF, filter.Police.Serialize()) - if (filter.Police.Rate != nl.TcRateSpec{}) { - payload := SerializeRtab(filter.Rtab) - police.AddRtAttr(nl.TCA_POLICE_RATE, payload) - } - if (filter.Police.PeakRate != nl.TcRateSpec{}) { - payload := SerializeRtab(filter.Ptab) - police.AddRtAttr(nl.TCA_POLICE_PEAKRATE, payload) + if err := encodePolice(police, filter.Police); err != nil { + return err } } if filter.ClassId != 0 { @@ -479,7 +402,7 @@ func (h *Handle) FilterList(link Link, parent uint32) ([]Filter, error) { case "u32": filter = &U32{} case "fw": - filter = &Fw{} + filter = &FwFilter{} case "bpf": filter = &BpfFilter{} case "matchall": @@ -551,6 +474,53 @@ func toAttrs(tcgen *nl.TcGen, attrs *ActionAttrs) { attrs.Bindcnt = int(tcgen.Bindcnt) } +func encodePolice(attr *nl.RtAttr, action *PoliceAction) error { + var rtab [256]uint32 + var ptab [256]uint32 + police := nl.TcPolice{} + police.Index = uint32(action.Attrs().Index) + police.Bindcnt = int32(action.Attrs().Bindcnt) + police.Capab = uint32(action.Attrs().Capab) + police.Refcnt = int32(action.Attrs().Refcnt) + police.Rate.Rate = action.Rate + police.PeakRate.Rate = action.PeakRate + police.Action = int32(action.ExceedAction) + + if police.Rate.Rate != 0 { + police.Rate.Mpu = action.Mpu + police.Rate.Overhead = action.Overhead + if CalcRtable(&police.Rate, rtab[:], action.RCellLog, action.Mtu, action.LinkLayer) < 0 { + return errors.New("TBF: failed to calculate rate table") + } + police.Burst = Xmittime(uint64(police.Rate.Rate), action.Burst) + } + + police.Mtu = action.Mtu + if police.PeakRate.Rate != 0 { + police.PeakRate.Mpu = action.Mpu + police.PeakRate.Overhead = action.Overhead + if CalcRtable(&police.PeakRate, ptab[:], action.PCellLog, action.Mtu, action.LinkLayer) < 0 { + return errors.New("POLICE: failed to calculate peak rate table") + } + } + + attr.AddRtAttr(nl.TCA_POLICE_TBF, police.Serialize()) + if police.Rate.Rate != 0 { + attr.AddRtAttr(nl.TCA_POLICE_RATE, SerializeRtab(rtab)) + } + if police.PeakRate.Rate != 0 { + attr.AddRtAttr(nl.TCA_POLICE_PEAKRATE, SerializeRtab(ptab)) + } + if action.AvRate != 0 { + attr.AddRtAttr(nl.TCA_POLICE_AVRATE, nl.Uint32Attr(action.AvRate)) + } + if action.NotExceedAction != 0 { + attr.AddRtAttr(nl.TCA_POLICE_RESULT, nl.Uint32Attr(uint32(action.NotExceedAction))) + } + + return nil +} + func EncodeActions(attr *nl.RtAttr, actions []Action) error { tabIndex := int(nl.TCA_ACT_TAB) @@ -558,6 +528,14 @@ func EncodeActions(attr *nl.RtAttr, actions []Action) error { switch action := action.(type) { default: return fmt.Errorf("unknown action type %s", action.Type()) + case *PoliceAction: + table := attr.AddRtAttr(tabIndex, nil) + tabIndex++ + table.AddRtAttr(nl.TCA_ACT_KIND, nl.ZeroTerminated("police")) + aopts := table.AddRtAttr(nl.TCA_ACT_OPTIONS, nil) + if err := encodePolice(aopts, action); err != nil { + return err + } case *MirredAction: table := attr.AddRtAttr(tabIndex, nil) tabIndex++ @@ -652,6 +630,29 @@ func EncodeActions(attr *nl.RtAttr, actions []Action) error { return nil } +func parsePolice(data syscall.NetlinkRouteAttr, police *PoliceAction) { + switch data.Attr.Type { + case nl.TCA_POLICE_RESULT: + police.NotExceedAction = TcPolAct(native.Uint32(data.Value[0:4])) + case nl.TCA_POLICE_AVRATE: + police.AvRate = native.Uint32(data.Value[0:4]) + case nl.TCA_POLICE_TBF: + p := *nl.DeserializeTcPolice(data.Value) + police.ActionAttrs = ActionAttrs{} + police.Attrs().Index = int(p.Index) + police.Attrs().Bindcnt = int(p.Bindcnt) + police.Attrs().Capab = int(p.Capab) + police.Attrs().Refcnt = int(p.Refcnt) + police.ExceedAction = TcPolAct(p.Action) + police.Rate = p.Rate.Rate + police.PeakRate = p.PeakRate.Rate + police.Burst = Xmitsize(uint64(p.Rate.Rate), p.Burst) + police.Mtu = p.Mtu + police.LinkLayer = int(p.Rate.Linklayer) & nl.TC_LINKLAYER_MASK + police.Overhead = p.Rate.Overhead + } +} + func parseActions(tables []syscall.NetlinkRouteAttr) ([]Action, error) { var actions []Action for _, table := range tables { @@ -680,6 +681,8 @@ func parseActions(tables []syscall.NetlinkRouteAttr) ([]Action, error) { action = &TunnelKeyAction{} case "skbedit": action = &SkbEditAction{} + case "police": + action = &PoliceAction{} default: break nextattr } @@ -758,6 +761,8 @@ func parseActions(tables []syscall.NetlinkRouteAttr) ([]Action, error) { gen := *nl.DeserializeTcGen(adatum.Value) toAttrs(&gen, action.Attrs()) } + case "police": + parsePolice(adatum, action.(*PoliceAction)) } } } @@ -813,7 +818,7 @@ func parseU32Data(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) } func parseFwData(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) { - fw := filter.(*Fw) + fw := filter.(*FwFilter) detailed := true for _, datum := range data { switch datum.Attr.Type { @@ -824,17 +829,12 @@ func parseFwData(filter Filter, data []syscall.NetlinkRouteAttr) (bool, error) { case nl.TCA_FW_INDEV: fw.InDev = string(datum.Value[:len(datum.Value)-1]) case nl.TCA_FW_POLICE: + var police PoliceAction adata, _ := nl.ParseRouteAttr(datum.Value) for _, aattr := range adata { - switch aattr.Attr.Type { - case nl.TCA_POLICE_TBF: - fw.Police = *nl.DeserializeTcPolice(aattr.Value) - case nl.TCA_POLICE_RATE: - fw.Rtab = DeserializeRtab(aattr.Value) - case nl.TCA_POLICE_PEAKRATE: - fw.Ptab = DeserializeRtab(aattr.Value) - } + parsePolice(aattr, &police) } + fw.Police = &police } } return detailed, nil diff --git a/filter_test.go b/filter_test.go index f960b80..8409680 100644 --- a/filter_test.go +++ b/filter_test.go @@ -7,6 +7,7 @@ import ( "reflect" "testing" + "github.com/vishvananda/netlink/nl" "golang.org/x/sys/unix" ) @@ -427,6 +428,12 @@ func TestFilterFwAddDel(t *testing.T) { t.Fatal("Failed to add class") } + police := NewPoliceAction() + police.Burst = 12345 + police.Rate = 1234 + police.PeakRate = 2345 + police.Action = TcAct(TC_POLICE_SHOT) + filterattrs := FilterAttrs{ LinkIndex: link.Attrs().Index, Parent: MakeHandle(0xffff, 0), @@ -434,20 +441,14 @@ func TestFilterFwAddDel(t *testing.T) { Priority: 1, Protocol: unix.ETH_P_IP, } - fwattrs := FilterFwAttrs{ - Buffer: 12345, - Rate: 1234, - PeakRate: 2345, - Action: TC_POLICE_SHOT, - ClassId: MakeHandle(0xffff, 2), + + filter := FwFilter{ + FilterAttrs: filterattrs, + ClassId: MakeHandle(0xffff, 2), + Police: police, } - filter, err := NewFw(filterattrs, fwattrs) - if err != nil { - t.Fatal(err) - } - - if err := FilterAdd(filter); err != nil { + if err := FilterAdd(&filter); err != nil { t.Fatal(err) } @@ -458,11 +459,11 @@ func TestFilterFwAddDel(t *testing.T) { if len(filters) != 1 { t.Fatal("Failed to add filter") } - fw, ok := filters[0].(*Fw) + fw, ok := filters[0].(*FwFilter) if !ok { t.Fatal("Filter is the wrong type") } - if fw.Police.Rate.Rate != filter.Police.Rate.Rate { + if fw.Police.Rate != filter.Police.Rate { t.Fatal("Police Rate doesn't match") } if fw.ClassId != filter.ClassId { @@ -471,11 +472,11 @@ func TestFilterFwAddDel(t *testing.T) { if fw.InDev != filter.InDev { t.Fatal("InDev doesn't match") } - if fw.AvRate != filter.AvRate { + if fw.Police.AvRate != filter.Police.AvRate { t.Fatal("AvRate doesn't match") } - if err := FilterDel(filter); err != nil { + if err := FilterDel(&filter); err != nil { t.Fatal(err) } filters, err = FilterList(link, MakeHandle(0xffff, 0)) @@ -1625,3 +1626,182 @@ func TestFilterFlowerAddDel(t *testing.T) { t.Fatal("Failed to remove qdisc") } } + +func TestFilterU32PoliceAddDel(t *testing.T) { + tearDown := setUpNetlinkTest(t) + defer tearDown() + if err := LinkAdd(&Ifb{LinkAttrs{Name: "foo"}}); err != nil { + t.Fatal(err) + } + if err := LinkAdd(&Ifb{LinkAttrs{Name: "bar"}}); err != nil { + t.Fatal(err) + } + link, err := LinkByName("foo") + if err != nil { + t.Fatal(err) + } + if err := LinkSetUp(link); err != nil { + t.Fatal(err) + } + redir, err := LinkByName("bar") + if err != nil { + t.Fatal(err) + } + if err := LinkSetUp(redir); err != nil { + t.Fatal(err) + } + + qdisc := &Ingress{ + QdiscAttrs: QdiscAttrs{ + LinkIndex: link.Attrs().Index, + Handle: MakeHandle(0xffff, 0), + Parent: HANDLE_INGRESS, + }, + } + if err := QdiscAdd(qdisc); err != nil { + t.Fatal(err) + } + qdiscs, err := SafeQdiscList(link) + if err != nil { + t.Fatal(err) + } + + found := false + for _, v := range qdiscs { + if _, ok := v.(*Ingress); ok { + found = true + break + } + } + if !found { + t.Fatal("Qdisc is the wrong type") + } + + const ( + policeRate = 0x40000000 // 1 Gbps + policeBurst = 0x19000 // 100 KB + policePeakRate = 0x4000 // 16 Kbps + ) + + police := NewPoliceAction() + police.Rate = policeRate + police.PeakRate = policePeakRate + police.Burst = policeBurst + police.ExceedAction = TC_POLICE_SHOT + police.NotExceedAction = TC_POLICE_UNSPEC + + classId := MakeHandle(1, 1) + filter := &U32{ + FilterAttrs: FilterAttrs{ + LinkIndex: link.Attrs().Index, + Parent: MakeHandle(0xffff, 0), + Priority: 1, + Protocol: unix.ETH_P_ALL, + }, + ClassId: classId, + Actions: []Action{ + police, + &MirredAction{ + ActionAttrs: ActionAttrs{ + Action: TC_ACT_STOLEN, + }, + MirredAction: TCA_EGRESS_REDIR, + Ifindex: redir.Attrs().Index, + }, + }, + } + + if err := FilterAdd(filter); err != nil { + t.Fatal(err) + } + + filters, err := FilterList(link, MakeHandle(0xffff, 0)) + if err != nil { + t.Fatal(err) + } + if len(filters) != 1 { + t.Fatal("Failed to add filter") + } + u32, ok := filters[0].(*U32) + if !ok { + t.Fatal("Filter is the wrong type") + } + + if len(u32.Actions) != 2 { + t.Fatalf("Too few Actions in filter") + } + if u32.ClassId != classId { + t.Fatalf("ClassId of the filter is the wrong value") + } + + // actions can be returned in reverse order + p, ok := u32.Actions[0].(*PoliceAction) + if !ok { + p, ok = u32.Actions[1].(*PoliceAction) + if !ok { + t.Fatal("Unable to find police action") + } + } + + if p.ExceedAction != TC_POLICE_SHOT { + t.Fatal("Police ExceedAction isn't TC_POLICE_SHOT") + } + + if p.NotExceedAction != TC_POLICE_UNSPEC { + t.Fatal("Police NotExceedAction isn't TC_POLICE_UNSPEC") + } + + if p.Rate != policeRate { + t.Fatal("Action Rate doesn't match") + } + + if p.PeakRate != policePeakRate { + t.Fatal("Action PeakRate doesn't match") + } + + if p.LinkLayer != nl.LINKLAYER_ETHERNET { + t.Fatal("Action LinkLayer doesn't match") + } + + mia, ok := u32.Actions[0].(*MirredAction) + if !ok { + mia, ok = u32.Actions[1].(*MirredAction) + if !ok { + t.Fatal("Unable to find mirred action") + } + } + + if mia.Attrs().Action != TC_ACT_STOLEN { + t.Fatal("Mirred action isn't TC_ACT_STOLEN") + } + + if err := FilterDel(filter); err != nil { + t.Fatal(err) + } + filters, err = FilterList(link, MakeHandle(0xffff, 0)) + if err != nil { + t.Fatal(err) + } + if len(filters) != 0 { + t.Fatal("Failed to remove filter") + } + + if err := QdiscDel(qdisc); err != nil { + t.Fatal(err) + } + qdiscs, err = SafeQdiscList(link) + if err != nil { + t.Fatal(err) + } + + found = false + for _, v := range qdiscs { + if _, ok := v.(*Ingress); ok { + found = true + break + } + } + if found { + t.Fatal("Failed to remove qdisc") + } +} diff --git a/qdisc_linux.go b/qdisc_linux.go index 9fc4b3d..e182e1c 100644 --- a/qdisc_linux.go +++ b/qdisc_linux.go @@ -706,3 +706,7 @@ func Xmittime(rate uint64, size uint32) uint32 { // https://git.kernel.org/pub/scm/network/iproute2/iproute2.git/tree/tc/tc_core.c#n62 return time2Tick(uint32(TIME_UNITS_PER_SEC * (float64(size) / float64(rate)))) } + +func Xmitsize(rate uint64, ticks uint32) uint32 { + return uint32((float64(rate) * float64(tick2Time(ticks))) / TIME_UNITS_PER_SEC) +}