diff --git a/nflog/nflog.go b/nflog/nflog.go index ca2924c8..5ae46517 100644 --- a/nflog/nflog.go +++ b/nflog/nflog.go @@ -37,6 +37,9 @@ import ( // ErrNotFound is returned for empty query results. var ErrNotFound = errors.New("not found") +// ErrInvalidState is returned if the state isn't valid. +var ErrInvalidState = fmt.Errorf("invalid state") + // query currently allows filtering by and/or receiver group key. // It is configured via QueryParameter functions. // @@ -239,6 +242,9 @@ func decodeState(r io.Reader) (state, error) { var e pb.MeshEntry _, err := pbutil.ReadDelimited(r, &e) if err == nil { + if e.Entry == nil || e.Entry.Receiver == nil { + return nil, ErrInvalidState + } st[stateKey(string(e.Entry.GroupKey), e.Entry.Receiver)] = &e continue } diff --git a/nflog/nflog_test.go b/nflog/nflog_test.go index b43b8c72..417af1a8 100644 --- a/nflog/nflog_test.go +++ b/nflog/nflog_test.go @@ -296,3 +296,14 @@ func TestQuery(t *testing.T) { require.EqualValues(t, firingAlerts, entry.FiringAlerts) require.EqualValues(t, resolvedAlerts, entry.ResolvedAlerts) } + +func TestStateDecodingError(t *testing.T) { + // Check whether decoding copes with erroneous data. + s := state{"": &pb.MeshEntry{}} + + msg, err := s.MarshalBinary() + require.NoError(t, err) + + _, err = decodeState(bytes.NewReader(msg)) + require.Equal(t, ErrInvalidState, err) +}