Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev 1st test #3

Merged
merged 14 commits into from
Dec 22, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
name: Security check (bandit)
entry: bandit
types: [python]
args: ["--recursive", "lib/"]
args: ["-x", "tests", --recursive, choice_learn]
language: system
- id: pytest-check
name: Tests (pytest)
Expand Down
1 change: 1 addition & 0 deletions choice_learn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Choice-Learn library for Python."""
1 change: 1 addition & 0 deletions choice_learn/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Data handling classes and functions."""
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import numpy as np
import pandas as pd
from choice_modeling.data.indexer import ChoiceDatasetIndexer
from choice_modeling.data.store import Store

from choice_learn.data.indexer import ChoiceDatasetIndexer
from choice_learn.data.store import Store


class ChoiceDataset(object):
Expand Down Expand Up @@ -686,7 +687,39 @@ def save(self):

def summary(self):
    """Print a human-readable summary of the dataset.

    Displays the number of items, sessions and choices, the average number
    of choices per session and, for each kind of features (items, sessions,
    sessions items), the registered names and the total number of features.

    Returns
    -------
    None
        The summary is written to the standard output.
    """
    num_sessions = self.get_num_sessions()
    num_choices = self.get_num_choices()
    # Guard against an empty dataset so that the summary never divides by zero
    choices_per_session = num_choices / num_sessions if num_sessions else 0

    print("Summary of the dataset:")
    print("Number of items:", self.get_num_items())
    print("Number of sessions:", num_sessions)
    print(
        "Number of choices:",
        num_choices,
        "Averaging",
        choices_per_session,
        "choices per session",
    )

    if self.items_features is not None:
        print(f"Items features: {self.items_features_names}")
        # Features are counted on the second axis of each items features array
        print(f"{sum(f.shape[1] for f in self.items_features)} items features")
    else:
        print("No items features registered")

    if self.sessions_features is not None:
        print(f"Sessions features: {self.sessions_features_names}")
        # Features are counted on the second axis of each sessions features array
        print(f"{sum(f.shape[1] for f in self.sessions_features)} session features")
    else:
        print("No sessions features registered")

    # The attribute is sessions_items_features (the misspelled name raised AttributeError)
    if self.sessions_items_features is not None:
        print(f"Session Items features: {self.sessions_items_features_names}")
        # Features are counted on the third axis of each sessions items features array
        print(
            f"{sum(f.shape[2] for f in self.sessions_items_features)} sessions items features"
        )
    else:
        print("No sessions items features registered")

def get_choice_batch(self, choice_index):
"""Method to access data within the ListChoiceDataset from its index.
Expand Down Expand Up @@ -845,7 +878,7 @@ def __getitem__(self, session_indexes):
sessions_items_features_names=self.sessions_items_features_names,
)

def batch(self, batch_size=None, shuffle=None, sample_weight=None):
def old_batch(self, batch_size=None, shuffle=None, sample_weight=None):
"""Iterates over dataset return batches of length self.batch_size.

Parameters
Expand Down Expand Up @@ -892,6 +925,50 @@ def batch(self, batch_size=None, shuffle=None, sample_weight=None):
yielded_size += 2 * num_choices

@property
def iloc(self):
def batch(self):
"""Indexer."""
return self.indexer

def iter_batch(self, batch_size=None, shuffle=None, sample_weight=None):
    """Iterate over the dataset and yield batches of at most batch_size choices.

    Newer version.

    Parameters
    ----------
    batch_size : int, optional
        Number of choices per batch, defaults to self.batch_size.
        The special value -1 yields the whole dataset as one single batch,
        which is never shuffled.
    shuffle : bool, optional
        Whether or not to shuffle the choices before batching,
        defaults to self.shuffle.
    sample_weight : Iterable, optional
        Weights (one per choice) to be returned with the right indexing
        during the shuffling.

    Yields
    ------
    object or tuple
        self.batch[indexes] for each successive chunk of choice indexes, or
        the tuple (self.batch[indexes], weights) when sample_weight is given.

    Raises
    ------
    ValueError
        If batch_size is neither -1 nor a strictly positive integer.
    """
    if batch_size is None:
        batch_size = self.batch_size
    if shuffle is None:
        shuffle = self.shuffle

    num_choices = self.get_num_choices()

    # Remember the "whole dataset" request before batch_size is overwritten,
    # otherwise the no-shuffle rule below could never be triggered.
    full_batch = batch_size == -1
    if full_batch:
        batch_size = num_choices
    elif batch_size < 1:
        # A null or negative step would never reach the end of the dataset
        raise ValueError(f"batch_size must be -1 or a positive integer, got {batch_size}")

    if num_choices == 0:
        return

    # Get indexes for each choice
    indexes = np.arange(num_choices)
    # Shuffle indexes, pointless when everything is returned in one batch
    if shuffle and not full_batch:
        indexes = np.random.permutation(indexes)

    # Convert once so that plain Python sequences support indexing with a list
    if sample_weight is not None:
        sample_weight = np.asarray(sample_weight)

    for start in range(0, num_choices, batch_size):
        batch_indexes = indexes[start : start + batch_size].tolist()
        # Return sample_weight if not None, for index matching
        if sample_weight is not None:
            yield self.batch[batch_indexes], sample_weight[batch_indexes]
        else:
            yield self.batch[batch_indexes]
2 changes: 1 addition & 1 deletion lib/data/indexer.py → choice_learn/data/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __getitem__(self, sequence_index):
if isinstance(sequence_index, slice):
return [
self.store.store[self.store.sequence[i]]
for i in range(*sequence_index.indices(len(self.sequence)))
for i in range(*sequence_index.indices(len(self.store.sequence)))
]
return self.store.store[self.store.sequence[sequence_index]]

Expand Down
7 changes: 4 additions & 3 deletions lib/data/store.py → choice_learn/data/store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Different classes to optimize RAM usage with repeated features over time."""
import numpy as np
from choice_modeling.data.indexer import OneHotStoreIndexer, StoreIndexer

from choice_learn.data.indexer import OneHotStoreIndexer, StoreIndexer


class Store(object):
Expand All @@ -22,7 +23,7 @@ def __init__(self, indexes=None, values=None, sequence=None, name=None, indexer=
name of the features store -- not used at the moment
"""
if indexes is None:
indexes = list(range(values))
indexes = list(range(len(values)))
self.store = {k: v for (k, v) in zip(indexes, values)}
self.sequence = np.array(sequence)
self.name = name
Expand Down Expand Up @@ -62,7 +63,7 @@ def __len__(self):
return len(self.sequence)

@property
def iloc(self):
def batch(self):
"""Indexing attribute."""
return self.indexer

Expand Down
1 change: 1 addition & 0 deletions choice_learn/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Models classes and functions."""
File renamed without changes.
3 changes: 2 additions & 1 deletion lib/models/rumnet.py → choice_learn/models/rumnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Implementation of RUMnet for easy use."""
import tensorflow as tf
from choice_modeling.models.base_model import ChoiceModel

from choice_learn.models.base_model import ChoiceModel


class PaperRUMnet(ChoiceModel):
Expand Down
File renamed without changes.
Empty file removed lib/.gitkeep
Empty file.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ requires-python = ">=3.8"
"Documentation" = "https://artefactory.github.io/choice-learn-private"

[tool.setuptools]
packages = ["lib", "config", "tests"]
packages = ["choice_learn", "config", "tests"]

[tool.ruff]
select = [
Expand Down Expand Up @@ -62,4 +62,7 @@ convention = "google"
quote-style = "double"

[tool.ruff.isort]
known-first-party = ["lib", "config", "tests"]
known-first-party = ["choice_learn", "config", "tests"]

[tool.bandit]
exclude_dirs = ["tests/"]
Empty file removed tests/unit_tests/.gitkeep
Empty file.
23 changes: 23 additions & 0 deletions tests/unit_tests/data/test_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Test the store module."""
from choice_learn.data.store import Store


def test_len_store():
    """Check that the length of a Store is the length of its sequence."""
    sequence = [0, 1, 2, 3, 0, 1, 2, 3]
    features_store = Store(values=[1, 2, 3, 4], sequence=sequence)
    store_length = len(features_store)
    assert store_length == 8


def test_get_store_element():
    """Check _get_store_element of Store with a single index and a list of indexes."""
    features_store = Store(
        values=[1, 2, 3, 4],
        sequence=[0, 1, 2, 3, 0, 1, 2, 3],
    )
    # Single index returns the stored value
    single_element = features_store._get_store_element(0)
    assert single_element == 1
    # List of indexes returns the list of stored values
    several_elements = features_store._get_store_element([0, 1, 2])
    assert several_elements == [1, 2, 3]


def test_store_batch():
    """Check the batch indexer of Store with an int, a slice and a list of indexes."""
    features_store = Store(
        values=[1, 2, 3, 4],
        sequence=[0, 1, 2, 3, 0, 1, 2, 3],
    )
    # Integer position in the sequence
    from_int = features_store.batch[1]
    assert from_int == 2
    # Slice of positions in the sequence
    from_slice = features_store.batch[2:4]
    assert from_slice == [3, 4]
    # Explicit list of positions, going through the repeated part of the sequence
    from_list = features_store.batch[[2, 3, 6, 7]]
    assert from_list == [3, 4, 3, 4]
assert store.batch[[2, 3, 6, 7]] == [3, 4, 3, 4]
6 changes: 0 additions & 6 deletions tests/unit_tests/test_placeholder.py

This file was deleted.

Loading