Add support set encoding (#25)
azoz01 authored Feb 13, 2024
1 parent 91d8c18 commit d72b92d
Showing 8 changed files with 265 additions and 27 deletions.
2 changes: 1 addition & 1 deletion config/03_openml_clf_data_experiment_config.yaml
@@ -4,7 +4,7 @@ learning_rate: 0.001
 weight_decay: 0
 batch_size: 37
 gradient_clipping: False
-early_stopping_intervals: 100
+early_stopping_intervals: 20

 support_size: 3
 query_size: 29
1 change: 1 addition & 0 deletions experiments/03_openml_clf.py
@@ -66,6 +66,7 @@ def main():
         hidden_size=config["hidden_size"],
         dropout_rate=config["dropout_rate"],
         is_classifier=config["is_classifier"],
+        inner_activation_function=nn.ReLU(),
     )

     results_path = Path("results") / config["name"]
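The script now chooses the network's inner activation at construction time instead of relying on a hard-coded default. A minimal sketch of the idea with illustrative values; any other constructor arguments the script passes before this hunk are omitted, and the claim that other activation modules also fit is an assumption:

```python
from torch import nn

# Illustrative keyword arguments mirroring the experiment script; the values
# here are made up and the remaining constructor arguments are omitted.
network_kwargs = dict(
    hidden_size=32,
    dropout_rate=0.1,
    is_classifier=True,
    inner_activation_function=nn.ReLU(),  # new in this commit; e.g. nn.GELU() presumably also works
)
```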
34 changes: 27 additions & 7 deletions liltab/data/dataloaders.py
@@ -172,6 +172,7 @@ def __init__(
         dataloaders: list[Iterable],
         batch_size: int = 32,
         num_batches: int = 1,
+        return_dataset_indicator: bool = True,
     ):
"""
Args:
Expand All @@ -184,6 +185,7 @@ def __init__(
"""
self.dataloaders = dataloaders
self.batch_size = batch_size
self.return_dataset_indicator = return_dataset_indicator

self.counter = 0
self.num_batches = num_batches
Expand All @@ -198,21 +200,33 @@ def __next__(self):
self.counter += 1
return [self._get_single_example() for _ in range(self.batch_size)]

-    def _get_single_example(self) -> tuple[Tensor, Tensor, Tensor, Tensor]:
+    def _get_single_example(
+        self,
+    ) -> Union[
+        tuple[Tensor, Tensor, Tensor, Tensor], tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]]
+    ]:
         """
         Returns support and query sets from one DataLoader
-        chosen randomly from the passed dataloaders.
+        chosen randomly from the passed dataloaders. If return_dataset_indicator is True,
+        then the id of the dataset is also returned.
         Returns:
-            tuple[Tensor, Tensor, Tensor, Tensor]:
-                (X_support, y_support, X_query, y_query)
+            Union[
+                tuple[Tensor, Tensor, Tensor, Tensor],
+                tuple[int, tuple[Tensor, Tensor, Tensor, Tensor]]
+            ]
         """
         dataloader_has_next = False
+        dataloader_idx = None
         while not dataloader_has_next:
             dataloader_idx = np.random.choice(self.n_dataloaders, 1)[0]
             dataloader = self.dataloaders[dataloader_idx]
             dataloader_has_next = dataloader.has_next()
-        return next(dataloader)
+
+        if self.return_dataset_indicator:
+            return dataloader_idx, next(dataloader)
+        else:
+            return next(dataloader)


 class RepeatableOutputComposedDataLoader:
@@ -223,7 +237,9 @@ class RepeatableOutputComposedDataLoader:
     of data. Useful with test/validation datasets.
     """

-    def __init__(self, dataloaders: list[Iterable], *args, **kwargs):
+    def __init__(
+        self, dataloaders: list[Iterable], return_dataset_indicator: bool = True, *args, **kwargs
+    ):
         """
         Args:
             dataloaders (list[Iterable]): list of
@@ -233,6 +249,7 @@ def __init__(self, dataloaders: list[Iterable], *args, **kwargs):

         self.batch_counter = 0
         self.n_dataloaders = len(dataloaders)
+        self.return_dataset_indicator = return_dataset_indicator

         self.loaded = False
         self.cache = OrderedDict()
@@ -246,4 +263,7 @@ def __next__(self):
         if self.loaded:
             raise StopIteration()
         self.loaded = True
-        return [sample for _, sample in self.cache.items()]
+        if self.return_dataset_indicator:
+            return list(self.cache.items())
+        else:
+            return [sample for _, sample in self.cache.items()]
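With `return_dataset_indicator` left at its default of `True`, each sampled task now arrives as `(dataset_idx, (X_support, y_support, X_query, y_query))`, and `RepeatableOutputComposedDataLoader` likewise yields `(key, sample)` pairs from its cache. A minimal sketch of the new contract; the stub loader below is hypothetical and models only the `has_next()`/`next()` interface that `ComposedDataLoader` relies on:

```python
import torch

from liltab.data.dataloaders import ComposedDataLoader


class StubTaskLoader:
    """Hypothetical stand-in for a per-dataset few-shot loader."""

    def __init__(self, n_attributes: int):
        self.n_attributes = n_attributes

    def has_next(self) -> bool:
        return True  # always ready; a real loader would eventually report False

    def __next__(self):
        # Support/query sizes follow the experiment config (3 and 29).
        X_support, y_support = torch.randn(3, self.n_attributes), torch.randn(3, 1)
        X_query, y_query = torch.randn(29, self.n_attributes), torch.randn(29, 1)
        return X_support, y_support, X_query, y_query


loader = ComposedDataLoader([StubTaskLoader(4), StubTaskLoader(7)], batch_size=8)
batch = next(loader)  # list of 8 tasks, each tagged with its dataloader index
for dataset_idx, (X_support, y_support, X_query, y_query) in batch:
    ...  # dataset_idx identifies the source dataset, e.g. for per-dataset encodings
```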
2 changes: 1 addition & 1 deletion liltab/data/datasets.py
@@ -37,7 +37,7 @@ def __init__(
         self.data = data
         if type(data) in [str, PosixPath]:
             self.df = pd.read_csv(data)
-        elif type(data) == pd.DataFrame:
+        elif type(data) is pd.DataFrame:
             self.df = data
         else:
             raise ValueError(
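The `==` → `is` change makes the exact-type check an identity comparison, which is the idiomatic form (linters flag `type(x) == T` as E721). A quick illustration of the behavior being relied on:

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2]})

# `type(df) is pd.DataFrame` compares type objects by identity and, unlike
# isinstance, deliberately rejects subclasses of DataFrame.
assert type(df) is pd.DataFrame


class MyFrame(pd.DataFrame):
    pass


assert type(MyFrame()) is not pd.DataFrame
assert isinstance(MyFrame(), pd.DataFrame)
```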
73 changes: 69 additions & 4 deletions liltab/model/heterogenous_attributes_network.py
@@ -2,7 +2,7 @@
 import torch.nn.functional as F

 from torch import nn, Tensor
-from typing import Callable
+from typing import Callable, Dict, Union
 from .utils import FeedForwardNetwork


@@ -133,7 +133,35 @@ def __init__(
         )
         self.is_classifier = is_classifier

-    def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tensor:
+    def forward(
+        self, X_support: Tensor, y_support: Tensor, X_query: Tensor, return_full_trace: bool = False
+    ) -> Union[Tensor, Dict[str, Tensor]]:
+        """
+        Returns the prediction of the network. If return_full_trace = True, then the
+        full trace of the prediction is returned.
+        Args:
+            X_support (Tensor): Support set attributes
+                with shape (n_support_observations, n_attributes)
+            y_support (Tensor): Support set responses
+                with shape (n_support_observations, n_responses)
+            X_query (Tensor): Query set attributes
+                with shape (n_query_observations, n_attributes)
+            return_full_trace (bool): if True, then instead of the vanilla prediction,
+                the trace of the prediction, i.e. the intermediate tensors, is returned
+                in the form of Dict[str, Tensor].
+        Returns:
+            Union[Tensor, Dict[str, Tensor]]: Inferred query set responses
+                or the full trace of the prediction.
+        """
+        prediction_trace = self._forward_with_full_trace(X_support, y_support, X_query)
+        if return_full_trace:
+            return prediction_trace
+        else:
+            return prediction_trace["prediction"]
+
+    def _forward_with_full_trace(
+        self, X_support: Tensor, y_support: Tensor, X_query: Tensor
+    ) -> Dict[str, Tensor]:
"""
Inference function of network. Inference is done in following steps:
1. Calculate initial representation for all atrributes and responses
Expand All @@ -146,6 +174,9 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens
5. Make prediction based on representations of support set attributes and
responses and query set attributes representations.
All representations calculations are done using feed forward neural networks.
As a result it returns trace from inference i.e. all intermediate tensors in form
of a dictionary.
Args:
X_support (Tensor): Support set attributes
with shape (n_support_observations, n_attributes)
Expand All @@ -154,7 +185,13 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tens
X_query (Tensor): Query set attributes
with shape (n_query_observations, n_attributes)
Returns:
Tensor: Inferred query set responses shaped (n_query_obervations, n_responses)
Dict[str, Tensor]: Full trace of inference:
attributes_initial_representation - initial representation of attributes.
responses_initial_representation - initial representation of responses.
support_set_representation - representation of support set.
attributes_representation - final representation of attributes.
responses_representation - final representation of responses.
prediction - final prediction i.e. predicted y_query.
"""
attributes_initial_representation = self._calculate_initial_features_representation(
self.initial_features_encoding_network,
@@ -206,7 +243,35 @@ def forward(self, X_support: Tensor, y_support: Tensor, X_query: Tensor) -> Tensor:
             X_query,
         )

-        return prediction
+        return {
+            "attributes_initial_representation": attributes_initial_representation,
+            "responses_initial_representation": responses_initial_representation,
+            "support_set_representation": support_set_representation,
+            "attributes_representation": attributes_representation,
+            "responses_representation": responses_representation,
+            "prediction": prediction,
+        }
+
+    def encode_support_set(self, X_support: Tensor, y_support: Tensor) -> Tensor:
+        attributes_initial_representation = self._calculate_initial_features_representation(
+            self.initial_features_encoding_network,
+            self.initial_features_representation_network,
+            X_support,
+        )
+        responses_initial_representation = self._calculate_initial_features_representation(
+            self.initial_features_encoding_network,
+            self.initial_features_representation_network,
+            y_support,
+        )
+        support_set_representation = self._calculate_support_set_representation(
+            self.interaction_encoding_network,
+            self.interaction_representation_network,
+            X_support,
+            attributes_initial_representation,
+            y_support,
+            responses_initial_representation,
+        )
+        return support_set_representation

     def _calculate_initial_features_representation(
         self,
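A sketch of the two new entry points; `model` is assumed to be an already-constructed `HeterogenousAttributesNetwork` (its constructor arguments are outside this diff), and the tensor sizes mirror the support/query sizes from the experiment config:

```python
import torch

# Assumes `model` is a constructed HeterogenousAttributesNetwork.
X_support, y_support = torch.randn(3, 5), torch.randn(3, 1)
X_query = torch.randn(29, 5)

prediction = model(X_support, y_support, X_query)  # plain prediction, as before
trace = model(X_support, y_support, X_query, return_full_trace=True)
print(sorted(trace))  # the six keys listed in the docstring, "prediction" among them

# Dataset-level encoding of the support set alone, without touching the query set;
# presumably what the new representation_penalty_weight in the trainer hooks into.
support_encoding = model.encode_support_set(X_support, y_support)
```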
33 changes: 32 additions & 1 deletion liltab/train/trainer.py
@@ -11,7 +11,7 @@
     ComposedDataLoader,
     RepeatableOutputComposedDataLoader,
 )
-from .utils import LightningWrapper
+from .utils import LightningWrapper, LightningEncoderWrapper
 from .logger import TensorBoardLogger, FileLogger


@@ -26,6 +26,7 @@ def __init__(
         gradient_clipping: bool,
         learning_rate: float,
         weight_decay: float,
+        representation_penalty_weight: float = 0,
         early_stopping_intervals: int = 100,
         check_val_every_n_epoch: int = 100,
         loss: Callable = nn.MSELoss(),
@@ -112,6 +113,7 @@ def __init__(
             check_val_every_n_epoch=check_val_every_n_epoch,
             callbacks=callbacks,
         )
+        self.representation_penalty_weight = representation_penalty_weight
         self.learning_rate = learning_rate
         self.weight_decay = weight_decay
         self.loss = loss
@@ -143,12 +145,41 @@ def train_and_test(
             model,
             learning_rate=self.learning_rate,
             weight_decay=self.weight_decay,
+            representation_penalty_weight=self.representation_penalty_weight,
             loss=self.loss,
         )
         self.trainer.fit(model_wrapper, train_loader, val_loader)
         test_results = self.trainer.test(model_wrapper, test_loader)
         return model_wrapper, test_results

+    def pretrain_encoder(
+        self,
+        model: HeterogenousAttributesNetwork,
+        train_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader,
+        val_loader: ComposedDataLoader | RepeatableOutputComposedDataLoader,
+    ) -> LightningEncoderWrapper:
+        """
+        Method used to pretrain the encoder.
+        Args:
+            model (HeterogenousAttributesNetwork): model to train
+            train_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader):
+                loader with training data
+            val_loader (ComposedDataLoader | RepeatableOutputComposedDataLoader):
+                loader with validation data
+        Returns:
+            LightningEncoderWrapper: wrapper around the pretrained encoder.
+        """
+        encoder_wrapper = LightningEncoderWrapper(
+            model,
+            learning_rate=self.learning_rate,
+            weight_decay=self.weight_decay,
+        )
+        self.trainer.fit(encoder_wrapper, train_loader, val_loader)
+        return encoder_wrapper


 class LoggerCallback(Callback):
     def __init__(
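How the new pretraining step might be wired in; a sketch assuming `trainer` is an instance of the trainer class above, `model` and the loaders are built as in the experiment scripts, and the `train_and_test` signature is inferred from the fit/test calls in this diff:

```python
# Assumed objects: `trainer` (the class whose __init__ gains
# representation_penalty_weight above), `model`, and the three loaders.
encoder_wrapper = trainer.pretrain_encoder(model, train_loader, val_loader)

# pretrain_encoder fits the same underlying network object, so the usual run
# can pick up the pretrained encoder weights afterwards.
model_wrapper, test_results = trainer.train_and_test(
    model, train_loader, val_loader, test_loader
)
```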