
Commit

added some tests, remove call to previous reset_optimizer
AlbinSou committed Mar 8, 2024
1 parent 068fae5 commit bf3a3ad
Showing 4 changed files with 326 additions and 205 deletions.
35 changes: 20 additions & 15 deletions avalanche/models/dynamic_optimizers.py
@@ -33,14 +33,18 @@ def map_optimized_params(optimizer, parameters, old_params=None):
"""
Establishes a mapping between a list of named parameters and the parameters
that are in the optimizer, additionally,
returns the list of
returns the lists of:
changed_parameters
new_parameters
removed_parameters: List of indexes of optimizer parameters that are not found in the new parameters
returns:
new_parameters: Names of new parameters in the provided "parameters" argument
changed_parameters: Names and indexes of parameters that have changed (grown, shrunk)
removed_parameters: List of indexes of optimizer parameters that are not found in the new parameters
"""

if old_params is None:
old_params = {}

group_mapping = defaultdict(dict)
new_parameters = []

@@ -51,7 +55,9 @@ def map_optimized_params(optimizer, parameters, old_params=None):
found_indexes.append(np.zeros(len(params)))

for n, p in parameters.items():
g = None
gidx = None
pidx = None

# Find param in optimizer
found = False

@@ -64,21 +70,27 @@ def map_optimized_params(optimizer, parameters, old_params=None):
params = group["params"]
for param_idx, po in enumerate(params):
if id(po) == search_id:
g = group_idx
gidx = group_idx
pidx = param_idx
found = True
# Update found indexes
assert found_indexes[group_idx][param_idx] == 0
found_indexes[group_idx][param_idx] = 1
break
if found:
break

if not found:
new_parameters.append(n)

if search_id != id(p):
if found:
changed_parameters.append((n, group_idx, param_idx))
changed_parameters.append((n, gidx, pidx))

group_mapping[n] = g
if len(optimizer.param_groups) > 1:
group_mapping[n] = gidx
else:
group_mapping[n] = 0

not_found_in_parameters = [np.where(arr == 0)[0] for arr in found_indexes]
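As an aside, the id-based matching described in the docstring above can be illustrated with a small, self-contained sketch. The toy model, shapes, and printed labels below are illustrative assumptions; the exact return value of map_optimized_params is not shown in this hunk.

import torch

# Toy setup: one model, one optimizer group (illustrative only).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Ids of the parameters currently held by the optimizer, per group,
# mirroring the found_indexes bookkeeping above.
optimizer_param_ids = [
    {id(p) for p in group["params"]} for group in optimizer.param_groups
]

# A parameter added later is classified as "new" because its id() does
# not appear in any optimizer group.
named_params = dict(model.named_parameters())
named_params["new_param"] = torch.nn.Parameter(torch.zeros(2, 4))

for name, p in named_params.items():
    in_optimizer = any(id(p) in ids for ids in optimizer_param_ids)
    print(name, "already tracked" if in_optimizer else "new parameter")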

@@ -214,13 +226,6 @@ def single_group(self):
return list(self.groups)[0]


def compare_keys(old_dict, new_dict):
not_in_new = list(set(old_dict.keys()) - set(new_dict.keys()))
in_both = list(set(old_dict.keys()) & set(new_dict.keys()))
not_in_old = list(set(new_dict.keys()) - set(old_dict.keys()))
return not_in_new, in_both, not_in_old


def reset_optimizer(optimizer, model):
"""Reset the optimizer to update the list of learnable parameters.
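For reference, reset_optimizer (defined just above) rebuilds the whole mapping from scratch, while update_optimizer refreshes it incrementally. A minimal usage sketch follows, assuming a plain SGD setup; the toy model and learning rate are illustrative, and the call signature is the one used in the removed strategy code further down.

import torch
from avalanche.models.dynamic_optimizers import reset_optimizer

# Rebuild the name -> optimized-parameter mapping from scratch.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimized_param_id = reset_optimizer(optimizer, model)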
(second changed file; file path not shown in this capture)
@@ -53,15 +53,12 @@ def make_optimizer(self, reset_optimizer_state=False, **kwargs):
optimizer, regardless of what parameters are
initially put in the optimizer.
"""
if self.optimized_param_id is None:
self.optimized_param_id = reset_optimizer(self.optimizer, self.model)
else:
self.optimized_param_id = update_optimizer(
self.optimizer,
dict(self.model.named_parameters()),
self.optimized_param_id,
reset_state=reset_optimizer_state,
)
self.optimized_param_id = update_optimizer(
self.optimizer,
dict(self.model.named_parameters()),
self.optimized_param_id,
reset_state=reset_optimizer_state,
)

def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
# If strategy has access to the task boundaries, and the current
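A hedged sketch of the simplified flow after this change: the first-call branch is gone, on the assumption that update_optimizer accepts None as the previous mapping (map_optimized_params above defaults old_params to an empty dict). The tiny strategy-like class below is illustrative only, not the real template.

import torch
from avalanche.models.dynamic_optimizers import update_optimizer

class TinyStrategyLike:
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.optimized_param_id = None  # no mapping before the first call

    def make_optimizer(self, reset_optimizer_state=False):
        # Single code path: the same call covers both the first call and
        # later refreshes after the model has grown or shrunk.
        self.optimized_param_id = update_optimizer(
            self.optimizer,
            dict(self.model.named_parameters()),
            self.optimized_param_id,
            reset_state=reset_optimizer_state,
        )

model = torch.nn.Linear(4, 2)
strategy = TinyStrategyLike(model, torch.optim.SGD(model.parameters(), lr=0.1))
strategy.make_optimizer()  # first call: previous mapping is None
model.add_module("extra_head", torch.nn.Linear(2, 2))
strategy.make_optimizer()  # refresh after the model changed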
181 changes: 0 additions & 181 deletions tests/models/test_models.py
@@ -83,187 +83,6 @@ def test_get_model(self):
self.assertIsInstance(model, pytorchcv.models.resnet.ResNet)


class DynamicOptimizersTests(unittest.TestCase):
if "USE_GPU" in os.environ:
use_gpu = os.environ["USE_GPU"].lower() in ["true"]
else:
use_gpu = False

print("Test on GPU:", use_gpu)

if use_gpu:
device = "cuda"
else:
device = "cpu"

def setUp(self):
common_setups()

def _iterate_optimizers(self, model, *optimizers):
for opt_class in optimizers:
if opt_class == "SGDmom":
yield torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
if opt_class == "SGD":
yield torch.optim.SGD(model.parameters(), lr=0.1)
if opt_class == "Adam":
yield torch.optim.Adam(model.parameters(), lr=0.001)
if opt_class == "AdamW":
yield torch.optim.AdamW(model.parameters(), lr=0.001)

def _is_param_in_optimizer(self, param, optimizer):
for group in optimizer.param_groups:
for curr_p in group["params"]:
if hash(curr_p) == hash(param):
return True
return False

def load_benchmark(self, use_task_labels=False):
"""
Returns a NC benchmark from a fake dataset of 10 classes, 5 experiences,
2 classes per experience.
:param fast_test: if True loads fake data, MNIST otherwise.
"""
return get_fast_benchmark(use_task_labels=use_task_labels)

def init_scenario(self, multi_task=False):
model = self.get_model(multi_task=multi_task)
criterion = CrossEntropyLoss()
benchmark = self.load_benchmark(use_task_labels=multi_task)
return model, criterion, benchmark

def test_optimizer_update(self):
model = SimpleMLP()
optimizer = SGD(model.parameters(), lr=1e-3)
strategy = Naive(model, optimizer)

# check add_param_group
p = torch.nn.Parameter(torch.zeros(10, 10))
add_new_params_to_optimizer(optimizer, p)
assert self._is_param_in_optimizer(p, strategy.optimizer)

# check new_param is in optimizer
# check old_param is NOT in optimizer
p_new = torch.nn.Parameter(torch.zeros(10, 10))

# Here we cannot know which parameter group it belongs to, but there is only one, so it should work
new_parameters = {"new_param": p_new}
new_parameters.update(dict(model.named_parameters()))
optimized = update_optimizer(optimizer, new_parameters, {"old_param": p})
self.assertTrue("new_param" in optimized)
self.assertFalse("old_param" in optimized)
self.assertTrue(self._is_param_in_optimizer(p_new, strategy.optimizer))
self.assertFalse(self._is_param_in_optimizer(p, strategy.optimizer))

def test_optimizers(self):
# SIT scenario
model, criterion, benchmark = self.init_scenario(multi_task=True)
for optimizer in self._iterate_optimizers(
model, "SGDmom", "Adam", "SGD", "AdamW"
):
strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
self._test_optimizer(strategy)

# Needs torch 2.0 ?
def test_checkpointing(self):
model, criterion, benchmark = self.init_scenario(multi_task=True)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
experience_0 = benchmark.train_stream[0]
strategy.train(experience_0)
old_state = copy.deepcopy(strategy.optimizer.state)
save_checkpoint(strategy, "./checkpoint.pt")

del strategy

model, criterion, benchmark = self.init_scenario(multi_task=True)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
strategy, exp_counter = maybe_load_checkpoint(
strategy, "./checkpoint.pt", strategy.device
)

# Check that the state has been well serialized
self.assertEqual(len(strategy.optimizer.state), len(old_state))
for (key_new, value_new_dict), (key_old, value_old_dict) in zip(
strategy.optimizer.state.items(), old_state.items()
):
self.assertTrue(torch.equal(key_new, key_old))

value_new = value_new_dict["momentum_buffer"]
value_old = value_old_dict["momentum_buffer"]

# Empty state
if len(value_new) == 0 or len(value_old) == 0:
self.assertTrue(len(value_new) == len(value_old))
else:
self.assertTrue(torch.equal(value_new, value_old))

experience_1 = benchmark.train_stream[1]
strategy.train(experience_1)
os.remove("./checkpoint.pt")

def test_mh_classifier(self):
model, criterion, benchmark = self.init_scenario(multi_task=True)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)
strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)
strategy.train(benchmark.train_stream)

def _test_optimizer(self, strategy):
# Add a parameter
module = torch.nn.Linear(10, 10)
param1 = list(module.parameters())[0]
strategy.make_optimizer()
self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer))
strategy.model.add_module("new_module", module)
strategy.make_optimizer()
self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer))
# Remove a parameter
del strategy.model.new_module

strategy.make_optimizer()
self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer))

def get_model(self, multi_task=False):
if multi_task:
model = MTSimpleMLP(input_size=6, hidden_size=10)
else:
model = SimpleMLP(input_size=6, hidden_size=10)
return model


class DynamicModelsTests(unittest.TestCase):
def setUp(self):
common_setups()
(diff for the remaining changed file not loaded)
