diff --git a/nl/seg6local_linux.go b/nl/seg6local_linux.go index 1500177..8172b84 100644 --- a/nl/seg6local_linux.go +++ b/nl/seg6local_linux.go @@ -12,6 +12,7 @@ const ( SEG6_LOCAL_NH6 SEG6_LOCAL_IIF SEG6_LOCAL_OIF + SEG6_LOCAL_BPF __SEG6_LOCAL_MAX ) const ( @@ -34,6 +35,7 @@ const ( SEG6_LOCAL_ACTION_END_S // 12 SEG6_LOCAL_ACTION_END_AS // 13 SEG6_LOCAL_ACTION_END_AM // 14 + SEG6_LOCAL_ACTION_END_BPF // 15 __SEG6_LOCAL_ACTION_MAX ) const ( @@ -71,6 +73,8 @@ func SEG6LocalActionString(action int) string { return "End.AS" case SEG6_LOCAL_ACTION_END_AM: return "End.AM" + case SEG6_LOCAL_ACTION_END_BPF: + return "End.BPF" } return "unknown" } diff --git a/route_linux.go b/route_linux.go index 929738e..448ff8f 100644 --- a/route_linux.go +++ b/route_linux.go @@ -273,6 +273,16 @@ type SEG6LocalEncap struct { In6Addr net.IP Iif int Oif int + bpf bpfObj +} + +func (e *SEG6LocalEncap) SetProg(progFd int, progName string) error { + if progFd <= 0 { + return fmt.Errorf("seg6local bpf SetProg: invalid fd") + } + e.bpf.progFd = progFd + e.bpf.progName = progName + return nil } func (e *SEG6LocalEncap) Type() int { @@ -306,6 +316,22 @@ func (e *SEG6LocalEncap) Decode(buf []byte) error { case nl.SEG6_LOCAL_OIF: e.Oif = int(native.Uint32(attr.Value[0:4])) e.Flags[nl.SEG6_LOCAL_OIF] = true + case nl.SEG6_LOCAL_BPF: + var bpfAttrs []syscall.NetlinkRouteAttr + bpfAttrs, err = nl.ParseRouteAttr(attr.Value) + bpfobj := bpfObj{} + for _, bpfAttr := range bpfAttrs { + switch bpfAttr.Attr.Type { + case nl.LWT_BPF_PROG_FD: + bpfobj.progFd = int(native.Uint32(bpfAttr.Value)) + case nl.LWT_BPF_PROG_NAME: + bpfobj.progName = string(bpfAttr.Value) + default: + err = fmt.Errorf("seg6local bpf decode: unknown attribute: Type %d", bpfAttr.Attr) + } + } + e.bpf = bpfobj + e.Flags[nl.SEG6_LOCAL_BPF] = true } } return err @@ -367,6 +393,16 @@ func (e *SEG6LocalEncap) Encode() ([]byte, error) { native.PutUint32(attr[4:], uint32(e.Oif)) res = append(res, attr...) } + if e.Flags[nl.SEG6_LOCAL_BPF] { + attr := nl.NewRtAttr(nl.SEG6_LOCAL_BPF, []byte{}) + if e.bpf.progFd != 0 { + attr.AddRtAttr(nl.LWT_BPF_PROG_FD, nl.Uint32Attr(uint32(e.bpf.progFd))) + } + if e.bpf.progName != "" { + attr.AddRtAttr(nl.LWT_BPF_PROG_NAME, nl.ZeroTerminated(e.bpf.progName)) + } + res = append(res, attr.Serialize()...) + } return res, err } func (e *SEG6LocalEncap) String() string { @@ -406,6 +442,9 @@ func (e *SEG6LocalEncap) String() string { } strs = append(strs, fmt.Sprintf("segs %d [ %s ]", len(e.Segments), strings.Join(segs, " "))) } + if e.Flags[nl.SEG6_LOCAL_BPF] { + strs = append(strs, fmt.Sprintf("bpf %s[%d]", e.bpf.progName, e.bpf.progFd)) + } return strings.Join(strs, " ") } func (e *SEG6LocalEncap) Equal(x Encap) bool { @@ -437,7 +476,7 @@ func (e *SEG6LocalEncap) Equal(x Encap) bool { if !e.InAddr.Equal(o.InAddr) || !e.In6Addr.Equal(o.In6Addr) { return false } - if e.Action != o.Action || e.Table != o.Table || e.Iif != o.Iif || e.Oif != o.Oif { + if e.Action != o.Action || e.Table != o.Table || e.Iif != o.Iif || e.Oif != o.Oif || e.bpf != o.bpf { return false } return true diff --git a/route_test.go b/route_test.go index c73205f..7c39965 100644 --- a/route_test.go +++ b/route_test.go @@ -1767,6 +1767,9 @@ func TestSEG6LocalEqual(t *testing.T) { var flags_end_b6_encaps [nl.SEG6_LOCAL_MAX]bool flags_end_b6_encaps[nl.SEG6_LOCAL_ACTION] = true flags_end_b6_encaps[nl.SEG6_LOCAL_SRH] = true + var flags_end_bpf [nl.SEG6_LOCAL_MAX]bool + flags_end_bpf[nl.SEG6_LOCAL_ACTION] = true + flags_end_bpf[nl.SEG6_LOCAL_BPF] = true cases := []SEG6LocalEncap{ { @@ -1819,6 +1822,15 @@ func TestSEG6LocalEqual(t *testing.T) { Segments: segs, }, } + + // SEG6_LOCAL_ACTION_END_BPF + endBpf := SEG6LocalEncap{ + Flags: flags_end_bpf, + Action: nl.SEG6_LOCAL_ACTION_END_BPF, + } + _ = endBpf.SetProg(1, "firewall") + cases = append(cases, endBpf) + for i1 := range cases { for i2 := range cases { got := cases[i1].Equal(&cases[i2])