From 8f50e4bbda191429d1bdbacbf555be50b26bbd78 Mon Sep 17 00:00:00 2001 From: VincentAuriau Date: Fri, 26 Apr 2024 10:22:44 +0200 Subject: [PATCH] FIX: indexer tests --- README.md | 2 +- .../data/test_dataset_indexer.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 571375d4..021e158b 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ This repository contains a private version of the package. ## Table of Contents -- [choice-learn-private](#choice-learn-private) +[Choice-Learn](#choice-learn-private) - [Introduction - Discrete Choice Modelling](#introduction---discrete-choice-modelling) - [What's in there ?](#whats-in-there) - [Getting Started](#getting-started---fast-track) diff --git a/tests/integration_tests/data/test_dataset_indexer.py b/tests/integration_tests/data/test_dataset_indexer.py index 82f66c5d..c1bf1f9b 100644 --- a/tests/integration_tests/data/test_dataset_indexer.py +++ b/tests/integration_tests/data/test_dataset_indexer.py @@ -51,20 +51,20 @@ def test_batch(): assert (batch[0][1] == np.array([1, 0, 0, 0])).all() assert (batch[1][0] == np.array([[2, 2, 2, 2], [2, 2, 2, 3]])).all() - assert (batch[1][0] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all() + assert (batch[1][1] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all() - assert (batch[3] == np.array([1.0, 1.0])).all() - assert batch[4] == 0 + assert (batch[2] == np.array([1.0, 1.0])).all() + assert batch[3] == 0 batch = dataset.batch[0] assert (batch[0][0] == np.array([2, 1])).all() assert (batch[0][1] == np.array([1, 0, 0, 0])).all() assert (batch[1][0] == np.array([[2, 2, 2, 2], [2, 2, 2, 3]])).all() - assert (batch[1][0] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all() + assert (batch[1][1] == np.array([[1, 0, 0, 0], [0, 1, 0, 0]])).all() - assert (batch[3] == np.array([1.0, 1.0])).all() - assert batch[4] == 0 + assert (batch[2] == np.array([1.0, 1.0])).all() + assert batch[3] == 0 batch = dataset.get_choices_batch([1, 2]) assert (batch[0][0] == np.array([[3, 4], [9, 4]])).all() @@ -74,11 +74,11 @@ def test_batch(): batch[1][0] == np.array([[[2, 2, 3, 2], [3, 2, 2, 2]], [[3, 2, 2, 2], [2, 3, 2, 2]]]) ).all() assert ( - batch[1][0] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]]) + batch[1][1] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]]) ).all() - assert (batch[3] == np.array([[1.0, 1.0], [1.0, 1.0]])).all() - assert (batch[4] == np.array([1, 1])).all() + assert (batch[2] == np.array([[1.0, 1.0], [1.0, 1.0]])).all() + assert (batch[3] == np.array([1, 1])).all() batch = dataset.batch[[1, 2]] assert (batch[0][0] == np.array([[3, 4], [9, 4]])).all() @@ -88,8 +88,8 @@ def test_batch(): batch[1][0] == np.array([[[2, 2, 3, 2], [3, 2, 2, 2]], [[3, 2, 2, 2], [2, 3, 2, 2]]]) ).all() assert ( - batch[1][0] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]]) + batch[1][1] == np.array([[[0, 0, 0, 1], [0, 0, 1, 0]], [[1, 0, 0, 0], [0, 1, 0, 0]]]) ).all() - assert (batch[3] == np.array([[1.0, 1.0], [1.0, 1.0]])).all() - assert (batch[4] == np.array([1, 1])).all() + assert (batch[2] == np.array([[1.0, 1.0], [1.0, 1.0]])).all() + assert (batch[3] == np.array([1, 1])).all()