From 7501206f20b00743afe8df9925b56ba74f1af4f7 Mon Sep 17 00:00:00 2001 From: Antoni Zajko Date: Tue, 13 Feb 2024 14:15:13 +0100 Subject: [PATCH] Add support set encoding --- .../03_openml_clf_data_experiment_config.yaml | 2 +- experiments/03_openml_clf.py | 1 + liltab/data/dataloaders.py | 34 ++++- liltab/data/datasets.py | 2 +- .../model/heterogenous_attributes_network.py | 73 ++++++++- liltab/train/trainer.py | 33 +++- liltab/train/utils.py | 142 ++++++++++++++++-- test/liltab/data/test_dataloaders.py | 5 +- 8 files changed, 265 insertions(+), 27 deletions(-) diff --git a/config/03_openml_clf_data_experiment_config.yaml b/config/03_openml_clf_data_experiment_config.yaml index 5020416..21f8302 100644 --- a/config/03_openml_clf_data_experiment_config.yaml +++ b/config/03_openml_clf_data_experiment_config.yaml @@ -4,7 +4,7 @@ learning_rate: 0.001 weight_decay: 0 batch_size: 37 gradient_clipping: False -early_stopping_intervals: 100 +early_stopping_intervals: 20 support_size: 3 query_size: 29 diff --git a/experiments/03_openml_clf.py b/experiments/03_openml_clf.py index 349786a..baf2a72 100644 --- a/experiments/03_openml_clf.py +++ b/experiments/03_openml_clf.py @@ -66,6 +66,7 @@ def main(): hidden_size=config["hidden_size"], dropout_rate=config["dropout_rate"], is_classifier=config["is_classifier"], + inner_activation_function=nn.ReLU(), ) results_path = Path("results") / config["name"] diff --git a/liltab/data/dataloaders.py b/liltab/data/dataloaders.py index a2d6c08..c7e7449 100644 --- a/liltab/data/dataloaders.py +++ b/liltab/data/dataloaders.py @@ -172,6 +172,7 @@ def __init__( dataloaders: list[Iterable], batch_size: int = 32, num_batches: int = 1, + return_dataset_indicator: float = True, ): """ Args: @@ -184,6 +185,7 @@ def __init__( """ self.dataloaders = dataloaders self.batch_size = batch_size + self.return_dataset_indicator = return_dataset_indicator self.counter = 0 self.num_batches = num_batches @@ -198,21 +200,33 @@ def __next__(self): self.counter += 1 return [self._get_single_example() for _ in range(self.batch_size)] - def _get_single_example(self) -> tuple[Tensor, Tensor, Tensor, Tensor]: + def _get_single_example( + self, + ) -> Union[ + tuple[Tensor, Tensor, Tensor, Tensor], tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]] + ]: """ Returns support and query sets from one DataLoaders from - randomly chosen from passed dataloaders. + randomly chosen from passed dataloaders. If return_dataset_indicator = True + then id of the dataset is also returned. Returns: - tuple[Tensor, Tensor, Tensor, Tensor]: - (X_support, y_support, X_query, y_query) + Union[ + tuple[Tensor, Tensor, Tensor, Tensor], + tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]] + ] """ dataloader_has_next = False + dataloader_idx = None while not dataloader_has_next: dataloader_idx = np.random.choice(self.n_dataloaders, 1)[0] dataloader = self.dataloaders[dataloader_idx] dataloader_has_next = dataloader.has_next() - return next(dataloader) + + if self.return_dataset_indicator: + return dataloader_idx, next(dataloader) + else: + return next(dataloader) class RepeatableOutputComposedDataLoader: @@ -223,7 +237,9 @@ class RepeatableOutputComposedDataLoader: of data. Useful with test/validation datasets. """ - def __init__(self, dataloaders: list[Iterable], *args, **kwargs): + def __init__( + self, dataloaders: list[Iterable], return_dataset_indicator: bool = True, *args, **kwargs + ): """ Args: dataloaders (list[Iterable]): list of @@ -233,6 +249,7 @@ def __init__(self, dataloaders: list[Iterable], *args, **kwargs): self.batch_counter = 0 self.n_dataloaders = len(dataloaders) + self.return_dataset_indicator = return_dataset_indicator self.loaded = False self.cache = OrderedDict() @@ -246,4 +263,7 @@ def __next__(self): if self.loaded: raise StopIteration() self.loaded = True - return [sample for _, sample in self.cache.items()] + if self.return_dataset_indicator: + return list(self.cache.items()) + else: + return [sample for _, sample in self.cache.items()] diff --git a/liltab/data/datasets.py b/liltab/data/datasets.py index 4338ede..abfce4c 100644 --- a/liltab/data/datasets.py +++ b/liltab/data/datasets.py @@ -37,7 +37,7 @@ def __init__( self.data = data if type(data) in [str, PosixPath]: self.df = pd.read_csv(data) - elif type(data) == pd.DataFrame: + elif type(data) is pd.DataFrame: self.df = data else: raise ValueError( diff --git a/liltab/model/heterogenous_attributes_network.py b/liltab/model/heterogenous_attributes_network.py index 2c96a31..43a24b3 100644 --- a/liltab/model/heterogenous_attributes_network.py +++ b/liltab/model/heterogenous_attributes_network.py @@ -2,7 +2,7 @@ import torch.nn.functional as F from torch import nn, Tensor -from typing import Callable +from typing import Callable, Dict, Union from .utils import FeedForwardNetwork @@ -133,7 +133,35 @@ def __init__( ) self.is_classifier = is_classifier - def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tensor: + def forward( + self, X_support: Tensor, y_support: Tensor, X_query: Tensor, return_full_trace: bool = False + ) -> Union[Tensor, Dict[str, Tensor]]: + """ + Returns prediction of the network. If return_full_trace = True, then full trace of + prediction is returned. + Args: + X_support (Tensor): Support set attributes + with shape (n_support_observations, n_attributes) + y_support (Tensor): Support set responses + with shape (n_support_observations, n_responses) + X_query (Tensor): Query set attributes + with shape (n_query_observations, n_attributes) + return_full_trace (bool): if True, then instead of vanilla prediction, + the trace of prediction i.e. intermediate tensors are returned in + for of Dict[str, Tensor]. + Returns: + Union[Tensor, Dict[str, Tensor]]: Inferred query set responses + or full trace of predictions. + """ + prediction_trace = self._forward_with_full_trace(X_support, y_support, X_query) + if return_full_trace: + return prediction_trace + else: + return prediction_trace["prediction"] + + def _forward_with_full_trace( + self, X_support: Tensor, y_support: Tensor, X_query: Tensor + ) -> Dict[str, Tensor]: """ Inference function of network. Inference is done in following steps: 1. Calculate initial representation for all atrributes and responses @@ -146,6 +174,9 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens 5. Make prediction based on representations of support set attributes and responses and query set attributes representations. All representations calculations are done using feed forward neural networks. + As a result it returns trace from inference i.e. all intermediate tensors in form + of a dictionary. + Args: X_support (Tensor): Support set attributes with shape (n_support_observations, n_attributes) @@ -154,7 +185,13 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens X_query (Tensor): Query set attributes with shape (n_query_observations, n_attributes) Returns: - Tensor: Inferred query set responses shaped (n_query_obervations, n_responses) + Dict[str, Tensor]: Full trace of inference: + attributes_initial_representation - initial representation of attributes. + responses_initial_representation - initial representation of responses. + support_set_representation - representation of support set. + attributes_representation - final representation of attributes. + responses_representation - final representation of responses. + prediction - final prediction i.e. predicted y_query. """ attributes_initial_representation = self._calculate_initial_features_representation( self.initial_features_encoding_network, @@ -206,7 +243,35 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens X_query, ) - return prediction + return { + "attributes_initial_representation": attributes_initial_representation, + "responses_initial_representation": responses_initial_representation, + "support_set_representation": support_set_representation, + "attributes_representation": attributes_representation, + "responses_representation": responses_representation, + "prediction": prediction, + } + + def encode_support_set(self, X_support: Tensor, y_support: Tensor) -> Tensor: + attributes_initial_representation = self._calculate_initial_features_representation( + self.initial_features_encoding_network, + self.initial_features_representation_network, + X_support, + ) + responses_initial_representation = self._calculate_initial_features_representation( + self.initial_features_encoding_network, + self.initial_features_representation_network, + y_support, + ) + support_set_representation = self._calculate_support_set_representation( + self.interaction_encoding_network, + self.interaction_representation_network, + X_support, + attributes_initial_representation, + y_support, + responses_initial_representation, + ) + return support_set_representation def _calculate_initial_features_representation( self, diff --git a/liltab/train/trainer.py b/liltab/train/trainer.py index 39716c2..829b133 100644 --- a/liltab/train/trainer.py +++ b/liltab/train/trainer.py @@ -11,7 +11,7 @@ ComposedDataLoader, RepeatableOutputComposedDataLoader, ) -from .utils import LightningWrapper +from .utils import LightningWrapper, LightningEncoderWrapper from .logger import TensorBoardLogger, FileLogger @@ -26,6 +26,7 @@ def __init__( gradient_clipping: bool, learning_rate: float, weight_decay: float, + representation_penalty_weight: float = 0, early_stopping_intervals: int = 100, check_val_every_n_epoch: int = 100, loss: Callable = nn.MSELoss(), @@ -112,6 +113,7 @@ def __init__( check_val_every_n_epoch=check_val_every_n_epoch, callbacks=callbacks, ) + self.representation_penalty_weight = representation_penalty_weight self.learning_rate = learning_rate self.weight_decay = weight_decay self.loss = loss @@ -143,12 +145,41 @@ def train_and_test( model, learning_rate=self.learning_rate, weight_decay=self.weight_decay, + representation_penalty_weight=self.representation_penalty_weight, loss=self.loss, ) self.trainer.fit(model_wrapper, train_loader, val_loader) test_results = self.trainer.test(model_wrapper, test_loader) return model_wrapper, test_results + def pretrain_encoder( + self, + model: HeterogenousAttributesNetwork, + train_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader, + val_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader, + ) -> tuple[LightningWrapper, list[dict[str, float]]]: + """ + Method used to pretrain encoder. + + Args: + model (HeterogenousAttributesNetwork): model to train + train_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader): + loader withTrainingData + val_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader): + loader with validation data + + Returns: + tuple[HeterogenousAttributesNetwork, list[dict[str, float]]]: + trained network with metrics on test set. + """ + encoder_wrapper = LightningEncoderWrapper( + model, + learning_rate=self.learning_rate, + weight_decay=self.weight_decay, + ) + self.trainer.fit(encoder_wrapper, train_loader, val_loader) + return encoder_wrapper + class LoggerCallback(Callback): def __init__( diff --git a/liltab/train/utils.py b/liltab/train/utils.py index 87f664c..581ad05 100644 --- a/liltab/train/utils.py +++ b/liltab/train/utils.py @@ -3,7 +3,7 @@ import torch.nn.functional as F from torch import optim, Tensor -from typing import Any, Callable +from typing import Any, Callable, List from ..model.heterogenous_attributes_network import HeterogenousAttributesNetwork @@ -19,6 +19,7 @@ def __init__( model: HeterogenousAttributesNetwork, learning_rate: float, weight_decay: float, + representation_penalty_weight: float = 0, loss: Callable = F.mse_loss, ): """ @@ -33,53 +34,170 @@ def __init__( self.loss = loss self.learning_rate = learning_rate self.weight_decay = weight_decay + self.representation_penalty_weight = representation_penalty_weight self.metrics_history = dict() self.save_hyperparameters() def training_step(self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx) -> float: sum_loss_value = 0.0 + supports_representations = [None] * len(batch) + indices = [None] * len(batch) for i, example in enumerate(batch): - X_support, y_support, X_query, y_query = example - prediction = self.model(X_support, y_support, X_query) - loss = self.loss(prediction, y_query) + idx, (X_support, y_support, X_query, y_query) = example + full_trace = self.model(X_support, y_support, X_query, return_full_trace=True) + prediction = full_trace["prediction"] + + supports_representations[i] = full_trace["support_set_representation"] + indices[i] = idx + loss = torch.nn.CrossEntropyLoss()(prediction, y_query) if torch.isnan(loss): sum_loss_value = sum_loss_value * (i + 1) / i if i > 0 else 0 else: sum_loss_value += loss + if self.representation_penalty_weight != 0: + rep_loss = ( + self.representation_penalty_weight + * self._calculate_representation_penalty(supports_representations, indices) + / 4 + ) + sum_loss_value += rep_loss + return sum_loss_value def validation_step( self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx ) -> float: sum_loss_value = 0.0 + supports_representations = [None] * len(batch) + indices = [None] * len(batch) for i, example in enumerate(batch): - X_support, y_support, X_query, y_query = example - prediction = self.model(X_support, y_support, X_query) - loss = self.loss(prediction, y_query) + idx, (X_support, y_support, X_query, y_query) = example + full_trace = self.model(X_support, y_support, X_query, return_full_trace=True) + prediction = full_trace["prediction"] + + supports_representations[i] = full_trace["support_set_representation"] + indices[i] = idx + loss = torch.nn.CrossEntropyLoss()(prediction, y_query) if torch.isnan(loss): sum_loss_value = sum_loss_value * (i + 1) / i if i > 0 else 0 else: sum_loss_value += loss + if self.representation_penalty_weight != 0: + rep_loss = ( + self.representation_penalty_weight + * self._calculate_representation_penalty(supports_representations, indices) + / 4 + ) + sum_loss_value += rep_loss + return sum_loss_value def test_step(self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx) -> float: sum_loss_value = 0.0 + supports_representations = [None] * len(batch) + indices = [None] * len(batch) for i, example in enumerate(batch): - X_support, y_support, X_query, y_query = example - prediction = self.model(X_support, y_support, X_query) - loss = self.loss(prediction, y_query) + idx, (X_support, y_support, X_query, y_query) = example + full_trace = self.model(X_support, y_support, X_query, return_full_trace=True) + prediction = full_trace["prediction"] + + supports_representations[i] = full_trace["support_set_representation"] + indices[i] = idx + loss = torch.nn.CrossEntropyLoss()(prediction, y_query) if torch.isnan(loss): sum_loss_value = sum_loss_value * (i + 1) / i if i > 0 else 0 else: sum_loss_value += loss + if self.representation_penalty_weight != 0: + rep_loss = ( + self.representation_penalty_weight + * self._calculate_representation_penalty(supports_representations, indices) + / 4 + ) + sum_loss_value += rep_loss + return sum_loss_value + def _calculate_representation_penalty( + self, supports_representations: List[Tensor], dataset_indices: List[int] + ): + support_size = supports_representations[0].shape[0] + supports_representations_to_penalty = torch.concat(supports_representations, dim=0) + dist_matrix = torch.cdist( + supports_representations_to_penalty, supports_representations_to_penalty + ) + indices_to_mask = ( + torch.Tensor(dataset_indices).reshape(-1, 1).repeat((1, support_size)).reshape(-1, 1) + ) + mask = (-1) ** (torch.cdist(indices_to_mask, indices_to_mask) == 0) + return (dist_matrix * mask).sum() + def configure_optimizers(self) -> Any: return optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) - def calculate_model_weights_norm(self) -> Tensor: - pass + +class LightningEncoderWrapper(pl.LightningModule): + def __init__( + self, + model: HeterogenousAttributesNetwork, + learning_rate: float, + weight_decay: float, + ): + super().__init__() + self.model = model + self.learning_rate = learning_rate + self.weight_decay = weight_decay + self.metrics_history = dict() + + self.save_hyperparameters() + + def training_step(self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx) -> float: + supports_representations = [None] * len(batch) + indices = [None] * len(batch) + for i, example in enumerate(batch): + idx, (X_support, y_support, X_query, _) = example + supports_representations[i] = self.model.encode_support_set(X_support, y_support) + indices[i] = idx + return self._calculate_representation_penalty(supports_representations, indices) + + def validation_step( + self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx + ) -> float: + supports_representations = [None] * len(batch) + indices = [None] * len(batch) + for i, example in enumerate(batch): + idx, (X_support, y_support, X_query, _) = example + supports_representations[i] = self.model.encode_support_set(X_support, y_support) + indices[i] = idx + return self._calculate_representation_penalty(supports_representations, indices) + + def test_step(self, batch: list[tuple[Tensor, Tensor, Tensor, Tensor]], batch_idx) -> float: + supports_representations = [None] * len(batch) + indices = [None] * len(batch) + for i, example in enumerate(batch): + idx, (X_support, y_support, X_query, _) = example + supports_representations[i] = self.model.encode_support_set(X_support, y_support) + indices[i] = idx + return self._calculate_representation_penalty(supports_representations, indices) + + def _calculate_representation_penalty( + self, supports_representations: List[Tensor], dataset_indices: List[int] + ): + support_size = supports_representations[0].shape[0] + supports_representations_to_penalty = torch.concat(supports_representations, dim=0) + data_length = supports_representations_to_penalty.shape[0] + dist_matrix = torch.cdist( + supports_representations_to_penalty, supports_representations_to_penalty + ) + indices_to_mask = ( + torch.Tensor(dataset_indices).reshape(-1, 1).repeat((1, support_size)).reshape(-1, 1) + ) + mask = (-1) ** (torch.cdist(indices_to_mask, indices_to_mask) == 0) + return (dist_matrix * mask).sum() / (data_length * (data_length - 1)) + + def configure_optimizers(self) -> Any: + return optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) diff --git a/test/liltab/data/test_dataloaders.py b/test/liltab/data/test_dataloaders.py index fff1ece..7f2626e 100644 --- a/test/liltab/data/test_dataloaders.py +++ b/test/liltab/data/test_dataloaders.py @@ -179,6 +179,7 @@ def test_composed_data_loader_returns_from_all_loaders_properly( ), ], batch_size=n_episodes, + return_dataset_indicator=False, ) loaded_dataset = list(dataloader) batches_lens = np.array(list(map(lambda t: t[0][1].shape[0], loaded_dataset[0]))) @@ -208,6 +209,7 @@ def run_experiment(): ), ], batch_size=batch_size, + return_dataset_indicator=False, ) batch = list(dataloader)[0] batches_lens = np.array(list(map(lambda t: t[0][1].shape[0], batch))) @@ -232,7 +234,8 @@ def test_repeatable_output_composed_data_loader_repeat_samples(resources_path, u 6, n_episodes=100, ), - ] + ], + return_dataset_indicator=False, ) dataloader_1 = list(dataloader)[0]