Docstrings
azoz01 committed Dec 5, 2023
1 parent b995727 commit 49a5d23
Showing 3 changed files with 38 additions and 18 deletions.
experiments/02_openml.py (1 addition & 1 deletion)
@@ -31,7 +31,7 @@ def main():
     train_loader = ComposedDataLoaderFactory.create_composed_dataloader_from_path(
         Path(config["train_data_path"]),
         RandomFeaturesPandasDataset,
-        {},
+        {"total_random_feature_sampling": True},
         FewShotDataLoader,
         {"support_size": config["support_size"], "query_size": config["query_size"]},
         ComposedDataLoader,
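
The only change in this file swaps the empty dataset keyword arguments for {"total_random_feature_sampling": True}. The flag's implementation is not shown in this commit; purely to illustrate the general idea of sampling a random subset of features from a pandas table (and not RandomFeaturesPandasDataset's actual behavior), a minimal sketch:

# Hypothetical sketch: drawing a random subset of feature columns from a
# pandas DataFrame, roughly the idea behind "random feature sampling".
# This is NOT the RandomFeaturesPandasDataset implementation.
import numpy as np
import pandas as pd

def sample_random_features(df: pd.DataFrame, rng: np.random.Generator) -> pd.DataFrame:
    # Pick a random number of columns (at least one), then a random subset of that size.
    n_features = rng.integers(1, df.shape[1] + 1)
    columns = rng.choice(df.columns, size=n_features, replace=False)
    return df.loc[:, columns]

df = pd.DataFrame(np.random.rand(8, 5), columns=[f"f{i}" for i in range(5)])
subset = sample_random_features(df, np.random.default_rng(0))
print(subset.columns.tolist())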
liltab/data/dataloaders.py (24 additions & 17 deletions)
@@ -68,23 +68,30 @@ def __init__(
             self.class_values_idx[val] = np.where(self.y == val)[0]
 
         if sample_classes_equally:
-            self.samples_per_class_support = {
-                class_value: self.support_size // len(self.class_values)
-                for class_value in self.class_values
-            }
-            self.samples_per_class_query = {
-                class_value: self.query_size // len(self.class_values)
-                for class_value in self.class_values
-            }
-        if self.sample_classes_stratified:
-            self.samples_per_class_support = {
-                class_value: int(self.support_size * (self.y == class_value).sum() / len(self.y))
-                for class_value in self.class_values
-            }
-            self.samples_per_class_query = {
-                class_value: int(self.query_size * (self.y == class_value).sum() / len(self.y))
-                for class_value in self.class_values
-            }
+            self._init_samples_per_class_equal()
+
+        if sample_classes_stratified:
+            self._init_samples_per_class_stratified()
+
+    def _init_samples_per_class_equal(self):
+        self.samples_per_class_support = {
+            class_value: self.support_size // len(self.class_values)
+            for class_value in self.class_values
+        }
+        self.samples_per_class_query = {
+            class_value: self.query_size // len(self.class_values)
+            for class_value in self.class_values
+        }
+
+    def _init_samples_per_class_stratified(self):
+        self.samples_per_class_support = {
+            class_value: int(self.support_size * (self.y == class_value).sum() / len(self.y))
+            for class_value in self.class_values
+        }
+        self.samples_per_class_query = {
+            class_value: int(self.query_size * (self.y == class_value).sum() / len(self.y))
+            for class_value in self.class_values
+        }
 
     def __iter__(self):
         return deepcopy(self)
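
The refactor above moves the two per-class budget computations out of __init__ into _init_samples_per_class_equal and _init_samples_per_class_stratified. A standalone sketch of the arithmetic the helpers perform, using hypothetical free functions rather than the repository's methods: equal sampling splits the budget evenly with integer division, stratified sampling allocates in proportion to class frequency.

# Standalone sketch of the two per-class budgets computed by the new helper
# methods. Hypothetical free functions, for illustration only.
import numpy as np

def equal_budget(class_values: list, size: int) -> dict:
    # Equal sampling: every class gets the same integer share of the budget.
    return {c: size // len(class_values) for c in class_values}

def stratified_budget(y: np.ndarray, class_values: list, size: int) -> dict:
    # Stratified sampling: each class gets a share proportional to its frequency in y.
    return {c: int(size * (y == c).sum() / len(y)) for c in class_values}

y = np.array([0] * 6 + [1] * 3 + [2] * 1)
classes = [0, 1, 2]
print(equal_budget(classes, 9))          # {0: 3, 1: 3, 2: 3}
print(stratified_budget(y, classes, 9))  # {0: 5, 1: 2, 2: 0}

Because the shares are truncated, the per-class counts need not sum exactly to the requested support or query size: the stratified split above totals 7 out of 9, and the equal split also falls short whenever the size is not divisible by the number of classes.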
liltab/data/datasets.py (13 additions & 0 deletions)
@@ -11,6 +11,12 @@


 class Dataset(ABC):
+    """
+    Abstract class for Datasets. It reads and stores data as Pandas
+    DataFrame. __getitem__ method is to be implemented with custom
+    indexing strategy.
+    """
+
     def __init__(
         self,
         data_path: str,
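
The new docstring describes the contract of the abstract class: load the data into a pandas DataFrame and leave __getitem__ to subclasses. A minimal sketch of that pattern with illustrative names and a simplified constructor (not liltab's actual signatures, which are truncated in this hunk):

# Minimal sketch of the described pattern: an abstract dataset that loads a
# DataFrame and defers indexing to subclasses. Names and signatures here are
# illustrative, not liltab's.
from abc import ABC, abstractmethod

import pandas as pd

class FrameDataset(ABC):
    def __init__(self, data_path: str):
        self.df = pd.read_csv(data_path)  # read and store data as a DataFrame

    @abstractmethod
    def __getitem__(self, idx):
        """Subclasses define the indexing strategy."""

class RowDataset(FrameDataset):
    def __getitem__(self, idx):
        # Simplest possible strategy: return a single row by position.
        return self.df.iloc[idx]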
@@ -45,6 +51,10 @@ def __init__(
         self.y = self.df[self.response_columns].values
 
     def _preprocess_data(self):
+        """
+        Standardizes data using z-score method. If encode_categorical_target = True
+        then response variable isn't scaled.
+        """
         self.preprocessing_pipeline = get_preprocessing_pipeline()
         if self.encode_categorical_target:
             self.df.loc[:, self.attribute_columns] = self.preprocessing_pipeline.fit_transform(
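
The docstring added above says features are standardized with the z-score method via get_preprocessing_pipeline(), whose definition is not part of this diff. A hedged sketch of what such a pipeline could look like with scikit-learn, as an assumption rather than the repository's actual pipeline:

# Hedged sketch: one way a z-score preprocessing pipeline could be built;
# the repository's get_preprocessing_pipeline() may differ.
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

def make_zscore_pipeline() -> Pipeline:
    return Pipeline(
        steps=[
            ("impute", SimpleImputer(strategy="mean")),  # fill gaps before scaling
            ("scale", StandardScaler()),                 # (x - mean) / std, i.e. z-score
        ]
    )

The hunk applies such a pipeline with fit_transform on self.df[self.attribute_columns], so only the attribute columns are scaled when the target is encoded categorically.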
@@ -56,6 +66,9 @@ def _preprocess_data(self):
             )
 
     def _encode_categorical_target(self):
+        """
+        Encodes categorical response using one-hot encoding.
+        """
         self.one_hot_encoder = OneHotEncoder(sparse=False)
         self.raw_y = self.df[self.response_columns]
         self.y = self.one_hot_encoder.fit_transform((self.df[self.response_columns]))
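
The method documented above delegates to scikit-learn's OneHotEncoder. A self-contained illustration on dummy data (not taken from the repository); note that the sparse=False argument seen in the hunk is named sparse_output in recent scikit-learn releases:

# Standalone illustration of the one-hot encoding step on dummy data.
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

df = pd.DataFrame({"target": ["cat", "dog", "cat", "bird"]})
encoder = OneHotEncoder(sparse_output=False)  # `sparse=False` in older scikit-learn
y = encoder.fit_transform(df[["target"]])
print(encoder.categories_)  # [array(['bird', 'cat', 'dog'], dtype=object)]
print(y.shape)              # (4, 3)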
