diff --git a/pkg/kadm/errors.go b/pkg/kadm/errors.go index 822bf614..878e62af 100644 --- a/pkg/kadm/errors.go +++ b/pkg/kadm/errors.go @@ -121,3 +121,14 @@ func (e *ShardErrors) Error() string { } return fmt.Sprintf("request %s has %d separate shard errors, first: %s", e.Name, len(e.Errs), e.Errs[0].Err) } + +// Unwrap returns the underlying errors. +func (e *ShardErrors) Unwrap() []error { + unwrapped := make([]error, 0, len(e.Errs)) + + for _, shardErr := range e.Errs { + unwrapped = append(unwrapped, shardErr.Err) + } + + return unwrapped +} diff --git a/pkg/kadm/errors_test.go b/pkg/kadm/errors_test.go new file mode 100644 index 00000000..4b219080 --- /dev/null +++ b/pkg/kadm/errors_test.go @@ -0,0 +1,23 @@ +package kadm + +import ( + "context" + "errors" + "testing" +) + +func TestShardErrors_Unwrap(t *testing.T) { + err1 := errors.New("test error 1") + err2 := errors.New("test error 2") + + errs := &ShardErrors{Errs: []ShardError{{Err: err1}, {Err: context.Canceled}}} + if !errors.Is(errs, err1) { + t.Errorf("ShardErrors does not match error %v", err1) + } + if !errors.Is(errs, context.Canceled) { + t.Errorf("ShardErrors does not match error %v", context.Canceled) + } + if errors.Is(errs, err2) { + t.Errorf("ShardErrors matches error %v but it not expected to match it", err2) + } +}