From 648b12d8c563fa761f5643bdc790a424e2d78a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=81ukasz=20Mierzwa?= Date: Mon, 20 Jun 2022 16:57:34 +0100 Subject: [PATCH] Implement Unwrap() on errors returned from rulefmt MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I'd like to unwrap errors returned from rulefmt but both Error and WrappedError types are missing Unwrap() method. Signed-off-by: Ɓukasz Mierzwa --- model/rulefmt/rulefmt.go | 10 ++++++++++ model/rulefmt/rulefmt_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/model/rulefmt/rulefmt.go b/model/rulefmt/rulefmt.go index 46e008007..6c8b5978d 100644 --- a/model/rulefmt/rulefmt.go +++ b/model/rulefmt/rulefmt.go @@ -48,6 +48,11 @@ func (err *Error) Error() string { return errors.Wrapf(err.Err.err, "group %q, rule %d, %q", err.Group, err.Rule, err.RuleName).Error() } +// Unwrap unpacks wrapped error for use in errors.Is & errors.As. +func (err *Error) Unwrap() error { + return &err.Err +} + // WrappedError wraps error with the yaml node which can be used to represent // the line and column numbers of the error. type WrappedError struct { @@ -66,6 +71,11 @@ func (we *WrappedError) Error() string { return we.err.Error() } +// Unwrap unpacks wrapped error for use in errors.Is & errors.As. +func (we *WrappedError) Unwrap() error { + return we.err +} + // RuleGroups is a set of rule groups that are typically exposed in a file. type RuleGroups struct { Groups []RuleGroup `yaml:"groups"` diff --git a/model/rulefmt/rulefmt_test.go b/model/rulefmt/rulefmt_test.go index 21afb0b46..f8e6869f0 100644 --- a/model/rulefmt/rulefmt_test.go +++ b/model/rulefmt/rulefmt_test.go @@ -15,6 +15,7 @@ package rulefmt import ( "errors" + "io" "path/filepath" "testing" @@ -299,3 +300,27 @@ func TestWrappedError(t *testing.T) { }) } } + +func TestErrorUnwrap(t *testing.T) { + err1 := errors.New("test error") + + tests := []struct { + wrappedError *Error + unwrappedError error + }{ + { + wrappedError: &Error{Err: WrappedError{err: err1}}, + unwrappedError: err1, + }, + { + wrappedError: &Error{Err: WrappedError{err: io.ErrClosedPipe}}, + unwrappedError: io.ErrClosedPipe, + }, + } + + for _, tt := range tests { + t.Run(tt.wrappedError.Error(), func(t *testing.T) { + require.ErrorIs(t, tt.wrappedError, tt.unwrappedError) + }) + } +}