From 39cdda10e705a55f23c328e8fde732e323ca1f11 Mon Sep 17 00:00:00 2001 From: lthoang Date: Thu, 7 Dec 2023 02:52:56 +0800 Subject: [PATCH] Add test case reading basket data --- tests/cornac/data/test_reader.py | 61 ++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 26 deletions(-) diff --git a/tests/cornac/data/test_reader.py b/tests/cornac/data/test_reader.py index fec565a76..f8c7a1eb1 100644 --- a/tests/cornac/data/test_reader.py +++ b/tests/cornac/data/test_reader.py @@ -13,31 +13,29 @@ # limitations under the License. # ============================================================================ -import unittest - from cornac.data import Reader from cornac.data.reader import read_text class TestReader(unittest.TestCase): - def setUp(self): - self.data_file = './tests/data.txt' + self.data_file = "./tests/data.txt" + self.basket_file = "./tests/basket.txt" self.reader = Reader() def test_raise(self): try: - self.reader.read(self.data_file, fmt='bla bla') + self.reader.read(self.data_file, fmt="bla bla") except ValueError: assert True def test_read_ui(self): - triplets = self.reader.read(self.data_file, fmt='UI') + triplets = self.reader.read(self.data_file, fmt="UI") self.assertEqual(len(triplets), 30) - self.assertEqual(triplets[0][1], '93') + self.assertEqual(triplets[0][1], "93") self.assertEqual(triplets[1][2], 1.0) - triplets = self.reader.read(self.data_file, fmt='UI', id_inline=True) + triplets = self.reader.read(self.data_file, fmt="UI", id_inline=True) self.assertEqual(len(triplets), 40) def test_read_uir(self): @@ -45,32 +43,32 @@ def test_read_uir(self): self.assertEqual(len(triplet_data), 10) self.assertEqual(triplet_data[4][2], 3) - self.assertEqual(triplet_data[6][1], '478') - self.assertEqual(triplet_data[8][0], '543') + self.assertEqual(triplet_data[6][1], "478") + self.assertEqual(triplet_data[8][0], "543") def test_read_uirt(self): - data = self.reader.read(self.data_file, fmt='UIRT') + data = self.reader.read(self.data_file, fmt="UIRT") self.assertEqual(len(data), 10) self.assertEqual(data[4][3], 891656347) self.assertEqual(data[4][2], 3) - self.assertEqual(data[4][1], '705') - self.assertEqual(data[4][0], '329') + self.assertEqual(data[4][1], "705") + self.assertEqual(data[4][0], "329") self.assertEqual(data[9][3], 879451804) def test_read_tup(self): - tup_data = self.reader.read(self.data_file, fmt='UITup') + tup_data = self.reader.read(self.data_file, fmt="UITup") self.assertEqual(len(tup_data), 10) - self.assertEqual(tup_data[4][2], [('3',), ('891656347',)]) - self.assertEqual(tup_data[6][1], '478') - self.assertEqual(tup_data[8][0], '543') + self.assertEqual(tup_data[4][2], [("3",), ("891656347",)]) + self.assertEqual(tup_data[6][1], "478") + self.assertEqual(tup_data[8][0], "543") def test_read_review(self): - review_data = self.reader.read('./tests/review.txt', fmt='UIReview') + review_data = self.reader.read("./tests/review.txt", fmt="UIReview") self.assertEqual(len(review_data), 5) - self.assertEqual(review_data[0][2], 'Sample text 1') - self.assertEqual(review_data[1][1], '257') - self.assertEqual(review_data[4][0], '329') + self.assertEqual(review_data[0][2], "Sample text 1") + self.assertEqual(review_data[1][1], "257") + self.assertEqual(review_data[4][0], "329") def test_filter(self): reader = Reader(bin_threshold=4.0) @@ -84,19 +82,30 @@ def test_filter(self): reader = Reader(min_item_freq=2) self.assertEqual(len(reader.read(self.data_file)), 0) - reader = Reader(user_set=['76'], item_set=['93']) + reader = Reader(user_set=["76"], item_set=["93"]) self.assertEqual(len(reader.read(self.data_file)), 1) - reader = Reader(user_set=['76', '768']) + reader = Reader(user_set=["76", "768"]) self.assertEqual(len(reader.read(self.data_file)), 2) - reader = Reader(item_set=['93', '257', '795']) + reader = Reader(item_set=["93", "257", "795"]) self.assertEqual(len(reader.read(self.data_file)), 3) def test_read_text(self): self.assertEqual(len(read_text(self.data_file, sep=None)), 10) - self.assertEqual(read_text(self.data_file, sep='\t')[1][0], '76') + self.assertEqual(read_text(self.data_file, sep="\t")[1][0], "76") + + def test_read_basket(self): + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBI")), 50 + ) + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBIT")), 50 + ) + self.assertEqual( + len(self.reader.read(self.basket_file, sep="\t", fmt="UBITJson")), 50 + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main()