Skip to content

Commit

Permalink
replace generic flag type constraint with any
Browse files Browse the repository at this point in the history
Signed-off-by: Artem Bortnikov <[email protected]>
  • Loading branch information
BROngineer committed Jun 15, 2024
1 parent 798d696 commit c9e94e9
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 44 deletions.
10 changes: 5 additions & 5 deletions flag/flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ func (e *expected) Shared() bool {
return e.shared
}

type result[T allowed] struct {
type result[T any] struct {
some *T
err bool
}

func ptrTo[T allowed](v T) *T {
func ptrTo[T any](v T) *T {
return &v
}

Expand All @@ -70,14 +70,14 @@ type flagTest struct {
expected expected
}

type getFlagTest[T allowed] struct {
type getFlagTest[T any] struct {
name string
opts []Option
input *string
wanted result[T]
}

func assertFlag[T allowed](t *testing.T, f flagPropertyGetter, tt flagTest) {
func assertFlag[T any](t *testing.T, f flagPropertyGetter, tt flagTest) {
assert.NotNil(t, f)
assert.Equal(t, tt.expected.Description(), f.Description())
assert.Equal(t, tt.expected.Shorthand(), f.Shorthand())
Expand All @@ -90,7 +90,7 @@ func assertFlag[T allowed](t *testing.T, f flagPropertyGetter, tt flagTest) {
assert.Equal(t, tt.expected.Shared(), f.IsShared())
}

func assertGetFlag[T allowed](t *testing.T, f flagPropertyGetter, tt getFlagTest[T]) {
func assertGetFlag[T any](t *testing.T, f flagPropertyGetter, tt getFlagTest[T]) {
assert.NotNil(t, f)
if tt.input != nil {
if tt.wanted.err {
Expand Down
37 changes: 2 additions & 35 deletions flag/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,11 @@ package flag
import (
"fmt"
"os"
"time"
)

const defaultSliceSeparator = ","

type intFlag interface {
~int | ~int8 | ~int16 | ~int32 | ~int64
}

type uintFlag interface {
~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64
}

type floatFlag interface {
~float32 | ~float64
}

type intSliceFlag interface {
~[]int | ~[]int8 | ~[]int16 | ~[]int32 | ~[]int64
}

type uintSliceFlag interface {
~[]uint | ~[]uint8 | ~[]uint16 | ~[]uint32 | ~[]uint64
}

type floatSliceFlag interface {
~[]float32 | ~[]float64
}

type sliceFlag interface {
~[]string | ~[]bool | ~[]time.Duration | intSliceFlag | uintSliceFlag | floatSliceFlag
}

type allowed interface {
~string | ~bool | intFlag | uintFlag | floatFlag | sliceFlag | time.Duration // | counter
}

type flag[T allowed] struct {
type flag[T any] struct {
name string
description string
shorthand string
Expand Down Expand Up @@ -85,7 +52,7 @@ func (f *flag[T]) IsVisited() bool {
return f.visited
}

func newFlag[T allowed](name string) *flag[T] {
func newFlag[T any](name string) *flag[T] {
return &flag[T]{name: name}
}

Expand Down
7 changes: 3 additions & 4 deletions flag/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (
// If the flag does not exist, it returns an error with the message "unknown flag".
// If the value is nil, it returns an error with the message "value is nil".
// If the type of the value does not match the specified type, it returns an error with the message "wrong type".
// The function expects the allowed type to be one of the types defined in the flags.generic.allowed interface.
func value[T allowed](v any) (*T, error) {
func value[T any](v any) (*T, error) {
var (
valuePtr *T
ok bool
Expand All @@ -28,7 +27,7 @@ func value[T allowed](v any) (*T, error) {
// DerefOrDie dereferences a pointer and checks for errors. If the error is not nil,
// it prints the error message to stderr and exits the program with code 1. If the pointer is nil,
// it prints an error message to stderr and exits the program with code 1. It returns the dereferenced value.
func DerefOrDie[T allowed](v any) T {
func DerefOrDie[T any](v any) T {
p, err := value[T](v)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err.Error())
Expand All @@ -44,7 +43,7 @@ func DerefOrDie[T allowed](v any) T {
// PtrOrDie returns the pointer value `p` and exits the program if there is an error `err`.
// If `err` is not nil, an error message is printed to stderr and the program exits with code 1.
// The function is used to simplify error handling in flag retrieval functions.
func PtrOrDie[T allowed](v any) *T {
func PtrOrDie[T any](v any) *T {
p, err := value[T](v)
if err != nil {
_, _ = fmt.Fprintf(os.Stderr, "Error: %v\n", err.Error())
Expand Down

0 comments on commit c9e94e9

Please sign in to comment.