Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support equivalent to golang flag.TextVar(), also fixes the test failure as described in #368 #418

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions flag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,6 @@ func TestMultipleNormalizeFlagNameInvocations(t *testing.T) {
}
}

//
func TestHiddenFlagInUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("secretFlag", true, "shhh")
Expand All @@ -1149,7 +1148,6 @@ func TestHiddenFlagInUsage(t *testing.T) {
}
}

//
func TestHiddenFlagUsage(t *testing.T) {
f := NewFlagSet("bob", ContinueOnError)
f.Bool("secretFlag", true, "shhh")
Expand Down Expand Up @@ -1238,8 +1236,8 @@ func TestPrintDefaults(t *testing.T) {
fs.PrintDefaults()
got := buf.String()
if got != defaultOutput {
fmt.Println("\n" + got)
fmt.Println("\n" + defaultOutput)
fmt.Print("\n" + got + "\n")
fmt.Print("\n" + defaultOutput + "\n")
t.Errorf("got %q want %q\n", got, defaultOutput)
}
}
Expand Down
81 changes: 81 additions & 0 deletions text.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package pflag

import (
"encoding"
"fmt"
"reflect"
)

// following is copied from go 1.23.4 flag.go
type textValue struct{ p encoding.TextUnmarshaler }

func newTextValue(val encoding.TextMarshaler, p encoding.TextUnmarshaler) textValue {
ptrVal := reflect.ValueOf(p)
if ptrVal.Kind() != reflect.Ptr {
panic("variable value type must be a pointer")
}
defVal := reflect.ValueOf(val)
if defVal.Kind() == reflect.Ptr {
defVal = defVal.Elem()
}
if defVal.Type() != ptrVal.Type().Elem() {
panic(fmt.Sprintf("default type does not match variable type: %v != %v", defVal.Type(), ptrVal.Type().Elem()))
}
ptrVal.Elem().Set(defVal)
return textValue{p}
}

func (v textValue) Set(s string) error {
return v.p.UnmarshalText([]byte(s))
}

func (v textValue) Get() interface{} {
return v.p
}

func (v textValue) String() string {
if m, ok := v.p.(encoding.TextMarshaler); ok {
if b, err := m.MarshalText(); err == nil {
return string(b)
}
}
return ""
}

//end of copy

func (v textValue) Type() string {
return reflect.ValueOf(v.p).Type().Name()
}

// GetText set out, which implements encoding.UnmarshalText, to the value of a flag with given name
func (f *FlagSet) GetText(name string, out encoding.TextUnmarshaler) error {
flag := f.Lookup(name)
if flag == nil {
return fmt.Errorf("flag accessed but not defined: %s", name)
}
if flag.Value.Type() != reflect.TypeOf(out).Name() {
fmt.Errorf("trying to get %s value of flag of type %s", reflect.TypeOf(out).Name(), flag.Value.Type())
}
return out.UnmarshalText([]byte(flag.Value.String()))
}

// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p.
func (f *FlagSet) TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
f.VarP(newTextValue(value, p), name, "", usage)
}

// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash.
func (f *FlagSet) TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) {
f.VarP(newTextValue(value, p), name, shorthand, usage)
}

// TextVar defines a flag with a specified name, default value, and usage string. The argument p must be a pointer to a variable that will hold the value of the flag, and p must implement encoding.TextUnmarshaler. If the flag is used, the flag value will be passed to p's UnmarshalText method. The type of the default value must be the same as the type of p.
func TextVar(p encoding.TextUnmarshaler, name string, value encoding.TextMarshaler, usage string) {
CommandLine.VarP(newTextValue(value, p), name, "", usage)
}

// TextVarP is like TextVar, but accepts a shorthand letter that can be used after a single dash.
func TextVarP(p encoding.TextUnmarshaler, name, shorthand string, value encoding.TextMarshaler, usage string) {
CommandLine.VarP(newTextValue(value, p), name, shorthand, usage)
}
53 changes: 53 additions & 0 deletions text_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package pflag

import (
"fmt"
"os"
"testing"
"time"
)

func setUpTime(t *time.Time) *FlagSet {
f := NewFlagSet("test", ContinueOnError)
f.TextVar(t, "time", time.Now(), "time stamp")
return f
}

func TestText(t *testing.T) {
testCases := []struct {
input string
success bool
expected time.Time
}{
{"2003-01-02T15:04:05Z", true, time.Date(2003, 1, 2, 15, 04, 05, 0, time.UTC)},
{"2003-01-02 15:05:01", false, time.Date(2002, 1, 2, 15, 05, 05, 07, time.UTC)},
{"2024-11-22T03:01:02Z", true, time.Date(2024, 11, 22, 3, 1, 02, 0, time.UTC)},
{"2006-01-02T15:04:05+07:00", true, time.Date(2006, 1, 2, 15, 4, 5, 0, time.FixedZone("UTC+7", 7*60*60))},
}

devnull, _ := os.Open(os.DevNull)
os.Stderr = devnull
for i := range testCases {
var ts time.Time
f := setUpTime(&ts)
tc := &testCases[i]
arg := fmt.Sprintf("--time=%s", tc.input)
err := f.Parse([]string{arg})
if err != nil && tc.success == true {
t.Errorf("expected success, got %q", err)
continue
} else if err == nil && tc.success == false {
t.Errorf("expected failure, but succeeded")
continue
} else if tc.success {
parsedT := new(time.Time)
err := f.GetText("time", parsedT)
if err != nil {
t.Errorf("Got error trying to fetch the time flag: %v", err)
}
if !parsedT.Equal(tc.expected) {
t.Errorf("expected %q, got %q", tc.expected, parsedT)
}
}
}
}