Skip to content

Commit

Permalink
Dev 1st test (#3)
Browse files Browse the repository at this point in the history
* ADD: 1st test

* ADD: first working tests organization

* FIX: lib renaming -> choice_learn apply changes
  • Loading branch information
VincentAuriau authored Dec 22, 2023
1 parent 3744dc0 commit efa9d69
Show file tree
Hide file tree
Showing 16 changed files with 121 additions and 19 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
name: Security check (bandit)
entry: bandit
types: [python]
args: ["--recursive", "lib/"]
args: ["-x", "tests", --recursive, choice_learn]
language: system
- id: pytest-check
name: Tests (pytest)
Expand Down
1 change: 1 addition & 0 deletions choice_learn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Choice-Learn library for Python."""
1 change: 1 addition & 0 deletions choice_learn/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Data handling classes and functions."""
87 changes: 82 additions & 5 deletions lib/data/choice_dataset.py → choice_learn/data/choice_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import numpy as np
import pandas as pd
from choice_modeling.data.indexer import ChoiceDatasetIndexer
from choice_modeling.data.store import Store

from choice_learn.data.indexer import ChoiceDatasetIndexer
from choice_learn.data.store import Store


class ChoiceDataset(object):
Expand Down Expand Up @@ -686,7 +687,39 @@ def save(self):

def summary(self):
"""Method to display a summary of the dataset."""
raise NotImplementedError
print("Summary of the dataset:")
print("Number of items:", self.get_num_items())
print("Number of sessions:", self.get_num_sessions())
print(
"Number of choices:",
self.get_num_choices(),
"Averaging",
self.get_num_choices() / self.get_num_sessions(),
"choices per session",
)
if self.items_features is not None:
print(f"Items features: {self.items_features_names}")
if self.items_features is not None:
print(f"{sum([f.shape[1] for f in self.items_features])} items features")
else:
print("No items features registered")

if self.sessions_features is not None:
print(f"Sessions features: {self.sessions_features_names}")
if self.sessions_features is not None:
print(f"{sum([f.shape[1] for f in self.sessions_features])} session features")
else:
print("No sessions features registered")

if self.sessions_featuresitems_features is not None:
print(f"Session Items features: {self.sessions_items_features_names}")
if self.sessions_items_features is not None:
print(
f"{sum([f.shape[2] for f in self.sessions_items_features])} sessions \
items features"
)
else:
print("No sessions items features registered")

def get_choice_batch(self, choice_index):
"""Method to access data within the ListChoiceDataset from its index.
Expand Down Expand Up @@ -845,7 +878,7 @@ def __getitem__(self, session_indexes):
sessions_items_features_names=self.sessions_items_features_names,
)

def batch(self, batch_size=None, shuffle=None, sample_weight=None):
def old_batch(self, batch_size=None, shuffle=None, sample_weight=None):
"""Iterates over dataset return batches of length self.batch_size.
Parameters
Expand Down Expand Up @@ -892,6 +925,50 @@ def batch(self, batch_size=None, shuffle=None, sample_weight=None):
yielded_size += 2 * num_choices

@property
def iloc(self):
def batch(self):
"""Indexer."""
return self.indexer

def iter_batch(self, batch_size=None, shuffle=None, sample_weight=None):
"""Iterates over dataset return batches of length self.batch_size.
Newer version.
Parameters
----------
batch_size : int
batch size to set
shuffle: bool
Whether or not to shuffle the dataset
sample_weight : Iterable
list of weights to be returned with the right indexing during the shuffling
"""
if batch_size is None:
batch_size = self.batch_size
if shuffle is None:
shuffle = self.shuffle
if batch_size == -1:
batch_size = self.get_num_choices()

# Get indexes for each choice
num_choices = self.get_num_choices()
indexes = np.arange(num_choices)
# Shuffle indexes
if shuffle and not batch_size == -1:
indexes = np.random.permutation(indexes)

yielded_size = 0
while yielded_size < num_choices:
# Return sample_weight if not None, for index matching
if sample_weight is not None:
yield (
self.batch[indexes[yielded_size : yielded_size + batch_size].tolist()],
sample_weight[indexes[yielded_size : yielded_size + batch_size].tolist()],
)
else:
yield self.batch[indexes[yielded_size : yielded_size + batch_size].tolist()]
yielded_size += batch_size

# Special exit strategy for batch_size = -1
if batch_size == -1:
yielded_size += 2 * num_choices
2 changes: 1 addition & 1 deletion lib/data/indexer.py → choice_learn/data/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __getitem__(self, sequence_index):
if isinstance(sequence_index, slice):
return [
self.store.store[self.store.sequence[i]]
for i in range(*sequence_index.indices(len(self.sequence)))
for i in range(*sequence_index.indices(len(self.store.sequence)))
]
return self.store.store[self.store.sequence[sequence_index]]

Expand Down
7 changes: 4 additions & 3 deletions lib/data/store.py → choice_learn/data/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Different classes to optimize RAM usage with repeated features over time."""
import numpy as np
from choice_modeling.data.indexer import OneHotStoreIndexer, StoreIndexer

from choice_learn.data.indexer import OneHotStoreIndexer, StoreIndexer


class Store(object):
Expand All @@ -22,7 +23,7 @@ def __init__(self, indexes=None, values=None, sequence=None, name=None, indexer=
name of the features store -- not used at the moment
"""
if indexes is None:
indexes = list(range(values))
indexes = list(range(len(values)))
self.store = {k: v for (k, v) in zip(indexes, values)}
self.sequence = np.array(sequence)
self.name = name
Expand Down Expand Up @@ -62,7 +63,7 @@ def __len__(self):
return len(self.sequence)

@property
def iloc(self):
def batch(self):
"""Indexing attribute."""
return self.indexer

Expand Down
1 change: 1 addition & 0 deletions choice_learn/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Models classes and functions."""
File renamed without changes.
File renamed without changes.
3 changes: 2 additions & 1 deletion lib/models/rumnet.py → choice_learn/models/rumnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of RUMnet for easy use."""
import tensorflow as tf
from choice_modeling.models.base_model import ChoiceModel

from choice_learn.models.base_model import ChoiceModel


class PaperRUMnet(ChoiceModel):
Expand Down
File renamed without changes.
Empty file removed lib/.gitkeep
Empty file.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ requires-python = ">=3.8"
"Documentation" = "https://artefactory.github.io/choice-learn-private"

[tool.setuptools]
packages = ["lib", "config", "tests"]
packages = ["choice_learn", "config", "tests"]

[tool.ruff]
select = [
Expand Down Expand Up @@ -62,4 +62,7 @@ convention = "google"
quote-style = "double"

[tool.ruff.isort]
known-first-party = ["lib", "config", "tests"]
known-first-party = ["choice_learn", "config", "tests"]

[tool.bandit]
exclude_dirs = ["tests/"]
Empty file removed tests/unit_tests/.gitkeep
Empty file.
23 changes: 23 additions & 0 deletions tests/unit_tests/data/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Test the store module."""
from choice_learn.data.store import Store


def test_len_store():
"""Test the __len__ method of Store."""
store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3])
assert len(store) == 8


def test_get_store_element():
"""Test the _get_store_element method of Store."""
store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3])
assert store._get_store_element(0) == 1
assert store._get_store_element([0, 1, 2]) == [1, 2, 3]


def test_store_batch():
"""Test the batch method of Store."""
store = Store(values=[1, 2, 3, 4], sequence=[0, 1, 2, 3, 0, 1, 2, 3])
assert store.batch[1] == 2
assert store.batch[2:4] == [3, 4]
assert store.batch[[2, 3, 6, 7]] == [3, 4, 3, 4]
6 changes: 0 additions & 6 deletions tests/unit_tests/test_placeholder.py

This file was deleted.

0 comments on commit efa9d69

Please sign in to comment.