Skip to content

Commit

Permalink
chore: modernize tests (#15244)
Browse files Browse the repository at this point in the history
Signed-off-by: Manik Rana <[email protected]>
Signed-off-by: Manik Rana <[email protected]>
  • Loading branch information
Maniktherana authored Feb 22, 2024
1 parent ba3531f commit 6fec119
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 234 deletions.
34 changes: 10 additions & 24 deletions go/cache/lru_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package cache

import (
"testing"

"github.com/stretchr/testify/assert"
)

type CacheValue struct {
Expand All @@ -27,24 +29,12 @@ type CacheValue struct {
func TestInitialState(t *testing.T) {
cache := NewLRUCache[*CacheValue](5)
l, sz, c, e, h, m := cache.Len(), cache.UsedCapacity(), cache.MaxCapacity(), cache.Evictions(), cache.Hits(), cache.Misses()
if l != 0 {
t.Errorf("length = %v, want 0", l)
}
if sz != 0 {
t.Errorf("size = %v, want 0", sz)
}
if c != 5 {
t.Errorf("capacity = %v, want 5", c)
}
if e != 0 {
t.Errorf("evictions = %v, want 0", c)
}
if h != 0 {
t.Errorf("hits = %v, want 0", c)
}
if m != 0 {
t.Errorf("misses = %v, want 0", c)
}
assert.Zero(t, l)
assert.EqualValues(t, 0, sz)
assert.EqualValues(t, 5, c)
assert.EqualValues(t, 0, e)
assert.EqualValues(t, 0, h)
assert.EqualValues(t, 0, m)
}

func TestSetInsertsValue(t *testing.T) {
Expand Down Expand Up @@ -137,12 +127,8 @@ func TestCapacityIsObeyed(t *testing.T) {
// Insert one more; something should be evicted to make room.
cache.Set("key4", value)
sz, evictions := cache.UsedCapacity(), cache.Evictions()
if sz != size {
t.Errorf("post-evict cache.UsedCapacity() = %v, expected %v", sz, size)
}
if evictions != 1 {
t.Errorf("post-evict cache.Evictions() = %v, expected 1", evictions)
}
assert.Equal(t, size, sz)
assert.EqualValues(t, 1, evictions)

// Check various other stats
if l := cache.Len(); int64(l) != size {
Expand Down
14 changes: 8 additions & 6 deletions go/cache/theine/singleflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestDo(t *testing.T) {
return "bar", nil
})

assert.Equal(t, "bar (string)", fmt.Sprintf("%v (%T)", v, v), "incorrect Do value")
assert.NoError(t, err, "got Do error")
assert.Equal(t, "bar (string)", fmt.Sprintf("%v (%T)", v, v))
assert.NoError(t, err)
}

func TestDoErr(t *testing.T) {
Expand Down Expand Up @@ -85,11 +85,11 @@ func TestDoDupSuppress(t *testing.T) {
defer wg2.Done()
wg1.Done()
v, err, _ := g.Do("key", fn)
if !assert.NoError(t, err, "unexpected Do error") {
if !assert.NoError(t, err) {
return
}

assert.Equal(t, "bar", v, "unexpected Do value")
assert.Equal(t, "bar", v)
}()
}
wg1.Wait()
Expand All @@ -98,7 +98,8 @@ func TestDoDupSuppress(t *testing.T) {
c <- "bar"
wg2.Wait()
got := atomic.LoadInt32(&calls)
assert.True(t, got > 0 && got < n, "number of calls not between 0 and %d", n)
assert.Greater(t, got, int32(0))
assert.Less(t, got, int32(n))
}

// Test singleflight behaves correctly after Do panic.
Expand Down Expand Up @@ -131,7 +132,7 @@ func TestPanicDo(t *testing.T) {

select {
case <-done:
assert.Equal(t, int32(n), panicCount, "unexpected number of panics")
assert.EqualValues(t, n, panicCount)
case <-time.After(time.Second):
require.Fail(t, "Do hangs")
}
Expand All @@ -152,6 +153,7 @@ func TestGoexitDo(t *testing.T) {
var err error
defer func() {
assert.NoError(t, err)

if atomic.AddInt32(&waited, -1) == 0 {
close(done)
}
Expand Down
40 changes: 10 additions & 30 deletions go/event/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"reflect"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

type testInterface1 interface {
Expand Down Expand Up @@ -56,10 +58,7 @@ func TestStaticListener(t *testing.T) {
AddListener(func(testEvent1) { triggered = true })
AddListener(func(testEvent2) { t.Errorf("wrong listener type triggered") })
Dispatch(testEvent1{})

if !triggered {
t.Errorf("static listener failed to trigger")
}
assert.True(t, triggered, "static listener failed to trigger")
}

func TestPointerListener(t *testing.T) {
Expand All @@ -69,10 +68,7 @@ func TestPointerListener(t *testing.T) {
AddListener(func(ev *testEvent2) { ev.triggered = true })
AddListener(func(testEvent2) { t.Errorf("non-pointer listener triggered on pointer type") })
Dispatch(testEvent)

if !testEvent.triggered {
t.Errorf("pointer listener failed to trigger")
}
assert.True(t, testEvent.triggered, "pointer listener failed to trigger")
}

func TestInterfaceListener(t *testing.T) {
Expand All @@ -82,10 +78,7 @@ func TestInterfaceListener(t *testing.T) {
AddListener(func(testInterface1) { triggered = true })
AddListener(func(testInterface2) { t.Errorf("interface listener triggered on non-matching type") })
Dispatch(testEvent1{})

if !triggered {
t.Errorf("interface listener failed to trigger")
}
assert.True(t, triggered, "interface listener failed to trigger")
}

func TestEmptyInterfaceListener(t *testing.T) {
Expand All @@ -94,10 +87,7 @@ func TestEmptyInterfaceListener(t *testing.T) {
triggered := false
AddListener(func(any) { triggered = true })
Dispatch("this should match any")

if !triggered {
t.Errorf("any listener failed to trigger")
}
assert.True(t, triggered, "empty listener failed to trigger")
}

func TestMultipleListeners(t *testing.T) {
Expand Down Expand Up @@ -144,7 +134,6 @@ func TestBadListenerWrongType(t *testing.T) {

defer func() {
err := recover()

if err == nil {
t.Errorf("bad listener type (not a func) failed to trigger panic")
}
Expand Down Expand Up @@ -186,10 +175,8 @@ func TestDispatchPointerToValueInterfaceListener(t *testing.T) {
triggered = true
})
Dispatch(&testEvent1{})
assert.True(t, triggered, "Dispatch by pointer failed to trigger interface listener")

if !triggered {
t.Errorf("Dispatch by pointer failed to trigger interface listener")
}
}

func TestDispatchValueToValueInterfaceListener(t *testing.T) {
Expand All @@ -200,10 +187,7 @@ func TestDispatchValueToValueInterfaceListener(t *testing.T) {
triggered = true
})
Dispatch(testEvent1{})

if !triggered {
t.Errorf("Dispatch by value failed to trigger interface listener")
}
assert.True(t, triggered, "Dispatch by value failed to trigger interface listener")
}

func TestDispatchPointerToPointerInterfaceListener(t *testing.T) {
Expand All @@ -212,10 +196,8 @@ func TestDispatchPointerToPointerInterfaceListener(t *testing.T) {
triggered := false
AddListener(func(testInterface2) { triggered = true })
Dispatch(&testEvent2{})
assert.True(t, triggered, "interface listener failed to trigger for pointer")

if !triggered {
t.Errorf("interface listener failed to trigger for pointer")
}
}

func TestDispatchValueToPointerInterfaceListener(t *testing.T) {
Expand Down Expand Up @@ -245,10 +227,8 @@ func TestDispatchUpdate(t *testing.T) {

ev := &testUpdateEvent{}
DispatchUpdate(ev, "hello")
assert.True(t, triggered, "listener failed to trigger on DispatchUpdate()")

if !triggered {
t.Errorf("listener failed to trigger on DispatchUpdate()")
}
want := "hello"
if got := ev.update.(string); got != want {
t.Errorf("ev.update = %#v, want %#v", got, want)
Expand Down
63 changes: 21 additions & 42 deletions go/event/syslogger/syslogger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"testing"

"vitess.io/vitess/go/event"

"github.com/stretchr/testify/assert"
)

type TestEvent struct {
Expand Down Expand Up @@ -70,10 +72,8 @@ func TestSyslog(t *testing.T) {

ev := new(TestEvent)
event.Dispatch(ev)
assert.True(t, ev.triggered)

if !ev.triggered {
t.Errorf("Syslog() was not called on event that implements Syslogger")
}
}

// TestBadWriter verifies we are still triggering (to normal logs) if
Expand All @@ -87,55 +87,40 @@ func TestBadWriter(t *testing.T) {
wantLevel := "ERROR"
ev := &TestEvent{priority: syslog.LOG_ALERT, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().msg, wantMsg) {
t.Errorf("error log msg [%s], want msg [%s]", tl.getLog().msg, wantMsg)
}
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().msg, wantMsg))
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

ev = &TestEvent{priority: syslog.LOG_CRIT, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

ev = &TestEvent{priority: syslog.LOG_ERR, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

ev = &TestEvent{priority: syslog.LOG_EMERG, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

wantLevel = "WARNING"
ev = &TestEvent{priority: syslog.LOG_WARNING, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

wantLevel = "INFO"
ev = &TestEvent{priority: syslog.LOG_INFO, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

ev = &TestEvent{priority: syslog.LOG_NOTICE, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))

ev = &TestEvent{priority: syslog.LOG_DEBUG, message: wantMsg}
event.Dispatch(ev)
if !strings.Contains(tl.getLog().level, wantLevel) {
t.Errorf("error log level [%s], want level [%s]", tl.getLog().level, wantLevel)
}
assert.True(t, strings.Contains(tl.getLog().level, wantLevel))
assert.True(t, ev.triggered)

if !ev.triggered {
t.Errorf("passed nil writer to client")
}
}

// TestWriteError checks that we don't panic on a write error.
Expand All @@ -150,24 +135,18 @@ func TestInvalidSeverity(t *testing.T) {
writer = fw

event.Dispatch(&TestEvent{priority: syslog.Priority(123), message: "log me"})
assert.NotEqual(t, "log me", fw.message)

if fw.message == "log me" {
t.Errorf("message was logged despite invalid severity")
}
}

func testSeverity(sev syslog.Priority, t *testing.T) {
fw := &fakeWriter{}
writer = fw

event.Dispatch(&TestEvent{priority: sev, message: "log me"})
assert.Equal(t, sev, fw.priority)
assert.Equal(t, "log me", fw.message)

if fw.priority != sev {
t.Errorf("wrong priority: got %v, want %v", fw.priority, sev)
}
if fw.message != "log me" {
t.Errorf(`wrong message: got "%v", want "%v"`, fw.message, "log me")
}
}

func TestEmerg(t *testing.T) {
Expand Down
10 changes: 4 additions & 6 deletions go/flagutil/flagutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"testing"

"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

Expand All @@ -38,12 +39,9 @@ func TestStringList(t *testing.T) {
t.Errorf("v.Set(%v): %v", in, err)
continue
}
if strings.Join(p, ".") != out {
t.Errorf("want %#v, got %#v", strings.Split(out, "."), p)
}
if p.String() != in {
t.Errorf("v.String(): want %#v, got %#v", in, p.String())
}
assert.Equal(t, out, strings.Join(p, "."))
assert.Equal(t, in, p.String())

}
}

Expand Down
12 changes: 6 additions & 6 deletions go/history/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package history

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestHistory(t *testing.T) {
Expand All @@ -33,9 +35,8 @@ func TestHistory(t *testing.T) {
t.Errorf("len(records): want %v, got %v. records: %+v", want, got, q)
}
for i, record := range records {
if record != want[i] {
t.Errorf("record doesn't match: want %v, got %v", want[i], record)
}
assert.Equal(t, want[i], record)

}

for ; i < 6; i++ {
Expand All @@ -48,9 +49,8 @@ func TestHistory(t *testing.T) {
t.Errorf("len(records): want %v, got %v. records: %+v", want, got, q)
}
for i, record := range records {
if record != want[i] {
t.Errorf("record doesn't match: want %v, got %v", want[i], record)
}
assert.Equal(t, want[i], record)

}
}

Expand Down
Loading

0 comments on commit 6fec119

Please sign in to comment.