Skip to content

Commit

Permalink
FIX: indexer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VincentAuriau committed Apr 26, 2024
1 parent 9fbcc25 commit 8f50e4b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ This repository contains a private version of the package.

## Table of Contents

- [choice-learn-private](#choice-learn-private)
[Choice-Learn](#choice-learn-private)
- [Introduction - Discrete Choice Modelling](#introduction---discrete-choice-modelling)
- [What's in there ?](#whats-in-there)
- [Getting Started](#getting-started---fast-track)
Expand Down
24 changes: 12 additions & 12 deletions tests/integration_tests/data/test_dataset_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ def test_batch():
assert (batch[0][1] == np.array([1, 0, 0, 0])).all()

assert (batch[1][0] == np.array([[2, 2, 2, 2], [2, 2, 2, 3]])).all()
assert (batch[1][0] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all()
assert (batch[1][1] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all()

assert (batch[3] == np.array([1.0, 1.0])).all()
assert batch[4] == 0
assert (batch[2] == np.array([1.0, 1.0])).all()
assert batch[3] == 0

batch = dataset.batch[0]
assert (batch[0][0] == np.array([2, 1])).all()
assert (batch[0][1] == np.array([1, 0, 0, 0])).all()

assert (batch[1][0] == np.array([[2, 2, 2, 2], [2, 2, 2, 3]])).all()
assert (batch[1][0] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all()
assert (batch[1][1] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all()

assert (batch[3] == np.array([1.0, 1.0])).all()
assert batch[4] == 0
assert (batch[2] == np.array([1.0, 1.0])).all()
assert batch[3] == 0

batch = dataset.get_choices_batch([1, 2])
assert (batch[0][0] == np.array([[3, 4], [9, 4]])).all()
Expand All @@ -74,11 +74,11 @@ def test_batch():
batch[1][0] == np.array([[[2, 2, 3, 2], [3, 2, 2, 2]], [[3, 2, 2, 2], [2, 3, 2, 2]]])
).all()
assert (
batch[1][0] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]])
batch[1][1] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]])
).all()

assert (batch[3] == np.array([[1.0, 1.0], [1.0, 1.0]])).all()
assert (batch[4] == np.array([1, 1])).all()
assert (batch[2] == np.array([[1.0, 1.0], [1.0, 1.0]])).all()
assert (batch[3] == np.array([1, 1])).all()

batch = dataset.batch[[1, 2]]
assert (batch[0][0] == np.array([[3, 4], [9, 4]])).all()
Expand All @@ -88,8 +88,8 @@ def test_batch():
batch[1][0] == np.array([[[2, 2, 3, 2], [3, 2, 2, 2]], [[3, 2, 2, 2], [2, 3, 2, 2]]])
).all()
assert (
batch[1][0] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]])
batch[1][1] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]])
).all()

assert (batch[3] == np.array([[1.0, 1.0], [1.0, 1.0]])).all()
assert (batch[4] == np.array([1, 1])).all()
assert (batch[2] == np.array([[1.0, 1.0], [1.0, 1.0]])).all()
assert (batch[3] == np.array([1, 1])).all()

0 comments on commit 8f50e4b

Please sign in to comment.