diff --git a/mnist_test.go b/mnist_test.go index 5faf456..9e6a287 100644 --- a/mnist_test.go +++ b/mnist_test.go @@ -47,3 +47,16 @@ func TestLoad(t *testing.T) { } println(train.Count(), test.Count()) } + +func TestSweeperNext(t *testing.T) { + train, _, err := Load("./data") + if err != nil { + t.Fatalf("load (%s)", err) + } + sweeper := train.Sweep() + var currentIndex = sweeper.i + sweeper.Next() + if currentIndex == sweeper.i { + t.Errorf("Next does not increase index") + } +} diff --git a/util.go b/util.go index 1a015ed..1af888d 100644 --- a/util.go +++ b/util.go @@ -57,10 +57,12 @@ type Sweeper struct { // Next returns the next image and its label in the data set. // If the end is reached, present is set to false. func (sw *Sweeper) Next() (image RawImage, label Label, present bool) { - if sw.i >= len(sw.set.Images) { + var prevIndex = sw.i + sw.i += 1 + if prevIndex >= len(sw.set.Images) { return nil, 0, false } - return sw.set.Images[sw.i], sw.set.Labels[sw.i], true + return sw.set.Images[prevIndex], sw.set.Labels[prevIndex], true } // Sweep creates a new sweep iterator over the data set