diff --git a/go/flagutil/flagutil_test.go b/go/flagutil/flagutil_test.go index 2502213ab73..1ddbf693e27 100644 --- a/go/flagutil/flagutil_test.go +++ b/go/flagutil/flagutil_test.go @@ -35,10 +35,9 @@ func TestStringList(t *testing.T) { "3ala,": "3ala.", } for in, out := range wanted { - if err := p.Set(in); err != nil { - t.Errorf("v.Set(%v): %v", in, err) - continue - } + err := p.Set(in) + assert.NoError(t, err) + assert.Equal(t, out, strings.Join(p, ".")) assert.Equal(t, in, p.String()) @@ -49,12 +48,10 @@ func TestStringList(t *testing.T) { func TestEmptyStringList(t *testing.T) { var p StringListValue var _ pflag.Value = &p - if err := p.Set(""); err != nil { - t.Fatalf("p.Set(\"\"): %v", err) - } - if len(p) != 0 { - t.Fatalf("len(p) != 0: got %v", len(p)) - } + + err := p.Set("") + require.NoError(t, err) + require.Len(t, p, 0) } type pair struct { @@ -81,28 +78,15 @@ func TestStringMap(t *testing.T) { }, } for _, want := range wanted { - if err := v.Set(want.in); err != want.err { - t.Errorf("v.Set(%v): %v", want.in, want.err) - continue - } - if want.err != nil { - continue - } + err := v.Set(want.in) + assert.ErrorIs(t, err, want.err) - if len(want.out) != len(v) { - t.Errorf("want %#v, got %#v", want.out, v) + if want.err != nil { continue } - for key, value := range want.out { - if v[key] != value { - t.Errorf("want %#v, got %#v", want.out, v) - continue - } - } - if vs := v.String(); vs != want.in { - t.Errorf("v.String(): want %#v, got %#v", want.in, vs) - } + assert.EqualValues(t, want.out, v) + assert.Equal(t, want.in, v.String()) } } @@ -119,3 +103,179 @@ func TestStringMapValue(t *testing.T) { require.Equal(t, "StringMap", strMapVal.Type()) require.Equal(t, map[string]string(map[string]string{"key": "val"}), strMapVal.Get()) } + +func TestDualFormatStringListVar(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + var flagVal []string + testValue := []string{"testValue1", "testValue2", "testValue3"} + + DualFormatStringListVar(testFlagSet, &flagVal, testFlagName, testValue, "usage string") + assert.Equal(t, testValue, flagVal) + + want := "testValue1,testValue2,testValue3" + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, want, f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, want, f.Value.String()) + + newVal := "newValue1,newValue2" + err := testFlagSet.Set("test-flag-name", newVal) + assert.NoError(t, err) + + assert.Equal(t, newVal, f.Value.String()) + assert.Equal(t, []string{"newValue1", "newValue2"}, flagVal) +} + +func TestDualFormatStringVar(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + var flagVal string + testValue := "testValue" + + DualFormatStringVar(testFlagSet, &flagVal, testFlagName, testValue, "usage string") + assert.Equal(t, testValue, flagVal) + + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, testValue, f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, testValue, f.Value.String()) + + newVal := "newValue" + err := testFlagSet.Set("test-flag-name", newVal) + assert.NoError(t, err) + + assert.Equal(t, newVal, f.Value.String()) + assert.Equal(t, newVal, flagVal) +} + +func TestDualFormatBoolVar(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + var flagVal bool + + DualFormatBoolVar(testFlagSet, &flagVal, testFlagName, true, "usage string") + assert.True(t, flagVal) + + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, "true", f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, "true", f.Value.String()) + + err := testFlagSet.Set("test-flag-name", "false") + assert.NoError(t, err) + + assert.Equal(t, "false", f.Value.String()) + assert.False(t, flagVal) +} + +func TestDualFormatInt64Var(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + var flagVal int64 + + DualFormatInt64Var(testFlagSet, &flagVal, testFlagName, int64(256), "usage string") + assert.Equal(t, int64(256), flagVal) + + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, "256", f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, "256", f.Value.String()) + + newVal := "128" + err := testFlagSet.Set("test-flag-name", newVal) + assert.NoError(t, err) + + assert.Equal(t, newVal, f.Value.String()) + assert.Equal(t, int64(128), flagVal) +} + +func TestDualFormatIntVar(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + var flagVal int + + DualFormatIntVar(testFlagSet, &flagVal, testFlagName, 128, "usage string") + assert.Equal(t, 128, flagVal) + + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, "128", f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, "128", f.Value.String()) + + newVal := "256" + err := testFlagSet.Set("test-flag-name", newVal) + assert.NoError(t, err) + + assert.Equal(t, newVal, f.Value.String()) + assert.Equal(t, 256, flagVal) +} + +type MockValue struct { + val *bool +} + +func (b MockValue) Set(s string) error { + if s == "true" { + *b.val = true + } else { + *b.val = false + } + return nil +} + +func (b MockValue) String() string { + if *b.val { + return "true" + } + return "false" +} + +func (b MockValue) Type() string { + return "bool" +} + +func TestDualFormatVar(t *testing.T) { + testFlagSet := pflag.NewFlagSet("testFlagSet", pflag.ExitOnError) + + testFlagName := "test-flag_name" + flagVal := true + value := MockValue{val: &flagVal} + + DualFormatVar(testFlagSet, value, testFlagName, "usage string") + + f := testFlagSet.Lookup("test-flag-name") + assert.NotNil(t, f) + assert.Equal(t, "true", f.Value.String()) + + f = testFlagSet.Lookup("test_flag_name") + assert.NotNil(t, f) + assert.Equal(t, "true", f.Value.String()) + + newVal := "false" + err := testFlagSet.Set("test-flag-name", newVal) + assert.NoError(t, err) + + assert.Equal(t, newVal, f.Value.String()) + assert.False(t, flagVal) +}