Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support set encoding #25

Merged
merged 1 commit into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion config/03_openml_clf_data_experiment_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ learning_rate: 0.001
weight_decay: 0
batch_size: 37
gradient_clipping: False
early_stopping_intervals: 100
early_stopping_intervals: 20

support_size: 3
query_size: 29
Expand Down
1 change: 1 addition & 0 deletions experiments/03_openml_clf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def main():
hidden_size=config["hidden_size"],
dropout_rate=config["dropout_rate"],
is_classifier=config["is_classifier"],
inner_activation_function=nn.ReLU(),
)

results_path = Path("results") / config["name"]
Expand Down
34 changes: 27 additions & 7 deletions liltab/data/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def __init__(
dataloaders: list[Iterable],
batch_size: int = 32,
num_batches: int = 1,
return_dataset_indicator: float = True,
):
"""
Args:
Expand All @@ -184,6 +185,7 @@ def __init__(
"""
self.dataloaders = dataloaders
self.batch_size = batch_size
self.return_dataset_indicator = return_dataset_indicator

self.counter = 0
self.num_batches = num_batches
Expand All @@ -198,21 +200,33 @@ def __next__(self):
self.counter += 1
return [self._get_single_example() for _ in range(self.batch_size)]

def _get_single_example(self) -> tuple[Tensor, Tensor, Tensor, Tensor]:
def _get_single_example(
self,
) -> Union[
tuple[Tensor, Tensor, Tensor, Tensor], tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]]
]:
"""
Returns support and query sets from one DataLoaders from
randomly chosen from passed dataloaders.
randomly chosen from passed dataloaders. If return_dataset_indicator = True
then id of the dataset is also returned.

Returns:
tuple[Tensor, Tensor, Tensor, Tensor]:
(X_support, y_support, X_query, y_query)
Union[
tuple[Tensor, Tensor, Tensor, Tensor],
tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]]
]
"""
dataloader_has_next = False
dataloader_idx = None
while not dataloader_has_next:
dataloader_idx = np.random.choice(self.n_dataloaders, 1)[0]
dataloader = self.dataloaders[dataloader_idx]
dataloader_has_next = dataloader.has_next()
return next(dataloader)

if self.return_dataset_indicator:
return dataloader_idx, next(dataloader)
else:
return next(dataloader)


class RepeatableOutputComposedDataLoader:
Expand All @@ -223,7 +237,9 @@ class RepeatableOutputComposedDataLoader:
of data. Useful with test/validation datasets.
"""

def __init__(self, dataloaders: list[Iterable], *args, **kwargs):
def __init__(
self, dataloaders: list[Iterable], return_dataset_indicator: bool = True, *args, **kwargs
):
"""
Args:
dataloaders (list[Iterable]): list of
Expand All @@ -233,6 +249,7 @@ def __init__(self, dataloaders: list[Iterable], *args, **kwargs):

self.batch_counter = 0
self.n_dataloaders = len(dataloaders)
self.return_dataset_indicator = return_dataset_indicator

self.loaded = False
self.cache = OrderedDict()
Expand All @@ -246,4 +263,7 @@ def __next__(self):
if self.loaded:
raise StopIteration()
self.loaded = True
return [sample for _, sample in self.cache.items()]
if self.return_dataset_indicator:
return list(self.cache.items())
else:
return [sample for _, sample in self.cache.items()]
2 changes: 1 addition & 1 deletion liltab/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
self.data = data
if type(data) in [str, PosixPath]:
self.df = pd.read_csv(data)
elif type(data) == pd.DataFrame:
elif type(data) is pd.DataFrame:
self.df = data
else:
raise ValueError(
Expand Down
73 changes: 69 additions & 4 deletions liltab/model/heterogenous_attributes_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch.nn.functional as F

from torch import nn, Tensor
from typing import Callable
from typing import Callable, Dict, Union
from .utils import FeedForwardNetwork


Expand Down Expand Up @@ -133,7 +133,35 @@ def __init__(
)
self.is_classifier = is_classifier

def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tensor:
def forward(
self, X_support: Tensor, y_support: Tensor, X_query: Tensor, return_full_trace: bool = False
) -> Union[Tensor, Dict[str, Tensor]]:
"""
Returns prediction of the network. If return_full_trace = True, then full trace of
prediction is returned.
Args:
X_support (Tensor): Support set attributes
with shape (n_support_observations, n_attributes)
y_support (Tensor): Support set responses
with shape (n_support_observations, n_responses)
X_query (Tensor): Query set attributes
with shape (n_query_observations, n_attributes)
return_full_trace (bool): if True, then instead of vanilla prediction,
the trace of prediction i.e. intermediate tensors are returned in
for of Dict[str, Tensor].
Returns:
Union[Tensor, Dict[str, Tensor]]: Inferred query set responses
or full trace of predictions.
"""
prediction_trace = self._forward_with_full_trace(X_support, y_support, X_query)
if return_full_trace:
return prediction_trace
else:
return prediction_trace["prediction"]

def _forward_with_full_trace(
self, X_support: Tensor, y_support: Tensor, X_query: Tensor
) -> Dict[str, Tensor]:
"""
Inference function of network. Inference is done in following steps:
1. Calculate initial representation for all atrributes and responses
Expand All @@ -146,6 +174,9 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens
5. Make prediction based on representations of support set attributes and
responses and query set attributes representations.
All representations calculations are done using feed forward neural networks.
As a result it returns trace from inference i.e. all intermediate tensors in form
of a dictionary.

Args:
X_support (Tensor): Support set attributes
with shape (n_support_observations, n_attributes)
Expand All @@ -154,7 +185,13 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens
X_query (Tensor): Query set attributes
with shape (n_query_observations, n_attributes)
Returns:
Tensor: Inferred query set responses shaped (n_query_obervations, n_responses)
Dict[str, Tensor]: Full trace of inference:
attributes_initial_representation - initial representation of attributes.
responses_initial_representation - initial representation of responses.
support_set_representation - representation of support set.
attributes_representation - final representation of attributes.
responses_representation - final representation of responses.
prediction - final prediction i.e. predicted y_query.
"""
attributes_initial_representation = self._calculate_initial_features_representation(
self.initial_features_encoding_network,
Expand Down Expand Up @@ -206,7 +243,35 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens
X_query,
)

return prediction
return {
"attributes_initial_representation": attributes_initial_representation,
"responses_initial_representation": responses_initial_representation,
"support_set_representation": support_set_representation,
"attributes_representation": attributes_representation,
"responses_representation": responses_representation,
"prediction": prediction,
}

def encode_support_set(self, X_support: Tensor, y_support: Tensor) -> Tensor:
attributes_initial_representation = self._calculate_initial_features_representation(
self.initial_features_encoding_network,
self.initial_features_representation_network,
X_support,
)
responses_initial_representation = self._calculate_initial_features_representation(
self.initial_features_encoding_network,
self.initial_features_representation_network,
y_support,
)
support_set_representation = self._calculate_support_set_representation(
self.interaction_encoding_network,
self.interaction_representation_network,
X_support,
attributes_initial_representation,
y_support,
responses_initial_representation,
)
return support_set_representation

def _calculate_initial_features_representation(
self,
Expand Down
33 changes: 32 additions & 1 deletion liltab/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
ComposedDataLoader,
RepeatableOutputComposedDataLoader,
)
from .utils import LightningWrapper
from .utils import LightningWrapper, LightningEncoderWrapper
from .logger import TensorBoardLogger, FileLogger


Expand All @@ -26,6 +26,7 @@ def __init__(
gradient_clipping: bool,
learning_rate: float,
weight_decay: float,
representation_penalty_weight: float = 0,
early_stopping_intervals: int = 100,
check_val_every_n_epoch: int = 100,
loss: Callable = nn.MSELoss(),
Expand Down Expand Up @@ -112,6 +113,7 @@ def __init__(
check_val_every_n_epoch=check_val_every_n_epoch,
callbacks=callbacks,
)
self.representation_penalty_weight = representation_penalty_weight
self.learning_rate = learning_rate
self.weight_decay = weight_decay
self.loss = loss
Expand Down Expand Up @@ -143,12 +145,41 @@ def train_and_test(
model,
learning_rate=self.learning_rate,
weight_decay=self.weight_decay,
representation_penalty_weight=self.representation_penalty_weight,
loss=self.loss,
)
self.trainer.fit(model_wrapper, train_loader, val_loader)
test_results = self.trainer.test(model_wrapper, test_loader)
return model_wrapper, test_results

def pretrain_encoder(
self,
model: HeterogenousAttributesNetwork,
train_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader,
val_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader,
) -> tuple[LightningWrapper, list[dict[str, float]]]:
"""
Method used to pretrain encoder.

Args:
model (HeterogenousAttributesNetwork): model to train
train_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader):
loader withTrainingData
val_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader):
loader with validation data

Returns:
tuple[HeterogenousAttributesNetwork, list[dict[str, float]]]:
trained network with metrics on test set.
"""
encoder_wrapper = LightningEncoderWrapper(
model,
learning_rate=self.learning_rate,
weight_decay=self.weight_decay,
)
self.trainer.fit(encoder_wrapper, train_loader, val_loader)
return encoder_wrapper


class LoggerCallback(Callback):
def __init__(
Expand Down
Loading
Loading