Skip to content

Commit

Permalink
custom flag with user-defined parser
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Bortnikov <[email protected]>
  • Loading branch information
BROngineer committed Jun 15, 2024
1 parent c9e94e9 commit d264430
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 8 deletions.
6 changes: 6 additions & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const (
noValueProvidedMessage = "no value provided for flag"
parseErrorMessage = "failed to parse flag"
typeMismatchMessage = "value type mismatch"
noParserDefinedMessage = "no input parser defined for flag"
)

var (
Expand All @@ -21,6 +22,7 @@ var (
ErrNoValueProvided = errors.New(noValueProvidedMessage)
ErrParseFailed = errors.New(parseErrorMessage)
ErrTypeMismatch = errors.New(typeMismatchMessage)
ErrNoParserDefined = errors.New(noParserDefinedMessage)
)

func UnknownFlag(flagName string) error {
Expand All @@ -46,3 +48,7 @@ func ParseError(flagName string, err error) error {
func TypeMismatch(actual, expected any) error {
return fmt.Errorf("expected %T, got %T: %w", expected, actual, ErrTypeMismatch)
}

func NoParserDefined(flagName string) error {
return fmt.Errorf("%s: %w", flagName, ErrNoParserDefined)
}
41 changes: 41 additions & 0 deletions flag/custom.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package flag

import "github.com/brongineer/helium/errors"

type Custom[T any] struct {
*flag[T]
}

func (f *Custom[T]) FromCommandLine(input string) error {
if f.CommandLineParser() == nil {
return errors.NoParserDefined(f.Name())
}

val, err := customParse[T](f.CommandLineParser(), input, f.Name())
if err != nil {
return err
}
f.value = &val
f.visited = true
return nil
}

func (f *Custom[T]) FromEnvVariable(input string) error {
if f.EnvVariableParser() == nil {
return errors.NoParserDefined(f.Name())
}

val, err := customParse[T](f.EnvVariableParser(), input, f.Name())
if err != nil {
return err
}
f.value = &val
f.visited = true
return nil
}

func NewCustom[T any](name string, opts ...Option) *Custom[T] {
f := newFlag[T](name)
applyForFlag(f, opts...)
return &Custom[T]{f}
}
95 changes: 95 additions & 0 deletions flag/flag_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package flag

import (
"reflect"
"strconv"
"testing"
"time"

Expand All @@ -17,6 +19,8 @@ type flagPropertyGetter interface {
IsVisited() bool
FromCommandLine(string) error
FromEnvVariable(string) error
CommandLineParser() func(string) (any, error)
EnvVariableParser() func(string) (any, error)
}

type expected struct {
Expand All @@ -26,6 +30,8 @@ type expected struct {
defaultValue any
required bool
shared bool
cmdParser func(string) (any, error)
envParser func(string) (any, error)
}

func (e *expected) Description() string {
Expand Down Expand Up @@ -55,6 +61,14 @@ func (e *expected) Shared() bool {
return e.shared
}

func (e *expected) CommandLineParser() func(string) (any, error) {
return e.cmdParser
}

func (e *expected) EnvVariableParser() func(string) (any, error) {
return e.envParser
}

type result[T any] struct {
some *T
err bool
Expand Down Expand Up @@ -88,6 +102,16 @@ func assertFlag[T any](t *testing.T, f flagPropertyGetter, tt flagTest) {
assert.Equal(t, tt.expected.DefaultValue(), *actual)
}
assert.Equal(t, tt.expected.Shared(), f.IsShared())
if tt.expected.CommandLineParser() == nil {
assert.Nil(t, f.CommandLineParser())
} else {
assert.Equal(t, reflect.ValueOf(tt.expected.CommandLineParser()).Pointer(), reflect.ValueOf(f.CommandLineParser()).Pointer())
}
if tt.expected.EnvVariableParser() == nil {
assert.Nil(t, f.EnvVariableParser())
} else {
assert.Equal(t, reflect.ValueOf(tt.expected.EnvVariableParser()).Pointer(), reflect.ValueOf(f.EnvVariableParser()).Pointer())
}
}

func assertGetFlag[T any](t *testing.T, f flagPropertyGetter, tt getFlagTest[T]) {
Expand All @@ -104,6 +128,77 @@ func assertGetFlag[T any](t *testing.T, f flagPropertyGetter, tt getFlagTest[T])
assert.Equal(t, tt.wanted.some, PtrOrDie[T](f.Value()))
}

type custom struct {
field int
}

func parser(input string) (any, error) {
v, err := strconv.Atoi(input)
if err != nil {
return nil, err
}
return &custom{field: v}, nil
}

func TestFlag_Custom(t *testing.T) {
t.Parallel()
tests := []flagTest{
{
"sample",
[]Option{},
expected{},
},
{
"sample",
[]Option{CommandLineParser(parser)},
expected{cmdParser: parser},
},
}
for _, tc := range tests {
tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
f := NewCustom[custom](tt.name, tt.opts...)
assertFlag[custom](t, f, tt)
})
}
}

func TestFlag_GetCustom(t *testing.T) {
t.Parallel()
tests := []getFlagTest[custom]{
{
"sample",
[]Option{},
ptrTo("test"),
result[custom]{err: true},
},
{
"sample",
[]Option{CommandLineParser(parser)},
ptrTo("10"),
result[custom]{
ptrTo(custom{field: 10}),
false,
},
},
{
"sample",
[]Option{CommandLineParser(parser)},
ptrTo("invalid"),
result[custom]{err: true},
},
}
for _, tc := range tests {
tt := tc
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
f := NewCustom[custom](tt.name, tt.opts...)
assertGetFlag[custom](t, f, tt)
})
}
}

func TestFlag_String(t *testing.T) {
t.Parallel()
tests := []flagTest{
Expand Down
55 changes: 47 additions & 8 deletions flag/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,23 @@ package flag
import (
"fmt"
"os"

"github.com/brongineer/helium/errors"
)

const defaultSliceSeparator = ","

type flag[T any] struct {
name string
description string
shorthand string
shared bool
visited bool
defaultValue *T
value *T
separator string
name string
description string
shorthand string
shared bool
visited bool
defaultValue *T
value *T
separator string
commandLineParser func(string) (any, error)
envVariableParser func(string) (any, error)
}

func (f *flag[T]) Value() any {
Expand Down Expand Up @@ -52,6 +56,14 @@ func (f *flag[T]) IsVisited() bool {
return f.visited
}

func (f *flag[T]) CommandLineParser() func(string) (any, error) {
return f.commandLineParser
}

func (f *flag[T]) EnvVariableParser() func(string) (any, error) {
return f.envVariableParser
}

func newFlag[T any](name string) *flag[T] {
return &flag[T]{name: name}
}
Expand Down Expand Up @@ -83,3 +95,30 @@ func (f *flag[T]) setDefaultValue(value any) {
func (f *flag[T]) setSeparator(separator string) {
f.separator = separator
}

func (f *flag[T]) setCommandLineParser(parser func(string) (any, error)) {
f.commandLineParser = parser
}

func (f *flag[T]) setEnvVariableParser(parser func(string) (any, error)) {
f.envVariableParser = parser
}

func customParse[T any](parser func(string) (any, error), input, flag string) (T, error) {
var (
val T
parsed any
valPtr *T
err error
)
parsed, err = parser(input)
if err != nil {
return val, errors.ParseError(flag, err)
}
valPtr, err = value[T](parsed)
if err != nil {
return val, errors.ParseError(flag, err)
}
val = *valPtr
return val, nil
}
22 changes: 22 additions & 0 deletions flag/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ type flagPropertySetter interface {
setShared()
setDefaultValue(any)
setSeparator(string)
setCommandLineParser(func(string) (any, error))
setEnvVariableParser(func(string) (any, error))
}

type Option interface {
Expand Down Expand Up @@ -70,6 +72,26 @@ func Separator(value string) Option {
return separator{value}
}

type commandLineParser func(string) (any, error)

func (o commandLineParser) apply(f flagPropertySetter) {
f.setCommandLineParser(o)
}

func CommandLineParser(parser func(string) (any, error)) Option {
return commandLineParser(parser)
}

type envVariableParser func(string) (any, error)

func (o envVariableParser) apply(f flagPropertySetter) {
f.setEnvVariableParser(o)
}

func EnvVariableParser(parser func(string) (any, error)) Option {
return envVariableParser(parser)
}

func applyForFlag(f flagPropertySetter, opts ...Option) {
for _, opt := range opts {
opt.apply(f)
Expand Down

0 comments on commit d264430

Please sign in to comment.