diff --git a/pkg/config/duration.go b/pkg/config/duration.go index 13b4e263b0..d5f52646d1 100644 --- a/pkg/config/duration.go +++ b/pkg/config/duration.go @@ -27,12 +27,36 @@ func MustNewDuration(d time.Duration) *Duration { return &rv } +func ParseDuration(s string) (Duration, error) { + d, err := time.ParseDuration(s) + if err != nil { + return Duration{}, err + } + + return NewDuration(d) +} + func (d Duration) Duration() time.Duration { return d.d } +// Before returns the time d units before time t +func (d Duration) Before(t time.Time) time.Time { + return t.Add(-d.Duration()) +} + +// Shorter returns true if and only if d is shorter than od. +func (d Duration) Shorter(od Duration) bool { return d.d < od.d } + +// IsInstant is true if and only if d is of duration 0 +func (d Duration) IsInstant() bool { return d.d == 0 } + +// String returns a string representing the duration in the form "72h3m0.5s". +// Leading zero units are omitted. As a special case, durations less than one +// second format use a smaller unit (milli-, micro-, or nanoseconds) to ensure +// that the leading digit is non-zero. The zero duration formats as 0s. func (d Duration) String() string { - return d.d.String() + return d.Duration().String() } // MarshalJSON implements the json.Marshaler interface. @@ -58,6 +82,21 @@ func (d *Duration) UnmarshalJSON(input []byte) error { return nil } +func (d *Duration) Scan(v interface{}) (err error) { + switch tv := v.(type) { + case int64: + *d, err = NewDuration(time.Duration(tv)) + return err + default: + return errors.Errorf(`don't know how to parse "%s" of type %T as a `+ + `models.Duration`, tv, tv) + } +} + +func (d Duration) Value() (driver.Value, error) { + return int64(d.d), nil +} + // MarshalText implements the text.Marshaler interface. func (d Duration) MarshalText() ([]byte, error) { return []byte(d.d.String()), nil @@ -76,18 +115,3 @@ func (d *Duration) UnmarshalText(input []byte) error { *d = pd return nil } - -func (d *Duration) Scan(v interface{}) (err error) { - switch tv := v.(type) { - case int64: - *d, err = NewDuration(time.Duration(tv)) - return err - default: - return errors.Errorf(`don't know how to parse "%s" of type %T as a `+ - `models.Duration`, tv, tv) - } -} - -func (d Duration) Value() (driver.Value, error) { - return int64(d.d), nil -} diff --git a/pkg/config/duration_test.go b/pkg/config/duration_test.go new file mode 100644 index 0000000000..05ac8ac31a --- /dev/null +++ b/pkg/config/duration_test.go @@ -0,0 +1,74 @@ +package config + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDuration_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input Duration + want string + }{ + {"zero", *MustNewDuration(0), `"0s"`}, + {"one second", *MustNewDuration(time.Second), `"1s"`}, + {"one minute", *MustNewDuration(time.Minute), `"1m0s"`}, + {"one hour", *MustNewDuration(time.Hour), `"1h0m0s"`}, + {"one hour thirty minutes", *MustNewDuration(time.Hour + 30*time.Minute), `"1h30m0s"`}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + b, err := json.Marshal(test.input) + assert.NoError(t, err) + assert.Equal(t, test.want, string(b)) + }) + } +} + +func TestDuration_Scan_Value(t *testing.T) { + t.Parallel() + + d := MustNewDuration(100) + require.NotNil(t, d) + + val, err := d.Value() + require.NoError(t, err) + + dNew := MustNewDuration(0) + err = dNew.Scan(val) + require.NoError(t, err) + + require.Equal(t, d, dNew) +} + +func TestDuration_MarshalJSON_UnmarshalJSON(t *testing.T) { + t.Parallel() + + d := MustNewDuration(100) + require.NotNil(t, d) + + json, err := d.MarshalJSON() + require.NoError(t, err) + + dNew := MustNewDuration(0) + err = dNew.UnmarshalJSON(json) + require.NoError(t, err) + + require.Equal(t, d, dNew) +} + +func TestDuration_MakeDurationFromString(t *testing.T) { + t.Parallel() + + d, err := ParseDuration("1s") + require.NoError(t, err) + require.Equal(t, 1*time.Second, d.Duration()) + + _, err = ParseDuration("xyz") + require.Error(t, err) +} diff --git a/pkg/config/url.go b/pkg/config/url.go index 9d694ea9ed..e4a1711685 100644 --- a/pkg/config/url.go +++ b/pkg/config/url.go @@ -21,8 +21,30 @@ func MustParseURL(s string) *URL { return u } +func (u *URL) String() string { + return (*url.URL)(u).String() +} + +// URL returns a copy of u as a *url.URL +func (u *URL) URL() *url.URL { + if u == nil { + return nil + } + // defensive copy + r := url.URL(*u) + if u.User != nil { + r.User = new(url.Userinfo) + *r.User = *u.User + } + return &r +} + +func (u *URL) IsZero() bool { + return (url.URL)(*u) == url.URL{} +} + func (u *URL) MarshalText() ([]byte, error) { - return []byte((*url.URL)(u).String()), nil + return []byte(u.String()), nil } func (u *URL) UnmarshalText(input []byte) error {