Merge pull request #10893 from prymitive/unwrap_errors

Implement Unwrap() on errors returned from rulefmt
This commit is contained in:
Julien Pivotto 2022-06-30 11:30:21 +02:00 committed by GitHub
commit c637705403
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 0 deletions

View File

@ -53,6 +53,11 @@ func (err *Error) Error() string {
return fmt.Sprintf("group %q, rule %d, %q: %v", err.Group, err.Rule, err.RuleName, err.Err.err) return fmt.Sprintf("group %q, rule %d, %q: %v", err.Group, err.Rule, err.RuleName, err.Err.err)
} }
// Unwrap unpacks wrapped error for use in errors.Is & errors.As.
func (err *Error) Unwrap() error {
return &err.Err
}
// WrappedError wraps error with the yaml node which can be used to represent // WrappedError wraps error with the yaml node which can be used to represent
// the line and column numbers of the error. // the line and column numbers of the error.
type WrappedError struct { type WrappedError struct {
@ -75,6 +80,11 @@ func (we *WrappedError) Error() string {
return we.err.Error() return we.err.Error()
} }
// Unwrap unpacks wrapped error for use in errors.Is & errors.As.
func (we *WrappedError) Unwrap() error {
return we.err
}
// RuleGroups is a set of rule groups that are typically exposed in a file. // RuleGroups is a set of rule groups that are typically exposed in a file.
type RuleGroups struct { type RuleGroups struct {
Groups []RuleGroup `yaml:"groups"` Groups []RuleGroup `yaml:"groups"`

View File

@ -15,6 +15,7 @@ package rulefmt
import ( import (
"errors" "errors"
"io"
"path/filepath" "path/filepath"
"testing" "testing"
@ -303,3 +304,27 @@ func TestWrappedError(t *testing.T) {
}) })
} }
} }
func TestErrorUnwrap(t *testing.T) {
err1 := errors.New("test error")
tests := []struct {
wrappedError *Error
unwrappedError error
}{
{
wrappedError: &Error{Err: WrappedError{err: err1}},
unwrappedError: err1,
},
{
wrappedError: &Error{Err: WrappedError{err: io.ErrClosedPipe}},
unwrappedError: io.ErrClosedPipe,
},
}
for _, tt := range tests {
t.Run(tt.wrappedError.Error(), func(t *testing.T) {
require.ErrorIs(t, tt.wrappedError, tt.unwrappedError)
})
}
}