diff --git a/silence/silence.go b/silence/silence.go index 45268be6..26e3d344 100644 --- a/silence/silence.go +++ b/silence/silence.go @@ -40,6 +40,9 @@ import ( // ErrNotFound is returned if a silence was not found. var ErrNotFound = fmt.Errorf("not found") +// ErrInvalidState is returned if the state isn't valid. +var ErrInvalidState = fmt.Errorf("invalid state") + func utcNow() time.Time { return time.Now().UTC() } @@ -758,6 +761,9 @@ func decodeState(r io.Reader) (state, error) { var s pb.MeshSilence _, err := pbutil.ReadDelimited(r, &s) if err == nil { + if s.Silence == nil { + return nil, ErrInvalidState + } st[s.Silence.Id] = &s continue } diff --git a/silence/silence_test.go b/silence/silence_test.go index bde9afbe..febdaf0c 100644 --- a/silence/silence_test.go +++ b/silence/silence_test.go @@ -1080,3 +1080,14 @@ func TestStateCoding(t *testing.T) { require.Equal(t, in, out, "decoded data doesn't match encoded data") } } + +func TestStateDecodingError(t *testing.T) { + // Check whether decoding copes with erroneous data. + s := state{"": &pb.MeshSilence{}} + + msg, err := s.MarshalBinary() + require.NoError(t, err) + + _, err = decodeState(bytes.NewReader(msg)) + require.Equal(t, ErrInvalidState, err) +}