Skip to content

Commit

Permalink
bug fixes in FroSSL
Browse files Browse the repository at this point in the history
  • Loading branch information
OFSkean committed Aug 4, 2024
1 parent b67c7eb commit 92b28b5
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 20 deletions.
2 changes: 1 addition & 1 deletion scripts/pretrain/imagenet-100/frossl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ checkpoint:
dir: "trained_models"
frequency: 1
auto_resume:
enabled: True
enabled: False

# overwrite PL stuff
max_epochs: 400
Expand Down
73 changes: 58 additions & 15 deletions solo/losses/frossl.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,65 @@
# Copyright 2024 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, List, Sequence, Dict
import torch
import torch.distributed as dist
import torch.nn.functional as F

def calculate_frobenius_regularization_term(z: torch.Tensor) -> torch.Tensor:
V, N, D = z.shape

if N > D:
cov = torch.matmul(z.transpose(1, 2), z) # V x D x D

Check warning on line 29 in solo/losses/frossl.py

View check run for this annotation

Codecov / codecov/patch

solo/losses/frossl.py#L29

Added line #L29 was not covered by tests
else:
cov = torch.matmul(z, z.transpose(1, 2)) # V x N x N

# divide each view covariance by its trace
trace = torch.diagonal(cov, dim1=1, dim2=2) # V x D
trace = torch.sum(trace, dim=1) # V x 1
cov = cov / trace.unsqueeze(-1).unsqueeze(-1)

# REGULARIZATION TERM - sum the log-frobenius norm of each view covariance matrix
fro_norm_per_view = torch.linalg.norm(cov, dim=(1,2), ord='fro') # V x 1
regularization_term = -torch.sum( 2*torch.log(fro_norm_per_view) ) # we bring frobenius square outside log

return regularization_term

def calculate_invariance_term(z: torch.Tensor) -> torch.Tensor:
V, N, D = z.shape

# INVARIANCE - align each view to the average view
average_z = torch.mean(z, dim=0) # N x D, samples are averaged across views
average_z = average_z.repeat(V, 1, 1) # V x N x D
invariance_loss_term = F.mse_loss(z, average_z)

return invariance_loss_term

def frossl_loss_func(
z: torch.Tensor, invariance_weight=1
z: torch.Tensor, invariance_weight=1, logger=None
) -> torch.Tensor:
"""Computes FroSSL's loss given batch of projected features z
from num_crops different views.
Args:
z (torch.Tensor): views x N x D Tensor containing projected features from the views.
Every Nth sample is a different view of the same image.
z (torch.Tensor): V x N x D Tensor containing projected features from the views.
Every N-th sample is a different view of the same image.
invariance_weight (float): weight for the invariance loss term. default is 1.
Return:
Expand All @@ -21,19 +69,14 @@ def frossl_loss_func(

z = F.normalize(z, dim=-1) # V x N x D

if N > D:
cov = view_embeddings.T @ view_embeddings # V x D x D
else:
cov = view_embeddings @ view_embeddings.T # V x N x N
cov = cov / torch.trace(cov)
regularization_term = calculate_frobenius_regularization_term(z)

# sum the log-frobenius norm of each view covariance matrix
fro_norm_per_view = torch.linalg.norm(cov, ord='fro') # V x 1
regularization_term = torch.sum( -2*torch.log(fro_norm) ) # bring frobenius square outside log
invariance_term = calculate_invariance_term(z)
invariance_term = D * invariance_weight * invariance_term

# align each view to the average view
average_z = torch.mean(z, dim=0) # N x D, samples are averaged across views
invariance_loss_term = F.mse_loss(z, average_z)
if logger is not None:
logger("frossl_regularization_loss", regularization_term, sync_dist=True)
logger("frossl_invariance_loss", invariance_term, sync_dist=True)

total_loss = regularization_term + invariance_weight*invariance_loss_term
total_loss = regularization_term + invariance_term
return total_loss
43 changes: 39 additions & 4 deletions solo/methods/frossl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
# Copyright 2024 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Any, List, Sequence, Dict

import omegaconf
Expand All @@ -10,7 +29,7 @@
class FroSSL(BaseMethod):
def __init__(self, cfg: omegaconf.DictConfig):
"""Implements FroSSL (https://arxiv.org/pdf/2310.02903)
Heavily adapted from https://github.com/OFSkean/FroSSL
Extra cfg settings:
method_kwargs:
Expand Down Expand Up @@ -82,6 +101,22 @@ def forward(self, X):
out.update({"z": z})
return out

def multicrop_forward(self, X: torch.tensor) -> Dict[str, Any]:
"""Performs the forward pass for the multicrop views.
Args:
X (torch.Tensor): batch of images in tensor format.
Returns:
Dict[]: a dict containing the outputs of the parent
and the projected features.
"""

out = super().multicrop_forward(X)
z = self.projector(out["feats"])
out.update({"z": z})
return out

def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
"""Training step for FroSSL reusing BaseMethod training step.
Expand All @@ -98,9 +133,9 @@ def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
class_loss = out["loss"]

z = torch.stack(out["z"], dim=0) # V x N_per_gpu x D
z = torch.gather(z, dim=1) # V x N_total x D
z = gather(z, dim=1) # V x N_total x D

frossl_loss = frossl_loss_func(z, invariance_weight=self.invariance_weight)
self.log("train_frossl_loss", frossl_loss, on_epoch=True, sync_dist=True)
frossl_loss = frossl_loss_func(z, invariance_weight=self.invariance_weight, logger=self.log)
self.log("train_frossl_loss", frossl_loss, sync_dist=True)

return frossl_loss + class_loss
131 changes: 131 additions & 0 deletions tests/methods/test_frossl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2023 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
from solo.methods import FroSSL

from .utils import gen_base_cfg, gen_batch, gen_trainer, prepare_dummy_dataloaders


def test_frossl():
method_kwargs = {
"proj_hidden_dim": 2048,
"proj_output_dim": 256,
"invariance_weight": 1.4,
}

cfg = gen_base_cfg("frossl", batch_size=2, num_classes=100, momentum=True)
cfg.method_kwargs = method_kwargs
model = FroSSL(cfg)

# test arguments
model.add_and_assert_specific_cfg(cfg)

# test parameters
assert model.learnable_params is not None

# test forward
batch, _ = gen_batch(cfg.optimizer.batch_size, cfg.data.num_classes, "imagenet100")
out = model(batch[1][0])
assert (
"logits" in out
and isinstance(out["logits"], torch.Tensor)
and out["logits"].size() == (cfg.optimizer.batch_size, cfg.data.num_classes)
)
assert (
"feats" in out
and isinstance(out["feats"], torch.Tensor)
and out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim)
)
assert (
"z" in out
and isinstance(out["z"], torch.Tensor)
and out["z"].size() == (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"])
)
print('here')

multicrop_out = model.multicrop_forward(batch[1][0])
assert (
"feats" in multicrop_out
and isinstance(multicrop_out["feats"], torch.Tensor)
and multicrop_out["feats"].size() == (cfg.optimizer.batch_size, model.features_dim)
)
assert (
"z" in multicrop_out
and isinstance(multicrop_out["z"], torch.Tensor)
and multicrop_out["z"].size()
== (cfg.optimizer.batch_size, method_kwargs["proj_output_dim"])
)

# imagenet
model = FroSSL(cfg)

trainer = gen_trainer(cfg)
train_dl, val_dl = prepare_dummy_dataloaders(
"imagenet100",
num_large_crops=cfg.data.num_large_crops,
num_small_crops=0,
num_classes=cfg.data.num_classes,
batch_size=cfg.optimizer.batch_size,
)
trainer.fit(model, train_dl, val_dl)

# cifar
cfg.data.dataset = "cifar10"
cfg.data.num_classes = 10
model = FroSSL(cfg)

trainer = gen_trainer(cfg)
train_dl, val_dl = prepare_dummy_dataloaders(
"cifar10",
num_large_crops=cfg.data.num_large_crops,
num_small_crops=0,
num_classes=cfg.data.num_classes,
batch_size=cfg.optimizer.batch_size,
)
trainer.fit(model, train_dl, val_dl)

# multicrop
cfg.data.num_small_crops = 6
model = FroSSL(cfg)

trainer = gen_trainer(cfg)
train_dl, val_dl = prepare_dummy_dataloaders(
"imagenet100",
num_large_crops=cfg.data.num_large_crops,
num_small_crops=cfg.data.num_small_crops,
num_classes=cfg.data.num_classes,
batch_size=cfg.optimizer.batch_size,
)
trainer.fit(model, train_dl, val_dl)

# 8 large views
cfg.data.num_small_crops = 8
cfg.data.num_small_crops = 0
model = FroSSL(cfg)

trainer = gen_trainer(cfg)
train_dl, val_dl = prepare_dummy_dataloaders(
"imagenet100",
num_large_crops=cfg.data.num_large_crops,
num_small_crops=cfg.data.num_small_crops,
num_classes=cfg.data.num_classes,
batch_size=cfg.optimizer.batch_size,
)
trainer.fit(model, train_dl, val_dl)

0 comments on commit 92b28b5

Please sign in to comment.