diff --git a/adversarial_analysis.py b/adversarial_analysis.py deleted file mode 100644 index 716d6d5f..00000000 --- a/adversarial_analysis.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import sys - -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - -import numpy as np -import proplot as plot -import torch - -from DeepSparseCoding.utils.file_utils import Logger -import DeepSparseCoding.utils.loaders as loaders -import DeepSparseCoding.utils.run_utils as run_utils -import DeepSparseCoding.utils.dataset_utils as dataset_utils -import DeepSparseCoding.utils.run_utils as ru -import DeepSparseCoding.utils.plot_functions as pf - -import eagerpy as ep -from foolbox import PyTorchModel, accuracy, samples -import foolbox.attacks as fa - - -log_files = [ - os.path.join(*[ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log']), - os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log']) - ] - -cp_latest_filenames = [ - os.path.join(*[ROOT_DIR,'Torch_projects', 'mlp_768_mnist', 'checkpoints', 'mlp_768_mnist_latest_checkpoint_v0.pt']), - os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints', 'lca_768_mlp_mnist_latest_checkpoint_v0.pt']) - ] - -attack_params = { - 'linfPGD': { - 'abs_stepsize':0.01, - 'steps':5000 - } -} - -attacks = [ - #fa.FGSM(), - fa.LinfPGD(**attack_params['linfPGD']), - #fa.LinfBasicIterativeAttack(), - #fa.LinfAdditiveUniformNoiseAttack(), - #fa.LinfDeepFoolAttack(), -] - -epsilons = [ # allowed perturbation size - 0.0, - 0.05, - 0.1, - 0.15, - 0.2, - 0.25, - 0.3, - 0.35, - #0.4, - 0.5, - #0.8, - 1.0 -] - -num_models = len(log_files) -for model_index in range(num_models): - logger = Logger(log_files[model_index], overwrite=False) - log_text = logger.load_file() - params = logger.read_params(log_text)[-1] - params.cp_latest_filename = cp_latest_filenames[model_index] - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) - for key, value in data_params.items(): - setattr(params, key, value) - model = loaders.load_model(params.model_type) - model.setup(params, logger) - model.params.analysis_out_dir = os.path.join( - *[model.params.model_out_dir, 'analysis', model.params.version]) - model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles') - if not os.path.exists(model.params.analysis_save_dir): - os.makedirs(model.params.analysis_save_dir) - model.to(params.device) - model.load_checkpoint() - fmodel = PyTorchModel(model.eval(), bounds=(0, 1)) - print('\n', '~' * 79) - num_batches = len(test_loader.dataset) // model.params.batch_size - attack_success = np.zeros( - (len(attacks), len(epsilons), num_batches, model.params.batch_size), dtype=np.bool) - for batch_index, (data, target) in enumerate(test_loader): - data = model.preprocess_data(data.to(model.params.device)) - target = target.to(model.params.device) - images, labels = ep.astensors(*(data, target)) - del data; del target - print(f'Model type: {model.params.model_type} [{model_index+1} out of {len(log_files)}]') - print(f'Batch {batch_index+1} out of {num_batches}') - print(f'accuracy {accuracy(fmodel, images, labels)}') - for attack_index, attack in enumerate(attacks): - advs, inputs, success = attack(fmodel, images, labels, epsilons=epsilons) - assert success.shape == (len(epsilons), len(images)) - success_ = success.numpy() - assert success_.dtype == np.bool - attack_success[attack_index, :, batch_index, :] = 
success_ - print('\n', attack) - print(' ', 1.0 - success_.mean(axis=-1).round(2)) - np.savez('tmp_perturbations.npz', data=advs[0].numpy()) - np.savez('tmp_images.npz', data=images.numpy()) - np.savez('tmp_inputs.npz', data=inputs[0].numpy()) - import IPython; IPython.embed(); raise SystemExit - robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1) - print('\n', '-' * 79, '\n') - print('worst case (best attack per-sample)') - print(' ', robust_accuracy.round(2)) - print('-' * 79) - attack_success = attack_success.reshape( - (len(attacks), len(epsilons), num_batches*model.params.batch_size)) - attack_types = [str(type(attack)).split('.')[-1][:-2] for attack in attacks] - output_filename = os.path.join(model.params.analysis_save_dir, - f'linf_adversarial_analysis.npz') - out_dict = { - 'adversarial_analysis':attack_success, - 'attack_types':attack_types, - 'epsilons':epsilons, - 'attack_params':attack_params} - np.savez(output_filename, data=out_dict) diff --git a/datasets/synthetic.py b/datasets/synthetic.py index db82e635..3af48bc6 100644 --- a/datasets/synthetic.py +++ b/datasets/synthetic.py @@ -1,5 +1,9 @@ import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np from scipy.stats import norm @@ -7,9 +11,6 @@ import torch import torchvision -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.data_processing as dp class SyntheticImages(torchvision.datasets.vision.VisionDataset): diff --git a/models/base_model.py b/models/base_model.py index c1e383fa..904cb7ea 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -1,16 +1,20 @@ import os +import subprocess +import pprint import numpy as np import torch +from DeepSparseCoding.utils.file_utils import summary_string from DeepSparseCoding.utils.file_utils import Logger +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape +import DeepSparseCoding.utils.loaders as loaders class BaseModel(object): def setup(self, params, logger=None): """ Setup required model components - #TODO: log system info, including git commit hash """ self.load_params(params) self.check_params() @@ -18,6 +22,7 @@ def setup(self, params, logger=None): if logger is None: self.init_logging() self.log_params() + self.logger.log_info(self.get_env_details()) else: self.logger = logger @@ -92,24 +97,106 @@ def log_params(self, params=None): dump_obj = self.params.__dict__ self.logger.log_params(dump_obj) - def log_info(self, string): - """Log input string""" - self.logger.log_info(string) + def get_train_stats(self, batch_step=None): + """ + Get default statistics about current training run + + Keyword arguments: + batch_step: [int] current batch iteration. The default assumes that training has finished. 
+ """ + if batch_step is None: + batch_step = self.params.num_batches + epoch = batch_step / self.params.batches_per_epoch + stat_dict = { + 'epoch':int(epoch), + 'batch_step':batch_step, + 'train_progress':np.round(batch_step/self.params.num_batches, 3), + } + return stat_dict - def write_checkpoint(self): - """Write checkpoints""" - torch.save(self.state_dict(), self.params.cp_latest_filename) - self.log_info('Full model saved in file %s'%self.params.cp_latest_filename) + def get_env_details(self): + env = {} + for k in ['SYSTEMROOT', 'PATH']: + v = os.environ.get(k) + if v is not None: + env[k] = v + commit_cmd = ['git', 'rev-parse', 'HEAD'] + commit = subprocess.Popen(commit_cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + commit = commit.strip().decode('ascii') + branch_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD'] + branch = subprocess.Popen(branch_cmd, stdout=subprocess.PIPE, env=env).communicate()[0] + branch = branch.strip().decode('ascii') + system_details = os.uname() + out_dict = { + 'current_branch':branch, + 'current_commit_hash':commit, + 'sysname':system_details.sysname, + 'release':system_details.release, + 'machine':system_details.machine + } + if torch.cuda.is_available(): + out_dict['gpu_device'] = torch.cuda.get_device_name(0) + return out_dict - def load_checkpoint(self, cp_file=None): + def log_architecture_details(self): + """ + Log model architecture with computed output sizes and number of parameters for each layer + """ + architecture_string = '\n'+summary_string( + self, + input_size=tuple(self.params.data_shape), + batch_size=self.params.batch_size, + device=self.params.device, + dtype=torch.FloatTensor + )[0] + architecture_string += '\n' + self.logger.log_string(architecture_string) + + def write_checkpoint(self, batch_step=None): + """ + Write checkpoints + + Keyword arguments: + batch_step: [int] current batch iteration. The default assumes that training has finished. 
+ """ + output_dict = {} + if(self.params.model_type.lower() == 'ensemble'): + for module in self: + module_name = module.params.submodule_name + output_dict[module_name+'_module_state_dict'] = module.state_dict() + output_dict[module_name+'_optimizer_state_dict'] = module.optimizer.state_dict() + else: + output_dict['model_state_dict'] = self.state_dict() + module_state_dict_name = 'optimizer_state_dict' + output_dict[module_state_dict_name] = self.optimizer.state_dict(), + ## TODO: Save scheduler state dict as well + training_stats = self.get_train_stats(batch_step) + output_dict.update(training_stats) + torch.save(output_dict, self.params.cp_latest_filename) + self.logger.log_string('Full model saved in file %s'%self.params.cp_latest_filename) + + def get_checkpoint_from_log(self, logfile): + model_params = loaders.load_params_from_log(logfile) + checkpoint = torch.load(model_params.cp_latest_filename) + return checkpoint + + def load_checkpoint(self, cp_file=None, load_optimizer=False): """ Load checkpoint - Inputs: - model_dir: String specifying the path to the checkpoint + Keyword arguments: + model_dir: [str] specifying the path to the checkpoint """ if cp_file is None: cp_file = self.params.cp_latest_filename - return self.load_state_dict(torch.load(cp_file)) + checkpoint = torch.load(cp_file) + if load_optimizer: + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.load_state_dict(checkpoint['model_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + _ = checkpoint.pop('model_state_dict', None) + training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter + out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}' + return out_str def get_optimizer(self, optimizer_params, trainable_variables): optimizer_name = optimizer_params.optimizer.name @@ -129,8 +216,8 @@ def get_optimizer(self, optimizer_params, trainable_variables): def setup_optimizer(self): self.optimizer = self.get_optimizer( - optimizer_params=self.params, - trainable_variables=self.parameters()) + optimizer_params=self.params, + trainable_variables=self.parameters()) self.scheduler = torch.optim.lr_scheduler.MultiStepLR( self.optimizer, milestones=self.params.optimizer.milestones, @@ -143,21 +230,18 @@ def print_update(self, input_data, input_labels=None, batch_step=0): input_data: data object containing the current image batch input_labels: data object containing the current label batch batch_step: current batch number within the schedule - NOTE: For the analysis code to parse update statistics, the self.js_dumpstring() call - must receive a dict object. Additionally, the self.js_dumpstring() output must be - logged with tags. 
- For example: logging.info(''+self.js_dumpstring(output_dictionary)+'') + NOTE: For the analysis code to parse update statistics, + the logger.log_stats() function must be used """ update_dict = self.generate_update_dict(input_data, input_labels, batch_step) - js_str = self.js_dumpstring(update_dict) - self.log_info(''+js_str+'') + self.logger.log_stats(update_dict) def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): """ Generates a dictionary to be logged in the print_update function """ if update_dict is None: - update_dict = dict() + update_dict = self.get_train_stats(batch_step) for param_name, param_var in self.named_parameters(): grad = param_var.grad update_dict[param_name+'_grad_max_mean_min'] = [ diff --git a/models/ensemble_model.py b/models/ensemble_model.py index 60bbb343..d1b334fd 100644 --- a/models/ensemble_model.py +++ b/models/ensemble_model.py @@ -1,3 +1,5 @@ +import pprint + import torch import DeepSparseCoding.utils.loaders as loaders @@ -15,28 +17,96 @@ def setup(self, params, logger=None): self.setup_optimizer() def setup_module(self, params): - for subparams in params.ensemble_params: + layer_names = [] # TODO: Make this submodule_name=model_type+layer_name is unique, not layer_name is unique + for sub_index, subparams in enumerate(params.ensemble_params): + layer_names.append(subparams.layer_name) + assert len(set(layer_names)) == len(layer_names), ( + 'The "layer_name" parameter must be unique for each module in the ensemble.') + subparams.submodule_name = subparams.model_type + '_' + subparams.layer_name subparams.epoch_size = params.epoch_size subparams.batches_per_epoch = params.batches_per_epoch subparams.num_batches = params.num_batches #subparams.num_val_images = params.num_val_images #subparams.num_test_images = params.num_test_images - subparams.data_shape = params.data_shape + if not hasattr(subparams, 'data_shape'): # TODO: This is a workaround for a dependency on data_shape in lca module + subparams.data_shape = params.data_shape super(EnsembleModel, self).setup_ensemble_module(params) self.submodel_classes = [] - for submodel_params in self.params.ensemble_params: - self.submodel_classes.append(loaders.load_model_class(submodel_params.model_type)) + for ensemble_index, subparams in enumerate(self.params.ensemble_params): + submodule_class = loaders.load_model_class(subparams.model_type) + self.submodel_classes.append(submodule_class) + if subparams.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(subparams.checkpoint_boot_log) + submodule = self.__getitem__(ensemble_index) + module_state_dict_name = subparams.submodule_name+'_module_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + submodule.load_state_dict(checkpoint[module_state_dict_name]) + else: # it was trained on its own + if 'model_state_dict' in checkpoint.keys(): + submodule.load_state_dict(checkpoint['model_state_dict']) + else: + assert False, ( + f'subparams {subparams} has checkpoint_boot_log set to ' + +f'{subparams.checkpoint_boot_log}, but that log does not have the ' + +f'appropriate key. 
The key "{module_state_dict_name}" must be in ' + +f'checkpoint.keys() = {checkpoint.keys}') def setup_optimizer(self): for module in self: module.optimizer = self.get_optimizer( optimizer_params=module.params, trainable_variables=module.parameters()) + if module.params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(module.params.checkpoint_boot_log) + module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) + else: # it was trained on its own + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'][0]) #TODO: For some reason this is a tuple of size 1 containing the dictionary. It should just be the dictionary + for group in module.optimizer.param_groups: # overwrite learning rates + group['lr'] = module.params.weight_lr + group['initial_lr'] = module.params.weight_lr + ## TODO: load scheduler state dict with checkpoint, set last_epoch correctly + ## https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html module.scheduler = torch.optim.lr_scheduler.MultiStepLR( module.optimizer, milestones=module.params.optimizer.milestones, gamma=module.params.optimizer.lr_decay_rate) + def load_checkpoint(self, cp_file=None, load_optimizer=False): + """ + Load checkpoint + Keyword arguments: + model_dir: [str] specifying the path to the checkpoint + """ + if cp_file is None: + cp_file = self.params.cp_latest_filename + checkpoint = torch.load(cp_file) + for module in self: + module_state_dict_name = module.params.submodule_name+'_module_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.load_state_dict(checkpoint[module_state_dict_name]) + _ = checkpoint.pop(module_state_dict_name, None) + else: # it was trained on its own + module.load_state_dict(checkpoint['model_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + if load_optimizer: + module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict' + if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble + module.optimizer.load_state_dict(checkpoint[module_state_dict_name]) + _ = checkpoint.pop(module_state_dict_name, None) + else: + module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + _ = checkpoint.pop('optimizer_state_dict', None) + for group in module.optimizer.param_groups: # overwrite learning rates + group['lr'] = module.params.weight_lr + group['initial_lr'] = module.params.weight_lr + ## TODO: Load scheduler state dict as well + _ = checkpoint.pop('model_state_dict', None) + training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter + out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}' + return out_str + def preprocess_data(self, data): """ We assume that only the first submodel will be preprocessing the input data @@ -59,7 +129,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0): input_labels, batch_step, update_dict=dict()) for key, value in submodel_update_dict.items(): if key not in ['epoch', 'batch_step']: - key = submodule.params.model_type+'_'+key + key = submodule.params.submodule_name + '_' + key update_dict[key] = value x = submodule.get_encodings(x) return update_dict diff --git a/models/lca_model.py b/models/lca_model.py index ec13c014..ad78e3a8 100644 --- 
a/models/lca_model.py +++ b/models/lca_model.py @@ -11,6 +11,10 @@ def setup(self, params, logger=None): super(LcaModel, self).setup(params, logger) self.setup_module(params) self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): input_tensor, input_labels = input_tuple @@ -24,16 +28,12 @@ def get_total_loss(self, input_tuple): def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None): if update_dict is None: update_dict = super(LcaModel, self).generate_update_dict(input_data, input_labels, batch_step) - epoch = batch_step / self.params.batches_per_epoch - stat_dict = { - 'epoch':int(epoch), - 'batch_step':batch_step, - 'train_progress':np.round(batch_step/self.params.num_batches, 3), - 'weight_lr':self.scheduler.get_lr()[0]} + stat_dict = dict() latents = self.get_encodings(input_data) recon = self.get_recon_from_latents(latents) recon_loss = losses.half_squared_l2(input_data, recon).item() sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item() + stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] stat_dict['loss_recon'] = recon_loss stat_dict['loss_sparse'] = sparse_loss stat_dict['loss_total'] = recon_loss + sparse_loss @@ -41,7 +41,20 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda input_data.max().item(), input_data.mean().item(), input_data.min().item()] stat_dict['recon_max_mean_min'] = [ recon.max().item(), recon.mean().item(), recon.min().item()] - latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero - stat_dict['latents_fraction_active'] = latent_nnz / latents.numel() + def count_nonzero(array, dim): + # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7 + return torch.sum(array !=0, dim=dim, dtype=torch.float) + latent_dims = tuple([i for i in range(len(latents.shape))]) + latent_nnz = count_nonzero(latents, dim=latent_dims).item() + stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel() + if self.params.layer_types[0] == 'conv': + latent_map_dims = latent_dims[2:] + latent_map_size = np.prod(list(latents.shape[2:])) + latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size + latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item() + stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz + num_channels = latents.shape[1] + latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item() + stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz update_dict.update(stat_dict) return update_dict diff --git a/models/mlp_model.py b/models/mlp_model.py index bf755d12..28d0f261 100644 --- a/models/mlp_model.py +++ b/models/mlp_model.py @@ -10,6 +10,10 @@ def setup(self, params, logger=None): super(MlpModel, self).setup(params, logger) self.setup_module(params) self.setup_optimizer() + if params.checkpoint_boot_log != '': + checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log) + self.module.load_state_dict(checkpoint['model_state_dict']) + self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) def get_total_loss(self, input_tuple): input_tensor, input_label = input_tuple @@ -20,16 +24,12 @@ def get_total_loss(self, input_tuple): def 
generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
         if update_dict is None:
             update_dict = super(MlpModel, self).generate_update_dict(input_data, input_labels, batch_step)
-        epoch = batch_step / self.params.batches_per_epoch
-        stat_dict = {
-            'epoch':int(epoch),
-            'batch_step':batch_step,
-            'train_progress':np.round(batch_step/self.params.num_batches, 3)}
+        stat_dict = dict()
         pred = self.forward(input_data)
-        #total_loss = F.nll_loss(pred, input_labels)
         total_loss = self.loss_fn(pred, input_labels)
         pred = pred.max(1, keepdim=True)[1]
         correct = pred.eq(input_labels.view_as(pred)).sum().item()
+        stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] # one LR for all parameters
         stat_dict['loss'] = total_loss.item()
         stat_dict['train_accuracy'] = 100. * correct / self.params.batch_size
         update_dict.update(stat_dict)
diff --git a/models/pooling_model.py b/models/pooling_model.py
new file mode 100644
index 00000000..3f5caf1a
--- /dev/null
+++ b/models/pooling_model.py
@@ -0,0 +1,46 @@
+import torch
+
+import DeepSparseCoding.modules.losses as losses
+
+from DeepSparseCoding.models.base_model import BaseModel
+from DeepSparseCoding.modules.pooling_module import PoolingModule
+
+class PoolingModel(BaseModel, PoolingModule):
+    """
+    TODO: rename pool_ksize and pool_stride to just kernel_size and stride
+    """
+    def setup(self, params, logger=None):
+        self.setup_module(params)
+        self.setup_optimizer()
+        if params.checkpoint_boot_log != '':
+            checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)
+            self.module.load_state_dict(checkpoint['model_state_dict'])
+            self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+    def get_total_loss(self, input_tuple):
+        def loss_fn(model_output):
+            output_loss = losses.trace_covariance(model_output)
+            w_stride = self.params.pool_stride
+            weight_loss = losses.weight_orthogonality(self.weight, stride=w_stride, padding=0)
+            return output_loss + weight_loss
+        input_tensor, input_label = input_tuple
+        layer_output = self.forward(input_tensor)
+        self.loss_fn = loss_fn
+        return self.loss_fn(layer_output)
+
+    def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
+        if update_dict is None:
+            update_dict = super(PoolingModel, self).generate_update_dict(input_data, input_labels, batch_step)
+        stat_dict = dict()
+        rep = self.forward(input_data)
+        def count_nonzero(array, dim):
+            # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7
+            return torch.sum(array != 0, dim=dim, dtype=torch.float)
+        rep_dims = tuple([i for i in range(len(rep.shape))])
+        rep_nnz = count_nonzero(rep, dim=rep_dims).item()
+        stat_dict['fraction_active_all_latents'] = rep_nnz / rep.numel()
+        total_loss = self.loss_fn(rep)
+        stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0]
+        stat_dict['loss'] = total_loss.item()
+        update_dict.update(stat_dict)
+        return update_dict
diff --git a/modules/activations.py b/modules/activations.py
index 98446536..56a502a2 100644
--- a/modules/activations.py
+++ b/modules/activations.py
@@ -1,17 +1,5 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-
-def activation_picker(activation_function):
-    if activation_function == 'identity':
-        return lambda x: x
-    if activation_function == 'relu':
-        return F.relu
-    if activation_function == 'lrelu' or activation_function == 'leaky_relu':
-        return F.leaky_relu
-    if activation_function == 'lca_threshold':
-        return lca_threshold
-    assert False, (f'Activation function
{activation_function} is not supported.') def lca_threshold(u_in, thresh_type, rectify, sparse_threshold): u_zeros = torch.zeros_like(u_in) @@ -40,3 +28,12 @@ def lca_threshold(u_in, thresh_type, rectify, sparse_threshold): else: assert False, (f'Parameter thresh_type must be "soft" or "hard", not {thresh_type}') return a_out + +def activation_picker(activation_function): + if activation_function == 'identity': + return nn.Identity() + if activation_function == 'relu': + return nn.ReLU() + if activation_function == 'lrelu' or activation_function == 'leaky_relu': + return nn.LeakyReLU() + assert False, (f'Activation function {activation_function} is not supported.') diff --git a/modules/ensemble_module.py b/modules/ensemble_module.py index 9e7b7932..27e60f77 100644 --- a/modules/ensemble_module.py +++ b/modules/ensemble_module.py @@ -4,21 +4,19 @@ class EnsembleModule(nn.Sequential): - def __init__(self): # do not do Sequential's init - super(nn.Sequential, self).__init__() - def setup_ensemble_module(self, params): self.params = params for subparams in params.ensemble_params: submodule = loaders.load_module(subparams.model_type) submodule.setup_module(subparams) - self.add_module(subparams.model_type, submodule) + self.add_module(subparams.layer_name, submodule) def forward(self, x): - self.layer_list = [x] for module in self: - self.layer_list.append(module.get_encodings(self.layer_list[-1])) # latent encodings - return self.layer_list[-1] + if module.params.layer_types[0] == 'fc': + x = x.view(x.size(0), -1) #flat + x = module(x) + return x def get_encodings(self, x): return self.forward(x) diff --git a/modules/lca_module.py b/modules/lca_module.py index 9444a5d6..b8446aae 100644 --- a/modules/lca_module.py +++ b/modules/lca_module.py @@ -3,62 +3,132 @@ import torch.nn.functional as F from DeepSparseCoding.modules.activations import lca_threshold +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape +import DeepSparseCoding.utils.data_processing as dp class LcaModule(nn.Module): + """ + Keyword arguments: + params: [dict] with keys: + data_shape [list of int] of shape [elements, channels, height, width]; Assumes h = w (i.e. 
square inputs)
+        The remaining keys are only used when layer_types[0] is "conv":
+            kernel_size: [int] edge size of the square convolving kernel
+            stride: [int] vertical and horizontal stride of the convolution
+            padding: [int] zero-padding added to both sides of the input
+    """
     def setup_module(self, params):
         self.params = params
-        self.w = nn.Parameter(
-            F.normalize(
-                torch.randn(self.params.num_pixels, self.params.num_latent),
-                p=2, dim=0),
-            requires_grad=True)
+        if self.params.layer_types[0] == 'fc':
+            self.w_shape = self.params.layer_channels[::-1] #[outputs, inputs]
+            self.layer_output_shape = [self.params.layer_channels[-1]]
+        else:
+            assert (self.params.data_shape[-1] % self.params.stride == 0), (
+                f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')
+            self.w_shape = [
+                self.params.layer_channels[1],
+                self.params.layer_channels[0],
+                self.params.kernel_size,
+                self.params.kernel_size
+            ]
+            output_height = compute_conv_output_shape(
+                self.params.data_shape[1],
+                self.params.kernel_size,
+                self.params.stride,
+                self.params.padding,
+                dilation=1)
+            output_width = compute_conv_output_shape(
+                self.params.data_shape[2],
+                self.params.kernel_size,
+                self.params.stride,
+                self.params.padding,
+                dilation=1)
+            self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width]
+        w_init = torch.randn(self.w_shape)
+        w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)
+        self.weight = nn.Parameter(w_init_normed, requires_grad=True)

     def preprocess_data(self, input_tensor):
-        input_tensor = input_tensor.view(-1, self.params.num_pixels)
+        if self.params.layer_types[0] == 'fc':
+            input_tensor = input_tensor.view(self.params.batch_size, -1)
         return input_tensor

-    def compute_excitatory_current(self, input_tensor):
-        return torch.matmul(input_tensor, self.w)
+    def compute_excitatory_current(self, input_tensor, a_in):
+        if self.params.layer_types[0] == 'fc':
+            excitatory_current = torch.mm(input_tensor, self.weight.T)
+        else:
+            recon = self.get_recon_from_latents(a_in)
+            recon_error = input_tensor - recon
+            error_injection = F.conv2d(
+                input=recon_error,
+                weight=self.weight,
+                bias=None,
+                stride=self.params.stride,
+                padding=self.params.padding
+            )
+            excitatory_current = error_injection + a_in
+        return excitatory_current

     def compute_inhibitory_connectivity(self):
-        lca_g = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1),
-            self.w) - torch.eye(self.params.num_latent,
+        identity = torch.eye(self.params.layer_channels[1],
             requires_grad=True, device=self.params.device)
-        return lca_g
+        if self.params.layer_types[0] == 'fc':
+            inhibitory_connectivity = torch.mm(self.weight, self.weight.T) - identity
+        else:
+            conv_kernels = self.weight.view(1, -1)
+            inhibitory_connectivity = torch.mm(conv_kernels, conv_kernels.T) - identity
+        return inhibitory_connectivity

     def threshold_units(self, u_in):
         a_out = lca_threshold(u_in, self.params.thresh_type, self.params.rectify_a,
             self.params.sparse_mult)
         return a_out

-    def step_inference(self, u_in, a_in, b, g, step):
-        lca_explain_away = torch.matmul(a_in, g)
-        du = b - lca_explain_away - u_in
+    def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity, step):
+        if self.params.layer_types[0] == 'fc':
+            lca_explain_away = torch.mm(a_in, inhibitory_connectivity)
+        else:
+            lca_explain_away = 0 # already computed in excitatory_current
+        du = excitatory_current - lca_explain_away - u_in
         u_out = u_in + self.params.step_size * du
         return u_out, lca_explain_away
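# A minimal standalone sketch (editorial illustration, not part of this patch) of the
# fully-connected LCA dynamics that compute_excitatory_current, compute_inhibitory_connectivity,
# and step_inference implement above. Toy shapes; `W`, `soft_threshold`, `step_size`, and
# `num_steps` are hypothetical stand-ins for self.weight, lca_threshold (thresh_type='soft',
# rectify_a=False), params.step_size, and params.num_steps:
import torch

torch.manual_seed(0)
batch_size, num_pixels, num_latents = 8, 64, 128
W = torch.nn.functional.normalize(torch.randn(num_latents, num_pixels), p=2, dim=1) # [outputs, inputs]
x = torch.rand(batch_size, num_pixels)

def soft_threshold(u, sparse_mult=0.2):
    # non-rectified soft threshold: shrink |u| by sparse_mult, zero out everything inside it
    return torch.sign(u) * torch.clamp(torch.abs(u) - sparse_mult, min=0.0)

b = torch.mm(x, W.T) # excitatory current (fc branch of compute_excitatory_current)
G = torch.mm(W, W.T) - torch.eye(num_latents) # inhibitory connectivity minus self-connections
u = torch.zeros(batch_size, num_latents)
step_size, num_steps = 0.1, 50
for _ in range(num_steps):
    a = soft_threshold(u)
    u = u + step_size * (b - torch.mm(a, G) - u) # du = b - G a - u, as in step_inference
a = soft_threshold(u)
recon = torch.mm(a, W) # reconstruction from latents, as in get_recon_from_latents
print(f'fraction of active latents: {a.ne(0).float().mean().item():.3f}')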
def infer_coefficients(self, input_tensor): - lca_b = self.compute_excitatory_current(input_tensor) - lca_g = self.compute_inhibitory_connectivity() - u_list = [torch.zeros([input_tensor.shape[0], self.params.num_latent], - device=self.params.device)] + output_shape = [input_tensor.shape[0]] + self.layer_output_shape + u_list = [torch.zeros(output_shape, device=self.params.device)] a_list = [self.threshold_units(u_list[0])] - # TODO: look into redoing this with a register_buffer that gets updated? look up simple RNN code... + excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) + inhibitory_connectivity = self.compute_inhibitory_connectivity() for step in range(self.params.num_steps-1): - u = self.step_inference(u_list[step], a_list[step], lca_b, lca_g, step)[0] + u = self.step_inference( + u_list[step], + a_list[step], + excitatory_current, + inhibitory_connectivity, + step + )[0] u_list.append(u) a_list.append(self.threshold_units(u)) + if self.params.layer_types[0] == 'conv': + excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1]) return (u_list, a_list) - def get_recon_from_latents(self, latents): - return torch.matmul(latents, torch.transpose(self.w, dim0=0, dim1=1)) + def get_recon_from_latents(self, a_in): + if self.params.layer_types[0] == 'fc': + recon = torch.mm(a_in, self.weight) + else: + recon = F.conv_transpose2d( + input=a_in, + weight=self.weight, + bias=None, + stride=self.params.stride, + padding=self.params.padding + ) + return recon - def get_encodings(self, input_tensor): + def forward(self, input_tensor): u_list, a_list = self.infer_coefficients(input_tensor) return a_list[-1] - def forward(self, input_tensor): - latents = self.get_encodings(input_tensor) - reconstruction = self.get_recon_from_latents(latents) - return reconstruction + def get_encodings(self, input_tensor): + return self.forward(input_tensor) diff --git a/modules/losses.py b/modules/losses.py index d9ea61b7..4176a10e 100644 --- a/modules/losses.py +++ b/modules/losses.py @@ -1,15 +1,24 @@ +import numpy as np import torch import DeepSparseCoding.utils.data_processing as dp +#def l2_flatness(z1, z2, z3, weight): +# """ +# Minimized when a straight line can be drawn through [z1, z2, z3]. +# Extended from equations 8 and 12 in +# Chen, Paiton, Olshausen (2018) - The Sparse Manifold Transform +# """ +# z_mat = + def half_squared_l2(x1, x2): """ Computes the standard reconstruction loss. It will average over batch dimensions. - Args: + Keyword arguments: x1: Tensor with original input image x2: Tensor with reconstructed image for comparison - Returns: + Outputs: recon_loss: Tensor representing the squared l2 distance between the inputs, averaged over batch """ dp.check_all_same_shape([x1, x2]) @@ -22,15 +31,15 @@ def half_squared_l2(x1, x2): def half_weight_norm_squared(weight_list): """ Computes a loss that encourages each weight in the list of weights to have unit l2 norm. 
-    Args:
+    Keyword arguments:
         weight_list: List of torch variables
-    Returns:
-        w_norm_loss: 0.5 * sum of (1 - l2_norm(w))^2 for each w in weight_list
+    Outputs:
+        w_norm_loss: 0.5 * sum of (1 - l2_norm(weight))^2 for each weight in weight_list
     """
     w_norm_list = []
-    for w in weight_list:
-        reduc_dim = list(range(1, len(w.shape)))
-        w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(tf.pow(w, 2.), axis=reduc_dim)), 2.))
+    for weight in weight_list:
+        reduc_dim = list(range(1, len(weight.shape)))
+        w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(torch.pow(weight, 2.), axis=reduc_dim)), 2.))
         w_norm_list.append(w_norm)
     norm_loss = 0.5 * torch.sum(w_norm_list)
     return norm_loss
@@ -39,12 +48,12 @@ def weight_decay(weight_list):
     """
     Computes typical weight decay loss
-    Args:
+    Keyword arguments:
         weight_list: List of torch variables
-    Returns:
-        decay_loss: 0.5 * sum of w^2 for each w in weight_list
+    Outputs:
+        decay_loss: 0.5 * sum of weight^2 for each weight in weight_list
     """
-    decay_loss = 0.5 * torch.sum([torch.sum(torch.pow(w, 2.)) for w in weight_list])
+    decay_loss = 0.5 * sum(torch.sum(torch.pow(weight, 2.)) for weight in weight_list)
     return decay_loss
@@ -52,11 +61,66 @@ def l1_norm(latents):
     """
     Computes the L1 norm of for a batch of input vector
     This is the sparsity loss for a Laplacian prior
-    Args:
+    Keyword arguments:
         latents: torch tensor of any shape, but where first index is always batch
-    Returns:
+    Outputs:
         sparse_loss: sum of abs of latents, averaged over the batch
     """
     reduc_dim = list(range(1, len(latents.shape)))
     sparse_loss = torch.mean(torch.sum(torch.abs(latents), dim=reduc_dim, keepdim=False))
     return sparse_loss
+
+
+def trace_covariance(latents):
+    """
+    Returns loss that is the trace of the covariance matrix of the latents
+
+    Keyword arguments:
+        latents: torch tensor of shape [num_batch, num_latents] or [num_batch, num_channels, latents_h, latents_w]
+    Outputs:
+        loss
+    """
+    covariance = dp.covariance(latents) # [num_channels, num_channels]
+    if latents.ndim == 4:
+        num_batch, num_channels, latents_h, latents_w = latents.shape
+        covariance = covariance / (latents_h * latents_w - 1.0)
+    trace = torch.trace(covariance)
+    target = torch.trace(torch.eye(covariance.size(0), device=trace.device)) # should = trace.size[0]
+    return torch.norm(trace - target, p='fro')
+
+
+def weight_orthogonality(weight, stride=1, padding=0):
+    """
+    Returns l2 loss that is minimized when the weights are orthogonal
+
+    Keyword arguments:
+        weight [torch tensor] layer weight, either fully connected or 2d convolutional
+        stride [int] layer stride for convolutional layers
+        padding [int] layer padding for convolutional layers
+
+    Outputs:
+        loss
+
+    Note:
+        Convolutional orthogonalization loss is based on
+        Orthogonal Convolutional Neural Networks
+        https://arxiv.org/abs/1911.12207
+        https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks
+    """
+    w_shape = weight.shape
+    if weight.ndim == 2: # fully-connected, [inputs, outputs]
+        loss = torch.norm(torch.mm(weight.T, weight) - torch.eye(w_shape[1], device=weight.device))
+    elif weight.ndim == 4: # convolutional, [output_channels, input_channels, height, width]
+        out_channels, in_channels, in_height, in_width = w_shape
+        output = torch.conv2d(weight, weight, stride=stride, padding=padding)
+        out_height = output.shape[-2]
+        out_width = output.shape[-1]
+        target = torch.zeros((out_channels, out_channels, out_height, out_width),
+            device=weight.device)
+        center_h = int(np.floor(out_height / 2))
+        center_w = int(np.floor(out_width / 2))
+        target[:, :, center_h, center_w] = torch.eye(out_channels, device=weight.device)
+        loss = torch.norm(output - target, p='fro')
+    else:
+        assert False, (f'weight ndim must be 2 or 4, not {weight.ndim}')
+    return loss
diff --git a/modules/mlp_module.py b/modules/mlp_module.py
index 4877d8eb..cebda2c3 100644
--- a/modules/mlp_module.py
+++ b/modules/mlp_module.py
@@ -1,7 +1,11 @@
+from collections import OrderedDict
+
+import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F

 from DeepSparseCoding.modules.activations import activation_picker
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape

 class MlpModule(nn.Module):
@@ -9,28 +13,102 @@ def setup_module(self, params):
         self.params = params
         self.act_funcs = [activation_picker(act_func_str)
             for act_func_str in self.params.activation_functions]
+        self.layer_output_shapes = [self.params.data_shape] # [channels, height, width]
         self.layers = []
+        self.pooling = []
         self.dropout = []
         for layer_index, layer_type in enumerate(self.params.layer_types):
             if layer_type == 'fc':
+                if(layer_index > 0 and self.params.layer_types[layer_index-1] == 'conv'):
+                    in_features = int(np.prod(self.layer_output_shapes[-1]))
+                else:
+                    in_features = self.params.layer_channels[layer_index]
                 layer = nn.Linear(
-                    in_features = self.params.layer_channels[layer_index],
-                    out_features = self.params.layer_channels[layer_index+1],
-                    bias = True)
+                    in_features=in_features,
+                    out_features=self.params.layer_channels[layer_index + 1],
+                    bias=True)
                 self.register_parameter('fc'+str(layer_index)+'_w', layer.weight)
                 self.register_parameter('fc'+str(layer_index)+'_b', layer.bias)
                 self.layers.append(layer)
+                self.layer_output_shapes.append(self.params.layer_channels[layer_index + 1])
+            elif layer_type == 'conv':
+                layer = nn.Conv2d(
+                    in_channels=self.params.layer_channels[layer_index],
+                    out_channels=self.params.layer_channels[layer_index + 1],
+                    kernel_size=self.params.kernel_sizes[layer_index],
+                    stride=self.params.strides[layer_index],
+                    padding=0,
+                    dilation=1,
+                    bias=True)
+                self.register_parameter('conv'+str(layer_index)+'_w', layer.weight)
+                self.register_parameter('conv'+str(layer_index)+'_b', layer.bias)
+                self.layers.append(layer)
+                output_channels = self.params.layer_channels[layer_index + 1]
+                output_height = compute_conv_output_shape(
+                    self.layer_output_shapes[-1][1],
+                    self.params.kernel_sizes[layer_index],
+                    self.params.strides[layer_index],
+                    padding=0,
+                    dilation=1)
+                output_width = compute_conv_output_shape(
+                    self.layer_output_shapes[-1][2],
+                    self.params.kernel_sizes[layer_index],
+                    self.params.strides[layer_index],
+                    padding=0,
+                    dilation=1)
+                self.layer_output_shapes.append([output_channels, output_height, output_width])
+            else:
+                assert False, ('layer_type parameter must be "fc" or "conv", not %s'%(layer_type))
+            if(self.params.max_pool[layer_index] and layer_type == 'conv'):
+                self.pooling.append(nn.MaxPool2d(
+                    kernel_size=self.params.pool_ksizes[layer_index],
+                    stride=self.params.pool_strides[layer_index],
+                    padding=0,
+                    dilation=1))
+                output_channels = self.params.layer_channels[layer_index + 1]
+                output_height = compute_conv_output_shape(
+                    self.layer_output_shapes[-1][1],
+                    self.params.pool_ksizes[layer_index],
+                    self.params.pool_strides[layer_index],
+                    padding=0,
+                    dilation=1)
+                output_width = compute_conv_output_shape(
+                    self.layer_output_shapes[-1][2],
+                    self.params.pool_ksizes[layer_index],
+                    self.params.pool_strides[layer_index],
+                    padding=0,
+                    dilation=1)
+                self.layer_output_shapes.append([output_channels, output_height, output_width])
             else:
-                assert False, ('layer_type parameter must be "fc", not %g'%(layer_type))
+                self.pooling.append(nn.Identity()) # do nothing
             self.dropout.append(nn.Dropout(p=self.params.dropout_rate[layer_index]))
+        conv_module_dict = OrderedDict()
+        fc_module_dict = OrderedDict()
+        layer_zip = zip(self.params.layer_types, self.layers, self.act_funcs, self.pooling,
+            self.dropout)
+        for layer_idx, full_layer in enumerate(layer_zip):
+            for component_idx, layer_component in enumerate(full_layer[1:]):
+                component_id = f'{layer_idx:02}-{component_idx:02}'
+                if full_layer[0] == 'fc':
+                    fc_module_dict[full_layer[0] + component_id] = layer_component
+                else:
+                    conv_module_dict[full_layer[0] + component_id] = layer_component
+        self.conv_sequential = lambda x: x # identity by default
+        self.fc_sequential = lambda x: x # identity by default
+        if len(conv_module_dict) > 0:
+            self.conv_sequential = nn.Sequential(conv_module_dict)
+        if len(fc_module_dict) > 0:
+            self.fc_sequential = nn.Sequential(fc_module_dict)

     def preprocess_data(self, input_tensor):
-        input_tensor = input_tensor.view(-1, self.params.layer_channels[0])
+        if self.params.layer_types[0] == 'fc':
+            input_tensor = input_tensor.view(input_tensor.size(0), -1) #flat
         return input_tensor

     def forward(self, x):
-        for dropout, act_func, layer in zip(self.dropout, self.act_funcs, self.layers):
-            x = dropout(act_func(layer(x)))
+        x = self.conv_sequential(x)
+        x = x.view(x.size(0), -1) #flat
+        x = self.fc_sequential(x)
         return x

     def get_encodings(self, input_tensor):
diff --git a/modules/pooling_module.py b/modules/pooling_module.py
new file mode 100644
index 00000000..80f2bca2
--- /dev/null
+++ b/modules/pooling_module.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class PoolingModule(nn.Module):
+    def setup_module(self, params):
+        params.weight_decay = 0 # used by base model; pooling layer never has weight decay
+        self.params = params
+        if self.params.layer_types[0] == 'fc':
+            self.layer = nn.Linear(
+                in_features=self.params.layer_channels[0],
+                out_features=self.params.layer_channels[1],
+                bias=False)
+            self.weight = self.layer.weight # [outputs, inputs]
+
+        elif self.params.layer_types[0] == 'conv':
+            self.layer = nn.Conv2d(
+                in_channels=self.params.layer_channels[0],
+                out_channels=self.params.layer_channels[1],
+                kernel_size=self.params.pool_ksize,
+                stride=self.params.pool_stride,
+                padding=0,
+                dilation=1,
+                bias=False)
+            nn.init.orthogonal_(self.layer.weight) # initialize to orthogonal matrix
+            self.weight = self.layer.weight
+
+        else:
+            assert False, ('layer_types[0] parameter must be "fc" or "conv", not %s'%(self.params.layer_types[0]))
+
+    def forward(self, x):
+        if self.params.layer_types[0] == 'fc':
+            x = x.view(x.shape[0], -1) # flat
+        return self.layer(x)
+
+    def get_encodings(self, input_tensor):
+        return self.forward(input_tensor)
diff --git a/notebooks/monitor_training.ipynb b/notebooks/monitor_training.ipynb
index 2c3ecad2..3796865e 100644
--- a/notebooks/monitor_training.ipynb
+++ b/notebooks/monitor_training.ipynb
@@ -23,7 +23,8 @@
    "outputs": [],
    "source": [
     "workspace_dir = os.path.expanduser('~')+'/Work/'\n",
-    "log_file = workspace_dir+'/Torch_projects/lca_768_mlp_mnist/logfiles/lca_768_mlp_mnist_v0.log'\n",
+    "model_name = 'conv_lca_mnist'\n",
+    "log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n",
     "logger = Logger(log_file, overwrite=False)\n",
     "\n",
     "log_text =
logger.load_file()\n", @@ -41,9 +42,9 @@ "outputs": [], "source": [ "x_key = 'epoch'\n", - "y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n", - "y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n", - "stats_fig = pf.plot_stats(model_stats, x_key=x_key, y_keys=y_keys, y_labels=y_labels, start_index=0)" + "#y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n", + "#y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n", + "stats_fig = pf.plot_stats(model_stats, x_key=x_key)#, y_keys=y_keys, y_labels=y_labels, start_index=0)" ] }, { @@ -70,7 +71,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.7" + "version": "3.6.8" }, "varInspector": { "cols": { diff --git a/notebooks/smt_analysis.ipynb b/notebooks/smt_analysis.ipynb new file mode 100644 index 00000000..c79282e9 --- /dev/null +++ b/notebooks/smt_analysis.ipynb @@ -0,0 +1,1470 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "\n", + "ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))\n", + "if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)\n", + "\n", + "import scipy\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "import proplot as plot\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.gridspec as gridspec\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "\n", + "from DeepSparseCoding.utils.file_utils import Logger\n", + "import DeepSparseCoding.utils.run_utils as run_utils\n", + "import DeepSparseCoding.utils.dataset_utils as dataset_utils\n", + "import DeepSparseCoding.utils.loaders as loaders\n", + "import DeepSparseCoding.utils.plot_functions as pf\n", + "import DeepSparseCoding.utils.data_processing as dp\n", + "from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "workspace_dir = '/mnt/qb/bethge/dpaiton/'\n", + "model_name = 'smt_cifar10'\n", + "model_version = 'lplp'\n", + "log_file = workspace_dir + os.path.join(*['Projects', model_name, 'logfiles', f'{model_name}_v{model_version}.log'])\n", + "logger = Logger(log_file, overwrite=False)\n", + "log_text = logger.load_file()\n", + "params = logger.read_params(log_text)[-1]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "model_stats = logger.read_stats(log_text)\n", + "x_key = \"epoch\"\n", + "y_keys = [key for key in list(model_stats.keys()) if 'test_' not in key]\n", + "stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)\n", + "\n", + "if 'test_epoch' in list(model_stats.keys()):\n", + " x_key = \"test_epoch\"\n", + " y_keys = [key for key in list(model_stats.keys()) if 'test_' in key]\n", + " test_stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = 
loaders.load_model(params.model_type)\n",
+    "model.setup(params, logger)\n",
+    "model.to(params.device)\n",
+    "model_state_str = model.load_checkpoint()\n",
+    "print(model_state_str)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "lca_weights = model.lca_1.weight.detach().cpu().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "def normalize_data_with_max(data):\n",
+    "    \"\"\"\n",
+    "    Normalize data by dividing by abs(max(data))\n",
+    "    If abs(max(data)) is zero, then output is zero\n",
+    "    Inputs:\n",
+    "        data: [np.ndarray] data to be normalized\n",
+    "    Outputs:\n",
+    "        norm_data: [np.ndarray] normalized data\n",
+    "        data_max: [float] max that was divided out\n",
+    "    \"\"\"\n",
+    "    data_max = np.max(np.abs(data), axis=(1,2), keepdims=True)\n",
+    "    norm_data = np.divide(data, data_max, out=np.zeros_like(data), where=data_max!=0)\n",
+    "    return norm_data, data_max\n",
+    "\n",
+    "def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n",
+    "    if normalize:\n",
+    "        #matrix = normalize_data_with_max(matrix)[0]\n",
+    "        matrix = dp.rescale_data_to_one(torch.from_numpy(matrix), eps=1e-10, samplewise=True)[0].numpy()\n",
+    "    num_weights, img_c, img_h, img_w = matrix.shape\n",
+    "    #if img_c == 1:\n",
+    "    #    matrix = matrix.squeeze()\n",
+    "    #else:\n",
+    "    #    # TODO: separate channels, pad each individual one, then recombine.\n",
+    "    #    assert False, (f'Multiple color channels are not currently supported') \n",
+    "    num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n",
+    "    matrices = []\n",
+    "    for channel_idx in range(img_c):\n",
+    "        channel_matrix = matrix[:, channel_idx, ...].copy()\n",
+    "        if num_extra_images > 0:\n",
+    "            channel_matrix = np.concatenate(\n",
+    "                [channel_matrix, np.zeros((num_extra_images, img_h, img_w))], axis=0)\n",
+    "        channel_matrix = np.pad(channel_matrix,\n",
+    "            pad_width=((0,0), (pad_size, pad_size), (pad_size, pad_size)),\n",
+    "            mode='constant', constant_values=pad_value)\n",
+    "        padded_img_h, padded_img_w = channel_matrix.shape[1:]\n",
+    "        num_edge_tiles = int(np.sqrt(channel_matrix.shape[0]))\n",
+    "        tiles = channel_matrix.reshape(num_edge_tiles, num_edge_tiles, padded_img_h, padded_img_w)\n",
+    "        tiles = tiles.swapaxes(1, 2)\n",
+    "        matrices.append(tiles.reshape(num_edge_tiles * padded_img_h, num_edge_tiles * padded_img_w))\n",
+    "    padded_matrix = np.stack(matrices, axis=0) # channel dim first\n",
+    "    return padded_matrix\n",
+    "    \n",
+    "def plot_matrix(matrix, title='', cmap=None):\n",
+    "    fig, ax = plot.subplots(figsize=(10,10))\n",
+    "    ax = pf.clear_axis(ax)\n",
+    "    ax.imshow(matrix, cmap=cmap)#, vmin=0.0, vmax=1.0)#, cmap='greys_r')\n",
+    "    ax.format(title=title)\n",
+    "    plot.show()\n",
+    "    return fig\n",
+    "\n",
+    "pad_value = 0.5\n",
+    "num_pad_pix = 1\n",
+    "padded_matrix = pad_matrix_to_image(lca_weights, num_pad_pix, pad_value, normalize=True)\n",
+    "fig = plot_matrix(np.transpose(padded_matrix, axes=[1, 2, 0]), title=f'{model.params.model_name} weights')\n",
+    "fig.savefig(\n",
+    "    f'{model.params.disp_dir}/weights_plot_matrix.png',\n",
+    "    transparent=False,\n",
+    "    bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def rgb_to_gray(rgb):\n",
+    "    num, chan, height, width = rgb.shape\n",
+    "    gray = np.zeros((num, 1, height, width))\n",
+    "    for neuron_idx in
range(num):\n", + " gray[neuron_idx, ...] = 0.2125 * rgb[neuron_idx, 0, ...]\n", + " gray[neuron_idx, ...] += 0.7154 * rgb[neuron_idx, 1, ...]\n", + " gray[neuron_idx, ...] += 0.0721 * rgb[neuron_idx, 2, ...]\n", + " return gray\n", + "\n", + "gray_lca_weights = rgb_to_gray(lca_weights)\n", + "pad_value = 0.5\n", + "num_pad_pix = 1\n", + "padded_matrix = pad_matrix_to_image(gray_lca_weights, num_pad_pix, pad_value, normalize=True)\n", + "fig = plot_matrix(np.squeeze(np.transpose(padded_matrix, axes=[1, 2, 0])), title=f'{model.params.model_name} weights', cmap='grays_r')\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/weights_grayscale_plot_matrix.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_gaussian(shape, mean, cov):\n", + " \"\"\"\n", + " Generate a Gaussian PDF from given mean & cov\n", + " Inputs:\n", + " shape: [tuple] specifying (num_rows, num_cols)\n", + " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n", + " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n", + " Outputs:\n", + " tuple containing (Gaussian PDF, grid_points used to generate PDF)\n", + " grid_points are specified as a tuple of (y,x) points\n", + " \"\"\"\n", + " (y_size, x_size) = shape\n", + " y = np.linspace(0, y_size, np.int32(np.floor(y_size)))\n", + " x = np.linspace(0, x_size, np.int32(np.floor(x_size)))\n", + " y, x = np.meshgrid(y, x)\n", + " pos = np.empty(x.shape + (2,)) #x.shape == y.shape\n", + " pos[:, :, 0] = y; pos[:, :, 1] = x\n", + " gauss = scipy.stats.multivariate_normal(mean, cov)\n", + " return (gauss.pdf(pos), (y,x))\n", + "\n", + "\n", + "def gaussian_fit(pyx):\n", + " \"\"\"\n", + " Compute the expected mean & covariance matrix for a 2-D gaussian fit of input distribution\n", + " Inputs:\n", + " pyx: [np.ndarray] of shape [num_rows, num_cols] that indicates the probability function to fit\n", + " Outputs:\n", + " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n", + " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n", + " \"\"\"\n", + " assert pyx.ndim == 2, (\n", + " \"Input must have 2 dimensions specifying [num_rows, num_cols]\")\n", + " mean = np.zeros((1,2), dtype=np.float32) # [mu_y, mu_x]\n", + " for idx in np.ndindex(pyx.shape): # [y, x] ticks columns (x) first, then rows (y)\n", + " mean += np.asarray([pyx[idx]*idx[0], pyx[idx]*idx[1]])[None,:]\n", + " cov = np.zeros((2,2), dtype=np.float32)\n", + " for idx in np.ndindex(pyx.shape): # ticks columns first, then rows\n", + " cov += np.dot((idx-mean).T, (idx-mean))*pyx[idx] # typically an outer-product\n", + " return (np.squeeze(mean), cov)\n", + "\n", + "\n", + "def get_gauss_fit(prob_map, num_attempts=1, perc_mean=0.33):\n", + " \"\"\"\n", + " Returns a gaussian fit for a given probability map\n", + " Fitting is done via robust regression, where a fit is\n", + " continuously refined by deleting outliers num_attempts times\n", + " Inputs:\n", + " prob_map: 2-D probability map to be fit\n", + " num_attempts: Number of times to fit & remove outliers\n", + " perc_mean: All probability values below perc_mean*mean(gauss_fit) will be\n", + " considered outliers for repeated attempts\n", + " Outputs:\n", + " gauss_fit: [np.ndarray] specifying the 2-D Gaussian PDF\n", + " grid: [tuple] containing (y,x) points with which the Gaussian PDF can be plotted\n", + " gauss_mean: [np.ndarray] of shape 
(2,) specifying the 2-D Gaussian center\n",
+    "        gauss_cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n",
+    "    \"\"\"\n",
+    "    assert prob_map.ndim==2, (\n",
+    "        \"get_gauss_fit: Input prob_map must have 2 dimensions specifying [num_rows, num_cols]\")\n",
+    "    if num_attempts < 1:\n",
+    "        num_attempts = 1\n",
+    "    orig_prob_map = prob_map.copy()\n",
+    "    gauss_success = False\n",
+    "    while not gauss_success:\n",
+    "        prob_map = orig_prob_map.copy()\n",
+    "        try:\n",
+    "            for i in range(num_attempts):\n",
+    "                map_min = np.min(prob_map)\n",
+    "                prob_map -= map_min\n",
+    "                map_sum = np.sum(prob_map)\n",
+    "                if map_sum != 1.0:\n",
+    "                    prob_map /= map_sum\n",
+    "                gauss_mean, gauss_cov = gaussian_fit(prob_map)\n",
+    "                gauss_fit, grid = generate_gaussian(prob_map.shape, gauss_mean, gauss_cov)\n",
+    "                gauss_fit = (gauss_fit * map_sum) + map_min\n",
+    "                if i < num_attempts-1:\n",
+    "                    gauss_mask = gauss_fit.copy().T\n",
+    "                    mask_slice = np.where(gauss_mask < perc_mean*np.mean(gauss_mask))\n",
+    "                    gauss_mask[mask_slice] = 0\n",
+    "                    gauss_mask[np.where(gauss_mask > 0)] = 1\n",
+    "                    prob_map *= gauss_mask\n",
+    "            gauss_success = True\n",
+    "        except np.linalg.LinAlgError: # Usually means cov matrix is singular\n",
+    "            print(\"get_gauss_fit: Failed to fit Gaussian at attempt \",i,\", trying again.\"+\n",
+    "                \"\\n  To avoid this try decreasing perc_mean.\")\n",
+    "            num_attempts = i-1\n",
+    "            if num_attempts <= 0:\n",
+    "                assert False, (\"get_gauss_fit: np.linalg.LinAlgError - Unable to fit gaussian.\")\n",
+    "    return (gauss_fit, grid, gauss_mean, gauss_cov)\n",
+    "\n",
+    "\n",
+    "def hilbert_amplitude(weights, padding=None):\n",
+    "    \"\"\"\n",
+    "    Compute Hilbert amplitude envelope of weight matrix\n",
+    "    Inputs:\n",
+    "        weights: [np.ndarray] of shape [num_inputs, num_outputs]\n",
+    "            num_inputs must have an even square root\n",
+    "        padding: [int] specifying how much 0-padding to use for FFT\n",
+    "            default is the closest power of 2 of sqrt(num_inputs)\n",
+    "    Outputs:\n",
+    "        env: [np.ndarray] of shape [num_outputs, num_inputs]\n",
+    "            Hilbert envelope\n",
+    "        bff_filt: [np.ndarray] of shape [num_outputs, padded_num_inputs]\n",
+    "            Filtered Fourier transform of basis function\n",
+    "        hil_filt: [np.ndarray] of shape [num_outputs, sqrt(num_inputs), sqrt(num_inputs)]\n",
+    "            Hilbert filter to be applied in Fourier space\n",
+    "        bffs: [np.ndarray] of shape [num_outputs, padded_num_inputs, padded_num_inputs]\n",
+    "            Fourier transform of input weights\n",
+    "    \"\"\"\n",
+    "    cart2pol = lambda x,y: (np.arctan2(y,x), np.hypot(x, y))\n",
+    "    num_inputs, num_outputs = weights.shape\n",
+    "    assert np.sqrt(num_inputs) == np.floor(np.sqrt(num_inputs)), (\n",
+    "        \"weights.shape[0] must have an even square root.\")\n",
+    "    patch_edge_size = int(np.sqrt(num_inputs))\n",
+    "    if padding is None or padding <= patch_edge_size:\n",
+    "        # Amount of zero padding for fft2 (closest power of 2)\n",
+    "        N = np.int(2**(np.ceil(np.log2(patch_edge_size))))\n",
+    "    else:\n",
+    "        N = np.int(padding)\n",
+    "    # Analytic signal envelope for weights\n",
+    "    # (Hilbert transform of each basis function)\n",
+    "    env = np.zeros((num_outputs, num_inputs), dtype=complex)\n",
+    "    # Fourier transform of weights\n",
+    "    bffs = np.zeros((num_outputs, N, N), dtype=complex)\n",
+    "    # Filtered Fourier transform of weights\n",
+    "    bff_filt = np.zeros((num_outputs, N**2), dtype=complex)\n",
+    "    # Hilbert filters\n",
+    "    hil_filt = np.zeros((num_outputs, N, N))\n",
+    "    # Grid for creating filter\n",
+    "    f = (2/N) * np.pi * np.arange(-N/2.0, N/2.0)\n",
+    "    (fx, fy) = np.meshgrid(f, f)\n",
+    "    (theta, r) = cart2pol(fx, fy)\n",
+    "    for neuron_idx
in range(num_outputs):\n", + " # Grab single basis function, reshape to a square image\n", + " bf = weights[:, neuron_idx].reshape(patch_edge_size, patch_edge_size)\n", + " # Convert basis function into DC-centered Fourier domain\n", + " bff = np.fft.fftshift(np.fft.fft2(bf-np.mean(bf), [N, N]))\n", + " bffs[neuron_idx, ...] = bff\n", + " # Find indices of the peak amplitude\n", + " max_ys = np.abs(bff).argmax(axis=0) # Returns row index for each col\n", + " max_x = np.argmax(np.abs(bff).max(axis=0))\n", + " # Convert peak amplitude location into angle in freq domain\n", + " fx_ang = f[max_x]\n", + " fy_ang = f[max_ys[max_x]]\n", + " theta_max = np.arctan2(fy_ang, fx_ang)\n", + " # Define the half-plane with respect to the maximum\n", + " ang_diff = np.abs(theta-theta_max)\n", + " idx = (ang_diff>np.pi).nonzero()\n", + " ang_diff[idx] = 2.0 * np.pi - ang_diff[idx]\n", + " hil_filt[neuron_idx, ...] = (ang_diff < np.pi/2.0).astype(int)\n", + " # Create analytic signal from the inverse FT of the half-plane filtered bf\n", + " abf = np.fft.ifft2(np.fft.fftshift(hil_filt[neuron_idx, ...]*bff))\n", + " env[neuron_idx, ...] = abf[0:patch_edge_size, 0:patch_edge_size].reshape(num_inputs)\n", + " bff_filt[neuron_idx, ...] = (hil_filt[neuron_idx, ...]*bff).reshape(N**2)\n", + " return (env, bff_filt, hil_filt, bffs)\n", + "\n", + "\n", + "def get_dictionary_stats(weights, padding=None, num_gauss_fits=20, gauss_thresh=0.2):\n", + " \"\"\"\n", + " Compute summary statistics on dictionary elements using Hilbert amplitude envelope\n", + " Inputs:\n", + " weights: [np.ndarray] of shape [num_inputs, num_outputs]\n", + " padding: [int] total image size to pad out to in the FFT computation\n", + " num_gauss_fits: [int] total number of attempts to make when fitting the BFs\n", + " gauss_thresh: All probability values below gauss_thresh*mean(gauss_fit) will be\n", + " considered outliers for repeated fits\n", + " Outputs:\n", + " The function output is a dictionary containing the keys for each type of analysis\n", + " Each key dereferences a list of len num_outputs (i.e. one entry for each weight vector)\n", + " The keys and their list entries are as follows:\n", + " basis_functions: [np.ndarray] of shape [patch_edge_size, patch_edge_size]\n", + " envelopes: [np.ndarray] of shape [N, N], where N is the amount of padding\n", + " for the hilbert_amplitude function\n", + " envelope_centers: [tuples of ints] indicating the (y, x) position of the\n", + " center of the Hilbert envelope\n", + " gauss_fits: [list of np.ndarrays] containing (gaussian_fit, grid) where gaussian_fit\n", + " is returned from get_gauss_fit and specifies the 2D Gaussian PDF fit to the Hilbert\n", + " envelope and grid is a tuple containing (y,x) points with which the Gaussian PDF\n", + " can be plotted\n", + " gauss_centers: [list of ints] containing the (y,x) position of the center of\n", + " the Gaussian fit\n", + " gauss_orientations: [list of np.ndarrays] containing the (eigenvalues, eigenvectors) of\n", + " the covariance matrix for the Gaussian fit of the Hilbert amplitude envelope. 
Both are\n",
+    "      sorted from highest to lowest eigenvalue.\n",
+    "    fourier_centers: [list of floats] containing the (y,x) position of the center (max) of\n",
+    "      the Fourier amplitude map\n",
+    "    num_inputs: [int] dim[0] of input weights\n",
+    "    num_outputs: [int] dim[1] of input weights\n",
+    "    patch_edge_size: [int] int(floor(sqrt(num_inputs)))\n",
+    "    areas: [list of floats] area of enclosed ellipse\n",
+    "    spatial_frequencies: [list of floats] dominant spatial frequency for basis function\n",
+    "  \"\"\"\n",
+    "  envelope, bff_filt, hil_filter, bffs = hilbert_amplitude(weights, padding)\n",
+    "  num_inputs, num_outputs = weights.shape\n",
+    "  patch_edge_size = int(np.floor(np.sqrt(num_inputs)))\n",
+    "  basis_funcs = [None]*num_outputs\n",
+    "  envelopes = [None]*num_outputs\n",
+    "  gauss_fits = [None]*num_outputs\n",
+    "  gauss_centers = [None]*num_outputs\n",
+    "  diameters = [None]*num_outputs\n",
+    "  gauss_orientations = [None]*num_outputs\n",
+    "  envelope_centers = [None]*num_outputs\n",
+    "  fourier_centers = [None]*num_outputs\n",
+    "  ellipse_orientations = [None]*num_outputs\n",
+    "  fourier_maps = [None]*num_outputs\n",
+    "  spatial_frequencies = [None]*num_outputs\n",
+    "  areas = [None]*num_outputs\n",
+    "  phases = [None]*num_outputs\n",
+    "  for bf_idx in range(num_outputs):\n",
+    "    # Reformatted individual basis function\n",
+    "    basis_funcs[bf_idx] = weights.T[bf_idx, ...].reshape((patch_edge_size, patch_edge_size))\n",
+    "    # Reformatted individual envelope filter\n",
+    "    envelopes[bf_idx] = np.abs(envelope[bf_idx, ...]).reshape((patch_edge_size, patch_edge_size))\n",
+    "    # Basis function center\n",
+    "    max_ys = envelopes[bf_idx].argmax(axis=0) # Returns row index for each col\n",
+    "    max_x = np.argmax(envelopes[bf_idx].max(axis=0))\n",
+    "    y_cen = max_ys[max_x]\n",
+    "    x_cen = max_x\n",
+    "    envelope_centers[bf_idx] = (y_cen, x_cen)\n",
+    "    # Gaussian fit to Hilbert amplitude envelope\n",
+    "    gauss_fit, grid, gauss_mean, gauss_cov = get_gauss_fit(envelopes[bf_idx],\n",
+    "      num_gauss_fits, gauss_thresh)\n",
+    "    gauss_fits[bf_idx] = (gauss_fit, grid)\n",
+    "    gauss_centers[bf_idx] = gauss_mean\n",
+    "    evals, evecs = np.linalg.eigh(gauss_cov)\n",
+    "    sort_indices = np.argsort(evals)[::-1]\n",
+    "    gauss_orientations[bf_idx] = (evals[sort_indices], evecs[:, sort_indices])\n",
+    "    width, height = evals[sort_indices] # Width & height are relative to orientation\n",
+    "    diameters[bf_idx] = np.sqrt(width**2+height**2)\n",
+    "    # Fourier function center, spatial frequency, orientation\n",
+    "    fourier_map = np.sqrt(np.real(bffs[bf_idx, ...])**2+np.imag(bffs[bf_idx, ...])**2)\n",
+    "    fourier_maps[bf_idx] = fourier_map\n",
+    "    N = fourier_map.shape[0]\n",
+    "    center_freq = int(np.floor(N/2))\n",
+    "    fourier_map[center_freq, center_freq] = 0 # remove DC component\n",
+    "    max_fys = fourier_map.argmax(axis=0)\n",
+    "    max_fx = np.argmax(fourier_map.max(axis=0))\n",
+    "    fy_cen = (max_fys[max_fx] - (N/2)) * (patch_edge_size/N)\n",
+    "    fx_cen = (max_fx - (N/2)) * (patch_edge_size/N)\n",
+    "    fourier_centers[bf_idx] = [fy_cen, fx_cen]\n",
+    "    # NOTE: we flip fourier_centers because fx_cen is the peak of the x frequency,\n",
+    "    # which would be a y coordinate\n",
+    "    ellipse_orientations[bf_idx] = np.arctan2(*fourier_centers[bf_idx][::-1])\n",
+    "    spatial_frequencies[bf_idx] = np.sqrt(fy_cen**2 + fx_cen**2)\n",
+    "    areas[bf_idx] = np.pi * np.prod(evals)\n",
+    "    phases[bf_idx] = np.angle(bffs[bf_idx])[y_cen, x_cen]\n",
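+    "  # Package all per-basis-function statistics into a single dictionary\n",
+    "  output = {\"basis_functions\":basis_funcs, \"envelopes\":envelopes, 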
\"gauss_fits\":gauss_fits,\n", + " \"gauss_centers\":gauss_centers, \"gauss_orientations\":gauss_orientations, \"areas\":areas,\n", + " \"fourier_centers\":fourier_centers, \"fourier_maps\":fourier_maps, \"num_inputs\":num_inputs,\n", + " \"spatial_frequencies\":spatial_frequencies, \"envelope_centers\":envelope_centers,\n", + " \"num_outputs\":num_outputs, \"patch_edge_size\":patch_edge_size, \"phases\":phases,\n", + " \"ellipse_orientations\":ellipse_orientations, \"diameters\":diameters}\n", + " return output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bf_stats = get_dictionary_stats(\n", + " gray_lca_weights.reshape(gray_lca_weights.shape[0], -1).T,\n", + " padding=32,\n", + " num_gauss_fits=20,\n", + " gauss_thresh=0.2)\n", + "\n", + "np.savez(\n", + " model.params.save_dir+'bf_summary_stats.npz',\n", + " data={'bf_stats':bf_stats})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def clear_axis(ax, spines=\"none\"):\n", + " for ax_loc in [\"top\", \"bottom\", \"left\", \"right\"]:\n", + " ax.spines[ax_loc].set_color(spines)\n", + " ax.set_yticklabels([])\n", + " ax.set_xticklabels([])\n", + " ax.get_xaxis().set_visible(False)\n", + " ax.get_yaxis().set_visible(False)\n", + " ax.tick_params(axis=\"both\", bottom=False, top=False, left=False, right=False)\n", + " return ax\n", + "\n", + "def plot_ellipse(axis, center, shape, angle, color_val=\"auto\", alpha=1.0, lines=False,\n", + " fill_ellipse=False):\n", + " \"\"\"\n", + " Add an ellipse to given axis\n", + " Inputs:\n", + " axis [matplotlib.axes._subplots.AxesSubplot] axis on which ellipse should be drawn\n", + " center [tuple or list] specifying [y, x] center coordinates\n", + " shape [tuple or list] specifying [width, height] shape of ellipse\n", + " angle [float] specifying angle of ellipse\n", + " color_val [matplotlib color spec] specifying the color of the edge & face of the ellipse\n", + " alpha [float] specifying the transparency of the ellipse\n", + " lines [bool] if true, output will be a line, where the secondary axis of the ellipse\n", + " is collapsed\n", + " fill_ellipse [bool] if true and lines is false then a filled ellipse will be plotted\n", + " Outputs:\n", + " ellipse [matplotlib.patches.ellipse] ellipse object\n", + " \"\"\"\n", + " if fill_ellipse:\n", + " face_color_val = \"none\" if color_val==\"auto\" else color_val\n", + " else:\n", + " face_color_val = \"none\"\n", + " y_cen, x_cen = center\n", + " width, height = shape\n", + " if lines:\n", + " min_length = 0.1\n", + " if width < height:\n", + " width = min_length\n", + " elif width > height:\n", + " height = min_length\n", + " ellipse = matplotlib.patches.Ellipse(xy=[x_cen, y_cen], width=width,\n", + " height=height, angle=angle, edgecolor=color_val, facecolor=face_color_val,\n", + " alpha=alpha, fill=True)\n", + " axis.add_artist(ellipse)\n", + " ellipse.set_clip_box(axis.bbox)\n", + " return ellipse\n", + "\n", + "def plot_ellipse_summaries(bf_stats, num_bf=-1, lines=False, rand_bf=False):\n", + " \"\"\"\n", + " Plot basis functions with summary ellipses drawn over them\n", + " Inputs:\n", + " bf_stats [dict] output of dp.get_dictionary_stats()\n", + " num_bf [int] number of basis functions to plot (<=0 is all; >total is all)\n", + " lines [bool] If true, will plot lines instead of ellipses\n", + " rand_bf [bool] If true, will choose a random set of basis functions\n", + " \"\"\"\n", + " tot_num_bf = 
len(bf_stats[\"basis_functions\"])\n",
+    "  if num_bf <= 0 or num_bf > tot_num_bf:\n",
+    "    num_bf = tot_num_bf\n",
+    "  SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n",
+    "    for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n",
+    "  sf_sort_indices = np.argsort(SFs)\n",
+    "  if rand_bf:\n",
+    "    bf_range = np.random.choice([i for i in range(tot_num_bf)], num_bf, replace=False)\n",
+    "  num_plots_y = int(np.ceil(np.sqrt(num_bf)))\n",
+    "  num_plots_x = int(np.ceil(np.sqrt(num_bf)))\n",
+    "  gs = gridspec.GridSpec(num_plots_y, num_plots_x)\n",
+    "  fig = plt.figure(figsize=(17,17))\n",
+    "  filter_idx = 0\n",
+    "  for plot_id in np.ndindex((num_plots_y, num_plots_x)):\n",
+    "    ax = clear_axis(fig.add_subplot(gs[plot_id]))\n",
+    "    if filter_idx < tot_num_bf and filter_idx < num_bf:\n",
+    "      if rand_bf:\n",
+    "        bf_idx = bf_range[filter_idx]\n",
+    "      else:\n",
+    "        bf_idx = filter_idx\n",
+    "      bf = bf_stats[\"basis_functions\"][bf_idx]\n",
+    "      ax.imshow(bf, interpolation=\"nearest\", cmap=\"Greys_r\")\n",
+    "      ax.set_title(str(bf_idx), fontsize=\"8\")\n",
+    "      center = bf_stats[\"gauss_centers\"][bf_idx]\n",
+    "      evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n",
+    "      orientations = bf_stats[\"fourier_centers\"][bf_idx]\n",
+    "      angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n",
+    "      alpha = 1.0\n",
+    "      ellipse = plot_ellipse(ax, center, evals, angle, color_val=\"b\", alpha=alpha, lines=lines)\n",
+    "      filter_idx += 1\n",
+    "    ax.set_aspect(\"equal\")\n",
+    "  plt.show()\n",
+    "  return fig"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig = plot_ellipse_summaries(bf_stats, lines=False)\n",
+    "\n",
+    "fig.savefig(\n",
+    "    f'{model.params.disp_dir}/basis_function_fits.png',\n",
+    "    transparent=False,\n",
+    "    bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def bgr_colormap():\n",
+    "  \"\"\"\n",
+    "  In cdict, the first column is interpolated between 0.0 & 1.0 - this indicates the value to be plotted;\n",
+    "  the second column specifies how interpolation should be done from below;\n",
+    "  the third column specifies how interpolation should be done from above.\n",
+    "  If the second column does not equal the third, then there will be a break in the colors\n",
+    "  \"\"\"\n",
+    "  darkness = 0.85 # 0 is black, 1 is white\n",
+    "  cdict = {\n",
+    "    'red':   ((0.0, 0.0, 0.0),\n",
+    "              (0.5, darkness, darkness),\n",
+    "              (1.0, 1.0, 1.0)),\n",
+    "    'green': ((0.0, 0.0, 0.0),\n",
+    "              (0.5, darkness, darkness),\n",
+    "              (1.0, 0.0, 0.0)),\n",
+    "    'blue':  ((0.0, 1.0, 1.0),\n",
+    "              (0.5, darkness, darkness),\n",
+    "              (1.0, 0.0, 0.0))\n",
+    "  }\n",
+    "  return LinearSegmentedColormap(\"bgr\", cdict)\n",
+    "\n",
+    "def plot_pooling_centers(bf_stats, pooling_filters, num_pooling_filters, num_connected_weights,\n",
+    "  spot_size=10, figsize=None):\n",
+    "  \"\"\"\n",
+    "  Plot 2nd layer (fully-connected) weights in terms of spatial/frequency centers of\n",
+    "  1st layer weights\n",
+    "  Inputs:\n",
+    "    bf_stats [dict] Output of dp.get_dictionary_stats() which was run on the 1st layer weights\n",
+    "    pooling_filters [np.ndarray] 2nd layer weights\n",
+    "      should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n",
+    "    num_pooling_filters [int] How many 2nd layer neurons to plot\n",
+    "    num_connected_weights [int] How many of the strongest 1st layer connections to color per neuron\n",
+    "    figsize [tuple] Containing the (width, height) of the figure, in inches\n",
+    "    spot_size [int] How big to make the points\n",
+    "  \"\"\"\n",
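+    "  # Layout: each 2nd-layer neuron gets a pair of panels; the left scatter shows\n",
+    "  # the spatial (Gaussian) centers of its 1st-layer inputs, the right shows their\n",
+    "  # Fourier centers, with color giving signed connection strength on a symlog scale.\n",
+    "  num_filters_y = 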
int(np.ceil(np.sqrt(num_pooling_filters)))\n", + " num_filters_x = int(np.ceil(np.sqrt(num_pooling_filters)))\n", + " tot_pooling_filters = pooling_filters.shape[1]\n", + " #filter_indices = np.random.choice(tot_pooling_filters, num_pooling_filters, replace=False)\n", + " filter_indices = np.arange(tot_pooling_filters, dtype=np.int32)\n", + " cmap = plt.get_cmap(bgr_colormap())# Could also use \"nipy_spectral\", \"coolwarm\", \"bwr\"\n", + " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n", + " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n", + " x_p_cent = [x for (y,x) in bf_stats[\"gauss_centers\"]]# Get raw points\n", + " y_p_cent = [y for (y,x) in bf_stats[\"gauss_centers\"]]\n", + " x_f_cent = [x for (y,x) in bf_stats[\"fourier_centers\"]]\n", + " y_f_cent = [y for (y,x) in bf_stats[\"fourier_centers\"]]\n", + " max_sf = np.max(np.abs(x_f_cent+y_f_cent))\n", + " pair_w_gap = 0.01\n", + " group_w_gap = 0.03\n", + " h_gap = 0.03\n", + " plt_w = (num_filters_x/num_pooling_filters)\n", + " plt_h = plt_w\n", + " if figsize is None:\n", + " fig = plt.figure()\n", + " figsize = (fig.get_figwidth(), fig.get_figheight())\n", + " else:\n", + " fig = plt.figure(figsize=figsize) #figsize is (w,h)\n", + " axes = []\n", + " filter_id = 0\n", + " for plot_id in np.ndindex((num_filters_y, num_filters_x)):\n", + " if all(pid == 0 for pid in plot_id):\n", + " axes.append(clear_axis(fig.add_axes([0, plt_h+h_gap, 2*plt_w, plt_h])))\n", + " scalarMap._A = []\n", + " cbar = fig.colorbar(scalarMap, ax=axes[-1], ticks=[-1, 0, 1], aspect=10, location=\"bottom\")\n", + " cbar.ax.set_xticklabels([\"-1\", \"0\", \"1\"])\n", + " cbar.ax.xaxis.set_ticks_position('top')\n", + " cbar.ax.xaxis.set_label_position('top')\n", + " for label in cbar.ax.xaxis.get_ticklabels():\n", + " label.set_weight(\"bold\")\n", + " label.set_fontsize(10+figsize[0])\n", + " if (filter_id < num_pooling_filters):\n", + " example_filter = pooling_filters[:, filter_indices[filter_id]]\n", + " top_indices = np.argsort(np.abs(example_filter))[::-1] #descending\n", + " selected_indices = top_indices[:num_connected_weights][::-1] #select top, plot weakest first\n", + " filter_norm = np.max(np.abs(example_filter))\n", + " connection_colors = [scalarMap.to_rgba(example_filter[bf_idx]/filter_norm)\n", + " for bf_idx in range(bf_stats[\"num_outputs\"])]\n", + " if num_connected_weights < top_indices.size:\n", + " black_indices = top_indices[num_connected_weights:][::-1]\n", + " xp = [x_p_cent[i] for i in black_indices]+[x_p_cent[i] for i in selected_indices]\n", + " yp = [y_p_cent[i] for i in black_indices]+[y_p_cent[i] for i in selected_indices]\n", + " xf = [x_f_cent[i] for i in black_indices]+[x_f_cent[i] for i in selected_indices]\n", + " yf = [y_f_cent[i] for i in black_indices]+[y_f_cent[i] for i in selected_indices]\n", + " c = [(0.1,0.1,0.1,1.0) for i in black_indices]+[connection_colors[i] for i in selected_indices]\n", + " else:\n", + " xp = [x_p_cent[i] for i in selected_indices]\n", + " yp = [y_p_cent[i] for i in selected_indices]\n", + " xf = [x_f_cent[i] for i in selected_indices]\n", + " yf = [y_f_cent[i] for i in selected_indices]\n", + " c = [connection_colors[i] for i in selected_indices]\n", + " (y_id, x_id) = plot_id\n", + " if x_id == 0:\n", + " ax_l = 0\n", + " ax_b = - y_id * (plt_h+h_gap)\n", + " else:\n", + " bbox = axes[-1].get_position().get_points()[0]#bbox is [[x0,y0],[x1,y1]]\n", + " prev_l = bbox[0]\n", + " prev_b = bbox[1]\n", + " ax_l = prev_l + plt_w + 
group_w_gap\n", + " ax_b = prev_b\n", + " ax_w = plt_w\n", + " ax_h = plt_h\n", + " axes.append(clear_axis(fig.add_axes([ax_l, ax_b, ax_w, ax_h])))\n", + " axes[-1].invert_yaxis()\n", + " axes[-1].scatter(xp, yp, c=c, s=spot_size, alpha=0.8)\n", + " axes[-1].set_xlim(0, bf_stats[\"patch_edge_size\"]-1)\n", + " axes[-1].set_ylim(bf_stats[\"patch_edge_size\"]-1, 0)\n", + " axes[-1].set_aspect(\"equal\")\n", + " axes[-1].set_facecolor(\"w\")\n", + " axes.append(clear_axis(fig.add_axes([ax_l+ax_w+pair_w_gap, ax_b, ax_w, ax_h])))\n", + " axes[-1].scatter(xf, yf, c=c, s=spot_size, alpha=0.8)\n", + " axes[-1].set_xlim([-max_sf, max_sf])\n", + " axes[-1].set_ylim([-max_sf, max_sf])\n", + " axes[-1].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n", + " axes[-1].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n", + " axes[-1].set_aspect(\"equal\")\n", + " axes[-1].set_facecolor(\"w\")\n", + " filter_id += 1\n", + " plt.show()\n", + " return fig" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "kernel_pos = 0\n", + "pool_weights = model.pool_1.layer.weight.detach().cpu().numpy()\n", + "outputs, inputs, kernel_h, kernel_w = pool_weights.shape\n", + "\n", + "fig = plot_pooling_centers(\n", + " bf_stats,\n", + " pool_weights[:, :, kernel_pos, kernel_pos].T,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=inputs,\n", + " spot_size=3,\n", + " figsize=(5, 5))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/pooling_spots.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_pooling_summaries(bf_stats, pooling_filters, num_pooling_filters,\n", + " num_connected_weights, lines=False, figsize=None):\n", + " \"\"\"\n", + " Plot 2nd layer (fully-connected) weights in terms of connection strengths to 1st layer weights\n", + " Inputs:\n", + " bf_stats [dict] output of dp.get_dictionary_stats() which was run on the 1st layer weights\n", + " pooling_filters [np.ndarray] 2nd layer weights\n", + " should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n", + " num_pooling_filters [int] How many 2nd layer neurons to plot\n", + " num_connected_weights [int] How many 1st layer weight summaries to include\n", + " for a given 2nd layer neuron\n", + " lines [bool] if True, 1st layer weight summaries will appear as lines instead of ellipses\n", + " \"\"\"\n", + " num_inputs = bf_stats[\"num_inputs\"]\n", + " num_outputs = bf_stats[\"num_outputs\"]\n", + " tot_pooling_filters = pooling_filters.shape[1]\n", + " patch_edge_size = np.int32(np.sqrt(num_inputs))\n", + " filter_idx_list = np.arange(num_pooling_filters, dtype=np.int32)\n", + " assert num_pooling_filters <= num_outputs, (\n", + " \"num_pooling_filters must be less than or equal to bf_stats['num_outputs']\")\n", + " cmap = bgr_colormap()#plt.get_cmap('bwr')\n", + " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n", + " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n", + " num_plots_y = np.int32(np.ceil(np.sqrt(num_pooling_filters)))\n", + " num_plots_x = np.int32(np.ceil(np.sqrt(num_pooling_filters)))+1 # +cbar col\n", + " gs_widths = [1 for _ in range(num_plots_x-1)]+[0.3]\n", + " gs = gridspec.GridSpec(num_plots_y, num_plots_x, width_ratios=gs_widths)\n", + " if figsize is None:\n", + " fig = plt.figure()\n", + " figsize = 
(fig.get_figwidth(), fig.get_figheight())\n",
+    "  else:\n",
+    "    fig = plt.figure(figsize=figsize)\n",
+    "  filter_total = 0\n",
+    "  for plot_id in np.ndindex((num_plots_y, num_plots_x-1)):\n",
+    "    (y_id, x_id) = plot_id\n",
+    "    ax = fig.add_subplot(gs[plot_id])\n",
+    "    if (filter_total < num_pooling_filters and x_id != num_plots_x-1):\n",
+    "      ax = clear_axis(ax, spines=\"k\")\n",
+    "      filter_idx = filter_idx_list[filter_total]\n",
+    "      example_filter = pooling_filters[:, filter_idx]\n",
+    "      top_indices = np.argsort(np.abs(example_filter))[::-1] # descending\n",
+    "      filter_norm = np.max(np.abs(example_filter))\n",
+    "      SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n",
+    "        for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n",
+    "      # Plot weakest of the top connected filters first because of occlusion\n",
+    "      for bf_idx in top_indices[:num_connected_weights][::-1]:\n",
+    "        connection_strength = example_filter[bf_idx]/filter_norm\n",
+    "        color_val = scalarMap.to_rgba(connection_strength)\n",
+    "        center = bf_stats[\"gauss_centers\"][bf_idx]\n",
+    "        evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n",
+    "        orientations = bf_stats[\"fourier_centers\"][bf_idx]\n",
+    "        angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n",
+    "        alpha = 0.5 # TODO: use spatial_freq for filled ellipses?\n",
+    "        ellipse = plot_ellipse(ax, center, evals, angle, color_val, alpha=alpha, lines=lines)\n",
+    "      ax.set_xlim(0, patch_edge_size-1)\n",
+    "      ax.set_ylim(patch_edge_size-1, 0)\n",
+    "      filter_total += 1\n",
+    "    else:\n",
+    "      ax = clear_axis(ax, spines=\"none\")\n",
+    "    ax.set_aspect(\"equal\")\n",
+    "  scalarMap._A = []\n",
+    "  ax = clear_axis(fig.add_subplot(gs[0, -1]))\n",
+    "  cbar = fig.colorbar(scalarMap, ax=ax, ticks=[-1, 0, 1])\n",
+    "  cbar.ax.set_yticklabels([\"-1\", \"0\", \"1\"])\n",
+    "  for label in cbar.ax.yaxis.get_ticklabels():\n",
+    "    label.set_weight(\"bold\")\n",
+    "    label.set_fontsize(14)\n",
+    "  plt.show()\n",
+    "  return fig"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig = plot_pooling_summaries(\n",
+    "    bf_stats,\n",
+    "    pool_weights[:, :, kernel_pos, kernel_pos].T,\n",
+    "    num_pooling_filters=outputs,\n",
+    "    num_connected_weights=40,\n",
+    "    lines=True,\n",
+    "    figsize=(18,18))\n",
+    "\n",
+    "fig.savefig(\n",
+    "    f'{model.params.disp_dir}/pooling_lines.png',\n",
+    "    transparent=False,\n",
+    "    bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "P = pool_weights[:, :, kernel_pos, kernel_pos] # [inputs, outputs]\n",
+    "p_norm = np.linalg.norm(P, ord=2, axis=0)\n",
+    "# Cosine similarity of neurons in the embedded (pooling) space;\n",
+    "# np.outer divides each Gram-matrix entry by the product of the two column norms\n",
+    "affinity = np.dot(P.T, P) / np.outer(p_norm, p_norm)\n",
+    "affinity = affinity.T # [inputs, inputs]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fig = plot_pooling_centers(\n",
+    "    bf_stats,\n",
+    "    affinity,\n",
+    "    num_pooling_filters=outputs,\n",
+    "    num_connected_weights=128,\n",
+    "    spot_size=30,\n",
+    "    figsize=(5, 5))\n",
+    "\n",
+    "fig.savefig(\n",
+    "    f'{model.params.disp_dir}/affinity_spots.png',\n",
+    "    transparent=False,\n",
+    "    bbox_inches='tight')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
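+    "# Rows of the affinity matrix act like pooling filters: for each lca_1 neuron,\n",
+    "# draw line summaries of the neurons it is most similar to in pooling space.\n",
+    "fig = plot_pooling_summaries(\n",
+    "    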
bf_stats,\n", + " affinity,\n", + " num_pooling_filters=outputs,\n", + " num_connected_weights=15,\n", + " lines=True,\n", + " figsize=(10, 10))\n", + "\n", + "fig.savefig(\n", + " f'{model.params.disp_dir}/affinity_lines.png',\n", + " transparent=False,\n", + " bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params.standardize_data = False\n", + "train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)\n", + "train_mean_image = 0#data_mean_std['dataset_mean_image'].to(model.params.device)\n", + "train_std_image = 1#data_mean_std['dataset_std_image'].to(model.params.device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "example_batch = next(iter(train_loader))[0].to(model.params.device)\n", + "example_batch = model[0].preprocess_data(example_batch)\n", + "#example_batch *= train_std_image\n", + "#example_batch += train_mean_image\n", + "batch_min = example_batch.min().item()\n", + "batch_max = example_batch.max().item()\n", + "\n", + "example_image = example_batch[0, ...]\n", + "print(\n", + " f'example image min = {example_image.min().item()}'+\n", + " f'\\nexample image mean = {example_image.mean().item()}'+\n", + " f'\\nexample image max = {example_image.max().item()}'+\n", + " f'\\nexample image std = {example_image.std().item()}')\n", + "preproc_image, example_image_mean, example_image_std = dp.standardize(example_image[None, ...], samplewise=True)\n", + "print(\n", + " f'preproc image min = {preproc_image.min().item()}'+\n", + " f'\\npreproc image mean = {preproc_image.mean().item()}'+\n", + " f'\\npreproc image max = {preproc_image.max().item()}'+\n", + " f'\\npreproc image std = {preproc_image.std().item()}')\n", + "\n", + "plot_example_image = ((example_image * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n", + "plot_preproc_image = ((preproc_image - preproc_image.min())/(preproc_image.max() - preproc_image.min()))[0,...].cpu().numpy().transpose(1,2,0)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=2)\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(plot_example_image, vmin=0, vmax=1)\n", + "ax.format(title='Original')\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.format(title='Preprocessed')\n", + "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_1 = model.lca_1.get_encodings(preproc_image)\n", + "beta_1 = model.pool_1.get_encodings(alpha_1)\n", + "alpha_2 = model.lca_2.get_encodings(beta_1)\n", + "beta_2 = model(preproc_image)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class lca_2_recon_params(LcaParams):\n", + " def set_params(self):\n", + " super(lca_2_recon_params, self).set_params()\n", + " self.model_type = 'lca'\n", + " self.model_name = 'lca_2_recon'\n", + " self.version = '0'\n", + " self.layer_types = ['fc']\n", + " self.standardize_data = False\n", + " self.rescale_data_to_one = False\n", + " self.center_dataset = False\n", + " self.batch_size = 1\n", + " self.dt = 0.001\n", + " self.tau = 0.2\n", + " self.num_steps = 75\n", + " self.rectify_a = True\n", + " self.thresh_type = 'hard'\n", + " self.compute_helper_params()\n", + " \n", + "params = lca_2_recon_params()\n", + "params.set_params()\n", + "params.layer_channels = 
list(model.pool_2.weight.shape)\n", + "params.sparse_mult = 0.0#model.lca_2.params.sparse_mult\n", + "params.data_shape = list(beta_2.shape)\n", + "params.epoch_size = 1\n", + "params.num_pixels = np.prod(params.data_shape)\n", + "\n", + "lca_2_recon_model = loaders.load_model(params.model_type)\n", + "lca_2_recon_model.setup(params)\n", + "lca_2_recon_model.to(params.device)\n", + "lca_2_recon_model.eval()\n", + "with torch.no_grad():\n", + " lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight.T)\n", + "a2h_b2 = lca_2_recon_model(beta_2)[:, :, None, None]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_2_bin = alpha_2.detach().cpu().numpy()>0\n", + "alpha_2_hat_bin = a2h_b2.detach().cpu().numpy()>0\n", + "\n", + "alpha_2_nnz = np.count_nonzero(alpha_2_bin)\n", + "alpha_2_hat_nnz = np.count_nonzero(alpha_2_hat_bin)\n", + "\n", + "hamming_dist = np.abs(np.sum(alpha_2_bin * alpha_2_hat_bin))\n", + "\n", + "num_alpha_2 = np.prod(list(a2h_b2.shape))\n", + "alpha_2_edge = int(np.sqrt(num_alpha_2))\n", + "\n", + "plot_alpha_2_hat = ((a2h_b2 - a2h_b2.min())/(a2h_b2.max() - a2h_b2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n", + "\n", + "plot_alpha_2 = ((alpha_2 - alpha_2.min()) / (alpha_2.max() - alpha_2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n", + "\n", + "alpha_diff = np.abs(plot_alpha_2 - plot_alpha_2_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(alpha_diff, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_alpha_2_hat, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_alpha_2, vmin=0, vmax=1)\n", + "ax.format(title='alpha 2')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'active index overlap = {hamming_dist}; alpha 2 nnz = {alpha_2_nnz}; alpha 2 hat nnz = {alpha_2_hat_nnz}')\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " b1h_a2h_b2 = F.conv_transpose2d(\n", + " input=a2h_b2,\n", + " weight=model.lca_2.weight,\n", + " bias=None,\n", + " stride=model.lca_2.params.stride,\n", + " padding=model.lca_2.params.padding)\n", + " \n", + " b1h_a2 = F.conv_transpose2d(\n", + " input=alpha_2,\n", + " weight=model.lca_2.weight,\n", + " bias=None,\n", + " stride=model.lca_2.params.stride,\n", + " padding=model.lca_2.params.padding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "squared_error = torch.pow(beta_1 - b1h_a2h_b2, 2.)\n", + "l2_dist = 0.5 * torch.mean(squared_error).detach().cpu().numpy()\n", + "\n", + "num_beta_1 = np.prod(list(b1h_a2h_b2.shape))\n", + "beta_1_edge = int(np.floor(np.sqrt(num_beta_1)))\n", + "beta_1_resh = int(beta_1_edge**2)\n", + "\n", + "plot_beta_1_hat = ((b1h_a2h_b2 - b1h_a2h_b2.min())/(b1h_a2h_b2.max() - b1h_a2h_b2.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n", + "\n", + "plot_beta_1 = ((beta_1 - beta_1.min()) / (beta_1.max() - beta_1.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n", + "\n", + "beta_diff = np.abs(plot_beta_1 - plot_beta_1_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = 
pf.clear_axis(axs[0])\n", + "ax.imshow(beta_diff, vmin=0, vmax=1)\n", + "ax.format(title='beta 1 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_beta_1_hat, vmin=0, vmax=1)\n", + "ax.format(title='beta 1 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_beta_1, vmin=0, vmax=1)\n", + "ax.format(title='beta 1')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'l2 distance = {l2_dist:0.5f}')\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from DeepSparseCoding.modules.lca_module import LcaModule\n", + "from DeepSparseCoding.models.base_model import BaseModel\n", + "from DeepSparseCoding.utils.run_utils import compute_deconv_output_shape\n", + "import DeepSparseCoding.modules.losses as losses\n", + "\n", + "class TransposedLcaModule(LcaModule):\n", + " def setup_module(self, params):\n", + " super(TransposedLcaModule, self).setup_module(params)\n", + " if self.params.layer_types[0] == 'conv':\n", + " assert (self.params.data_shape[-1] % self.params.stride == 0), (\n", + " f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')\n", + " self.w_shape = [\n", + " self.params.layer_channels[1],\n", + " self.params.layer_channels[0],\n", + " self.params.kernel_size,\n", + " self.params.kernel_size\n", + " ]\n", + " output_height = compute_deconv_output_shape(\n", + " self.params.data_shape[1],\n", + " self.params.kernel_size,\n", + " self.params.stride,\n", + " self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1)\n", + " output_width = compute_deconv_output_shape(\n", + " self.params.data_shape[2],\n", + " self.params.kernel_size,\n", + " self.params.stride,\n", + " self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1)\n", + " self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width]\n", + " w_init = torch.randn(self.w_shape)\n", + " w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)\n", + " self.weight = nn.Parameter(w_init_normed, requires_grad=True)\n", + "\n", + " def compute_excitatory_current(self, input_tensor, a_in):\n", + " if self.params.layer_types[0] == 'fc':\n", + " excitatory_current = torch.matmul(input_tensor, self.weight.T)\n", + " else:\n", + " recon = self.get_recon_from_latents(a_in)\n", + " recon_error = input_tensor - recon\n", + " error_injection = F.conv_transpose2d(\n", + " input=recon_error,\n", + " weight=self.weight,\n", + " bias=None,\n", + " stride=self.params.stride,\n", + " padding=self.params.padding,\n", + " output_padding=self.params.output_padding,\n", + " dilation=1\n", + " )\n", + " excitatory_current = error_injection + a_in\n", + " return excitatory_current\n", + "\n", + " def get_recon_from_latents(self, a_in):\n", + " if self.params.layer_types[0] == 'fc':\n", + " recon = torch.matmul(a_in, self.weight)\n", + " else:\n", + " recon = F.conv2d(\n", + " input=a_in,\n", + " weight=self.weight,\n", + " bias=None,\n", + " stride=self.params.stride,\n", + " padding=self.params.padding,\n", + " dilation=1\n", + " )\n", + " return recon\n", + "\n", + "class TransposedLcaModel(BaseModel, TransposedLcaModule):\n", + " def setup(self, params, logger=None):\n", + " super(TransposedLcaModel, self).setup(params, logger)\n", + " self.setup_module(params)\n", + " self.setup_optimizer()\n", + " if params.checkpoint_boot_log != '':\n", + " checkpoint = 
self.get_checkpoint_from_log(params.checkpoint_boot_log)\n",
+    "            # Load directly onto the model; there is no separate self.module here\n",
+    "            self.load_state_dict(checkpoint['model_state_dict'])\n",
+    "            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+    "\n",
+    "    def get_total_loss(self, input_tuple):\n",
+    "        input_tensor, input_labels = input_tuple\n",
+    "        latents = self.get_encodings(input_tensor)\n",
+    "        recon = self.get_recon_from_latents(latents)\n",
+    "        recon_loss = losses.half_squared_l2(input_tensor, recon)\n",
+    "        sparse_loss = self.params.sparse_mult * losses.l1_norm(latents)\n",
+    "        total_loss = recon_loss + sparse_loss\n",
+    "        return total_loss\n",
+    "\n",
+    "    def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):\n",
+    "        if update_dict is None:\n",
+    "            update_dict = super(TransposedLcaModel, self).generate_update_dict(input_data, input_labels, batch_step)\n",
+    "        stat_dict = dict()\n",
+    "        latents = self.get_encodings(input_data)\n",
+    "        recon = self.get_recon_from_latents(latents)\n",
+    "        recon_loss = losses.half_squared_l2(input_data, recon).item()\n",
+    "        sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()\n",
+    "        stat_dict['weight_lr'] = self.scheduler.get_lr()[0]\n",
+    "        stat_dict['loss_recon'] = recon_loss\n",
+    "        stat_dict['loss_sparse'] = sparse_loss\n",
+    "        stat_dict['loss_total'] = recon_loss + sparse_loss\n",
+    "        stat_dict['input_max_mean_min'] = [\n",
+    "            input_data.max().item(), input_data.mean().item(), input_data.min().item()]\n",
+    "        stat_dict['recon_max_mean_min'] = [\n",
+    "            recon.max().item(), recon.mean().item(), recon.min().item()]\n",
+    "        def count_nonzero(array, dim):\n",
+    "            # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7\n",
+    "            return torch.sum(array != 0, dim=dim, dtype=torch.float)\n",
+    "        latent_dims = tuple([i for i in range(len(latents.shape))])\n",
+    "        latent_nnz = count_nonzero(latents, dim=latent_dims).item()\n",
+    "        stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel()\n",
+    "        if self.params.layer_types[0] == 'conv':\n",
+    "            latent_map_dims = latent_dims[2:]\n",
+    "            latent_map_size = np.prod(list(latents.shape[2:]))\n",
+    "            latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size\n",
+    "            latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item()\n",
+    "            stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz\n",
+    "            num_channels = latents.shape[1]\n",
+    "            latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item()\n",
+    "            stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz\n",
+    "        update_dict.update(stat_dict)\n",
+    "        return update_dict"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class lca_1_recon_params(LcaParams):\n",
+    "    def set_params(self):\n",
+    "        super(lca_1_recon_params, self).set_params()\n",
+    "        self.model_type = 'lca'\n",
+    "        self.model_name = 'lca_1_recon'\n",
+    "        self.version = '0'\n",
+    "        self.layer_types = ['conv']\n",
+    "        self.standardize_data = False\n",
+    "        self.rescale_data_to_one = False\n",
+    "        self.center_dataset = False\n",
+    "        self.batch_size = 1\n",
+    "        self.dt = 0.001\n",
+    "        self.tau = 0.2\n",
+    "        self.num_steps = 75\n",
+    "        self.rectify_a = True\n",
+    "        self.thresh_type = 'hard'\n",
+    "        self.compute_helper_params()\n",
+    "\n",
+    "params = lca_1_recon_params()\n",
+    "params.set_params()\n",
+    "params.layer_channels = model.pool_1.params.layer_channels[::-1]\n",
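+    "# Reuse pool_1's kernel geometry so the transposed LCA model inverts the pooling op\n",
+    "params.kernel_size = 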
model.pool_1.params.pool_ksize\n", + "params.stride = model.pool_1.params.pool_stride\n", + "params.padding = 0\n", + "params.sparse_mult = 0.00#model.lca_1.params.sparse_mult\n", + "params.data_shape = list(b1h_a2h_b2.shape[1:])\n", + "params.epoch_size = 1\n", + "params.output_padding = 1\n", + "params.num_pixels = np.prod(params.data_shape)\n", + "\n", + "lca_1_recon_model = TransposedLcaModel()\n", + "lca_1_recon_model.setup(params)\n", + "lca_1_recon_model.to(params.device)\n", + "lca_1_recon_model.eval()\n", + "with torch.no_grad():\n", + " lca_1_recon_model.weight = nn.Parameter(model.pool_1.weight)\n", + "a1h_b1h_a2h_b2 = lca_1_recon_model(b1h_a2h_b2)\n", + "a1h_b1h_a2 = lca_1_recon_model(b1h_a2)\n", + "a1h_b1 = lca_1_recon_model(beta_1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "alpha_1_bin = alpha_1.detach().cpu().numpy()>0\n", + "alpha_1_hat_bin = a1h_b1h_a2h_b2.detach().cpu().numpy()>0\n", + "hamming_dist = np.abs(np.sum(alpha_1_bin * alpha_1_hat_bin))\n", + "alpha_1_nnz = np.count_nonzero(alpha_1_bin)\n", + "alpha_1_hat_nnz = np.count_nonzero(alpha_1_hat_bin)\n", + "\n", + "num_alpha_1 = np.prod(list(a1h_b1h_a2h_b2.shape))\n", + "alpha_1_edge = int(np.floor(np.sqrt(num_alpha_1)))\n", + "alpha_1_resh = int(alpha_1_edge**2)\n", + "\n", + "plot_alpha_1_hat = ((a1h_b1h_a2h_b2 - a1h_b1h_a2h_b2.min())/(a1h_b1h_a2h_b2.max() - a1h_b1h_a2h_b2.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n", + "\n", + "plot_alpha_1 = ((alpha_1 - alpha_1.min()) / (alpha_1.max() - alpha_1.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n", + "\n", + "alpha_diff = np.abs(plot_alpha_1 - plot_alpha_1_hat)\n", + "\n", + "fig, axs = plot.subplots(nrows=1, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0])\n", + "ax.imshow(alpha_diff, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1 differences')\n", + "\n", + "ax = pf.clear_axis(axs[1])\n", + "ax.imshow(plot_alpha_1_hat, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1 hat')\n", + "\n", + "ax = pf.clear_axis(axs[2])\n", + "m = ax.imshow(plot_alpha_1, vmin=0, vmax=1)\n", + "ax.format(title='alpha 1')\n", + "\n", + "ax.colorbar(m, ax=ax)\n", + "axs.format(suptitle=f'active index overlap = {hamming_dist}; alpha 1 nnz = {alpha_1_nnz}; alpha 1 hat nnz = {alpha_1_hat_nnz}')\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " recon_from_alpha_1 = F.conv_transpose2d(\n", + " input=alpha_1,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_beta_1 = F.conv_transpose2d(\n", + " input=a1h_b1,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_alpha_2 = F.conv_transpose2d(\n", + " input=a1h_b1h_a2,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)\n", + " \n", + " recon_from_beta_2 = F.conv_transpose2d(\n", + " input=a1h_b1h_a2h_b2,\n", + " weight=model.lca_1.weight,\n", + " bias=None,\n", + " stride=model.lca_1.params.stride,\n", + " padding=model.lca_1.params.padding)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + 
"def plot_func(x):\n", + " x *= example_image_std\n", + " x += example_image_mean\n", + " x = ((x - x.min()) / (x.max() - x.min())).squeeze().cpu().numpy().transpose(1,2,0)\n", + " return(x)\n", + "\n", + "alpha_2_nnz = torch.sum(a2h_b2 !=0,\n", + " dim=tuple([i for i in range(len(a2h_b2.shape))]),\n", + " dtype=torch.float)/a2h_b2.numel()\n", + "alpha_1_nnz = torch.sum(a1h_b1h_a2h_b2 !=0,\n", + " dim=tuple([i for i in range(len(a1h_b1h_a2h_b2.shape))]),\n", + " dtype=torch.float)/a1h_b1h_a2h_b2.numel()\n", + "print(\n", + " f'beta2 shape = {beta_2.shape}' + \n", + " f'\\nalpha2^ nnz = {alpha_2_nnz}'+\n", + " f'\\nalpha2^ shape = {a2h_b2.shape}'+\n", + " f'\\nbeta1^ shape = {b1h_a2h_b2.shape}'\n", + " f'\\nalpha1^ nnz = {alpha_1_nnz}'+\n", + " f'\\nalpha1^ shape = {a1h_b1h_a2h_b2.shape}'+\n", + " f'\\nimage^ shape = {recon_from_beta_2.shape}'\n", + ")\n", + "print(\n", + " f'recon min = {recon_from_beta_2.min().item()}'+\n", + " f'\\nrecon mean = {recon_from_beta_2.mean().item()}'+\n", + " f'\\nrecon max = {recon_from_beta_2.max().item()}'+\n", + " f'\\nrecon std = {recon_from_beta_2.std().item()}'\n", + ")\n", + "\n", + "fig, axs = plot.subplots(nrows=2, ncols=3)\n", + "\n", + "ax = pf.clear_axis(axs[0,0])\n", + "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n", + "ax.format(title='original')\n", + "\n", + "ax = pf.clear_axis(axs[0,1])\n", + "ax.imshow(plot_func(recon_from_alpha_1), vmin=0, vmax=1)\n", + "ax.format(title='recon from alpha 1')\n", + "\n", + "ax = pf.clear_axis(axs[0,2])\n", + "ax.imshow(plot_func(recon_from_beta_1), vmin=0, vmax=1)\n", + "ax.format(title='recon from beta 1')\n", + "\n", + "ax = pf.clear_axis(axs[1,0])\n", + "ax.imshow(plot_func(recon_from_alpha_2), vmin=0, vmax=1)\n", + "ax.format(title='recon from alpha 2')\n", + "\n", + "ax = pf.clear_axis(axs[1,1])\n", + "ax.imshow(plot_func(recon_from_beta_2), vmin=0, vmax=1)\n", + "ax.format(title='recon from beta 2')\n", + "\n", + "ax = pf.clear_axis(axs[1,2])\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.7" + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/visualize_model_weights.ipynb b/notebooks/visualize_model_weights.ipynb index 84678ccb..a6bb3460 100644 --- a/notebooks/visualize_model_weights.ipynb +++ b/notebooks/visualize_model_weights.ipynb @@ -29,16 +29,12 @@ "outputs": [], "source": [ "workspace_dir = os.path.expanduser(\"~\")+\"/Work/\"\n", - "model_name = 'lca_dsprites'\n", - "num_epochs 
= 100\n",
-    "sparse_mult = 0.05\n",
-    "model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n",
+    "model_name = 'conv_lca_mnist'\n",
+    "#num_epochs = 200\n",
+    "#sparse_mult = 0.25\n",
+    "#model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n",
     "log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n",
     "logger = Logger(log_file, overwrite=False)\n",
-    "\n",
-    "target_index = 1\n",
-    "\n",
-    "logger = Logger(log_files[target_index], overwrite=False)\n",
     "log_text = logger.load_file()\n",
     "params = logger.read_params(log_text)[-1]"
    ]
@@ -94,7 +90,7 @@
     "    paired_pics = [paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])]\n",
     "    print(np.array(paired_pics).shape)\n",
     "    visualize_util.grid_save_images(paired_pics, os.path.join('', \"reconstructions.jpg\"))\n",
-    "else:\n",
+    "elif model.params.model_type.lower() in ['mlp', 'ensemble']:\n",
     "    test_results = run_utils.test_epoch(0, model, test_loader, log_to_file=False)\n",
     "    print(test_results)"
    ]
@@ -111,8 +107,14 @@
     "else:\n",
     "    weights = list(model.parameters())[0].data.cpu().numpy()\n",
     "\n",
-    "num_neurons, num_pixels = weights.shape\n",
-    "weights = np.reshape(weights, [num_neurons, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])"
+    "if weights.ndim == 4:\n",
+    "    num_neurons, num_channels, num_h, num_w = weights.shape\n",
+    "    num_pixels = num_channels * num_h * num_w\n",
+    "elif weights.ndim == 2:\n",
+    "    num_neurons, num_pixels = weights.shape\n",
+    "    weights = np.reshape(weights, [num_neurons, 1, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])\n",
+    "else:\n",
+    "    assert False, (f'weights.ndim == {weights.ndim} must be 2 or 4')\n"
    ]
@@ -141,7 +143,11 @@
     "def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n",
     "    if normalize:\n",
     "        matrix = normalize_data_with_max(matrix)[0]\n",
-    "    num_weights, img_h, img_w = matrix.shape\n",
+    "    num_weights, img_c, img_h, img_w = matrix.shape\n",
+    "    if img_c == 1:\n",
+    "        matrix = matrix.squeeze()\n",
+    "    else:\n",
+    "        assert False, ('Multiple color channels are not currently supported') # TODO\n",
     "    num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n",
     "    if num_extra_images > 0:\n",
     "        matrix = np.concatenate(\n",
@@ -181,7 +187,8 @@
     "\n",
     "tfpf.plot_image(pad_matrix_to_image(weights), vmin=None, vmax=None, title=\"\", save_filename=model.params.disp_dir+\"/weights_plot_image.png\")\n",
     "tfpf.plot_weights(weights, save_filename=model.params.disp_dir+\"/weights_plot_weights.png\")\n",
-    "tfpf.plot_data_tiled(weights[..., None], save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")"
+    "tfpf.plot_data_tiled(np.transpose(weights, (0, 2, 3, 1)),\n",
+    "    save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")"
    ]
   },
   {
diff --git a/params/base_params.py b/params/base_params.py
index c82fe486..056dd3c9 100644
--- a/params/base_params.py
+++ b/params/base_params.py
@@ -9,10 +9,15 @@ class BaseParams(object):
     """
     all models
     batch_size [int] number of images in a training batch
+    center_dataset [bool] if True, subtract the mean dataset image from all datapoints
+    checkpoint_boot_log [str] path to a training log file for booting from checkpoint
+        if set, all specified model params must match those in the log file #TODO: meaningful errors if not
     data_dir [str] location of dataset folders
     device [str] which device to run on
     dtype [torch dtype] dtype for network variables
    eps [float] small value to avoid division by zero
+    fast_mnist [bool] if True, use the fastMNIST dataset,
+        which loads faster but does not allow for torchvision transforms like flip and rotate
     lib_root_dir [str] system location of this library directory
     log_to_file [bool] if set, log to file, else log to stderr
     model_name [str] name for model (can be anything)
@@ -34,6 +39,10 @@
     standardize_data [bool] if set, z-score data to have mean=0 and standard deviation=1 using numpy operators
     train_logs_per_epoch [int or None] how often to send updates to the logfile
     workspace_dir [str] system directory that is the parent to the primary repository directory
+    num_validation [int] number of images to reserve for the validation set (only works with some datasets)
+
+    ensemble
+    allow_parent_grads [bool] if True, allow loss gradients to propagate through all members of the ensemble

     mlp
     activation_functions [list of str] strings correspond to activation functions for layers.
@@ -44,6 +53,11 @@
     layer_types [list of str] weight connectivity type, either "conv" or "fc"
         len must be equal to the len of layer_channels - 1
     layer_channels [list of int] number of outputs per layer, including the input layer
+    kernel_sizes [list of ints] number of pixels on the edge of a square kernel, only used if layer_types is "conv"
+    strides [list of ints] number of pixels for the convolutional stride, assumes equal horizontal and vertical strides and is only used if layer_types is "conv"
+    max_pool [list of bools] if True, the network includes a max pooling op after the conv/fc op and before the dropout op
+    pool_ksizes [list of ints] number of pixels on the edge of a square max pooling kernel
+    pool_strides [list of ints] number of pixels in the pooling stride, assumes equal horizontal and vertical strides

     lca
     dt [float] discrete global time constant for neuron dynamics
@@ -52,7 +66,7 @@
     num_steps [int] number of lca inference steps to take
     rectify_a [bool] if set, rectify the layer 1 neuron activity
     sparse_mult [float] multiplier placed in front of the sparsity loss term
-    tau [float] LCA time constant
+    tau [float] LCA time constant; larger values result in smaller step sizes (i.e.
slower convergence) lca update rule (step_size) is multiplied by dt/tau thresh_type [str] specifying LCA threshold function; can be "hard" or "soft" @@ -66,8 +80,10 @@ def __init__(self): self.compute_helper_params() def set_params(self): - self.standardize_data = False + self.center_dataset = False + self.checkpoint_boot_log = '' self.rescale_data_to_one = False + self.standardize_data = False self.model_type = None self.log_to_file = True self.train_logs_per_epoch = None @@ -75,6 +91,7 @@ def set_params(self): self.shuffle_data = True self.eps = 1e-12 self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.num_validation = 0 self.rand_seed = 123456789 self.rand_state = np.random.RandomState(self.rand_seed) self.workspace_dir = os.path.join(os.path.expanduser('~'), 'Work') diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py new file mode 100644 index 00000000..40390315 --- /dev/null +++ b/params/lca_cifar10_params.py @@ -0,0 +1,47 @@ +import types + +from DeepSparseCoding.params.base_params import BaseParams + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.model_type = 'lca' + self.model_name = 'lca_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.layer_types = ['conv'] + self.num_validation = 10000 + self.standardize_data = True + self.rescale_data_to_one = False + self.center_dataset = False + self.batch_size = 25 + self.num_epochs = 500 + self.train_logs_per_epoch = 6 + self.renormalize_weights = True + self.layer_channels = [3, 128] + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.1#0.2 + self.num_steps = 37#75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.35#0.30 + self.compute_helper_params() + + def compute_helper_params(self): + super(params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + self.step_size = self.dt / self.tau + self.num_pixels = 3072 + self.in_channels = self.layer_channels[0] + self.out_channels = self.layer_channels[1] diff --git a/params/lca_dsprites_params.py b/params/lca_dsprites_params.py index d498b076..f0090e40 100644 --- a/params/lca_dsprites_params.py +++ b/params/lca_dsprites_params.py @@ -12,6 +12,7 @@ def set_params(self): super(params, self).set_params() self.model_type = 'lca' self.model_name = 'lca_dsprites' + self.layer_types = ['fc'] self.version = '0' self.dataset = 'dsprites' self.standardize_data = False @@ -26,13 +27,13 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True + self.layer_channels = [1, int(self.num_pixels*1.5)] self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = False self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = int(self.num_pixels*1.5) self.compute_helper_params() def compute_helper_params(self): diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py new file mode 100644 index 00000000..475fdee2 --- /dev/null +++ b/params/lca_mlp_cifar10_params.py @@ -0,0 +1,93 @@ +import os +import types +import numpy as np +import torch + +from 
DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_mnist_params import params as LcaParams +from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'lca_mlp_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.batch_size = 25 + self.num_epochs = 10#00 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = True + + +class lca_params(LcaParams): + def set_params(self): + super(lca_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_name = 'lca' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + self.weight_lr = 0.001 + self.renormalize_weights = True + self.layer_channels = [3, 512] + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + self.tau = 0.2 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + self.sparse_mult = 0.30 + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/lca_cifar10_v0.log' + self.compute_helper_params() + + +class mlp_params(MlpParams): + def set_params(self): + super(mlp_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'mlp' + self.layer_name = 'classifier' + self.weight_lr = 2e-3 + self.weight_decay = 1e-6 + self.layer_types = ['fc'] + self.layer_channels = [None, 10] + self.activation_functions = ['identity'] + self.dropout_rate = [0.0] # probability of value being set to zero + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_params_inst = lca_params() + mlp_params_inst = mlp_params() + lca_output_height = compute_conv_output_shape( + 32, # TODO: infer this? 
currently hardcoded CIFAR10 size + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_width = compute_conv_output_shape( + 32, + lca_params_inst.kernel_size, + lca_params_inst.stride, + lca_params_inst.padding, + dilation=1) + lca_output_shape = [lca_params_inst.layer_channels[1], lca_output_height, lca_output_width] + mlp_params_inst.layer_channels[0] = np.prod(lca_output_shape) + self.ensemble_params = [lca_params_inst, mlp_params_inst] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py index 212de71f..6f230aac 100644 --- a/params/lca_mlp_mnist_params.py +++ b/params/lca_mlp_mnist_params.py @@ -14,11 +14,13 @@ def __init__(self): self.model_name = 'lca_768_mlp_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.num_pixels = 28*28*1 self.batch_size = 100 self.num_epochs = 1200 self.train_logs_per_epoch = 4 + self.allow_parent_grads = False class lca_params(LcaParams): @@ -27,6 +29,7 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' + self.layer_name = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 self.optimizer = types.SimpleNamespace() @@ -34,14 +37,14 @@ def set_params(self): self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs self.optimizer.lr_decay_rate = 0.5 self.renormalize_weights = True + self.layer_channels = [1, 768] self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 768#self.num_pixels*4 - #self.allow_parent_grads = False # TODO: enable this param + self.checkpoint_boot_log = '' self.compute_helper_params() @@ -51,10 +54,11 @@ def set_params(self): for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' + self.layer_name = 'classifier' self.weight_lr = 1e-4 self.weight_decay = 0.0 self.layer_types = ['fc'] - self.layer_channels = [768, 10]#[self.num_pixels*4, 10] + self.layer_channels = [768, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero self.optimizer = types.SimpleNamespace() diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py index dd9cdf34..6274ef54 100644 --- a/params/lca_mnist_params.py +++ b/params/lca_mnist_params.py @@ -1,38 +1,54 @@ -import os import types -import numpy as np -import torch - from DeepSparseCoding.params.base_params import BaseParams +CONV = True + + class params(BaseParams): def set_params(self): super(params, self).set_params() self.model_type = 'lca' - self.model_name = 'lca_768_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.num_pixels = 784 - self.batch_size = 100 - self.num_epochs = 1000 - self.weight_decay = 0. 
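
The ensemble params above size the classifier input by chaining compute_conv_output_shape over the LCA layer; that helper is imported from utils/run_utils but its body is not shown in this diff. Assuming it follows the standard convolution arithmetic (which the hardcoded 32-pixel CIFAR10 edge suggests), a minimal sketch with a worked check:

import math

def conv_output_size(in_size, kernel_size, stride, padding, dilation=1):
    # standard convolution arithmetic: floor((in + 2p - d*(k - 1) - 1) / s) + 1
    return math.floor((in_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride) + 1

# For lca_mlp_cifar10 above: conv_output_size(32, 8, 2, 0) == 13, so the LCA
# feature map is [512, 13, 13] and mlp layer_channels[0] becomes 512 * 13 * 13.
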
- self.weight_lr = 0.1 - self.train_logs_per_epoch = 6 - self.optimizer = types.SimpleNamespace() - self.optimizer.name = 'sgd' - self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs - self.optimizer.lr_decay_rate = 0.5 - self.renormalize_weights = True self.dt = 0.001 self.tau = 0.03 self.num_steps = 75 self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 768#self.num_pixels*4 + self.renormalize_weights = True + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.num_epochs = 1000 + self.weight_decay = 0.0 + self.train_logs_per_epoch = 6 + if CONV: + self.layer_types = ['conv'] + self.model_name = 'conv_lca_mnist' + self.rescale_data_to_one = True + self.batch_size = 50 + self.weight_lr = 0.001 + self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.layer_channels = [1, 128] + self.kernel_size = 8 + self.stride = 2 + self.padding = 0 + else: + self.layer_types = ['fc'] + self.model_type = 'lca' + self.model_name = 'lca_768_mnist' + self.rescale_data_to_one = False + self.batch_size = 100 + self.weight_lr = 0.1 + self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.5 + self.layer_channels = [1, 768] #self.num_pixels * 4 self.compute_helper_params() def compute_helper_params(self): diff --git a/params/mlp_cifar10_params.py b/params/mlp_cifar10_params.py new file mode 100644 index 00000000..24aa8f29 --- /dev/null +++ b/params/mlp_cifar10_params.py @@ -0,0 +1,43 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + self.model_type = 'mlp' + self.model_name = 'mlp_cifar10' + self.version = '0' + self.dataset = 'cifar10' + self.standardize_data = True + self.rescale_data_to_one = False + self.center_dataset = False + self.num_validation = 1000 + self.batch_size = 50 + self.num_epochs = 500 + self.weight_decay = 3e-6 + self.weight_lr = 2e-3 + self.layer_types = ['conv', 'fc'] + self.layer_channels = [3, 512, 10] + self.kernel_sizes = [8, None] + self.strides = [2, None] + self.activation_functions = ['lrelu', 'identity'] + self.dropout_rate = [0.5, 0.0] # probability of value being set to zero + self.max_pool = [True, False] + self.pool_ksizes = [5, None] + self.pool_strides = [4, None] + self.train_logs_per_epoch = 4 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'adam' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.1 + + def compute_helper_params(self): + super(params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] diff --git a/params/mlp_mnist_params.py b/params/mlp_mnist_params.py index 901ab60e..dcc7e195 100644 --- a/params/mlp_mnist_params.py +++ b/params/mlp_mnist_params.py @@ -14,17 +14,19 @@ def set_params(self): self.model_name = 'mlp_768_mnist' self.version = '0' self.dataset = 'mnist' + self.fast_mnist = True self.standardize_data = False self.rescale_data_to_one = False - self.num_pixels = 28*28*1 self.batch_size = 50 self.num_epochs = 300 self.weight_lr = 5e-4 self.weight_decay = 2e-6 self.layer_types = ['fc', 'fc'] + self.num_pixels = 28*28*1 self.layer_channels = [self.num_pixels, 768, 10] self.activation_functions = ['lrelu', 
'identity'] self.dropout_rate = [0.5, 0.0] # probability of value being set to zero + self.max_pool = [False, False] self.train_logs_per_epoch = 4 self.optimizer = types.SimpleNamespace() self.optimizer.name = 'adam' diff --git a/params/smt_cifar10_params.py b/params/smt_cifar10_params.py new file mode 100644 index 00000000..940d4fcf --- /dev/null +++ b/params/smt_cifar10_params.py @@ -0,0 +1,259 @@ +import os +import types + +import numpy as np +import torch + +from DeepSparseCoding.params.base_params import BaseParams +from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams +from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams +from DeepSparseCoding.utils.run_utils import compute_conv_output_shape + + +class shared_params(object): + def __init__(self): + self.model_type = 'ensemble' + self.model_name = 'test_smt_cifar10' + #self.version = 'lplpm' + self.version = '2lp' + self.dataset = 'cifar10' + self.standardize_data = True + self.rescale_data_to_one = False + self.center_dataset = False + self.batch_size = 30 + self.num_epochs = 200 + self.train_logs_per_epoch = 4 + self.allow_parent_grads = False + + +class lca_1_params(LcaParams): + def set_params(self): + super(lca_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'lca' + self.layer_name = 'lca_1' + self.layer_types = ['conv'] + self.weight_decay = 0.0 + #self.weight_lr = 1e-3 + self.weight_lr = 0.0 # For next layer training + self.renormalize_weights = True + #self.layer_channels = [3, 128] + self.layer_channels = [3, 256] + self.kernel_size = 8 + #self.stride = 2 + self.stride = 1 + self.padding = 0 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.dt = 0.001 + #self.tau = 0.2#0.10 + self.tau = 0.25 + #self.num_steps = 75#37 + self.num_steps = 75 + self.rectify_a = True + self.thresh_type = 'hard' + #self.sparse_mult = 0.35 + self.sparse_mult = 0.28 + #self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2l.log' + self.compute_helper_params() + + +class pooling_1_params(BaseParams): + def set_params(self): + super(pooling_1_params, self).set_params() + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'pool_1' + self.layer_types = ['conv'] + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_channels = [128, 32] + self.layer_channels = [256, 32] + #self.pool_ksize = 2 + self.pool_ksize = 4 + self.pool_stride = 2 + self.renormalize_weights = True + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2lp.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_1_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in 
self.optimizer.lr_annealing_milestone_frac] + + +class lca_2_params(LcaParams): + def set_params(self): + super(lca_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in lca_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'lca_2' + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_channels = [32, 256] + self.layer_channels = [32, 512] + #self.kernel_size = 6 + self.kernel_size = 8 + self.stride = 1 + self.padding = 0 + self.sparse_mult = 0.15 + self.tau = 0.20 + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + self.compute_helper_params() + +class pooling_2_params(BaseParams): + def set_params(self): + super(pooling_2_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value) + self.layer_name = 'pool_2' + self.weight_lr = 1e-3 + #self.weight_lr = 0.0 # For next layer training + #self.layer_types = ['fc'] + self.layer_types = ['conv'] + #self.layer_channels = [None, 64] + self.layer_channels = [512, 150] + self.pool_ksize = 4 + self.pool_stride = 1 + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.checkpoint_boot_log = '' + #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log' + self.compute_helper_params() + + def compute_helper_params(self): + super(pooling_2_params, self).compute_helper_params() + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + + +class mlp_params(MlpParams): + def set_params(self): + super(mlp_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'mlp' + self.layer_name = 'classifier' + self.weight_lr = 1e-2 + self.weight_decay = 1e-6 + self.layer_types = ['fc'] + #self.layer_channels = [64, 10] + self.layer_channels = [150, 10] + self.activation_functions = ['identity'] + self.dropout_rate = [0.0] # probability of value being set to zero + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.compute_helper_params() + + +class params(BaseParams): + def set_params(self): + super(params, self).set_params() + lca_1_params_inst = lca_1_params() + pooling_1_params_inst = pooling_1_params() + lca_2_params_inst = lca_2_params() + pooling_2_params_inst = pooling_2_params() + mlp_params_inst = mlp_params() + data_shape = [3, 32, 32] + lca_1_output_height = compute_conv_output_shape( + data_shape[1], + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + lca_1_output_width = compute_conv_output_shape( + data_shape[2], + lca_1_params_inst.kernel_size, + lca_1_params_inst.stride, + lca_1_params_inst.padding, + dilation=1) + lca_1_shape = [ + lca_1_params_inst.layer_channels[-1], + lca_1_output_height, + lca_1_output_width + ] + pooling_1_output_height = compute_conv_output_shape( + lca_1_output_height, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + 
pooling_1_output_width = compute_conv_output_shape( + lca_1_output_width, + pooling_1_params_inst.pool_ksize, + pooling_1_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_1_shape = [ + pooling_1_params_inst.layer_channels[-1], + pooling_1_output_height, + pooling_1_output_width + ] + lca_2_params_inst.data_shape = [ + int(pooling_1_params_inst.layer_channels[-1]), + int(pooling_1_output_height), + int(pooling_1_output_width)] + lca_2_output_height = compute_conv_output_shape( + pooling_1_output_height, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_output_width = compute_conv_output_shape( + pooling_1_output_width, + lca_2_params_inst.kernel_size, + lca_2_params_inst.stride, + lca_2_params_inst.padding, + dilation=1) + lca_2_shape = [ + lca_2_params_inst.layer_channels[-1], + lca_2_output_height, + lca_2_output_width + ] + lca_2_flat_dim = int(np.prod(lca_2_shape)) + pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim + pooling_2_output_height = compute_conv_output_shape( + lca_2_output_height, + pooling_2_params_inst.pool_ksize, + pooling_2_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_2_output_width = compute_conv_output_shape( + lca_2_output_width, + pooling_2_params_inst.pool_ksize, + pooling_2_params_inst.pool_stride, + padding=0, + dilation=1) + pooling_2_shape = [ + pooling_2_params_inst.layer_channels[-1], + pooling_2_output_height, + pooling_2_output_width + ] + l1_overcompleteness = np.prod(lca_1_shape) / np.prod(data_shape) + p1_overcompleteness = np.prod(pooling_1_shape) / np.prod(lca_1_shape) + l2_overcompleteness = np.prod(lca_2_shape) / np.prod(pooling_1_shape) + p2_overcompleteness = np.prod(pooling_2_shape) / np.prod(lca_2_shape) + self.ensemble_params = [ + lca_1_params_inst, + pooling_1_params_inst, + #lca_2_params_inst, + #pooling_2_params_inst, + #mlp_params_inst + ] + for key, value in shared_params().__dict__.items(): + setattr(self, key, value) diff --git a/params/test_params.py b/params/test_params.py index 01a687b5..80a25b6b 100644 --- a/params/test_params.py +++ b/params/test_params.py @@ -1,13 +1,14 @@ import os import sys import types +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import torch -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - from DeepSparseCoding.params.base_params import BaseParams from DeepSparseCoding.params.lca_mnist_params import params as LcaParams from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams @@ -35,6 +36,7 @@ def __init__(self): self.num_test_images = 0 self.standardize_data = False self.rescale_data_to_one = False + self.allow_parent_grads = False self.num_epochs = 3 self.train_logs_per_epoch = 1 @@ -42,18 +44,17 @@ def __init__(self): class base_params(BaseParams): def set_params(self): super(base_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) class lca_params(BaseParams): def set_params(self): super(lca_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'lca' self.weight_decay = 0.0 self.weight_lr = 0.1 + self.layer_types = ['fc'] self.optimizer = 
types.SimpleNamespace() self.optimizer.name = 'sgd' self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs @@ -65,17 +66,49 @@ def set_params(self): self.rectify_a = True self.thresh_type = 'soft' self.sparse_mult = 0.25 - self.num_latent = 128 + self.layer_channels = [64, 128] self.optimizer.milestones = [frac * self.num_epochs for frac in self.optimizer.lr_annealing_milestone_frac] self.step_size = self.dt / self.tau +# TODO: Add ability to test multiple param values +#class conv_lca_params(lca_params): +# def set_params(self): +# super(conv_lca_params, self).set_params() +# self.layer_types = ['conv'] +# self.kernel_size = 8 +# self.stride = 2 +# self.padding = 0 +# self.optimizer.milestones = [frac * self.num_epochs +# for frac in self.optimizer.lr_annealing_milestone_frac] +# self.step_size = self.dt / self.tau +# self.out_channels = self.layer_channels +# self.in_channels = 1 + + +class pooling_params(BaseParams): + def set_params(self): + super(pooling_params, self).set_params() + for key, value in shared_params().__dict__.items(): setattr(self, key, value) + self.model_type = 'pooling' + self.layer_name = 'test_pool_1' + self.weight_lr = 1e-3 + self.layer_types = ['conv'] + self.layer_channels = [128, 32] + self.pool_ksize = 2 + self.pool_stride = 2 # non-overlapping + self.optimizer = types.SimpleNamespace() + self.optimizer.name = 'sgd' + self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs + self.optimizer.lr_decay_rate = 0.8 + self.optimizer.milestones = [frac * self.num_epochs + for frac in self.optimizer.lr_annealing_milestone_frac] + class mlp_params(BaseParams): def set_params(self): super(mlp_params, self).set_params() - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + for key, value in shared_params().__dict__.items(): setattr(self, key, value) self.model_type = 'mlp' self.weight_lr = 1e-4 self.weight_decay = 0.0 @@ -83,6 +116,7 @@ def set_params(self): self.layer_channels = [128, 10] self.activation_functions = ['identity'] self.dropout_rate = [0.0] # probability of value being set to zero + self.max_pool = [False] self.optimizer = types.SimpleNamespace() self.optimizer.name = 'adam' self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs @@ -94,6 +128,9 @@ def set_params(self): class ensemble_params(BaseParams): def set_params(self): super(ensemble_params, self).set_params() - self.ensemble_params = [lca_params(), mlp_params()] - for key, value in shared_params().__dict__.items(): - setattr(self, key, value) + layer1_params = lca_params() + layer1_params.layer_name = 'layer1' + layer2_params = mlp_params() + layer2_params.layer_name = 'layer2' + self.ensemble_params = [layer1_params, layer2_params] + for key, value in shared_params().__dict__.items(): setattr(self, key, value) diff --git a/requirements.txt b/requirements.txt index f7e76024..02e1795b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,9 +11,4 @@ Pillow>=5.3.0 scikit-image>=0.14.1 scikit-learn>=0.20.0 scipy>=1.1.0 -seaborn>=0.9.0 -tensorflow-gpu==1.15.2 -tensorflow-estimator==1.15.1 -tensorboard==1.15 -tensorflow-probability==0.8.0 -tensorflow-compression +seaborn>=0.9.0 \ No newline at end of file diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py index a615b4d0..7a3f8af6 100644 --- a/tests/test_data_processing.py +++ b/tests/test_data_processing.py @@ -1,15 +1,14 @@ import os import sys import unittest +from os.path import dirname as up +ROOT_DIR = 
up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import torch import numpy as np - -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.data_processing as dp @@ -21,9 +20,9 @@ def test_reshape_data(self): function call: reshape_data(data, flatten=None, out_shape=None): 24 possible conditions: data: [np.ndarray] data of shape: - n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k - (l) - single data point of shape l, assumes 1 color channel - (n, l) - n data points, each of shape l (flattened) + n is num_examples, i is num_channels, j is num_rows, k is num_cols + (i) - single data point of shape i + (n, i) - n data points, each of shape i (flattened) (i, j, k) - single datapoint of of shape (i, j, k) (n, i, j, k) - n data points, each of shape (i,j,k) flatten: True, False, None @@ -44,8 +43,8 @@ def test_reshape_data(self): input_array_list = [ np.zeros((num_elements)), # assumed num_examples == 1 np.zeros((num_examples, num_elements)), - np.zeros((num_rows, num_cols, num_channels)), # assumed num_examples == 1 - np.zeros((num_examples, num_rows, num_cols, num_channels))] + np.zeros((num_channels, num_rows, num_cols)), # assumed num_examples == 1 + np.zeros((num_examples, num_channels, num_rows, num_cols))] for input_array in input_array_list: input_shape = input_array.shape input_ndim = input_array.ndim @@ -54,7 +53,7 @@ def test_reshape_data(self): out_shape_list = [ None, (num_elements,), - (num_rows, num_cols, num_channels)] + (num_channels, num_rows, num_cols)] if(num_channels == 1): out_shape_list.append((num_rows, num_cols)) else: @@ -62,7 +61,7 @@ def test_reshape_data(self): out_shape_list = [ None, (num_examples, num_elements), - (num_examples, num_rows, num_cols, num_channels)] + (num_examples, num_channels, num_rows, num_cols)] if(num_channels == 1): out_shape_list.append((num_examples, num_rows, num_cols)) for out_shape in out_shape_list: @@ -83,7 +82,7 @@ def test_reshape_data(self): reshaped_array = reshape_outputs[0].numpy() err_msg += f'\nreshaped_array.shape={reshaped_array.shape}' self.assertEqual(reshape_outputs[1], input_shape, err_msg) # orig_shape - (resh_num_examples, resh_num_rows, resh_num_cols, resh_num_channels) = reshape_outputs[2:] + (resh_num_examples, resh_num_channels, resh_num_rows, resh_num_cols) = reshape_outputs[2:] err_msg += (f'\nfunction_shape_outputs={reshape_outputs[2:]}') if(out_shape is None): if(flatten is None): @@ -105,26 +104,26 @@ def test_reshape_data(self): expected_out_shape, err_msg) self.assertEqual( - resh_num_rows*resh_num_cols*resh_num_channels, + resh_num_channels * resh_num_rows * resh_num_cols, expected_out_shape[1], err_msg) elif(flatten == False): - expected_out_shape = (num_examples, num_rows, num_cols, num_channels) + expected_out_shape = (num_examples, num_channels, num_rows, num_cols) err_msg += f'\nexpected_out_shape={expected_out_shape}' self.assertEqual( reshaped_array.shape, expected_out_shape, err_msg) self.assertEqual( - resh_num_rows, + resh_num_channels, expected_out_shape[1], err_msg) self.assertEqual( - resh_num_cols, + resh_num_rows, expected_out_shape[2], err_msg) self.assertEqual( - resh_num_channels, + resh_num_cols, expected_out_shape[3], err_msg) else: @@ -135,22 +134,12 @@ def test_reshape_data(self): self.assertEqual(reshaped_array.shape, expected_out_shape, err_msg) self.assertEqual(resh_num_examples, None, err_msg) - - def 
test_flatten_feature_map(self): - unflat_shape = [8, 4, 4, 3] - flat_shape = [8, 4*4*3] - shapes = [unflat_shape, flat_shape] - for shape in shapes: - test_map = torch.zeros(shape) - flat_map = dp.flatten_feature_map(test_map).numpy() - self.assertEqual(list(flat_map.shape), flat_shape) - def test_standardize(self): num_tolerance_decimals = 5 unflat_shape = [8, 4, 4, 3] flat_shape = [8, 4*4*3] shape_options = [unflat_shape, flat_shape] - eps_options = [1e-6, None] + eps_options = [1e-8, None] samplewise_options = [True, False] for shape in shape_options: for eps_val in eps_options: @@ -237,3 +226,40 @@ def test_label_conversion(self): np.testing.assert_equal(func_dense, dense_labels) func_one_hot = dp.dense_to_one_hot(torch.tensor(dense_labels), num_classes).numpy() np.testing.assert_equal(func_one_hot, one_hot_labels) + + def test_atleastkd(self): + x = np.random.standard_normal([2, 3, 4]) + ks = [0, 3, 8, 10] + for k in ks: + new_x = dp.atleast_kd(torch.tensor(x), k).numpy() + test_nd = np.maximum(k, x.ndim) + np.testing.assert_equal(new_x.ndim, test_nd) + + def test_l2_weight_norm(self): + w_fc = np.random.standard_normal([38, 24]) + w_conv = np.random.standard_normal([38, 24, 8, 8]) + for w in [w_fc, w_conv]: + w_norm = dp.get_weights_l2_norm(torch.tensor(w), eps=1e-12).numpy() + normed_w = dp.l2_normalize_weights(torch.tensor(w), eps=1e-12).numpy() + normed_w_norm = dp.get_weights_l2_norm(torch.tensor(normed_w), eps=1e-12).numpy() + np.testing.assert_allclose(normed_w_norm, 1.0, rtol=1e-10) + np.testing.assert_allclose(w / w_norm, normed_w, rtol=1e-10) + + def test_patches(self): + err = 1e-6 + rand_mean = 0; rand_var = 1 + num_im = 10; im_edge = 512; im_chan = 1; patch_edge = 16 + num_patches = int(num_im * (im_edge / patch_edge)**2) + rand_seed = 1234 + rand_state = np.random.RandomState(rand_seed) + data = np.stack([rand_state.normal(rand_mean, rand_var, size=[im_chan, im_edge, im_edge]) + for _ in range(num_im)]) + data_shape = list(data.shape) + patch_shape = [im_chan, patch_edge, patch_edge] + datapoint = torch.tensor(data[0, ...]) + datapoint_patches = dp.single_image_to_patches(datapoint, patch_shape) + datapoint_recon = dp.patches_to_single_image(datapoint_patches, data_shape[1:]) + np.testing.assert_allclose(datapoint.numpy(), datapoint_recon.numpy(), rtol=err) + patches = dp.images_to_patches(torch.tensor(data), patch_shape) + data_recon = dp.patches_to_images(patches, data_shape[1:]) + np.testing.assert_allclose(data, data_recon.numpy(), rtol=err) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index d1c0cde4..c2ac7ea2 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,13 +2,14 @@ import sys import unittest import types +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np from torchvision import datasets -ROOT_DIR = os.path.dirname(os.getcwd()) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.utils.dataset_utils as dataset_utils class TestDatasets(unittest.TestCase): @@ -33,7 +34,7 @@ def test_mnist(self): params.dataset = 'mnist' params.shuffle_data = True params.batch_size = 10000 - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) + train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) assert len(train_loader.dataset) == params.epoch_size @@ -60,10 +61,15 
@@ def test_synthetic(self): params.dist_type = dist_type params.num_classes = num_classes params.rand_state = rand_state - train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params) + train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) assert len(train_loader.dataset) == epoch_size for batch_idx, (data, target) in enumerate(train_loader): - assert data.numpy().shape == (params.batch_size, params.data_edge_size, params.data_edge_size, 1) + expected_size = ( + params.batch_size, + 1, + params.data_edge_size, + params.data_edge_size) + assert data.numpy().shape == expected_size assert batch_idx + 1 == epoch_size // params.batch_size diff --git a/tests/test_foolbox.py b/tests/test_foolbox.py index 02dc807d..24ca9774 100644 --- a/tests/test_foolbox.py +++ b/tests/test_foolbox.py @@ -1,16 +1,17 @@ import os import sys import unittest +from os.path import dirname as up -#import numpy as np -import eagerpy as ep -from foolbox import PyTorchModel, accuracy, samples -import foolbox.attacks as fa - -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) -import DeepSparseCoding.utils.loaders as loaders +#import numpy as np +#import eagerpy as ep +#from foolbox import PyTorchModel, accuracy, samples +#import foolbox.attacks as fa + +#import DeepSparseCoding.utils.loaders as loaders #import DeepSparseCoding.utils.dataset_utils as datasets #import DeepSparseCoding.utils.run_utils as run_utils @@ -28,7 +29,7 @@ # 'steps':3}} # max perturbation it can reach is 0.5 # attack = fa.LinfPGD(**attack_params['linfPGD']) # epsilons = [0.3] # allowed perturbation size -# params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params') +# params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params') # params['ensemble'].train_logs_per_epoch = None # params['ensemble'].shuffle_data = False # train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['ensemble']) @@ -42,4 +43,4 @@ # fmodel = PyTorchModel(model.eval(), bounds=(0, 1)) # model_output = fmodel.forward() # adv_model_outputs, adv_images, success = attack(fmodel, train_data_batch, train_target_batch, epsilons=epsilons) -# \ No newline at end of file +# diff --git a/tests/test_models.py b/tests/test_models.py index a5e16ccd..7fdd172a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,12 +1,13 @@ import os import sys import unittest +from os.path import dirname as up -import numpy as np - -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) +import numpy as np + import DeepSparseCoding.utils.loaders as loaders import DeepSparseCoding.utils.dataset_utils as datasets import DeepSparseCoding.utils.run_utils as run_utils @@ -18,37 +19,54 @@ def setUp(self): self.model_list = loaders.get_model_list(self.dsc_dir) self.test_params_file = os.path.join(*[self.dsc_dir, 'params', 'test_params.py']) + ### TODO - define endpoint function for checkpoint loading & test independently + ### TODO - add ability to test multiple options (e.g. 
'conv' and 'fc') from test params def test_model_loading(self): for model_type in self.model_list: - model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end + model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end model = loaders.load_model(model_type) - params = loaders.load_params(self.test_params_file, key=model_type+'_params') - train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params) + params = loaders.load_params_file(self.test_params_file, key=model_type+'_params') + train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params)[:4] for key, value in data_params.items(): setattr(params, key, value) model.setup(params) - ### TODO - more basic test to compute gradients per model### + + ### TODO - more basic test to compute gradients per model #def test_gradients(self): # for model_type in self.model_list: # model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end # model = loaders.load_model(model_type) + ### TODO - test for gradient blocking + #def test_get_module_encodings(self): + # """ + # Test for gradient blocking in the get_module_encodings function + + # construct test model1 & model2 + # construct test ensemble model = model1 -> model2 + # get encoding & grads for allow_grads={True, False} + # False: compare grads for model1 alone vs model1 in ensemble + # True: ensure that grad is different from model1 alone + # * Should also manually compute grads to compare? + # """ + # # test should utilize run_utils.get_module_encodings() + + def test_lca_ensemble_gradients(self): + ## Load models params = {} models = {} - params['lca'] = loaders.load_params(self.test_params_file, key='lca_params') + params['lca'] = loaders.load_params_file(self.test_params_file, key='lca_params') params['lca'].train_logs_per_epoch = None params['lca'].shuffle_data = False - train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca']) - for key, value in data_params.items(): - setattr(params['lca'], key, value) + train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca'])[:4] + for key, value in data_params.items(): setattr(params['lca'], key, value) models['lca'] = loaders.load_model(params['lca'].model_type) models['lca'].setup(params['lca']) models['lca'].to(params['lca'].device) - params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params') - for key, value in data_params.items(): - setattr(params['ensemble'], key, value) + params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params') + for key, value in data_params.items(): setattr(params['ensemble'], key, value) err_msg = f'\ndata_shape={params["ensemble"].data_shape}' err_msg += f'\nnum_pixels={params["ensemble"].num_pixels}' err_msg += f'\nbatch_size={params["ensemble"].batch_size}' @@ -56,9 +74,11 @@ def test_lca_ensemble_gradients(self): models['ensemble'] = loaders.load_model(params['ensemble'].model_type) models['ensemble'].setup(params['ensemble']) models['ensemble'].to(params['ensemble'].device) + ## Overwrite weight initialization so that they have the same weights ensemble_state_dict = models['ensemble'].state_dict() - ensemble_state_dict['lca.w'] = models['lca'].w.clone() + ensemble_state_dict['layer1.weight'] = models['lca'].weight.clone() models['ensemble'].load_state_dict(ensemble_state_dict) + ## Load data data, target = next(iter(train_loader)) train_data_batch = 
models['lca'].preprocess_data(data.to(params['lca'].device)) train_target_batch = target.to(params['lca'].device) @@ -66,27 +86,37 @@ def test_lca_ensemble_gradients(self): for submodel in models['ensemble']: submodel.optimizer.zero_grad() inputs = [train_data_batch] # only the first model acts on input + ## Verify feedforward encodings + lca_encoding = models['lca'](inputs[0]).cpu().detach().numpy() + ensemble_encoding = models['ensemble'][0].get_encodings(inputs[0]).cpu().detach().numpy() + assert np.all(lca_encoding == ensemble_encoding), (err_msg+'\n' + +f'Forward encodings for lca and ensemble[0] should be equal, but are not') + ## Verify LCA loss + lca_loss = models['lca'].get_total_loss((inputs[0], train_target_batch)) + ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)] + lca_loss_val = lca_loss.cpu().detach().numpy() + ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy() + assert lca_loss_val == ensemble_loss_val, (err_msg+'\n' + +f'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}') + ## Compute remaining ensemble outputs for submodel in models['ensemble']: inputs.append(submodel.get_encodings(inputs[-1]).detach()) - lca_loss = models['lca'].get_total_loss((train_data_batch, train_target_batch)) - ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)] ensemble_losses.append(models['ensemble'].get_total_loss((inputs[1], train_target_batch), 1)) + ## Verify lca grad & ensemble grad are equal lca_loss.backward() ensemble_losses[0].backward() ensemble_losses[1].backward() - lca_loss_val = lca_loss.cpu().detach().numpy() - lca_w_grad = models['lca'].w.grad.cpu().numpy() - ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy() - ensemble_w_grad = models['ensemble'][0].w.grad.cpu().numpy() - assert lca_loss_val == ensemble_loss_val, (err_msg+'\n' - +'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}') - assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\nGrads should be equal, but are not.') - lca_pre_train_w = models['lca'].w.cpu().detach().numpy().copy() - ensemble_pre_train_w = models['ensemble'][0].w.cpu().detach().numpy().copy() + lca_w_grad = models['lca'].weight.grad.cpu().numpy() + ensemble_w_grad = models['ensemble'][0].weight.grad.cpu().numpy() + assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\n' + +f'Grads should be equal, but are not.') + ## Verify weight updates are equal + lca_pre_train_w = models['lca'].weight.cpu().detach().numpy().copy() + ensemble_pre_train_w = models['ensemble'][0].weight.cpu().detach().numpy().copy() run_utils.train_epoch(1, models['lca'], train_loader) run_utils.train_epoch(1, models['ensemble'], train_loader) - lca_w = models['lca'].w.cpu().detach().numpy().copy() - ensemble_w = models['ensemble'][0].w.cpu().detach().numpy().copy() + lca_w = models['lca'].weight.cpu().detach().numpy().copy() + ensemble_w = models['ensemble'][0].weight.cpu().detach().numpy().copy() assert np.all(lca_pre_train_w == ensemble_pre_train_w), (err_msg+'\n' +"lca & ensemble weights are not equal before one epoch of training") assert not np.all(lca_pre_train_w == lca_w), (err_msg+'\n' diff --git a/tests/test_param_loading.py b/tests/test_param_loading.py index 721b6a59..7d488a73 100644 --- a/tests/test_param_loading.py +++ b/tests/test_param_loading.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.getcwd()) +ROOT_DIR = 
up(up(up(os.path.realpath(__file__)))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import DeepSparseCoding.utils.loaders as loaders @@ -12,4 +13,4 @@ def test_param_loading(): for params_name in params_list: if 'test_' not in params_name: params_file = os.path.join(*[dsc_dir, 'params', params_name+'.py']) - params = loaders.load_params(params_file, key='params') + params = loaders.load_params_file(params_file, key='params') diff --git a/tf1x/additional_requirements.txt b/tf1x/additional_requirements.txt new file mode 100644 index 00000000..8e5be756 --- /dev/null +++ b/tf1x/additional_requirements.txt @@ -0,0 +1,5 @@ +tensorflow-gpu==1.15.2 +tensorflow-estimator==1.15.1 +tensorboard==1.15 +tensorflow-probability==0.8.0 +tensorflow-compression \ No newline at end of file diff --git a/tf1x/analysis/iso_response_analysis.py b/tf1x/analysis/iso_response_analysis.py index 754dd044..21cbd701 100644 --- a/tf1x/analysis/iso_response_analysis.py +++ b/tf1x/analysis/iso_response_analysis.py @@ -219,29 +219,32 @@ def __init__(self): cont_analysis['min_angle'] = 15 cont_analysis['batch_size'] = 100 cont_analysis['vh_image_scale'] = 31.773287 # Mean of the l2 norm of the training set - cont_analysis['comparison_method'] = 'closest' # rand or closest - + cont_analysis['comparison_method'] = 'rand' # rand or closest + cont_analysis['measure_upper_right'] = False + cont_analysis['bounds'] = ((-1, 1), (-1, 1)) + cont_analysis['target_act'] = 0.5 cont_analysis['num_neurons'] = 100 # How many neurons to plot cont_analysis['num_comparisons'] = 300 # How many planes to construct (None is all of them) cont_analysis['x_range'] = [-2.0, 2.0] cont_analysis['y_range'] = [-2.0, 2.0] cont_analysis['num_images'] = int(30**2) - cont_analysis['params_list'] = [lca_512_vh_params()] + cont_analysis['params_list'] = [lca_1024_vh_params(), lca_2560_vh_params()] #cont_analysis['params_list'] = [lca_768_vh_params()] #cont_analysis['params_list'] = [lca_1024_vh_params()] #cont_analysis['params_list'] = [lca_2560_vh_params()] #cont_analysis['iso_save_name'] = "iso_curvature_xrange1.3_yrange-2.2_" #cont_analysis['iso_save_name'] = "iso_curvature_ryan_" - cont_analysis['iso_save_name'] = "rescaled_closecomp_" + cont_analysis['iso_save_name'] = "newfits_rescaled_randomcomp_" #cont_analysis['iso_save_name'] = '' - np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz", - data=cont_analysis) analyzer_list = [load_analyzer(params) for params in cont_analysis['params_list']] for analyzer, params in zip(analyzer_list, cont_analysis['params_list']): + save_root=analyzer.analysis_out_dir+'savefiles/' + np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz", + data=cont_analysis) print(analyzer.analysis_params.display_name) print("Computing the iso-response vectors...") cont_analysis['target_neuron_ids'] = iso_data.get_rand_target_neuron_ids( @@ -297,7 +300,6 @@ def __init__(self): datapoints, get_dsc_activations_cell, activation_function_kwargs) - save_root=analyzer.analysis_out_dir+'savefiles/' if use_rand_orth_vects: np.savez(save_root+'iso_rand_activations_'+cont_analysis['iso_save_name']+params.save_info+'.npz', data=activations) @@ -310,10 +312,11 @@ def __init__(self): data=contour_dataset) cont_analysis['comparison_neuron_ids'] = analyzer.comparison_neuron_ids cont_analysis['contour_dataset'] = contour_dataset + cont_analysis['activations'] = activations curvatures, fits = hist_funcs.iso_response_curvature_poly_fits( cont_analysis['activations'], 
target_act=cont_analysis['target_act'], - measure_upper_right=False + bounds=cont_analysis['bounds'] ) cont_analysis['curvatures'] = np.stack(np.stack(curvatures, axis=0), axis=0) np.savez(save_root+'group_iso_vectors_'+cont_analysis['iso_save_name']+params.save_info+'.npz', diff --git a/tf1x/analyze_model.py b/tf1x/analyze_model.py index c8db1e28..a30975aa 100644 --- a/tf1x/analyze_model.py +++ b/tf1x/analyze_model.py @@ -1,13 +1,14 @@ import os import sys import argparse +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import tensorflow as tf -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - from DeepSparseCoding.tf1x.utils.logger import Logger import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.data.data_selector as ds diff --git a/tf1x/params/lca_conv_params.py b/tf1x/params/lca_conv_params.py index cedd0d96..55f492d2 100644 --- a/tf1x/params/lca_conv_params.py +++ b/tf1x/params/lca_conv_params.py @@ -57,19 +57,24 @@ def set_data_params(self, data_type): self.data_type = data_type if data_type.lower() == "mnist": self.model_name += "_mnist" + self.log_int = 200 self.rescale_data = True self.center_data = False self.whiten_data = False self.lpf_data = False # only for ZCA + self.num_val = 0 + self.batch_size = 50 self.lpf_cutoff = 0.7 - self.num_neurons = 768 + self.num_neurons = 128 + self.num_steps = 75 self.stride_y = 2 self.stride_x = 2 self.patch_size_y = 8 # weight receptive field self.patch_size_x = 8 for schedule_idx in range(len(self.schedule)): - self.schedule[schedule_idx]["sparse_mult"] = 0.21 - self.schedule[schedule_idx]["weight_lr"] = [0.1] + self.schedule[schedule_idx]["num_batches"] = int(6e5) + self.schedule[schedule_idx]["sparse_mult"] = 0.25 + self.schedule[schedule_idx]["weight_lr"] = [0.001] elif data_type.lower() == "vanhateren": self.model_name += "_vh" diff --git a/tf1x/tests/analysis/atas_test.py b/tf1x/tests/analysis/atas_test.py index af071a2b..061a12aa 100644 --- a/tf1x/tests/analysis/atas_test.py +++ b/tf1x/tests/analysis/atas_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/data/data_selector_test.py b/tf1x/tests/data/data_selector_test.py index 43864e26..76ec089b 100644 --- a/tf1x/tests/data/data_selector_test.py +++ b/tf1x/tests/data/data_selector_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/build_test.py b/tf1x/tests/models/build_test.py index c2acca84..70ffb516 100644 --- a/tf1x/tests/models/build_test.py +++ b/tf1x/tests/models/build_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/comb_test.py b/tf1x/tests/models/comb_test.py index f54396f4..6d10e833 100644 --- 
a/tf1x/tests/models/comb_test.py +++ b/tf1x/tests/models/comb_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/models/run_test.py b/tf1x/tests/models/run_test.py index 20f379be..7342d63a 100644 --- a/tf1x/tests/models/run_test.py +++ b/tf1x/tests/models/run_test.py @@ -1,8 +1,9 @@ import copy import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/checkpoint_test.py b/tf1x/tests/utils/checkpoint_test.py index 4833d59e..bd7c5617 100644 --- a/tf1x/tests/utils/checkpoint_test.py +++ b/tf1x/tests/utils/checkpoint_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/contrast_normalize_test.py b/tf1x/tests/utils/contrast_normalize_test.py index 3ec003fe..32712405 100644 --- a/tf1x/tests/utils/contrast_normalize_test.py +++ b/tf1x/tests/utils/contrast_normalize_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/patches_test.py b/tf1x/tests/utils/patches_test.py index 4639a162..ad0e7c78 100644 --- a/tf1x/tests/utils/patches_test.py +++ b/tf1x/tests/utils/patches_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/reshape_data_test.py b/tf1x/tests/utils/reshape_data_test.py index ff6922ad..18b76eed 100644 --- a/tf1x/tests/utils/reshape_data_test.py +++ b/tf1x/tests/utils/reshape_data_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/tests/utils/standardize_data_test.py b/tf1x/tests/utils/standardize_data_test.py index 524d3f18..e918bd1a 100644 --- a/tf1x/tests/utils/standardize_data_test.py +++ b/tf1x/tests/utils/standardize_data_test.py @@ -1,7 +1,8 @@ import os import sys +from os.path import dirname as up -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) +ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__)))))) if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np diff --git a/tf1x/train_model.py b/tf1x/train_model.py index a11ff070..bc2cd08c 100644 --- a/tf1x/train_model.py +++ b/tf1x/train_model.py @@ -2,15 +2,16 @@ import sys import time as ti import argparse +from os.path import dirname as up + +ROOT_DIR = up(up(up(os.path.realpath(__file__)))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import matplotlib 
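
Every tf1x test hunk above makes the same path fix, and train_model.py below repeats it: the old ROOT_DIR = os.path.dirname(os.getcwd()) only resolved correctly when scripts were launched from one particular working directory, while os.path.realpath(__file__) anchors ROOT_DIR to the file itself. A minimal sketch of the pattern; the number of up() calls must match the file's depth below the directory containing the DeepSparseCoding checkout (three for tests/*.py, five for tf1x/tests/*/*.py):

import os
import sys
from os.path import dirname as up

# e.g. for a file at DeepSparseCoding/tests/example_test.py
ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)

import DeepSparseCoding.utils.loaders as loaders  # now resolves regardless of cwd
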
matplotlib.use("Agg") import numpy as np import tensorflow as tf -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.params.param_picker as pp import DeepSparseCoding.tf1x.models.model_picker as mp import DeepSparseCoding.tf1x.data.data_selector as ds diff --git a/tf1x/utils/data_processing.py b/tf1x/utils/data_processing.py index 1fd62606..999a8dcf 100644 --- a/tf1x/utils/data_processing.py +++ b/tf1x/utils/data_processing.py @@ -88,7 +88,7 @@ def reshape_data(data, flatten=None, out_shape=None): if flatten == True: data = np.reshape(data, (num_examples, num_rows*num_cols*num_channels)) else: - assert False, ("Data must have 1, 2, 3, or 4 dimensions.") + assert False, (f'Data must have 1, 2, 3, or 4 dimensions, not {orig_ndim}') else: num_examples = None; num_rows=None; num_cols=None; num_channels=None data = np.reshape(data, out_shape) @@ -284,7 +284,7 @@ def generate_grating(patch_edge_size, location, diameter, orientation, frequency """ vals = np.linspace(-np.pi, np.pi, patch_edge_size) X, Y = np.meshgrid(vals, vals) - Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # countercloclwise + Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # counterclockwise Yr = np.sin(orientation)*X + np.cos(orientation)*Y stim = contrast*np.sin(Yr*frequency+phase) if diameter > 0: # Generate mask @@ -958,13 +958,13 @@ def pca_reduction(data, num_pcs=-1): data_mean = data.mean(axis=(1))[:,None] data -= data_mean Cov = np.cov(data.T) # Covariace matrix - U, S, V = np.linalg.svd(Cov) # SVD decomposition + U, S, VT = np.linalg.svd(Cov) # SVD decomposition diagS = np.diag(S) if num_pcs <= 0: n = num_rows else: n = num_pcs - data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), V[:n, :])) + data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), VT[:n, :])) return data_reduc def compute_power_spectrum(data): diff --git a/tf1x/utils/jov_funcs.py b/tf1x/utils/jov_funcs.py index be46539e..5dd8f78e 100644 --- a/tf1x/utils/jov_funcs.py +++ b/tf1x/utils/jov_funcs.py @@ -119,10 +119,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) tenth_range_shift = ((max(analysis_dict['x_range']) - min(analysis_dict['x_range']))/10) # For shifting labels - #text_handle = curve_axes[-1].text( - # target_vector_x+(tenth_range_shift*phi_k_text_x_offset), - # target_vector_y+(tenth_range_shift*phi_k_text_y_offset), - # r'$\Phi_{k}$', horizontalalignment='center', verticalalignment='center') # plot comparison neuron arrow & label proj_comparison = analysis_dict['contour_dataset']['proj_comparison_vect'][neuron_index][orth_index] comparison_vector_x = proj_comparison[0].item() @@ -130,10 +126,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh curve_axes[-1].arrow(0, 0, comparison_vector_x, comparison_vector_y, width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) - #text_handle = curve_axes[-1].text( - # comparison_vector_x+(tenth_range_shift*phi_j_text_x_offset), - # comparison_vector_y+(tenth_range_shift*phi_j_text_y_offset), - # r'$\Phi_{j}$', horizontalalignment='center', verticalalignment='center') # Plot orthogonal vector Nu proj_orth = analysis_dict['contour_dataset']['proj_orth_vect'][neuron_index][orth_index] orth_vector_x 
= proj_orth[0].item() @@ -141,10 +133,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh curve_axes[-1].arrow(0, 0, orth_vector_x, orth_vector_y, width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length, fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth) - #text_handle = curve_axes[-1].text( - # orth_vector_x+(tenth_range_shift*nu_text_x_offset), - # orth_vector_y+(tenth_range_shift*nu_text_y_offset), - # r'$\nu$', horizontalalignment='center', verticalalignment='center') # Plot axes curve_axes[-1].set_aspect('equal') curve_axes[-1].plot(analysis_dict['x_range'], [0,0], color='k', linewidth=arrow_linewidth/2) @@ -153,14 +141,11 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh k_idx = analysis_dict["target_neuron_ids"][neuron_index] curv_val = curvatures[y_id, x_id] curve_axes[-1].set_title( - #f'k={k_idx}; j={j_idx}\nC={curv_val:.3f}', f'C={curv_val:.3f}', fontsize=rcParams['axes.titlesize']/2, pad=2, horizontalalignment='center' ) - #curve_axes[-1].text(x=-0.08, y=1.75, s=f'C={curvatures[y_id, x_id]:.3f}', - # horizontalalignment='right', verticalalignment='center', fontsize=6) if y_id==0: curve_axes[-1].set_ylabel(str(neuron_index), visible=True) # Add colorbar @@ -180,6 +165,123 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh return fig, contour_handles +def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list, + label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims, + scatter, log=True, text_width=200, width_ratio=1.0, dpi=100): + gs0_wspace = 0.5 + hspace_hist = 0.7 + wspace_hist = 0.08 + iso_response_line_thickness = 2 + respone_attenuation_line_thickness = 2 + num_y_plots = 2 if log else 1 + num_x_plots = 1 + fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi) + gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace) + if log: + curve_ax = fig.add_subplot(gs_base[0], projection='3d') + curve_ax.minorticks_off() + x_mesh, y_mesh = np.meshgrid(*contour_pts) + curve_ax.set_zlim(0, 1) + curve_ax.set_xlim3d(5, 200) + curve_ax.grid(False) + curve_ax.set_xticklabels([]) + curve_ax.set_yticklabels([]) + curve_ax.set_zticklabels([]) + curve_ax.zaxis.set_rotate_label(False) + if scatter: + curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01) + else: + curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1, + linestyles='dotted', linewidths=0.3, alpha=1.0) + # Plane vector visualizations + v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], + [0, 0.0], mutation_scale=10, + lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed') + curve_ax.add_artist(v) + curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red') + phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.], + [0, 0.0], mutation_scale=10, + lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed') + curve_ax.add_artist(phi_k) + curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red') + # Iso-response curve + loc0, loc1, loc2 = contour_text_loc[0] + curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10) + lines = np.array([0.2, 0.203, 0.197]) - 0.1 + for i in lines: + curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2) + # Response attenuation curve + loc0, loc1, loc2 = contour_text_loc[1] + curve_ax.text(loc0, loc1, 
loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10) + att_line_offset = 165 + x, y = contour_pts + curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset], + color='black', lw=2, zorder=2) + # Activity label + #loc0, loc1, loc2 = contour_text_loc[2] + #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z') + # Additional settings + curve_ax.view_init(view_elevation, contour_angle) + scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) + curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect + curve_ax._axis3don = False + gs_base_idx = 1 if log else 0 + # Histogram plots + num_hist_y_plots = 2 + num_hist_x_plots = 2 + gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[gs_base_idx], + hspace=hspace_hist, wspace=wspace_hist) + orig_ax = fig.add_subplot(gs_hist[0,0]) + axes = [] + for sub_plt_y in range(0, num_hist_y_plots): + axes.append([]) + for sub_plt_x in range(0, num_hist_x_plots): + if (sub_plt_x, sub_plt_y) == (0,0): + axes[sub_plt_y].append(orig_ax) + else: + axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax)) + all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title) + for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists): + sub_bins = np.squeeze(sub_bins) + all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel) + for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists): + axes[axis_y][axis_x].spines['top'].set_visible(False) + axes[axis_y][axis_x].spines['right'].set_visible(False) + axes[axis_y][axis_x].set_xticks(sub_bins, minor=True) + axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False) + axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f')) + for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors): + neuron_hist = np.squeeze(neuron_hist) + if log: + axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', + drawstyle='steps-mid', label=label) + axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation()) + else: + axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label) + axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1) + if axis_y == 0: + axes[axis_y][axis_x].set_title(sub_title) + axes[axis_y][axis_x].set_xlabel(sub_xlabel) + if axis_x == 0: + if log: + axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency') + else: + axes[axis_y][axis_x].set_ylabel('Relative\nFrequency') + ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels() + legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right', + ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5, + labelspacing=0., bbox_to_anchor=(0.95, 0.95)) + legend.get_frame().set_linewidth(0.0) + for text, color in zip(legend.get_texts(), axis_colors): + text.set_color(color) + for item in legend.legendHandles: + item.set_visible(False) + if axis_x == 1: + axes[axis_y][axis_x].tick_params(axis='y', labelleft=False) + plt.show() + return fig + + def plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices, num_levels, x_range, y_range, show_contours=True, curvature=None, text_width=200, width_fraction=1.0, dpi=100): 
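
The 3D surface in plot_curvature_histograms above annotates the input plane with arrows for the target vector phi_k and the orthogonal direction nu via Arrow3D, which is not defined in this diff. A common recipe, offered here only as an assumption about what the repo uses, subclasses FancyArrowPatch and projects the 3D endpoints into display space at draw time:

from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import proj3d

class Arrow3D(FancyArrowPatch):
    """FancyArrowPatch between two 3D points (hypothetical sketch)."""
    def __init__(self, xs, ys, zs, *args, **kwargs):
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        # project 3D endpoints with the Axes3D projection matrix, then draw in 2D
        xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
        self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
        super().draw(renderer)

# usage matches the calls above, e.g.:
# curve_ax.add_artist(Arrow3D([x0, x1], [y0, y1], [z0, z1], mutation_scale=10,
#                             lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed'))
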
arrow_width = 0.0 @@ -421,140 +523,142 @@ def compute_curvature_hists(analyzer_list, num_bins): rand_hist, _ = np.histogram(flat_rand_curvatures, attn_bins, density=False) analyzer.attn_rand_hist = rand_hist / len(flat_rand_curvatures) -def plot_curvature_histograms(activity, contour_pts, contour_angle, contour_text_loc, hist_list, label_list, - color_list, mesh_color, bin_centers, title, xlabel, curve_lims, log=True, - text_width=200, width_ratio=1.0, dpi=100): - gs0_wspace = 0.5 - hspace_hist = 0.7 - wspace_hist = 0.08 - view_elevation = 30 - iso_response_line_thickness = 2 - respone_attenuation_line_thickness = 2 - num_y_plots = 2 - num_x_plots = 1 - fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi) - gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace) - curve_ax = fig.add_subplot(gs_base[0], projection='3d') - x_mesh, y_mesh = np.meshgrid(*contour_pts) - curve_ax.set_zlim(0, 1) - curve_ax.set_xlim3d(5, 200) - curve_ax.grid(b=False, zorder=0) - x_ticks = curve_ax.get_xticks().tolist() - x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1], - len(x_ticks)), 1).astype(str) - a_x = [' ']*len(x_ticks) - a_x[1] = x_ticks[1] - a_x[-1] = x_ticks[-1] - curve_ax.set_xticklabels(a_x) - y_ticks = curve_ax.get_yticks().tolist() - y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1], - len(y_ticks)), 1).astype(str) - a_y = [' ']*len(y_ticks) - a_y[1] = y_ticks[1] - a_y[-1] = y_ticks[-1] - curve_ax.set_yticklabels(a_y) - curve_ax.set_zticklabels([]) - curve_ax.zaxis.set_rotate_label(False) - curve_ax.set_zlabel('Normalized\nActivity', rotation=95, labelpad=-12., position=(-10., 0.)) - #curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01) - curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1, - linestyles='dotted', linewidths=0.5, alpha=1.0) - # Plane vector visualizations - v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], - [0, 0.0], mutation_scale=10, - lw=1, arrowstyle='-|>', color='red', linestyle='dashed') - curve_ax.add_artist(v) - curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red') - phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.], - [0, 0.0], mutation_scale=10, - lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed') - curve_ax.add_artist(phi_k) - curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red') - # Iso-response curve - loc0, loc1, loc2 = contour_text_loc[0] - curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10) - iso_line_offset = 165 - x, y = contour_pts - curve_ax.plot(np.zeros_like(x)+iso_line_offset, y, activity[:, iso_line_offset], - color='black', lw=2, zorder=2) - # Response attenuation curve - loc0, loc1, loc2 = contour_text_loc[1] - curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10) - lines = np.array([0.2, 0.203, 0.197]) - 0.1 - for i in lines: - curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2) - # Additional settings - curve_ax.view_init(view_elevation, contour_angle) - scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz']) - curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # make sure it has a square aspect - num_hist_y_plots = 2 - num_hist_x_plots = 2 - gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1], - hspace=hspace_hist, wspace=wspace_hist) - orig_ax = 
fig.add_subplot(gs_hist[0,0])
-    axes = []
-    for sub_plt_y in range(0, num_hist_y_plots):
-        axes.append([])
-        for sub_plt_x in range(0, num_hist_x_plots):
-            if (sub_plt_x, sub_plt_y) == (0,0):
-                axes[sub_plt_y].append(orig_ax)
-            else:
-                axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))
-    #[curvature type] [iso/att]
-    #[dataset type] [comp/rand]
-    #[target neuron id]
-    all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)
-    for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):
-        sub_bins = np.squeeze(sub_bins)
-        #max_hist_val = 0.001
-        #min_hist_val = 100
-        all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)
-        for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):
-            axes[axis_y][axis_x].spines['top'].set_visible(False)
-            axes[axis_y][axis_x].spines['right'].set_visible(False)
-            axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)
-            axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False)
-            axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))
-            for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):
-                neuron_hist = np.squeeze(neuron_hist)
-                plot_hist = []
-
-                if log:
-                    axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)
-                    #axes[axis_y][axis_x].set_yscale('log')
-                    axes[axis_y][axis_x].yaxis.set_major_formatter(
-                        ticker.FuncFormatter(
-                            lambda y,pos: ('{{:.{:1d}f}}'.format(int(np.maximum(-np.log10(y),0)))).format(y)
-                        )
-                    )
-                else:
-                    axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)
-                #if np.max(hist) > max_hist_val:
-                #    max_hist_val = np.max(hist)
-                #if np.min(hist) < min_hist_val:
-                #    min_hist_val = np.min(hist)
-            axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)
-            if axis_y == 0:
-                axes[axis_y][axis_x].set_title(sub_title)
-            axes[axis_y][axis_x].set_xlabel(sub_xlabel)
-            if axis_x == 0:
-                if log:
-                    axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency')
-                else:
-                    axes[axis_y][axis_x].set_ylabel('Relative\nFrequency')
-                ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()
-                legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',
-                    ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,
-                    labelspacing=0., bbox_to_anchor=(0.95, 0.95))
-                legend.get_frame().set_linewidth(0.0)
-                for text, color in zip(legend.get_texts(), axis_colors):
-                    text.set_color(color)
-                for item in legend.legendHandles:
-                    item.set_visible(False)
-            if axis_x == 1:
-                axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)
-    plt.show()
-    return fig
+
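# A minimal usage sketch (not part of the patch) for the plot_curvature_histograms()
# function added above. Every value here is a synthetic stand-in chosen only to
# satisfy the function's expected shapes: hist_list/label_list/color_list are nested
# [curvature type][dataset type][neuron], matching the zips inside the function, and
# the 200-point grid covers the hard-coded att_line_offset=165 slice. It assumes the
# module-level imports (plt, np, gridspec, plticker, Arrow3D, set_size) are in scope.
import numpy as np

n_pts = 200
contour_pts = (np.linspace(-2, 2, n_pts), np.linspace(-2, 2, n_pts))
x_mesh, y_mesh = np.meshgrid(*contour_pts)
activity = np.exp(-0.5 * (x_mesh**2 + y_mesh**2))  # stand-in response surface
flat_hist = [np.full(50, 0.02)]  # one stand-in, normalized neuron histogram
bins = np.linspace(-0.05, 0.05, 50)
fig = plot_curvature_histograms(
    activity=activity,
    contour_pts=contour_pts,
    contour_angle=-75,
    view_elevation=30,
    contour_text_loc=[(-50, 100, 0.6), (150, -50, 0.4)],  # (x, y, z) label anchors
    hist_list=[[flat_hist, flat_hist], [flat_hist, flat_hist]],
    label_list=[[['LCA'], ['LCA']], [['LCA'], ['LCA']]],
    color_list=[[['black'], ['black']], [['black'], ['black']]],
    mesh_color='blue',
    bin_centers=[bins, bins],
    title=['Iso-response curvature', 'Response attenuation curvature'],
    xlabel=['Curvature (Comparison)', 'Curvature (Random)'],
    curve_lims={'x': (-2, 2), 'y': (-2, 2)},
    scatter=False,
    log=True,
)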
 def plot_contrast_orientation_tuning(bf_indices, contrasts, orientations, activations, figsize=(32,32)):
     '''
diff --git a/tf1x/utils/logger.py b/tf1x/utils/logger.py
index c527e488..edfb404e 100644
--- a/tf1x/utils/logger.py
+++ b/tf1x/utils/logger.py
@@ -90,7 +90,13 @@ def read_js(self, tokens, text):
     assert len(tokens) == 2, ("Input variable tokens must be a list of length 2")
     matches = re.findall(re.escape(tokens[0])+"([\s\S]*?)"+re.escape(tokens[1]), text)
     if len(matches) > 1:
-      js_matches = [js.loads(match) for match in matches]
+      js_matches = []
+      for match_idx, match in enumerate(matches):
+        try:
+          js_matches.append(js.loads(match))
+        except ValueError as err:  # json.JSONDecodeError is a subclass of ValueError
+          print(f'ERROR: JSON load failed on match index {match_idx}: {err}')
+          raise
     else:
       js_matches = [js.loads(matches[0])]
     return js_matches
diff --git a/tf1x/vis/JOV_Euler_Attacks.ipynb b/tf1x/vis/JOV_Euler_Attacks.ipynb
index efef9419..60a7d42c 100644
--- a/tf1x/vis/JOV_Euler_Attacks.ipynb
+++ b/tf1x/vis/JOV_Euler_Attacks.ipynb
@@ -32,11 +32,13 @@
    "outputs": [],
    "source": [
     "import autograd.numpy as np\n",
+    "import matplotlib as mpl\n",
     "import matplotlib.pyplot as plt\n",
     "\n",
     "root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))\n",
     "if root_path not in sys.path: sys.path.append(root_path)\n",
     "\n",
+    "import DeepSparseCoding.tf1x.utils.jov_funcs as jov\n",
     "import DeepSparseCoding.tf1x.analysis.analysis_picker as ap\n",
     "from DeepSparseCoding.tf1x.data.dataset import Dataset\n",
     "import DeepSparseCoding.tf1x.utils.data_processing
as dp" @@ -65,7 +67,8 @@ " import schematic_utils\n", "except ImportError:\n", " import sys\n", - " sys.path.append(\"../schematic_figure/\")\n", + " usr = os.path.expanduser('~')\n", + " sys.path.append(usr+'/Work/DeepSparseCoding/tf1x/')\n", " import schematic_utils" ] }, @@ -75,9 +78,30 @@ "metadata": {}, "outputs": [], "source": [ - "figsize = (16, 16)\n", + "text_width = 540.60236 #pt 416.83269 #pt = 14.65cm\n", + "text_width_cm = 18.9973 # 14.705\n", + "fontsize = 10\n", + "dpi = 300\n", + "file_extensions = ['.pdf']#, '.eps', '.png']\n", + "#figsize = (16, 16)\n", + "num_y_plots = 3\n", + "num_x_plots = 1\n", + "width_ratio = 1.0\n", + "figsize = jov.set_size(text_width, width_ratio, [num_y_plots, num_x_plots])\n", "fontsize = 20\n", - "dpi = 200" + "font_settings = {\n", + " \"text.usetex\": True,\n", + " \"font.family\": 'serif',\n", + " \"font.serif\": 'Computer Modern Roman',\n", + " \"axes.labelsize\": fontsize,\n", + " \"axes.titlesize\": fontsize,\n", + " \"figure.titlesize\": fontsize+2,\n", + " \"font.size\": fontsize,\n", + " \"legend.fontsize\": fontsize,\n", + " \"xtick.labelsize\": fontsize-2,\n", + " \"ytick.labelsize\": fontsize-2,\n", + "}\n", + "mpl.rcParams.update(font_settings)" ] }, { @@ -364,7 +388,8 @@ "metadata": {}, "outputs": [], "source": [ - "f = plt.figure(figsize=(2*figsize[0],figsize[1]), dpi=dpi)\n", + "figsize = (2*16,16)#(figsize[0], figsize[1])\n", + "f = plt.figure(figsize=figsize, dpi=dpi)\n", "fig_shape = (1, 4)\n", "\n", "mlp_ax = plt.subplot2grid(fig_shape, loc=(0, 0), colspan=1, fig=f)\n", @@ -404,11 +429,57 @@ "metadata": {}, "outputs": [], "source": [ - "for ext in [\".png\", \".eps\"]:\n", - " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic\"\n", + "for ext in file_extensions:#[\".png\", \".eps\"]:\n", + " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_new\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig_size_inches = f.get_size_inches()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig_size_inches" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f.set_size_inches(fig_size_inches/2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for ext in file_extensions:#[\".png\", \".eps\"]:\n", + " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_small\"\n", + " +\"_\"+analyzer.analysis_params.save_info+ext)\n", + " f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tf1x/vis/JOV_figs.ipynb b/tf1x/vis/JOV_figs.ipynb index 7c92c4ca..76111b37 100644 --- a/tf1x/vis/JOV_figs.ipynb +++ b/tf1x/vis/JOV_figs.ipynb @@ -85,10 +85,16 @@ "metadata": {}, "outputs": [], "source": [ - "text_width = 416.83269 #pt = 14.65cm\n", - "text_width_cm = 14.705\n", - "fontsize = 12\n", - "dpi = 1200" + "\"\"\"\n", + "textwidth in pt: 540.60236pt\n", + "textwidth in cm: 18.9973cm\n", + "textwidth in in: 7.48178in\n", + "\"\"\"\n", + "text_width = 540.60236 #pt 416.83269 #pt = 14.65cm\n", + 
"text_width_cm = 18.9973 # 14.705\n", + "fontsize = 10\n", + "dpi = 300\n", + "file_extensions = ['.pdf']#, '.eps', '.png']" ] }, { @@ -99,10 +105,11 @@ "source": [ "font_settings = {\n", " \"text.usetex\": True,\n", - " \"font.family\": \"serif\",\n", + " \"font.family\": 'serif',\n", + " \"font.serif\": 'Computer Modern Roman',\n", " \"axes.labelsize\": fontsize,\n", " \"axes.titlesize\": fontsize,\n", - " \"figure.titlesize\": fontsize,\n", + " \"figure.titlesize\": fontsize+2,\n", " \"font.size\": fontsize,\n", " \"legend.fontsize\": fontsize,\n", " \"xtick.labelsize\": fontsize-2,\n", @@ -305,18 +312,8 @@ "for analyzer, save_name in zip(analyzer_list, save_names):\n", " analyzer.iso_params = np.load(analyzer.analysis_out_dir+'savefiles/iso_params_'+save_name\n", " +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n", - " #min_angle = analyzer.iso_params['min_angle']\n", - " #batch_size = analyzer.iso_params['batch_size']\n", - " #vh_image_scale = analyzer.iso_params['vh_image_scale']\n", - " #comparison_method = analyzer.iso_params['comparison_method']\n", - " #num_neurons = analyzer.iso_params['num_neurons']\n", - " #analyzer.num_comparison_vectors = analyzer.iso_params['num_comparisons']\n", " x_range = analyzer.iso_params['x_range']\n", " y_range = analyzer.iso_params['y_range']\n", - " #num_images = analyzer.iso_params['num_images']\n", - " #params_list = analyzer.iso_params['params_list']\n", - " #iso_save_name = analyzer.iso_params['iso_save_name']\n", - " #target_neuron_ids = analyzer.iso_params['target_neuron_ids']\n", "\n", " iso_vectors = np.load(analyzer.analysis_out_dir+'savefiles/iso_vectors_'+save_name\n", " +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n", @@ -340,6 +337,28 @@ " analyzer = add_analyzer_keys(analyzer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target_act = 0.5 # target activity spot between min & max value of normalized activity (btwn 0 and 1)\n", + "lca_activations = analyzer_list[-1].comp_activations\n", + "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n", + " lca_activations,\n", + " target_act=target_act\n", + ")\n", + "max_comp_indices = []\n", + "max_vals = []\n", + "for target_neuron_id in range(len(curvatures)):\n", + " max_idx = np.argmax(curvatures[target_neuron_id])\n", + " max_comp_indices.append(max_idx)\n", + " max_vals.append(curvatures[target_neuron_id][max_idx])\n", + "max_target_id = np.argmax(max_vals)\n", + "max_comparison_id = max_comp_indices[max_target_id]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -363,27 +382,29 @@ "min_comparison_id = target_min_idx[min_target_id]\n", "\n", "# 8(.039), 17(.028) 23(.037), 25(.036), 41(.033), 48(.035), 49(0.039)\n", - "neuron_indices = [0, 0, 0, min_target_id]\n", - "orth_indices = [0, 0, 0, min_comparison_id]\n", - "target_act = 0.5 # target activity spot between min & max value of normalized activity (btwn 0 and 1)\n", + "neuron_indices = [0, 0, 0, max_target_id]#min_target_id]\n", + "orth_indices = [0, 0, 0, max_comparison_id]#min_comparison_id]\n", "num_plots_y = 2\n", "num_plots_x = 2\n", "width_fraction = 1.0\n", "show_contours = True\n", "\n", "lca_activations = analyzer_list[-1].comp_activations[neuron_indices[-1], orth_indices[-1], ...][None, None, ...]\n", - "curvatures, fits = ha.iso_response_curvature_poly_fits(\n", + "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n", " lca_activations,\n", - " 
target_act=target_act,\n", - " measure_upper_right=False\n", + " target_act=target_act\n", ")\n", "curvature = [None, None, None, curvatures[0][0]]\n", "\n", + "#for analyzer in analyzer_list:\n", + "# analyzer.comp_activations = analyzer.comp_activations - analyzer.comp_activations.min()\n", + "# analyzer.comp_activations = analyzer.comp_activations / analyzer.comp_activations.max()\n", + "\n", "contour_fig, contour_handles = nc.plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices,\n", " num_levels, x_range, y_range, show_contours, curvature, text_width, width_fraction, dpi)\n", "\n", "for analyzer, neuron_index, orth_index, save_suffix in zip(analyzer_list, neuron_indices, orth_indices, save_names):\n", - " for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " neuron_str = str(analyzer.target_neuron_ids[neuron_index])\n", " orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])\n", " save_name = analyzer.analysis_out_dir+\"/vis/iso_contour_comparison_\"\n", @@ -432,11 +453,11 @@ " num_y,\n", " show_contours,\n", " text_width,\n", - " width_fraction,\n", + " 1.00,\n", " dpi\n", ")\n", "\n", - "for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = analyzer.analysis_out_dir+\"/vis/scaled_iso_contours_set_\"\n", " if not show_contours:\n", " save_name += \"continuous_\"\n", @@ -550,154 +571,6 @@ "full_xlabel = [\"Curvature (Comparison)\", \"Curvature (Random)\"]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list,\n", - " label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims,\n", - " scatter, log=True, text_width=200, width_ratio=1.0, dpi=100):\n", - " gs0_wspace = 0.5\n", - " hspace_hist = 0.7\n", - " wspace_hist = 0.08\n", - " iso_response_line_thickness = 2\n", - " respone_attenuation_line_thickness = 2\n", - " num_y_plots = 2\n", - " num_x_plots = 1\n", - " fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)\n", - " gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace)\n", - " \n", - " curve_ax = fig.add_subplot(gs_base[0], projection='3d')\n", - " curve_ax.minorticks_off()\n", - " x_mesh, y_mesh = np.meshgrid(*contour_pts)\n", - " curve_ax.set_zlim(0, 1)\n", - " curve_ax.set_xlim3d(5, 200)\n", - " curve_ax.grid(False)\n", - " #x_ticks = curve_ax.get_xticks().tolist()\n", - " #x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1],\n", - " # len(x_ticks)), 1).astype(str)\n", - " #a_x = [' ']*len(x_ticks)\n", - " #a_x[1] = x_ticks[1]\n", - " #a_x[-1] = x_ticks[-1]\n", - " curve_ax.set_xticklabels([])#a_x)\n", - " #y_ticks = curve_ax.get_yticks().tolist()\n", - " #y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1],\n", - " # len(y_ticks)), 1).astype(str)\n", - " #a_y = [' ']*len(y_ticks)\n", - " #a_y[1] = y_ticks[1]\n", - " #a_y[-1] = y_ticks[-1]\n", - " curve_ax.set_yticklabels([])#a_y)\n", - " curve_ax.set_zticklabels([])\n", - " curve_ax.zaxis.set_rotate_label(False)\n", - " #curve_ax.set_zlabel('Activity', rotation=95, labelpad=-15., position=(-10., 0.))\n", - " if scatter:\n", - " curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01)\n", - " else:\n", - " curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1,\n", - " 
linestyles='dotted', linewidths=0.3, alpha=1.0)\n", - " \n", - " # Plane vector visualizations\n", - " v = nc.Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], \n", - " [0, 0.0], mutation_scale=10, \n", - " lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed')\n", - " curve_ax.add_artist(v)\n", - " curve_ax.text(-300/3., 280/3.0, 0.0, r'$\\nu$', color='red')\n", - " phi_k = nc.Arrow3D([-200/3., 0.], [200/2., 200/2.], \n", - " [0, 0.0], mutation_scale=10, \n", - " lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed')\n", - " curve_ax.add_artist(phi_k)\n", - " curve_ax.text(-175/3., 250/3.0, 0.0, r'${\\phi}_{k}$', color='red')\n", - " \n", - " # Iso-response curve\n", - " loc0, loc1, loc2 = contour_text_loc[0]\n", - " curve_ax.text(loc0, loc1, loc2, 'Iso-\\nresponse', color='black', weight='bold', zorder=10)\n", - " lines = np.array([0.2, 0.203, 0.197]) - 0.1\n", - " for i in lines:\n", - " curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2)\n", - " \n", - " # Response attenuation curve\n", - " loc0, loc1, loc2 = contour_text_loc[1]\n", - " curve_ax.text(loc0, loc1, loc2, 'Response\\nAttenuation', color='black', weight='bold', zorder=10)\n", - " att_line_offset = 165\n", - " x, y = contour_pts\n", - " curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset],\n", - " color='black', lw=2, zorder=2)\n", - " \n", - " # Activity label\n", - " #loc0, loc1, loc2 = contour_text_loc[2]\n", - " #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z')\n", - " \n", - " # Additional settings\n", - " curve_ax.view_init(view_elevation, contour_angle)\n", - " scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])\n", - " curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect\n", - " #curve_ax.get_xaxis().set_visible(False)\n", - " #curve_ax.get_yaxis().set_visible(False)\n", - " curve_ax._axis3don = False\n", - " #y_aspect = 2\n", - " #scale_x = [np.min(scaling), np.max(scaling)]\n", - " #scale_y = [y_aspect*np.min(scaling), y_aspect*np.max(scaling)]\n", - " #scale_z = [np.min(scaling), np.max(scaling)]\n", - " #curve_ax.auto_scale_xyz(scale_x, scale_y, scale_z)\n", - " \n", - " # Histogram plots\n", - " num_hist_y_plots = 2\n", - " num_hist_x_plots = 2\n", - " gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1],\n", - " hspace=hspace_hist, wspace=wspace_hist)\n", - " orig_ax = fig.add_subplot(gs_hist[0,0])\n", - " axes = []\n", - " for sub_plt_y in range(0, num_hist_y_plots):\n", - " axes.append([])\n", - " for sub_plt_x in range(0, num_hist_x_plots):\n", - " if (sub_plt_x, sub_plt_y) == (0,0):\n", - " axes[sub_plt_y].append(orig_ax)\n", - " else:\n", - " axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))\n", - " all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)\n", - " for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):\n", - " sub_bins = np.squeeze(sub_bins)\n", - " all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)\n", - " for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):\n", - " axes[axis_y][axis_x].spines['top'].set_visible(False)\n", - " axes[axis_y][axis_x].spines['right'].set_visible(False)\n", - " axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)\n", - " axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], 
minor=False)\n", - " axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))\n", - " for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):\n", - " neuron_hist = np.squeeze(neuron_hist)\n", - " if log:\n", - " axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-',\n", - " drawstyle='steps-mid', label=label)\n", - " axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation())\n", - " else:\n", - " axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)\n", - " axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)\n", - " if axis_y == 0:\n", - " axes[axis_y][axis_x].set_title(sub_title)\n", - " axes[axis_y][axis_x].set_xlabel(sub_xlabel)\n", - " if axis_x == 0:\n", - " if log:\n", - " axes[axis_y][axis_x].set_ylabel('Relative\\nLog Frequency')\n", - " else:\n", - " axes[axis_y][axis_x].set_ylabel('Relative\\nFrequency')\n", - " ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()\n", - " legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',\n", - " ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,\n", - " labelspacing=0., bbox_to_anchor=(0.95, 0.95))\n", - " legend.get_frame().set_linewidth(0.0)\n", - " for text, color in zip(legend.get_texts(), axis_colors):\n", - " text.set_color(color)\n", - " for item in legend.legendHandles:\n", - " item.set_visible(False)\n", - " if axis_x == 1:\n", - " axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)\n", - " plt.show()\n", - " return fig" - ] - }, { "cell_type": "code", "execution_count": null, @@ -718,12 +591,12 @@ "activity_loc = [-27, 150, 1.5]\n", "contour_text_loc = [iso_resp_loc, resp_att_loc, activity_loc]\n", "\n", - "curvature_log_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n", + "curvature_log_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n", " contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n", - " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".pdf\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_logy\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " curvature_log_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -737,12 +610,12 @@ }, "outputs": [], "source": [ - "curvature_lin_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n", + "curvature_lin_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n", " contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n", - " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".eps\"]:\n", + " for ext in file_extensions:\n", " 
save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_liny\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " curvature_lin_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -820,10 +693,10 @@ "density = False\n", "\n", "circ_var_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n", - " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/circular_variance_combo\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " circ_var_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" @@ -838,7 +711,7 @@ "spatial_frequencies = np.stack([np.array(analyzer.bf_stats[\"spatial_frequencies\"]) for analyzer in analyzer_list], axis=0)\n", "circular_variances = np.stack([variance for variance in circ_var_list], axis=0)\n", "\n", - "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width), dpi=dpi)\n", + "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width, fraction=0.75), dpi=dpi)\n", "ax = cv_vs_sf_fig.add_subplot()\n", "for analyzer_idx in range(len(analyzer_list)):\n", " ax.scatter(spatial_frequencies[analyzer_idx, :], circular_variances[analyzer_idx, :],\n", @@ -858,7 +731,7 @@ "plt.show()\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/spatial_freq_vs_circular_variance\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " cv_vs_sf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" @@ -871,7 +744,7 @@ "outputs": [], "source": [ "params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]\n", - "display_names = [\"512 Neurons\", \"768 Neurons\", \"1024 Neurons\"]#, \"2560 Neurons\"]\n", + "display_names = [\"512 Neurons\", \"1024 Neurons\", \"2560 Neurons\"]\n", "for params, display_name in zip(params_list, display_names):\n", " params.display_name = display_name\n", " params.model_dir = (os.path.expanduser(\"~\")+\"/Work/Projects/\"+params.model_name)\n", @@ -925,10 +798,10 @@ "density = True\n", "\n", "oc_vs_cv_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n", - " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n", + " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n", "\n", "for analyzer in analyzer_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = (analyzer.analysis_out_dir+\"/vis/overcompleteness_vs_circular_variance\"\n", " +\"_\"+analyzer.analysis_params.save_info+ext)\n", " oc_vs_cv_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" @@ -944,7 +817,9 @@ { "cell_type": "code", "execution_count": null, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ "params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_2560_vh_params()]\n", @@ -970,6 +845,156 @@ " analyzer_list.append(analyzer)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "def closest_val_in_array(num, arr):\n", + " curr = arr[0]\n", + " for val in arr:\n", + " if abs(num - val) < abs(num - curr):\n", + " curr = val\n", + " curr_idx = np.argwhere(np.array(arr) == curr).item()\n", + " return arr[curr_idx]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "code_folding": [], + "scrolled": false + }, + "outputs": [], + "source": [ + "num_interesting_vals = [\n", + " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n", + " analyzer.nat_selectivity['num_interesting_img_l']])\n", + " for analyzer in analyzer_list]\n", + "\n", + "num_interesting_medians = np.stack(\n", + " [np.array([np.median(np.array(analyzer.nat_selectivity['num_interesting_img_nl'])),\n", + " np.median(np.array(analyzer.nat_selectivity['num_interesting_img_l']))])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "num_interesting_means = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "num_interesting_stds = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_std'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_std']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", + "array = [\n", + " [1, 2, 3],\n", + " [4, 5, 6],\n", + "]\n", + "\n", + "scale = 1\n", + "rc_kwargs = {\n", + " 'fontsize':scale*matplotlib.rcParams['font.size'],\n", + " 'fontfamily':scale*matplotlib.rcParams['font.family'],\n", + " 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n", + " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n", + "}\n", + "figsize = nc.set_size(text_width, fraction=1.00)\n", + "with plot.rc.context(**rc_kwargs):\n", + " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, sharex=False, aspect=3.0, figsize=figsize)\n", + " for ovc_idx, overcompleteness in enumerate(num_interesting_vals):\n", + " ax = axs[ovc_idx]\n", + " df = pd.DataFrame(\n", + " overcompleteness.T,\n", + " columns=pd.Index(['Sparse Coding', 'Linear'])#, name='xlabel')\n", + " )\n", + " box_parts = ax.boxplot(\n", + " df,\n", + " notch=True,\n", + " fill=False,\n", + " whis=(5, 95),\n", + " marker='*',\n", + " markersize=1.0,\n", + " lw=1.2\n", + " )\n", + " colors = ['md_red', 'md_green']\n", + " for pc_idx, box in enumerate(box_parts['boxes']):\n", + " box.set_color(color_vals[colors[pc_idx]])\n", + " ax.format(\n", + " ylocator=50,\n", + " ylim=[0, np.max([np.max(val) for val in num_interesting_vals])],\n", + " title=analyzer_list[ovc_idx].nat_selectivity['oc_label'],\n", + " ylabel='Average number of\\nintersting images',\n", + " xtickminor=False,\n", + " xgrid=False\n", + " )\n", + "\n", + " for idx, analyzer in enumerate(analyzer_list):\n", + " ax = axs[idx+3]\n", + " angle_min = 0.0\n", + " angle_max = 90.0\n", + " nbins=20\n", + " bins = np.linspace(angle_min, angle_max, nbins)\n", + " lin_data = [mean for mean in analyzer.nat_selectivity['lin_means'] if mean>0]\n", + " non_lin_data = [mean for mean in analyzer.nat_selectivity['lca_means'] if mean>0]\n", + " hist_list = []\n", + " color_list = [color_vals['md_green'], color_vals['md_red']]\n", + " label_list = ['Linear Autoencoder', 'Sparse Coding']\n", + " handles = []\n", + " hist_max_list = []\n", + " for angles, label, color in zip([lin_data, non_lin_data], label_list, color_list):\n", + " # density means the y vals are probability density function 
at the bin, normalized such that the integral over the range is 1.\n", + " hist, bin_edges = np.histogram(np.array(angles).flatten(), bins, density=False)\n", + " hist_max_list.append(hist.max())\n", + " hist_list.append(hist)\n", + " bin_left, bin_right = bin_edges[:-1], bin_edges[1:]\n", + " bin_centers = bin_left + (bin_right - bin_left)/2\n", + " handles.append(ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid', color=color, label=label))\n", + " oc = analyzer.nat_selectivity['oc_label']\n", + " ax.spines['top'].set_visible(False)\n", + " ax.spines['right'].set_visible(False)\n", + " ax.set_xticks(bin_left, minor=True)\n", + " ax.set_xticks(bin_left[::2], minor=False)\n", + " ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n", + " ax.set_xticks([angle_min, angle_max//2, angle_max])\n", + " mid_val = max(hist_max_list)//2\n", + " max_val = int(max(hist_max_list))\n", + " #interval_list = list(range(0, mid_val+51, 50))\n", + " #new_mid = closest_val_in_array(mid_val, interval_list)\n", + " interval_list = list(range(0, max_val+51, 50))\n", + " new_max = closest_val_in_array(max_val, interval_list)\n", + " new_mid = new_max//2\n", + " ax.set_ylim([0, new_max+0.1*new_max])\n", + " ax.set_yticks([0, new_mid, new_max])\n", + " #axs[-1].legend(handles, ncol=1, frameon=False, loc='ur', bbox_to_anchor=[1, 1.02])\n", + " hist_ax_idx = 3\n", + " axs[hist_ax_idx].format(ylabel='Total number of\\ninteresting images')\n", + " axs[hist_ax_idx:].format(\n", + " suptitle='Sparse Coding Increases Neuron Selectivity for Natural Signals',\n", + " xlabel='Mean image-to-weight angle',\n", + " xlim=[0, 90],\n", + " ygrid=False\n", + " )\n", + "plot.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "for analyzer in analyzer_list:\n", + " for ext in file_extensions:\n", + " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_box_'\n", + " +analyzer.analysis_params.save_info+ext)\n", + " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -984,6 +1009,16 @@ "color_list = [color_vals['md_green'], color_vals['md_red']]\n", "label_list = ['Linear Autoencoder', 'Sparse Coding']\n", "\n", + "num_interesting_vals = [\n", + " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n", + " analyzer.nat_selectivity['num_interesting_img_l']])\n", + " for analyzer in analyzer_list]\n", + "\n", + "num_interesting_medians = np.stack(\n", + " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", + " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", + " for analyzer in analyzer_list], axis=0)\n", + "\n", "num_interesting_means = np.stack(\n", " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n", " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n", @@ -1010,10 +1045,11 @@ " 'fontsize':scale*matplotlib.rcParams['font.size'],\n", " 'fontfamily':scale*matplotlib.rcParams['font.family'],\n", " 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n", - " 'text.labelsize': scale*matplotlib.rcParams['font.size']-2\n", + " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n", "}\n", + "figsize = nc.set_size(text_width, fraction=1.00)\n", "with plot.rc.context(**rc_kwargs):\n", - " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, width=0.4*text_width_cm)\n", + " interesting_imgs_fig, axs = 
plot.subplots(array, sharey=False, aspect=3.0, figsize=figsize)#, width=0.4*text_width_cm)\n", " ax = axs[0]\n", " obj = ax.bar(\n", " df,\n", @@ -1033,7 +1069,7 @@ " xlocator=1,\n", " xminorlocator=0.5,\n", " ytickminor=False,\n", - " ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n", + " #ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n", " #suptitle='Average number of intersting images'\n", " ylabel='Average number of\\nintersting images',\n", " xgrid=False\n", @@ -1062,7 +1098,8 @@ " ax.set_xticks(bin_left[::2], minor=False)\n", " ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n", " ax.set_xticks([angle_min, angle_max//2, angle_max])\n", - " ax.set_ylim([0.0, max(hist_max_list)+0.1*max(hist_max_list)])\n", + " ax.set_ylim([0, max(hist_max_list)+0.1*max(hist_max_list)])\n", + " ax.set_yticks([0, max(hist_max_list)//2, int(max(hist_max_list))])\n", " ax.format(title=f'{oc}\\n')#, ygrid=False)\n", " #ax.grid(b=False, which='both', axis='both')\n", " axs[1].format(ylabel='Total number of\\ninteresting images')\n", @@ -1084,10 +1121,10 @@ "outputs": [], "source": [ "for analyzer in analyzer_list:\n", - " for ext in ['.png', '.pdf', '.eps']:\n", - " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_'\n", + " for ext in file_extensions:\n", + " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_bar_'\n", " +analyzer.analysis_params.save_info+ext)\n", - " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=interesting_imgs_fig.dpi)" + " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)" ] }, { @@ -1480,7 +1517,7 @@ " COLORS, inner_group_names, outer_group_names, titles, text_width, width_ratio=1.0, dpi=dpi)\n", "\n", "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = (output_dir+'/adv_mse_comparison_boxplots'+ext)\n", " adv_fig.savefig(save_name, transparent=False, bbox_inches='tight', pad_inches=0.05, dpi=dpi)" ] @@ -1785,7 +1822,7 @@ "outputs": [], "source": [ "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", + "for ext in file_extensions:\n", " save_name = (output_dir+'/adv_mse_comparison_example_images'+ext)\n", " adv_img_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)" ] @@ -1839,9 +1876,8 @@ "\n", "conf_fig = plot_average_conf_step(files, names)\n", "#for analyzer in analyzer_list:\n", - "for ext in [\".png\", \".eps\"]:\n", - " save_name = (output_dir+'adv_mse_comparison_example_images'\n", - " +\"_\"+analyzer.analysis_params.save_info+ext)\n", + "for ext in file_extensions:\n", + " save_name = (output_dir+'adv_mse_comparison_example_images'+ext)\n", " conf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] }, @@ -2082,66 +2118,75 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_average_mse_step(analysis_files, colors, title, model_names, bar_width, hatches, figsize, dpi):\n", + "def plot_average_mse_step(analysis_files, recons, confs, colors, title, model_names, bar_width, hatches, figsize, dpi):\n", " fig = plt.figure(figsize=figsize, dpi=dpi)\n", - " gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", - " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n", - " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n", - " group_data = []\n", - " group_means = 
[]\n", - " handles = []\n", + " num_conditions = len(analysis_files)\n", + " gs_top = gridspec.GridSpec(num_conditions, num_conditions)\n", " axes = []\n", - " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n", - " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n", - " for file_idx, (file, name) in enumerate(zip(analysis_files, model_names)):\n", - " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n", - " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n", - " if x_ax_idx == 0:\n", - " axes[-1].set_ylabel('Adversarial\\nConfidence')\n", - " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n", - " axes[-1].set_ylim([0, 100.1])\n", - " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n", - " std_vals = np.std(adv_conf, axis=-1)[1:]\n", - " else:\n", - " adv_mse = np.squeeze(analysis['input_adv_mses'])\n", - " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n", - " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n", - " thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n", - " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n", - " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n", - " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n", - " std_vals = np.std(adv_mse, axis=-1)[1:]\n", - " group_data.append(adv_mse[first_adv_cross, :])\n", - " group_means.append(mean_vals[first_adv_cross])\n", - " max_val = 0.020#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n", - " axes[-1].set_ylim([0, max_val])\n", - " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n", - " lw=2, color=colors[file_idx][0], zorder=1)\n", - " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n", - " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\", hatch=hatches[file_idx],\n", - " rasterized=False)\n", - " axes[-1].set_xlabel('Attack Step')\n", + " for condition, (condition_analysis_files, recon, conf) in enumerate(zip(analysis_files, recons, confs)):\n", + " #gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", + " gs0 = gridspec.GridSpecFromSubplotSpec(1, 2, gs_top[condition, :],\n", + " wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n", + " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n", + " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n", + " group_data = []\n", + " group_means = []\n", + " handles = []\n", + " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n", + " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n", + " for file_idx, (file, name) in enumerate(zip(condition_analysis_files, model_names)):\n", + " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n", + " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n", + " if x_ax_idx == 0:\n", + " if condition == 0:\n", + " axes[-1].set_ylabel('Adversarial\\nConfidence')\n", + " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n", + " axes[-1].set_ylim([0, 100.1])\n", + " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n", + " std_vals = np.std(adv_conf, axis=-1)[1:]\n", + " else:\n", + " if condition == 0:\n", + " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n", + " adv_mse = np.squeeze(analysis['input_adv_mses'])\n", + " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n", + " 
thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n", + " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n", + " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n", + " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n", + " std_vals = np.std(adv_mse, axis=-1)[1:]\n", + " group_data.append(adv_mse[first_adv_cross, :])\n", + " group_means.append(mean_vals[first_adv_cross])\n", + " max_val = 0.03#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n", + " axes[-1].set_ylim([0, max_val])\n", + " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n", + " lw=2, color=colors[file_idx][0], zorder=1)\n", + " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n", + " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\",\n", + " hatch=hatches[file_idx], rasterized=False)\n", + " if condition == num_conditions-1:\n", + " axes[-1].set_xlabel('Attack Step')\n", + " axes[-1].grid(False)\n", + " axes.append(fig.add_subplot(right_gs[0]))\n", + " x_pos = np.arange(2) + 2 * bar_width\n", + " linewidth = 1\n", + " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n", + " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n", + " float_colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n", + " axes[-1].set_title(f'c={recon}, '+r'$\\kappa$'+f'={conf}')\n", + " for data, means, pos, color, name in zip(group_data, group_means, x_pos, float_colors, model_names):\n", + " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n", + " whiskerprops = boxprops\n", + " capprops = boxprops\n", + " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n", + " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n", + " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n", + " meanprops=meanprops\n", + " ))\n", + " axes[-1].set_ylim([0, max_val])\n", + " axes[-1].set_yticklabels('')\n", + " axes[-1].get_xaxis().set_ticks([])\n", " axes[-1].grid(False)\n", - " axes.append(fig.add_subplot(right_gs[0]))\n", - " x_pos = np.arange(2) + 2 * bar_width\n", - " linewidth = 1\n", - " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n", - " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n", - " colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n", - " for data, means, pos, color, name in zip(group_data, group_means, x_pos, colors, model_names):\n", - " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n", - " whiskerprops = boxprops\n", - " capprops = boxprops\n", - " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n", - " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n", - " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n", - " meanprops=meanprops\n", - " ))\n", - " axes[-1].set_ylim([0, max_val])\n", - " axes[-1].set_yticklabels('')\n", - " axes[-1].get_xaxis().set_ticks([])\n", - " axes[-1].grid(False)\n", - " axes[-1].text(pos, 0.002, name, horizontalalignment='center', verticalalignment='center')\n", + " axes[-1].text(pos, 0.0025, name, horizontalalignment='center', verticalalignment='center')\n", " fig.subplots_adjust(top=0.8)\n", " fig.suptitle(title, y=0.98)\n", " return fig, axes" @@ -2156,35 +2201,44 @@ "outputs": [], "source": [ "colors = [[color_vals['md_blue'], color_vals['lt_blue']], 
[color_vals['md_red'], color_vals['lt_red']]]\n", - "model_names = ['w/o\\nLCA', 'w/\\nLCA']\n", + "model_names = ['w/o LCA', 'w/ LCA']\n", "hatches = ['///', '\\\\\\\\\\\\']\n", "\n", - "k_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_temp_kurakin_targeted.npz'\n", - "k_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_temp_kurakin_targeted.npz'\n", - "k_mlp_files = [projects_dir + model_name + k_file_path for model_name in [mnist_mlp_768_2layer]]\n", - "k_lca_files = [projects_dir + model_name + k_file_path for model_name in [mnist_lca_768_2layer]]\n", - "k_files = k_mlp_files + k_lca_files\n", - "\n", - "run_number = '5'\n", - "c_file_path = (analysis_dir+'savefiles/class_adversary_analysis_test_temp'\n", - " +str(run_number)+'_carlini_targeted.npz')\n", - "c_img_path = (analysis_dir+'savefiles/class_adversary_images_analysis_test_temp'\n", - " +str(run_number)+'_carlini_targeted.npz')\n", - "c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n", - "c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n", - "c_files = c_mlp_files + c_lca_files\n", - "\n", - "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2, 3])\n", "#carlini_title = 'Networks with an LCA layer require larger\\nperturbations for equal confidence with the Carlini attack'\n", "#carlini_title = 'Networks with an LCA layer are more robust than without'\n", - "carlini_title = ''#Networks with an LCA layer are more robust than without'\n", - "fig, ax = plot_average_mse_step(c_files, colors, carlini_title,\n", + "carlini_title = ''\n", + "\n", + "all_recons = []\n", + "all_confs = []\n", + "all_files = []\n", + "for recon in ['0.5', '1.0']:\n", + " for conf in ['0.0', '10.0']:\n", + " if conf == '10.0':\n", + " extra_str = '_'\n", + " temp = '1.00'\n", + " else:\n", + " extra_str = ''\n", + " temp = '1.0'\n", + " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n", + " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n", + " c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n", + " temp = '0.65'\n", + " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n", + " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n", + " c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n", + " c_files = c_mlp_files + c_lca_files\n", + " all_recons.append(recon)\n", + " all_confs.append(conf)\n", + " all_files.append(c_files)\n", + "\n", + "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2*2, 3])\n", + "fig, ax = plot_average_mse_step(all_files, all_recons, all_confs, colors, carlini_title,\n", " model_names, bar_width, hatches, figsize, dpi)\n", "\n", - "out_list = [projects_dir + model_name + '/analysis/0.0/vis/kurakin_carlini_mse_vs_iteration_temp' + str(run_number)\n", + "out_list = [projects_dir + model_name + '/analysis/0.0/vis/carlini_mse_vs_iteration_k0.0-10.0_conditions'\n", " for model_name in [mnist_lca_768_2layer, mnist_mlp_768_2layer]]\n", "for out_name in out_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = out_name+ext\n", " fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] @@ -2334,7 +2388,7 @@ "\n", "out_list += [path + lista + \"/analysis/0.0/vis/lista_adv_transferability\"]\n", "for out_name in 
out_list:\n", - " for ext in [\".png\", \".eps\"]:\n", + " for ext in file_extensions:\n", " save_name = out_name+ext\n", " fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)" ] @@ -2367,7 +2421,7 @@ " image_idx = category_idx + start_idx\n", " orig_ax = ax_orig_list[category_idx]\n", " if category_idx == 0:\n", - " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj, fontsize=fontsize)\n", + " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj)#, fontsize=fontsize)\n", " orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_ammount)\n", " orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)\n", " for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):\n", @@ -2388,27 +2442,27 @@ " vmax = 1.0\n", " if j == 0: # top left image\n", " if model_idx == 0:\n", - " current_ax.set_ylabel(r\"$s^{*}_{T}$\", fontsize=fontsize)\n", + " current_ax.set_ylabel(r\"$s^{*}_{T}$\")#, fontsize=fontsize)\n", " current_target_label = target_labels[j][model_idx][image_idx]\n", " if category_idx == 0: # top category only\n", " x_loc = group_name_loc[0]\n", " y_loc = group_name_loc[1]\n", - " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx], fontsize=fontsize,\n", + " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx],#, fontsize=fontsize,\n", " horizontalalignment='left', verticalalignment='bottom')\n", " else: # i == 1\n", " vmin = np.round(diff_vmin, 2)\n", " vmax = np.round(diff_vmax, 2)\n", " if j == 0 and model_idx == 0:\n", - " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\", fontsize=fontsize)\n", + " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\")#, fontsize=fontsize)\n", " if j == 0 and category_idx == num_categories-1: # bottom left\n", - " current_ax.set_xlabel(\"w/o\\nLCA\", fontsize=fontsize)\n", + " current_ax.set_xlabel(\"w/o\\nLCA\")#, fontsize=fontsize)\n", " elif j == 1 and category_idx == num_categories-1: # bottom right\n", - " current_ax.set_xlabel(\"w/\\nLCA\", fontsize=fontsize)\n", + " current_ax.set_xlabel(\"w/\\nLCA\")#, fontsize=fontsize)\n", " im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)\n", " if j == 1:\n", - " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax], labelsize=fontsize/2)\n", + " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])#, labelsize=fontsize/2)\n", "\n", - "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, fontsize, dpi=100):\n", + "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, dpi=100):\n", " mnist_mlp_grp, mnist_lca_grp = image_groups[0]\n", " cifar_mlp_grp, cifar_lca_grp = image_groups[1]\n", " mnist_orig_labels, mnist_target_labels, mnist_img_labels = labels[0]\n", @@ -2420,22 +2474,23 @@ " sub_wspace = 0.2\n", " orig_y_adj = 1.10\n", " img_label_loc = [-8.0, -8.0] # [x, y]\n", - " fig2 = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n", - " gs0 = plt.GridSpec(2, 1, figure=fig2, hspace=0.3)\n", + " fig = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n", + " gs0 = plt.GridSpec(2, 1, figure=fig, hspace=0.3)\n", " \n", " num_categories=3\n", " \n", " gs_mnist = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[0], hspace=hspace, wspace=wspace)\n", - " make_grid_subplots_with_fontsize(fig2, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n", + " 
make_grid_subplots_with_fontsize(fig, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n", " mnist_target_labels, mnist_img_labels, img_label_loc, orig_y_adj, mnist_start_idx, num_categories,\n", - " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\", fontsize=fontsize)\n", + " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\")\n", " \n", " gs_cifar = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[1], hspace=hspace, wspace=wspace)\n", - " make_grid_subplots_with_fontsize(fig2, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n", + " make_grid_subplots_with_fontsize(fig, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n", " cifar_target_labels, cifar_img_labels, img_label_loc, orig_y_adj, 0, num_categories,\n", - " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\", fontsize=fontsize)\n", + " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\")\n", " \n", - " plt.show()" + " plt.show()\n", + " return fig " ] }, { @@ -2446,9 +2501,17 @@ }, "outputs": [], "source": [ - "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx=44, cifar_start_idx=0,\n", - " figsize=(16, 16), fontsize=20, dpi=dpi)" + "figsize = nc.set_size(text_width, fraction=1.0, subplot=[16, 16])\n", + "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, label_groups, mnist_start_idx=44, cifar_start_idx=0,\n", + " figsize=figsize, dpi=dpi)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/tf1x/vis/tsne_analysis.py b/tf1x/vis/tsne_analysis.py index f0f4430e..b60d7220 100644 --- a/tf1x/vis/tsne_analysis.py +++ b/tf1x/vis/tsne_analysis.py @@ -1,5 +1,9 @@ import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) import numpy as np import pickle @@ -7,9 +11,6 @@ from tensorflow.contrib.tensorboard.plugins import projector from scipy.misc import imsave -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.analysis.analysis_picker as ap import DeepSparseCoding.tf1x.utils.data_processing as dp diff --git a/tf1x/vis/vis_class_adversarial.py b/tf1x/vis/vis_class_adversarial.py index 7bbd72a7..ac41ce23 100644 --- a/tf1x/vis/vis_class_adversarial.py +++ b/tf1x/vis/vis_class_adversarial.py @@ -2,6 +2,11 @@ matplotlib.use('Agg') import os import sys +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) + import numpy as np import matplotlib @@ -13,9 +18,6 @@ import pandas as pd import pdb -ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd())) -if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) - import DeepSparseCoding.tf1x.data.data_selector as ds import DeepSparseCoding.tf1x.utils.data_processing as dp import DeepSparseCoding.tf1x.utils.plot_functions as pf diff --git a/tf1x/vis/vis_conv_lca.py b/tf1x/vis/vis_conv_lca.py index 567dea57..08fbda10 100644 --- a/tf1x/vis/vis_conv_lca.py +++ b/tf1x/vis/vis_conv_lca.py @@ -1,6 +1,11 @@ # In[1]: import os +from os.path import dirname as up + +ROOT_DIR = up(up(up(up(os.path.realpath(__file__))))) +if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR) + import numpy as np import matplotlib matplotlib.use('Agg') @@ -11,9 +16,6 @@ import tensorflow as tf import pdb 
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
 import DeepSparseCoding.tf1x.data.data_selector as ds
 import DeepSparseCoding.tf1x.utils.data_processing as dp
 import DeepSparseCoding.tf1x.utils.plot_functions as pf
diff --git a/tf1x/vis/vis_corrupt.py b/tf1x/vis/vis_corrupt.py
index 8b6b3df4..dee47ad2 100644
--- a/tf1x/vis/vis_corrupt.py
+++ b/tf1x/vis/vis_corrupt.py
@@ -2,6 +2,10 @@ matplotlib.use('Agg')
 import os
 import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
 
 import numpy as np
 import matplotlib
@@ -13,9 +17,6 @@ import pandas as pd
 import pickle
 
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
 import DeepSparseCoding.tf1x.data.data_selector as ds
 import DeepSparseCoding.tf1x.utils.data_processing as dp
 import DeepSparseCoding.tf1x.utils.plot_functions as pf
diff --git a/tf1x/vis/vis_recon_adversarial.py b/tf1x/vis/vis_recon_adversarial.py
index ab364e0e..ae926b89 100644
--- a/tf1x/vis/vis_recon_adversarial.py
+++ b/tf1x/vis/vis_recon_adversarial.py
@@ -3,6 +3,10 @@ import os
 import sys
 import pdb
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
 
 import numpy as np
 import matplotlib
@@ -11,9 +15,6 @@ import matplotlib.gridspec as gridspec
 from skimage.measure import compare_psnr
 
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
 from DeepSparseCoding.tf1x.data.dataset import Dataset
 import DeepSparseCoding.tf1x.data.data_selector as ds
 import DeepSparseCoding.tf1x.utils.data_processing as dp
diff --git a/train_model.py b/train_model.py
index db2b2778..1bad1364 100644
--- a/train_model.py
+++ b/train_model.py
@@ -2,10 +2,13 @@ import sys
 import argparse
 import time as ti
+from os.path import dirname as up
 
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(os.path.realpath(__file__)))
 if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
 
+import torch
+
 import DeepSparseCoding.utils.loaders as loaders
 import DeepSparseCoding.utils.run_utils as run_utils
 import DeepSparseCoding.utils.dataset_utils as dataset_utils
@@ -19,31 +22,34 @@ t0 = ti.time()
 
 # Load params
-params = loaders.load_params(param_file)
+params = loaders.load_params_file(param_file)
 
 # Load data
-train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)
-for key, value in data_stats.items():
-  setattr(params, key, value)
+train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)[:4]
+for key, value in data_stats.items(): setattr(params, key, value)
 
 # Load model
 model = loaders.load_model(params.model_type)
 model.setup(params)
 model.to(params.device)
+model.log_architecture_details()
 
 # Train model
 for epoch in range(1, model.params.num_epochs+1):
   run_utils.train_epoch(epoch, model, train_loader)
-  if(model.params.model_type.lower() in ['mlp', 'ensemble']):
-    run_utils.test_epoch(epoch, model, test_loader)
-  model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}')
+  # TODO: Ensemble models might not actually have a classification objective / need validation
+  #if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use the validation set here; test at the end of training
+  #  run_utils.test_epoch(epoch, model, test_loader)
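+  # note: log_string appends to the logfile created in model.setup(); the print below mirrors progress to stdout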
+  model.logger.log_string(f'Completed epoch {epoch}/{model.params.num_epochs}')
   print(f'Completed epoch {epoch}/{model.params.num_epochs}')
 
+# Final outputs
 t1 = ti.time()
 tot_time=float(t1-t0)
 tot_images = model.params.num_epochs*len(train_loader.dataset)
 out_str = f'Training on {tot_images} images is complete. Total time was {tot_time} seconds.\n'
-model.log_info(out_str)
+model.logger.log_string(out_str)
 print('Training Complete\n')
 model.write_checkpoint()
diff --git a/utils/data_processing.py b/utils/data_processing.py
index fff2fddc..16482dfa 100644
--- a/utils/data_processing.py
+++ b/utils/data_processing.py
@@ -4,34 +4,36 @@ def reshape_data(data, flatten=None, out_shape=None):
   """
-  Helper function to reshape input data for processing and return data shape
-  Inputs:
+  Reshape input data for processing and return data shape
+
+  Keyword arguments:
     data: [tensor] data of shape:
-      n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k
+      n is num_examples, i is num_channels (a.k.a. vector length for flattened inputs), j is num_rows, k is num_cols
-      if out_shape is not specified, it is assumed that i == j
+      if out_shape is not specified, it is assumed that j == k
-      (l) - single data point of shape l, assumes 1 color channel
-      (n, l) - n data points, each of shape l (flattened)
-      (i, j, k) - single datapoint of of shape (i,j, k)
-      (n, i, j, k) - n data points, each of shape (i,j,k)
+      (i) - single data point of shape i (flattened; assumed j = k = 1)
+      (n, i) - n data points, each of shape i (flattened; assumed j = k = 1)
+      (i, j, k) - single data point of shape (i, j, k)
+      (n, i, j, k) - n data points, each of shape (i, j, k)
     flatten: [bool or None] specify the shape of the output
       If out_shape is not None, this arg has no effect
       If None, do not reshape data, but add num_examples dimension if necessary
-      If True, return ravelled data of shape (num_examples, num_elements)
-      If False, return unravelled data of shape (num_examples, sqrt(l), sqrt(l), 1)
-        where l is the number of elements (dimensionality) of the datapoints
+      If True, return ravelled data of shape (num_examples, num_elements), where num_elements=i*j*k
+      If False, return unravelled data of shape (num_examples, 1, sqrt(i), sqrt(i))
+        where i is the number of elements (size) of the data points and 1 is the assumed number of channels
      If data is flat and flatten==True, or !flat and flatten==False, then None condition will apply
    out_shape: [list or tuple] containing the desired output shape
      This will overwrite flatten, and return the input reshaped according to out_shape
+
   Outputs:
     tuple containing:
     data: [tensor] data with new shape
-      (num_examples, num_rows, num_cols, num_channels) if flatten==False
-      (num_examples, num_elements) if flatten==True
+      (num_examples, num_channels, num_rows, num_cols) if flatten==False
+      (num_examples, num_elements) if flatten==True, where num_elements = num_channels*num_rows*num_cols
     orig_shape: [tuple of int32] original shape of the input data
     num_examples: [int32] number of data examples or None if out_shape is specified
+    num_channels: [int32] number of data channels or None if out_shape is specified
     num_rows: [int32] number of data rows or None if out_shape is specified
     num_cols: [int32] number of data cols or None if out_shape is specified
-    num_channels: [int32] number of data channels or None if out_shape is specified
   """
   orig_shape = data.shape
   orig_ndim = data.ndim
@@ -48,7 +50,7 @@ def reshape_data(data, flatten=None, out_shape=None):
     elif flatten == True:
       num_rows = num_elements
       num_cols = 1
-      data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels))
+      data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
     else: # flatten == False
       sqrt_num_elements = np.sqrt(num_elements)
       assert np.floor(sqrt_num_elements) == np.ceil(sqrt_num_elements), (
@@ -57,7 +59,7 @@
         +' and data_shape='+str(orig_shape))
       num_rows = int(sqrt_num_elements)
       num_cols = num_rows
-      data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels))
+      data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols))
   elif orig_ndim == 2: # already flattened
     (num_examples, num_elements) = data.shape
     if flatten is None or flatten == True: # don't reshape data
@@ -71,34 +73,35 @@
       num_rows = int(sqrt_num_elements)
       num_cols = num_rows
       num_channels = 1
-      data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels))
+      data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols))
     else:
       assert False, ('flatten argument must be True, False, or None')
   elif orig_ndim == 3: # single data point
     num_examples = 1
-    num_rows, num_cols, num_channels = data.shape
+    num_channels, num_rows, num_cols = data.shape
     if flatten == True:
-      data = torch.reshape(data, (num_examples, num_rows * num_cols * num_channels))
+      data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
     elif flatten is None or flatten == False: # already not flat
-      data = data[None, ...]
+      data = data[None, ...] # add singleton num_examples dimension
     else:
       assert False, ('flatten argument must be True, False, or None')
-  elif orig_ndim == 4: # not flat
-    num_examples, num_rows, num_cols, num_channels = data.shape
+  elif orig_ndim == 4: # multiple data points, not flat
+    num_examples, num_channels, num_rows, num_cols = data.shape
     if flatten == True:
-      data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels))
+      data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
   else:
     assert False, ('Data must have 1, 2, 3, or 4 dimensions.')
   else:
     num_examples = None; num_rows=None; num_cols=None; num_channels=None
     data = torch.reshape(data, out_shape)
-  return (data, orig_shape, num_examples, num_rows, num_cols, num_channels)
+  return (data, orig_shape, num_examples, num_channels, num_rows, num_cols)
 
 def check_all_same_shape(tensor_list):
   """
   Verify that all tensors in the tensor list have the same shape
-  Args:
+
+  Keyword arguments:
     tensor_list: list of tensors to be checked
   Returns:
     raises error if the tensors are not the same shape
@@ -107,53 +110,137 @@
   for index, tensor in enumerate(tensor_list):
     if tensor.shape != first_shape:
       raise ValueError(
-        'Tensor entry %g in input list has shape %g, but should have shape %g'%(
-          index, tensor.shape, first_shape))
+        f'Tensor entry {index} in input list has shape {tensor.shape}, but should have shape {first_shape}')
 
-def flatten_feature_map(feature_map):
+def get_std_from_dataloader(loader, dataset_mean):
   """
-  Flatten input tensor from [batch, y, x, f] to [batch, y*x*f]
-  Args:
-    feature_map: tensor with shape [batch, y, x, f]
-  Returns:
-    reshaped_map: tensor with shape [batch, y*x*f]
-  """
-  map_shape = feature_map.shape
-  if(len(map_shape) == 4):
-    (batch, y, x, f) = map_shape
-    prev_input_features = int(y * x * f)
-    resh_map = torch.reshape(feature_map, [-1, prev_input_features])
-  elif(len(map_shape) == 2):
-    resh_map = feature_map
-  else:
-    raise ValueError('Input feature_map has incorrect ndims')
-  return resh_map
+  TODO: Calculate the standard deviation from all entries in a pytorch data loader
+
+  Keyword arguments:
+    loader: [pytorch DataLoader] containing the full dataset.
+      This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+    dataset_mean: [torch tensor] of the same shape as a single dataset sample
+
+  Outputs:
+    dataset_std: [torch tensor] of the same shape as a single dataset sample
+  """
+  #dataset_std = torch.zeros(next(iter(loader)).shape[1:])
+  #for data, target in loader:
+  #  std_sum_squares += (data - dataset_mean)**2
+  #dataset_std = torch.sqrt(std_sum_squares / len(loader.dataset))
+  #return dataset_std
+  raise NotImplementedError
+
+
+def get_std_from_dataloader(loader, dataset_mean=None):
+  """
+  Calculate the standard deviation from the mean for all entries in a pytorch data loader
+
+  Keyword arguments:
+    loader: [pytorch DataLoader] containing the full dataset.
+      This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+
+  Outputs:
+    dataset_std: [torch tensor] of the same shape as a single dataset sample
+  """
+  if dataset_mean is None:
+    dataset_mean = get_mean_from_dataloader(loader)
+  dataset_std = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension
+  num_batches = 0
+  for data, target in loader:
+    dataset_std += torch.std(data - dataset_mean, dim=0, keepdim=False)
+    num_batches += 1
+  return dataset_std / num_batches
+
+
+def get_mean_from_dataloader(loader):
+  """
+  Calculate the mean datapoint from all entries in a pytorch data loader
+
+  Keyword arguments:
+    loader: [pytorch DataLoader] containing the full dataset.
+      This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+
+  Outputs:
+    dataset_mean: [torch tensor] of the same shape as a single dataset sample
+  """
+  dataset_mean = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension
+  num_batches = 0
+  for data, target in loader:
+    dataset_mean += data.mean(dim=0, keepdim=False)
+    num_batches += 1
+  return dataset_mean / num_batches
+
+
+def center(data, samplewise=False, batch_size=100):
+  """
+  Center image dataset to have zero mean
+
+  Keyword arguments:
+    data: [tensor] unnormalized data
+    samplewise: [bool] if True, center each sample individually; if False, compute mean over entire batch
+
+  Outputs:
+    data: [tensor] centered data
+  """
+  data, orig_shape = reshape_data(data, flatten=True)[:2]
+  if(samplewise): # center each input sample individually
+    data_axis = tuple(range(data.ndim)[1:])
+    data_mean = torch.mean(data, dim=data_axis, keepdim=True)
+  else: # center the entire population
+    data_mean = torch.mean(data, dim=0)
+  data = data - data_mean
+  if(data.shape != orig_shape):
+    data = reshape_data(data, out_shape=orig_shape)[0]
+  return data, data_mean
+
+
+def standardize(data, eps=None, samplewise=False, batch_size=100, sample_mean=None, sample_std=None):
   """
   Standardize each image data to have zero mean and unit standard-deviation (z-score)
-  Uses population standard deviation data.sum() / N, where N = data.shape[0].
-  Inputs:
+
+  This function uses the population standard deviation, normalizing by N = data.shape[0] rather than N-1 (i.e. unbiased=False).
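+
+  Example (illustrative shapes): for x of shape [100, 1, 28, 28], standardize(x, samplewise=True)[0]
+  returns a tensor of the same shape in which every sample has zero mean and unit population standard deviation.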
+
+  Keyword arguments:
     data: [tensor] unnormalized data
     eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data)
+      defaults to 1/data_dim where data_dim is the total size of a data vector
     samplewise: [bool] if True, standardize each sample individually; akin to contrast-normalization
       if False, compute mean and std over entire batch
+    sample_mean: [tensor] to be used as the dataset mean instead of calculating it,
+      it should be the same shape as a single data element
+    sample_std: [tensor] to be used as the dataset standard deviation instead of calculating it,
+      it should be the same shape as a single data element
+
   Outputs:
     data: [tensor] normalized data
   """
   if(eps is None):
-    eps = 1.0 / np.sqrt(data[0,...].numel())
+    eps = 1.0 / data[0,...].numel()
   data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing
   num_examples = data.shape[0]
-  if(samplewise): # standardize the entire population
-    data_axis = tuple(range(data.ndim)[1:]) # standardize each example individually
-    data_mean = torch.mean(data, dim=data_axis, keepdim=True)
-    data_true_std = torch.std(data, unbiased=False, dim=data_axis, keepdim=True)
-  else: # standardize each input sample individually
-    data_mean = torch.mean(data)
-    data_true_std = torch.std(data, unbiased=False)
+  if(samplewise): # standardize each input sample individually
+    if sample_mean is None:
+      data_mean = torch.mean(data, dim=1, keepdim=True) # [num_examples, 1]
+    else:
+      data_mean = sample_mean.mean().repeat(num_examples, 1)
+    if sample_std is None:
+      data_true_std = torch.std(data - data_mean, unbiased=False, dim=1, keepdim=True)
+    else:
+      data_true_std = sample_std.mean().repeat(num_examples, 1)
+  else: # standardize the entire population
+    if sample_mean is None:
+      data_mean = torch.mean(data, dim=0, keepdim=True) # [1, sample_dim]
+    else:
+      data_mean = sample_mean.view(1, -1)
+    if sample_std is None:
+      data_true_std = torch.std(data - data_mean, dim=0, unbiased=False, keepdim=True)
+    else:
+      data_true_std = sample_std.view(1, -1)
   data_std = torch.where(data_true_std >= eps, data_true_std, eps*torch.ones_like(data_true_std))
   data = (data - data_mean) / data_std
   if(data.shape != orig_shape):
@@ -164,10 +248,12 @@ def standardize(data, eps=None, samplewise=True):
   """
   Rescale input data to be between 0 and 1
-  Inputs:
+
+  Keyword arguments:
     data: [tensor] unnormalized data
     eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data)
     samplewise: [bool] if True, compute it per-sample, otherwise normalize entire batch
+
   Outputs:
     data: [tensor] centered data of shape (n, i, j, k) or (n, l)
   """
@@ -175,9 +261,9 @@
   if(eps is None):
     eps = 1.0 / np.sqrt(data[0,...].numel())
   if(samplewise):
     data_min = torch.min(data.view(-1, np.prod(data.shape[1:])),
-      axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
+      axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1))
     data_max = torch.max(data.view(-1, np.prod(data.shape[1:])),
-      axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
+      axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1))
   else:
     data_min = torch.min(data)
     data_max = torch.max(data)
@@ -186,11 +272,16 @@
   data = (data - data_min) / data_range
   return data, data_min, data_max
 
+
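+# Example (illustrative): one_hot_to_dense maps each row to the index of its nonzero entry,
+# e.g. the rows of torch.eye(3) map to [0, 1, 2].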
 def one_hot_to_dense(one_hot_labels):
   """
-  converts a matrix of one-hot labels to a list of dense labels
-  Inputs:
+  Convert a matrix of one-hot labels to a list of dense labels
+
+  Keyword arguments:
     one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes]
+
   Outputs:
     dense_labels: 1D torch tensor array of labels
       The integer value indicates the class and 0 is assumed to be a class.
@@ -202,13 +291,250 @@
     dense_labels[label_id] = torch.nonzero(one_hot_labels[label_id, :] == 1)
   return dense_labels
 
+
 def dense_to_one_hot(labels_dense, num_classes):
   """
-  converts a (np.ndarray) vector of dense labels to a (np.ndarray) matrix of one-hot labels
-  e.g. [0, 1, 1, 3] -> [00, 01, 01, 11]
+  Convert a vector of dense labels to a matrix of one-hot labels. E.g. with num_classes=4, [0, 1, 1, 3] -> [[1,0,0,0], [0,1,0,0], [0,1,0,0], [0,0,0,1]]
+
+  Keyword arguments:
+    labels_dense: dense torch tensor of shape [num_labels], where each entry is an integer indicating the class label
+    num_classes: the total number of classes in the dataset
+
+  Outputs:
+    one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes]
   """
   num_labels = labels_dense.shape[0]
   index_offset = torch.arange(end=num_labels, dtype=torch.int32) * num_classes
   labels_one_hot = torch.zeros((num_labels, num_classes))
   labels_one_hot.view(-1)[index_offset + labels_dense.view(-1)] = 1
   return labels_one_hot
+
+
+def atleast_kd(x, k):
+  """
+  Return x reshaped to append singleton dimensions such that x.ndim is at least k
+
+  Keyword arguments:
+    x [Tensor or numpy ndarray]
+    k [int] minimum number of dimensions
+
+  Outputs:
+    x [same as input x] reshaped input to have at least k dimensions
+  """
+  shape = x.shape + (1,) * (k - x.ndim)
+  return x.reshape(shape)
+
+
+def get_weights_l2_norm(w, eps=1e-12):
+  """
+  Return l2 norm of weight matrix
+
+  Keyword arguments:
+    w [Tensor] assumed to have shape [outC, inC] or [outC, inC, kernH, kernW]
+      norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second
+    eps [float] minimum value to prevent division by zero
+
+  Outputs:
+    norm [Tensor] norm of each of the outC weight vectors
+  """
+  if w.ndim == 2: # fully-connected, [outputs, inputs]
+    norms = torch.norm(w, dim=1, keepdim=True)
+  elif w.ndim == 4: # convolutional, [out_channels, in_channels, kernel_height, kernel_width]
+    norms = torch.norm(w.flatten(start_dim=1), dim=-1, keepdim=True)
+  else:
+    assert False, (f'input w must have ndim = 2 or 4, not {w.ndim}')
+  if(torch.max(norms) <= eps): #TODO: raise proper warnings
+    print(f'Warning: weight norm is less than or equal to {eps}')
+  norms = torch.max(norms, eps*torch.ones_like(norms)) # prevent div by 0 # TODO: Change to torch.maximum when it is stable
+  norms = atleast_kd(norms, w.ndim)
+  return norms
+
+
+def l2_normalize_weights(w, eps=1e-12):
+  """
+  l2 normalize weight matrix
+
+  Keyword arguments:
+    w [Tensor] assumed to have shape [outC, inC] or [outC, inC, kernH, kernW]
+      norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second
+    eps [float] minimum value to prevent division by zero
+
+  Outputs:
+    w [Tensor] same type and shape as input w, but with unitary l2 norm when computed over all input dimensions
+  """
+  norms = get_weights_l2_norm(w, eps)
+  return w / norms
+
+
+def single_image_to_patches(image, patch_shape):
+  """
+  Extract patches from a single image
+
+  Keyword arguments:
+    image [torch tensor] of shape [im_chan, im_height, im_width]
+    patch_shape [tuple or list] containing the output shape
+      [patch_chan, patch_height, patch_width]
+      patch_chan must be the same as im_chan
+
+  It is recommended, though not required, that the patch height and width divide evenly into
+  the image height and width, respectively.
+
+  Outputs:
+    patches [torch tensor] of patches of shape [num_patches]+list(patch_shape)
+  """
+  try:
+    im_chan, im_height, im_width = image.shape
+    patch_chan, patch_height, patch_width = patch_shape
+  except Exception as e:
+    raise ValueError(
+      f'This function requires that: '
+      +f'1) The input variable "image" must have shape [im_chan, im_height, im_width], and is {image.shape}'
+      +f' and 2) the input variable "patch_shape" must have shape [patch_chan, patch_height, patch_width], and is {patch_shape}.'
+    ) from e
+  num_row_patches = np.floor(im_height / patch_height)
+  num_col_patches = np.floor(im_width / patch_width)
+  num_patches = int(num_row_patches * num_col_patches)
+  patches = torch.zeros((num_patches, patch_chan, patch_height, patch_width))
+  row_id = 0
+  col_id = 0
+  for patch_idx in range(num_patches):
+    row_end = row_id + patch_height
+    col_end = col_id + patch_width
+    try:
+      patches[patch_idx, ...] = image[:, row_id:row_end, col_id:col_end]
+    except Exception as e:
+      raise ValueError('This function requires that im_chan equal patch_chan.') from e
+    row_id += patch_height
+    if row_id >= im_height:
+      row_id = 0
+      col_id += patch_width
+      if col_id >= im_width:
+        col_id = 0
+  return patches
+
+
+def patches_to_single_image(patches, image_shape):
+  """
+  Convert input patches into a single output image
+
+  Keyword arguments:
+    patches [torch tensor] of shape [num_patches, patch_chan, patch_height, patch_width]
+    image_shape [list or tuple] of length 3 containing the image shape [im_chan, im_height, im_width]
+
+  im_chan is assumed to equal patch_chan
+
+  Outputs:
+    image [torch tensor] of shape [im_chan, im_height, im_width]
+  """
+  try:
+    num_patches, patch_chan, patch_height, patch_width = patches.shape
+    im_chan, im_height, im_width = image_shape
+  except Exception as e:
+    raise ValueError(
+      f'This function requires that input patches has shape'
+      f' [num_patches, patch_chan, patch_height, patch_width] and is {patches.shape}'
+      f' and input image_shape is a list or tuple of integers of length 3 containing [im_chan, im_height, im_width] and is {image_shape}'
+    ) from e
+  image = torch.zeros((im_chan, im_height, im_width))
+  row_id = 0
+  col_id = 0
+  for patch_idx in range(num_patches):
+    row_end = row_id + patch_height
+    col_end = col_id + patch_width
+    image[:, row_id:row_end, col_id:col_end] = patches[patch_idx, ...]
+    row_id += patch_height
+    if row_id >= im_height:
+      row_id = 0
+      col_id += patch_width
+      if col_id >= im_width:
+        col_id = 0
+  return image
+
+def images_to_patches(images, patch_shape):
+  """
+  Extract evenly distributed non-overlapping patches from an image dataset
+
+  Keyword arguments:
+    images [torch tensor] of shape [num_images, im_chan, im_height, im_width] or [im_chan, im_height, im_width] for a single image
+    patch_shape [tuple or list] containing the output shape
+      [patch_chan, patch_height, patch_width]
+      patch_chan must be the same as im_chan
+
+  It is recommended, though not required, that the patch height and width divide evenly into the image height and width, respectively.
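+  For example (illustrative numbers): 60 images of shape [1, 28, 28] with patch_shape = [1, 14, 14]
+  yield floor(28/14) * floor(28/14) = 4 patches per image, or 240 patches in total.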
+
+  Outputs:
+    patches [torch tensor] of patches of shape [num_patches]+list(patch_shape)
+  """
+  if images.ndim == 3: # single image
+    return single_image_to_patches(images, patch_shape)
+  num_im, im_chan, im_height, im_width = images.shape
+  patch_chan, patch_height, patch_width = patch_shape
+  num_row_patches = np.floor(im_height / patch_height)
+  num_col_patches = np.floor(im_width / patch_width)
+  num_patches_per_im = int(num_row_patches * num_col_patches)
+  tot_num_patches = int(num_patches_per_im * num_im)
+  patches = torch.zeros([tot_num_patches, ]+list(patch_shape))
+  patch_id = 0
+  for im_id in range(num_im):
+    image = images[im_id, ...]
+    image_patches = single_image_to_patches(image, patch_shape)
+    patch_end = patch_id + num_patches_per_im
+    patches[patch_id:patch_end, ...] = image_patches
+    patch_id += num_patches_per_im
+  return patches
+
+def patches_to_images(patches, image_shape):
+  """
+  Recombine patches tensor into a dataset of images
+
+  Keyword arguments:
+    patches [torch tensor] holding square patch data of shape [num_patches, patch_chan, patch_height, patch_width]
+    image_shape [list or tuple] containing the image dataset shape [im_chan, im_height, im_width]
+
+  It is assumed that im_chan equals patch_chan
+
+  Outputs:
+    images [torch tensor] holding the recombined image dataset
+  """
+  tot_num_patches, patch_chan, patch_height, patch_width = patches.shape
+  im_chan, im_height, im_width = image_shape
+  num_row_patches = np.floor(im_height / patch_height)
+  num_col_patches = np.floor(im_width / patch_width)
+  num_patches_per_im = int(num_row_patches * num_col_patches)
+  num_im = tot_num_patches // num_patches_per_im
+  images = torch.zeros([num_im]+list(image_shape))
+  patch_id = 0
+  for im_id in range(num_im):
+    patch_end = patch_id + num_patches_per_im
+    patch_batch = patches[patch_id:patch_end, ...]
+    images[im_id, ...] = patches_to_single_image(patch_batch, image_shape)
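+    # advance patch_id to the first patch of the next image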
+    patch_id += num_patches_per_im
+  return images
+
+
+def covariance(tensor):
+  """
+  Returns the covariance matrix of the input tensor
+
+  Keyword arguments:
+    tensor [torch tensor] of shape [num_batch, num_channels] or [num_batch, num_channels, elements_h, elements_w]
+      if tensor.ndim is 2 then the covariance is computed for each element over the batch dimension
+      if tensor.ndim is 4 then the covariance is computed over spatial dimensions for each channel and each batch instance
+
+  Outputs:
+    covariance matrix [torch tensor] of shape [num_channels, num_channels]
+  """
+  if tensor.ndim == 2: # [num_batch, num_channels]
+    centered_tensor = tensor - tensor.mean(dim=0, keepdim=True) # subtract mean vector
+    covariance = torch.mm(centered_tensor.T, centered_tensor) # sum over batch
+    covariance = covariance / (centered_tensor.shape[0]-1)
+  elif tensor.ndim == 4: # [num_batch, num_channels, elements_h, elements_w]
+    num_batch, num_channels, elements_h, elements_w = tensor.shape
+    flat_map = tensor.view(num_batch, num_channels, elements_h * elements_w)
+    cent_flat_map = flat_map - flat_map.mean(dim=2, keepdim=True) # subtract mean vector
+    covariance = torch.bmm(cent_flat_map, torch.transpose(cent_flat_map, 1, 2)) # sum over space
+    covariance = covariance.mean(dim=0, keepdim=False) # avg cov over batch
+  return covariance
diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py
index cd46541b..8c17df11 100644
--- a/utils/dataset_utils.py
+++ b/utils/dataset_utils.py
@@ -1,17 +1,55 @@
 import os
 import sys
+from os.path import dirname as up
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+
+from PIL import Image
 import numpy as np
 import torch
-from torchvision import datasets, transforms
-
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+from torchvision import transforms
+import torchvision.datasets
 
 import DeepSparseCoding.utils.data_processing as dp
 import DeepSparseCoding.datasets.synthetic as synthetic
 
+
+class FastMNIST(torchvision.datasets.MNIST):
+  """
+  The torchvision MNIST dataset has additional overhead that slows it down.
+  This loads the entire dataset onto the specified device at init, resulting in a considerable speedup
+  """
+  def __init__(self, *args, **kwargs):
+    device = kwargs.pop('device', 'cpu')
+    super().__init__(*args, **kwargs)
+    # Scale data to [0,1]
+    self.data = self.data.unsqueeze(-1).float().div(255)
+    self.data = self.data.permute(0, 3, 1, 2) # channels first
+    if self.transform is not None:
+      # doing this so that it is consistent with all other datasets
+      # to return a PIL Image
+      for data_idx in range(self.data.shape[0]):
+        self.data[data_idx, ...] = self.transform(
+          Image.fromarray(
+            self.data[data_idx, ...].numpy().squeeze(), mode='L'))[None, ...]
+    if self.target_transform is not None:
+      self.targets = torch.tensor([self.target_transform(int(target)) for target in self.targets])
+    # Put both data and targets on GPU in advance
+    self.data, self.targets = self.data.to(device), self.targets.to(device)
+
+  def __getitem__(self, index):
+    """
+    Args:
+      index (int): Index
+    Returns:
+      tuple: (image, target) where target is index of the target class.
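+      Note: no transform is applied here; all preprocessing was already applied once in __init__.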
+ """ + img, target = self.data[index], self.targets[index] + return img, target + + class CustomTensorDataset(torch.utils.data.Dataset): def __init__(self, data_tensor): self.data_tensor = data_tensor @@ -24,29 +61,98 @@ def __len__(self): def load_dataset(params): - new_params = {} + data_stats = {} # dataset statistics + extra_outputs = {} # depending on parameters may include dataset_mean, dataset_std, etc if(params.dataset.lower() == 'mnist'): preprocessing_pipeline = [ transforms.ToTensor(), - transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last - ] + ] if params.standardize_data: preprocessing_pipeline.append( transforms.Lambda(lambda x: dp.standardize(x, eps=params.eps)[0])) if params.rescale_data_to_one: preprocessing_pipeline.append( transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0])) - train_loader = torch.utils.data.DataLoader( - datasets.MNIST(root=params.data_dir, train=True, download=True, - transform=transforms.Compose(preprocessing_pipeline)), - batch_size=params.batch_size, shuffle=params.shuffle_data, - num_workers=0, pin_memory=False) - val_loader = None - test_loader = torch.utils.data.DataLoader( - datasets.MNIST(root=params.data_dir, train=False, download=True, - transform=transforms.Compose(preprocessing_pipeline)), - batch_size=params.batch_size, shuffle=params.shuffle_data, - num_workers=0, pin_memory=False) + kwargs = { + 'root':params.data_dir, + 'download':False, + 'transform':transforms.Compose(preprocessing_pipeline) + } + if(hasattr(params, 'fast_mnist') and params.fast_mnist): + kwargs['device'] = params.device + kwargs['train'] = True + train_loader = torch.utils.data.DataLoader( + FastMNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=False) + kwargs['train'] = False + val_loader = None + test_loader = torch.utils.data.DataLoader( + FastMNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=False) + else: + kwargs['train'] = True + train_loader = torch.utils.data.DataLoader( + torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + kwargs['train'] = False + val_loader = None + test_loader = torch.utils.data.DataLoader( + torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + + elif(params.dataset.lower() == 'cifar10'): + preprocessing_pipeline = [ + transforms.ToTensor(), + ] + kwargs = { + 'root': os.path.join(*[params.data_dir, 'cifar10']), + 'download': False, + 'train': True, + 'transform': transforms.Compose(preprocessing_pipeline) + } + if params.center_dataset: + dataset = torchvision.datasets.CIFAR10(**kwargs) + data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + dataset_mean_image = dp.get_mean_from_dataloader(data_loader) + preprocessing_pipeline.append( + transforms.Lambda(lambda x: x - dataset_mean_image)) + extra_outputs['dataset_mean_image'] = dataset_mean_image + if params.standardize_data: + #dataset = torchvision.datasets.CIFAR10(**kwargs) + #data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size, + # shuffle=False, num_workers=0, pin_memory=True) + #dataset_mean_image = dp.get_mean_from_dataloader(data_loader) + #extra_outputs['dataset_mean_image'] = dataset_mean_image + #dataset_std_image = dp.get_std_from_dataloader(data_loader, 
dataset_mean_image) + #extra_outputs['dataset_std_image'] = dataset_std_image + preprocessing_pipeline.append( + transforms.Lambda( + lambda x: dp.standardize(x, + eps=params.eps, + samplewise=True,#False, + batch_size=params.batch_size)[0] + #sample_mean=dataset_mean_image, + #sample_std=dataset_std_image)[0] + ) + ) + if params.rescale_data_to_one: + preprocessing_pipeline.append( + transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0])) + kwargs['transform'] = transforms.Compose(preprocessing_pipeline) + kwargs['train'] = True + dataset = torchvision.datasets.CIFAR10(**kwargs) + kwargs['train'] = False + testset = torchvision.datasets.CIFAR10(**kwargs) + num_train = len(dataset) - params.num_validation + trainset, valset = torch.utils.data.random_split(dataset, [num_train, params.num_validation]) + train_loader = torch.utils.data.DataLoader(trainset, batch_size=params.batch_size, + shuffle=params.shuffle_data, num_workers=0, pin_memory=True) + val_loader = torch.utils.data.DataLoader(valset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + test_loader = torch.utils.data.DataLoader(testset, batch_size=params.batch_size, + shuffle=False, num_workers=0, pin_memory=True) + elif(params.dataset.lower() == 'dsprites'): root = os.path.join(*[params.data_dir]) dsprites_file = os.path.join(*[root, 'dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz']) @@ -67,10 +173,11 @@ def load_dataset(params): pin_memory=False) val_loader = None test_loader = None + elif(params.dataset.lower() == 'synthetic'): - preprocessing_pipeline = [transforms.ToTensor(), - transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last - ] + preprocessing_pipeline = [ + transforms.ToTensor(), + ] train_loader = torch.utils.data.DataLoader( synthetic.SyntheticImages(params.epoch_size, params.data_edge_size, params.dist_type, params.rand_state, params.num_classes, @@ -79,20 +186,20 @@ def load_dataset(params): num_workers=0, pin_memory=False) val_loader = None test_loader = None - new_params["num_pixels"] = params.data_edge_size**2 + else: assert False, (f'Supported datasets are ["mnist", "dsprites", "synthetic"], not {dataset_name}') - new_params = {} - new_params['epoch_size'] = len(train_loader.dataset) + data_stats['epoch_size'] = len(train_loader.dataset) if(not hasattr(params, 'num_val_images')): if val_loader is None: - new_params['num_val_images'] = 0 + data_stats['num_val_images'] = 0 else: - new_params['num_val_images'] = len(val_loader.dataset) + data_stats['num_val_images'] = len(val_loader.dataset) if(not hasattr(params, 'num_test_images')): if test_loader is None: - new_params['num_test_images'] = 0 + data_stats['num_test_images'] = 0 else: - new_params['num_test_images'] = len(test_loader.dataset) - new_params['data_shape'] = list(next(iter(train_loader))[0].shape)[1:] - return (train_loader, val_loader, test_loader, new_params) + data_stats['num_test_images'] = len(test_loader.dataset) + data_stats['data_shape'] = list(next(iter(train_loader))[0].shape)[1:] + data_stats['num_pixels'] = np.prod(data_stats['data_shape']) + return (train_loader, val_loader, test_loader, data_stats, extra_outputs) diff --git a/utils/file_utils.py b/utils/file_utils.py index 6dd4b7f8..2ca41dc9 100644 --- a/utils/file_utils.py +++ b/utils/file_utils.py @@ -3,6 +3,7 @@ import types import os from copy import deepcopy +from collections import OrderedDict import importlib import numpy as np @@ -27,6 +28,16 @@ def js_dumpstring(self, obj): """Dump json string 
with special CustomEncoder"""
     return js.dumps(obj, sort_keys=True, indent=2, cls=CustomEncoder)
 
+  def log_string(self, string):
+    """Log input string"""
+    now = time.localtime(time.time())
+    time_str = time.strftime('%m/%d/%y %H:%M:%S', now)
+    out_str = '\n' + time_str + ' -- ' + str(string)
+    if(self.log_to_file):
+      self.file_obj.write(out_str)
+    else:
+      print(out_str)
+
   def log_trainable_variables(self, name_list):
     """
     Use logging to write names of trainable variables in model
@@ -34,7 +45,7 @@
       name_list: list containing variable names
     """
     js_str = self.js_dumpstring(name_list)
-    self.log_info(''+js_str+'')
+    self.log_string(''+js_str+'')
 
   def log_params(self, params):
     """
@@ -45,7 +56,7 @@
     out_params = deepcopy(params)
     if('ensemble_params' in out_params.keys()):
       for sub_idx, sub_params in enumerate(out_params['ensemble_params']):
-        sub_params.set_params()
+        #sub_params.set_params()
         for key, value in sub_params.__dict__.items():
           if(key != 'rand_state'):
             new_dict_key = f'{sub_idx}_{key}'
@@ -54,17 +65,17 @@
     if('rand_state' in out_params.keys()):
       del out_params['rand_state']
     js_str = self.js_dumpstring(out_params)
-    self.log_info(''+js_str+'')
+    self.log_string(''+js_str+'')
 
-  def log_info(self, string):
-    """Log input string"""
-    now = time.localtime(time.time())
-    time_str = time.strftime('%m/%d/%y %H:%M:%S', now)
-    out_str = '\n' + time_str + ' -- ' + str(string)
-    if(self.log_to_file):
-      self.file_obj.write(out_str)
-    else:
-      print(out_str)
+  def log_stats(self, stat_dict):
+    """Log dictionary of training / testing statistics"""
+    js_str = self.js_dumpstring(stat_dict)
+    self.log_string(''+js_str+'')
+
+  def log_info(self, info_dict):
+    """Log input dictionary in tags"""
+    js_str = self.js_dumpstring(info_dict)
+    self.log_string(''+js_str+'')
 
   def load_file(self, filename=None):
     """
@@ -168,6 +179,18 @@
         stats[key] = [js_match[key]]
     return stats
 
+  def read_architecture(self, text):
+    """
+    Parse the logged model architecture summary from log text
+    Outputs:
+      js_match: [dict] containing the architecture summary
+    Inputs:
+      text: [str] containing text to parse, can be obtained by calling load_file()
+    """
+    tokens = ['', '']
+    js_match = self.read_js(tokens, text)
+    return js_match
+
   def __del__(self):
     if(self.log_to_file and hasattr(self, 'file_obj')):
       self.file_obj.close()
@@ -199,3 +222,93 @@
   py_module = importlib.util.module_from_spec(spec)
   spec.loader.exec_module(py_module)
   return py_module
+
+def summary_string(model, input_size, batch_size=2, device=torch.device('cuda:0'), dtype=torch.FloatTensor):
+  """
+  Returns a string that summarizes the model architecture, including the number of parameters
+  and layer output sizes
+
+  Code is modified from:
+  https://github.com/sksq96/pytorch-summary
+
+  Keyword arguments:
+    model [torch module, module subclass, or EnsembleModel] model to summarize
+    input_size [tuple or list of tuples] must not include the batch dimension; if it is a list
+      of tuples then the architecture will be computed for each option
+    batch_size [positive int] how many images to feed into the model.
+      The default of 2 will ensure that batch norm works.
+    device [torch.device] which device to run the test on
+    dtype [torch tensor type] for the artificially generated inputs
+  """
+  def register_hook(module):
+    def hook(module, input, output):
+      class_name = str(module.__class__).split('.')[-1].split("'")[0]
+      module_idx = len(summary)
+      m_key = '%s-%i' % (class_name, module_idx + 1)
+      summary[m_key] = OrderedDict()
+      summary[m_key]['input_shape'] = list(input[0].size())
+      summary[m_key]['input_shape'][0] = batch_size
+      if isinstance(output, (list, tuple)):
+        summary[m_key]['output_shape'] = [
+          [-1] + list(o.size())[1:] for o in output
+        ]
+      else:
+        summary[m_key]['output_shape'] = list(output.size())
+        summary[m_key]['output_shape'][0] = batch_size
+      params = 0
+      if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
+        params += torch.prod(torch.LongTensor(list(module.weight.size())))
+        summary[m_key]['trainable'] = module.weight.requires_grad
+      if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
+        params += torch.prod(torch.LongTensor(list(module.bias.size())))
+      summary[m_key]['nb_params'] = params
+      summary[m_key]['gpu_mem'] = round(torch.cuda.memory_allocated(0)/1024**3, 1)
+    if len(list(module.children())) == 0: # only apply hooks at child modules to avoid applying them twice
+      hooks.append(module.register_forward_hook(hook))
+  x = torch.rand(batch_size, *input_size).type(dtype).to(device=device)
+  summary = OrderedDict() # used within hook function to store properties
+  hooks = [] # used within hook function to store registered hooks
+  model.apply(register_hook) # recursively apply register_hook function to model and all children
+  model(x) # make a forward pass
+  for h in hooks:
+    h.remove() # remove the hooks so they are not used at run time
+  summary_str = '----------------------------------------------------------------\n'
+  line_new = '{:>20} {:>25} {:>15}'.format('Layer (type)', 'Output Shape', 'Param #')
+  summary_str += line_new + '\n'
+  summary_str += '================================================================\n'
+  total_params = 0
+  total_output = 0
+  trainable_params = 0
+  for layer in summary:
+    line_new = '{:>20} {:>25} {:>15}'.format(
+      layer,
+      str(summary[layer]['output_shape']),
+      '{0:,}'.format(summary[layer]['nb_params']),
+    ) # input_shape, output_shape, trainable, nb_params
+    total_params += summary[layer]['nb_params']
+    total_output += np.prod(summary[layer]['output_shape'])
+    if 'trainable' in summary[layer]:
+      if summary[layer]['trainable'] == True:
+        trainable_params += summary[layer]['nb_params']
+    summary_str += line_new + '\n'
+  # assume 4 bytes/number (float on cuda).
+  total_input_size = abs(np.prod(input_size)) * batch_size * 4. / (1024 ** 2.)
+  total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
+  total_params_size = abs(total_params * 4. / (1024 ** 2.))
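+  # e.g. (illustrative): a 768x784 weight matrix holds 602,112 params: 602112 * 4 / 1024**2 ~= 2.30 MB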
+  total_size = total_params_size + total_output_size + total_input_size
+  summary_str += '================================================================\n'
+  summary_str += f'Total params: {total_params}\n'
+  summary_str += f'Trainable params: {trainable_params}\n'
+  param_diff = total_params - trainable_params
+  summary_str += f'Non-trainable params: {param_diff}\n'
+  summary_str += '----------------------------------------------------------------\n'
+  summary_str += f'Input size (MB): {total_input_size:0.2f}\n'
+  summary_str += f'Forward/backward pass size (MB): {total_output_size:0.2f}\n'
+  summary_str += f'Params size (MB): {total_params_size:0.2f}\n'
+  summary_str += f'Estimated total size (MB): {total_size:0.2f}\n'
+  ## TODO: Update pytorch for this to work
+  #device_memory = torch.cuda.memory_summary(device, abbreviated=True)
+  #summary_str += f'Device memory allocated with batch of inputs (GB): {device_memory}\n'
+  summary_str += '----------------------------------------------------------------\n'
+  return summary_str, (total_params, trainable_params)
diff --git a/utils/loaders.py b/utils/loaders.py
index 9b73a8fb..8b395e2a 100644
--- a/utils/loaders.py
+++ b/utils/loaders.py
@@ -1,11 +1,13 @@
 import os
 import sys
+from os.path import dirname as up
 
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
 if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
 
 import DeepSparseCoding.utils.file_utils as file_utils
 
+
 def get_dir_list(target_dir, target_string):
   dir_list = [filename.split('.')[0]
     for filename in os.listdir(target_dir)
@@ -42,13 +44,16 @@ def load_model_class(model_type):
   elif(model_type.lower() == 'lca'):
     py_module_name = 'LcaModel'
     file_name = os.path.join(*[dsc_dir, 'models', 'lca_model.py'])
+  elif(model_type.lower() == 'pooling'):
+    py_module_name = 'PoolingModel'
+    file_name = os.path.join(*[dsc_dir, 'models', 'pooling_model.py'])
   elif(model_type.lower() == 'ensemble'):
     py_module_name = 'EnsembleModel'
     file_name = os.path.join(*[dsc_dir, 'models', 'ensemble_model.py'])
   else:
     accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)]
     assert False, (
-      'Acceptible model_types are %s, not %s'%(','.join(accepted_names), model_type))
+      'Acceptable model_types are %s, not %s'%('; '.join(accepted_names), model_type))
   py_module = file_utils.python_module_from_file(py_module_name, file_name)
   py_module_class = getattr(py_module, py_module_name)
   return py_module_class
@@ -66,19 +71,29 @@ def load_module(module_type):
   elif(module_type.lower() == 'lca'):
     py_module_name = 'LcaModule'
     file_name = os.path.join(*[dsc_dir, 'modules', 'lca_module.py'])
+  elif(module_type.lower() == 'pooling'):
+    py_module_name = 'PoolingModule'
+    file_name = os.path.join(*[dsc_dir, 'modules', 'pooling_module.py'])
   elif(module_type.lower() == 'ensemble'):
     py_module_name = 'EnsembleModule'
     file_name = os.path.join(*[dsc_dir, 'modules', 'ensemble_module.py'])
   else:
     accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)]
     assert False, (
-      'Acceptible model_types are %s, not %s'%(','.join(accepted_names), module_type))
+      'Acceptable module_types are %s, not %s'%('; '.join(accepted_names), module_type))
   py_module = file_utils.python_module_from_file(py_module_name, file_name)
   py_module_class = getattr(py_module, py_module_name)
   return py_module_class()
 
-def load_params(file_name, key='params'):
+def load_params_from_log(log_file):
+  logger = file_utils.Logger(log_file, overwrite=False)
+  log_text = 
logger.load_file() + params = logger.read_params(log_text)[-1] + return params + + +def load_params_file(file_name, key='params'): params_module = file_utils.python_module_from_file(key, file_name) params = getattr(params_module, key)() return params diff --git a/utils/plot_functions.py b/utils/plot_functions.py index 150ab1eb..8374db1b 100644 --- a/utils/plot_functions.py +++ b/utils/plot_functions.py @@ -94,3 +94,29 @@ def plot_stats(data, x_key, x_label=None, y_keys=None, y_labels=None, start_inde return None plot.show() return fig + +def pad_images(images, pad_values=1): + """ + Convert an array of images into a single tiled image with padded border + + Keyword arguments: + images: [np.ndarray] of shape [num_samples, im_height, im_width, im_chan] + pad_values: [int] specifying what value will be used for padding + + Outputs: + padded_images: [np.ndarray] padded version of input + """ + n = int(np.ceil(np.sqrt(images.shape[0]))) + padding = (((0, n ** 2 - images.shape[0]), + (1, 1), (1, 1)) # add some space between filters + + ((0, 0),) * (images.ndim - 3)) # don't pad last dimension (if there is one) + padded_images = np.pad(images, padding, mode="constant", + constant_values=pad_values) + # tile the filters into an image + padded_images = padded_images.reshape(( + (n, n) + padded_images.shape[1:])).transpose(( + (0, 2, 1, 3) + tuple(range(4, padded_images.ndim + 1)))) + padded_images = padded_images.reshape((n * padded_images.shape[1], + n * padded_images.shape[3]) + padded_images.shape[4:]) + return padded_images + diff --git a/utils/run_utils.py b/utils/run_utils.py index 085fb6ca..20a1f7ea 100644 --- a/utils/run_utils.py +++ b/utils/run_utils.py @@ -1,5 +1,25 @@ +import numpy as np import torch +import DeepSparseCoding.utils.data_processing as dp + + +def compute_conv_output_shape(in_length, kernel_length, stride, padding=0, dilation=1): + out_shape = ((in_length + 2 * padding - dilation * (kernel_length - 1) - 1) / stride) + 1 + return np.floor(out_shape).astype(int) + + +def compute_deconv_output_shape(in_length, kernel_length, stride, padding=0, output_padding=0, dilation=1): + out_shape = (in_length - 1) * stride - 2 * padding + dilation * (kernel_length - 1) + output_padding + 1 + return np.floor(out_shape).astype(int) + + +def get_module_encodings(module, data, allow_grads=False): + if allow_grads: + return module.get_encodings(data) + else: + return module.get_encodings(data).detach() + def train_single_model(model, loss): model.optimizer.zero_grad() # clear gradietns of all optimized variables @@ -7,7 +27,7 @@ def train_single_model(model, loss): model.optimizer.step() if(hasattr(model.params, 'renormalize_weights') and model.params.renormalize_weights): with torch.no_grad(): # tell autograd to not record this operation - model.w.div_(torch.norm(model.w, dim=0, keepdim=True)) + model.weight.div_(dp.get_weights_l2_norm(model.weight)) def train_epoch(epoch, model, loader): @@ -18,33 +38,34 @@ def train_epoch(epoch, model, loader): for batch_idx, (data, target) in enumerate(loader): data, target = data.to(model.params.device), target.to(model.params.device) inputs = [] - if(model.params.model_type.lower() == 'ensemble'): # TODO: Move this to train_model + if(model.params.model_type.lower() == 'ensemble'): inputs.append(model[0].preprocess_data(data)) # First model preprocesses the input for submodule_idx, submodule in enumerate(model): loss = model.get_total_loss((inputs[-1], target), submodule_idx) train_single_model(submodule, loss) - # TODO: include optional parameter to 
allow gradients to propagate through the entire ensemble. - inputs.append(submodule.get_encodings(inputs[-1]).detach()) # must detach to prevent gradient leaking + encodings = get_module_encodings(submodule, inputs[-1], + model.params.allow_parent_grads) + inputs.append(encodings) else: inputs.append(model.preprocess_data(data)) loss = model.get_total_loss((inputs[-1], target)) train_single_model(model, loss) if model.params.train_logs_per_epoch is not None: if(batch_idx % int(num_batches/model.params.train_logs_per_epoch) == 0.): - batch_step = epoch * model.params.batches_per_epoch + batch_idx + batch_step = int((epoch - 1) * model.params.batches_per_epoch + batch_idx) model.print_update( input_data=inputs[0], input_labels=target, batch_step=batch_step) if(model.params.model_type.lower() == 'ensemble'): for submodule in model: - submodule.scheduler.step(epoch) + submodule.scheduler.step() else: - model.scheduler.step(epoch) + model.scheduler.step() def test_single_model(model, data, target, epoch): output = model(data) #test_loss = torch.nn.functional.nll_loss(output, target, reduction='sum').item() - test_loss = torch.nn.CorssEntropyLoss()(output, target) + test_loss = torch.nn.CrossEntropyLoss()(output, target) pred = output.max(1, keepdim=True)[1] correct = pred.eq(target.view_as(pred)).sum().item() return (test_loss, correct) @@ -76,13 +97,12 @@ def test_epoch(epoch, model, loader, log_to_file=True): test_accuracy = 100. * correct / len(loader.dataset) stat_dict = { 'test_epoch':epoch, - 'test_loss':test_loss, + 'test_loss':test_loss.item(), 'test_correct':correct, 'test_total':len(loader.dataset), 'test_accuracy':test_accuracy} if log_to_file: - js_str = model.js_dumpstring(stat_dict) - model.log_info(''+js_str+'') + model.logger.log_stats(stat_dict) else: return stat_dict