diff --git a/adversarial_analysis.py b/adversarial_analysis.py
deleted file mode 100644
index 716d6d5f..00000000
--- a/adversarial_analysis.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-import sys
-
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
-import numpy as np
-import proplot as plot
-import torch
-
-from DeepSparseCoding.utils.file_utils import Logger
-import DeepSparseCoding.utils.loaders as loaders
-import DeepSparseCoding.utils.run_utils as run_utils
-import DeepSparseCoding.utils.dataset_utils as dataset_utils
-import DeepSparseCoding.utils.run_utils as ru
-import DeepSparseCoding.utils.plot_functions as pf
-
-import eagerpy as ep
-from foolbox import PyTorchModel, accuracy, samples
-import foolbox.attacks as fa
-
-
-log_files = [
- os.path.join(*[ROOT_DIR, 'Torch_projects', 'mlp_768_mnist', 'logfiles', 'mlp_768_mnist_v0.log']),
- os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'logfiles', 'lca_768_mlp_mnist_v0.log'])
- ]
-
-cp_latest_filenames = [
- os.path.join(*[ROOT_DIR,'Torch_projects', 'mlp_768_mnist', 'checkpoints', 'mlp_768_mnist_latest_checkpoint_v0.pt']),
- os.path.join(*[ROOT_DIR, 'Torch_projects', 'lca_768_mlp_mnist', 'checkpoints', 'lca_768_mlp_mnist_latest_checkpoint_v0.pt'])
- ]
-
-attack_params = {
- 'linfPGD': {
- 'abs_stepsize':0.01,
- 'steps':5000
- }
-}
-
-attacks = [
- #fa.FGSM(),
- fa.LinfPGD(**attack_params['linfPGD']),
- #fa.LinfBasicIterativeAttack(),
- #fa.LinfAdditiveUniformNoiseAttack(),
- #fa.LinfDeepFoolAttack(),
-]
-
-epsilons = [ # allowed perturbation size
- 0.0,
- 0.05,
- 0.1,
- 0.15,
- 0.2,
- 0.25,
- 0.3,
- 0.35,
- #0.4,
- 0.5,
- #0.8,
- 1.0
-]
-
-num_models = len(log_files)
-for model_index in range(num_models):
- logger = Logger(log_files[model_index], overwrite=False)
- log_text = logger.load_file()
- params = logger.read_params(log_text)[-1]
- params.cp_latest_filename = cp_latest_filenames[model_index]
- train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
- for key, value in data_params.items():
- setattr(params, key, value)
- model = loaders.load_model(params.model_type)
- model.setup(params, logger)
- model.params.analysis_out_dir = os.path.join(
- *[model.params.model_out_dir, 'analysis', model.params.version])
- model.params.analysis_save_dir = os.path.join(model.params.analysis_out_dir, 'savefiles')
- if not os.path.exists(model.params.analysis_save_dir):
- os.makedirs(model.params.analysis_save_dir)
- model.to(params.device)
- model.load_checkpoint()
- fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
- print('\n', '~' * 79)
- num_batches = len(test_loader.dataset) // model.params.batch_size
- attack_success = np.zeros(
- (len(attacks), len(epsilons), num_batches, model.params.batch_size), dtype=np.bool)
- for batch_index, (data, target) in enumerate(test_loader):
- data = model.preprocess_data(data.to(model.params.device))
- target = target.to(model.params.device)
- images, labels = ep.astensors(*(data, target))
- del data; del target
- print(f'Model type: {model.params.model_type} [{model_index+1} out of {len(log_files)}]')
- print(f'Batch {batch_index+1} out of {num_batches}')
- print(f'accuracy {accuracy(fmodel, images, labels)}')
- for attack_index, attack in enumerate(attacks):
- advs, inputs, success = attack(fmodel, images, labels, epsilons=epsilons)
- assert success.shape == (len(epsilons), len(images))
- success_ = success.numpy()
- assert success_.dtype == np.bool
- attack_success[attack_index, :, batch_index, :] = success_
- print('\n', attack)
- print(' ', 1.0 - success_.mean(axis=-1).round(2))
- np.savez('tmp_perturbations.npz', data=advs[0].numpy())
- np.savez('tmp_images.npz', data=images.numpy())
- np.savez('tmp_inputs.npz', data=inputs[0].numpy())
- import IPython; IPython.embed(); raise SystemExit
- robust_accuracy = 1.0 - attack_success[:, :, batch_index, :].max(axis=0).mean(axis=-1)
- print('\n', '-' * 79, '\n')
- print('worst case (best attack per-sample)')
- print(' ', robust_accuracy.round(2))
- print('-' * 79)
- attack_success = attack_success.reshape(
- (len(attacks), len(epsilons), num_batches*model.params.batch_size))
- attack_types = [str(type(attack)).split('.')[-1][:-2] for attack in attacks]
- output_filename = os.path.join(model.params.analysis_save_dir,
- f'linf_adversarial_analysis.npz')
- out_dict = {
- 'adversarial_analysis':attack_success,
- 'attack_types':attack_types,
- 'epsilons':epsilons,
- 'attack_params':attack_params}
- np.savez(output_filename, data=out_dict)
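
Note: the removed script saved its results via np.savez(output_filename, data=out_dict). For reference, a minimal sketch of reading such a file back, assuming the key layout from the script above (the local filename here is hypothetical):

import numpy as np

analysis = np.load('linf_adversarial_analysis.npz', allow_pickle=True)['data'].item()
success = analysis['adversarial_analysis']  # bool array, [attacks, epsilons, num_images]
# worst case over attacks, averaged over images, mirroring the removed script
robust_accuracy = 1.0 - success.max(axis=0).mean(axis=-1)
for eps, acc in zip(analysis['epsilons'], robust_accuracy):
    print(f'epsilon {eps}: robust accuracy {acc:.2f}')
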
diff --git a/datasets/synthetic.py b/datasets/synthetic.py
index db82e635..3af48bc6 100644
--- a/datasets/synthetic.py
+++ b/datasets/synthetic.py
@@ -1,5 +1,9 @@
import os
import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
from scipy.stats import norm
@@ -7,9 +11,6 @@
import torch
import torchvision
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.utils.data_processing as dp
class SyntheticImages(torchvision.datasets.vision.VisionDataset):
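
Note: anchoring ROOT_DIR to the file location, rather than to os.getcwd() as before, makes the import path independent of where the interpreter is launched. A minimal sketch of the difference:

import os
from os.path import dirname as up

# old: one level above the current working directory; breaks when the script
# is run from anywhere other than inside the repository
cwd_root = os.path.dirname(os.getcwd())

# new: three levels above datasets/synthetic.py, i.e. the directory that
# contains the DeepSparseCoding package, regardless of the caller's cwd
file_root = up(up(up(os.path.realpath(__file__))))
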
diff --git a/models/base_model.py b/models/base_model.py
index c1e383fa..904cb7ea 100644
--- a/models/base_model.py
+++ b/models/base_model.py
@@ -1,16 +1,20 @@
import os
+import subprocess
+import pprint
import numpy as np
import torch
+from DeepSparseCoding.utils.file_utils import summary_string
from DeepSparseCoding.utils.file_utils import Logger
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
+import DeepSparseCoding.utils.loaders as loaders
class BaseModel(object):
def setup(self, params, logger=None):
"""
Setup required model components
- #TODO: log system info, including git commit hash
"""
self.load_params(params)
self.check_params()
@@ -18,6 +22,7 @@ def setup(self, params, logger=None):
if logger is None:
self.init_logging()
self.log_params()
+ self.logger.log_info(self.get_env_details())
else:
self.logger = logger
@@ -92,24 +97,106 @@ def log_params(self, params=None):
dump_obj = self.params.__dict__
self.logger.log_params(dump_obj)
- def log_info(self, string):
- """Log input string"""
- self.logger.log_info(string)
+ def get_train_stats(self, batch_step=None):
+ """
+ Get default statistics about current training run
+
+ Keyword arguments:
+ batch_step: [int] current batch iteration. The default assumes that training has finished.
+ """
+ if batch_step is None:
+ batch_step = self.params.num_batches
+ epoch = batch_step / self.params.batches_per_epoch
+ stat_dict = {
+ 'epoch':int(epoch),
+ 'batch_step':batch_step,
+ 'train_progress':np.round(batch_step/self.params.num_batches, 3),
+ }
+ return stat_dict
- def write_checkpoint(self):
- """Write checkpoints"""
- torch.save(self.state_dict(), self.params.cp_latest_filename)
- self.log_info('Full model saved in file %s'%self.params.cp_latest_filename)
+ def get_env_details(self):
+ env = {}
+ for k in ['SYSTEMROOT', 'PATH']:
+ v = os.environ.get(k)
+ if v is not None:
+ env[k] = v
+ commit_cmd = ['git', 'rev-parse', 'HEAD']
+ commit = subprocess.Popen(commit_cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ commit = commit.strip().decode('ascii')
+ branch_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
+ branch = subprocess.Popen(branch_cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
+ branch = branch.strip().decode('ascii')
+ system_details = os.uname()
+ out_dict = {
+ 'current_branch':branch,
+ 'current_commit_hash':commit,
+ 'sysname':system_details.sysname,
+ 'release':system_details.release,
+ 'machine':system_details.machine
+ }
+ if torch.cuda.is_available():
+ out_dict['gpu_device'] = torch.cuda.get_device_name(0)
+ return out_dict
- def load_checkpoint(self, cp_file=None):
+ def log_architecture_details(self):
+ """
+ Log model architecture with computed output sizes and number of parameters for each layer
+ """
+ architecture_string = '\n'+summary_string(
+ self,
+ input_size=tuple(self.params.data_shape),
+ batch_size=self.params.batch_size,
+ device=self.params.device,
+ dtype=torch.FloatTensor
+ )[0]
+ architecture_string += '\n'
+ self.logger.log_string(architecture_string)
+
+ def write_checkpoint(self, batch_step=None):
+ """
+ Write checkpoints
+
+ Keyword arguments:
+ batch_step: [int] current batch iteration. The default assumes that training has finished.
+ """
+ output_dict = {}
+ if(self.params.model_type.lower() == 'ensemble'):
+ for module in self:
+ module_name = module.params.submodule_name
+ output_dict[module_name+'_module_state_dict'] = module.state_dict()
+ output_dict[module_name+'_optimizer_state_dict'] = module.optimizer.state_dict()
+ else:
+ output_dict['model_state_dict'] = self.state_dict()
+ module_state_dict_name = 'optimizer_state_dict'
+            output_dict[module_state_dict_name] = self.optimizer.state_dict()
+ ## TODO: Save scheduler state dict as well
+ training_stats = self.get_train_stats(batch_step)
+ output_dict.update(training_stats)
+ torch.save(output_dict, self.params.cp_latest_filename)
+ self.logger.log_string('Full model saved in file %s'%self.params.cp_latest_filename)
+
+ def get_checkpoint_from_log(self, logfile):
+ model_params = loaders.load_params_from_log(logfile)
+ checkpoint = torch.load(model_params.cp_latest_filename)
+ return checkpoint
+
+ def load_checkpoint(self, cp_file=None, load_optimizer=False):
"""
Load checkpoint
- Inputs:
- model_dir: String specifying the path to the checkpoint
+ Keyword arguments:
+        cp_file: [str] path to the checkpoint file; defaults to params.cp_latest_filename
+        load_optimizer: [bool] if True, also restore the optimizer state
"""
if cp_file is None:
cp_file = self.params.cp_latest_filename
- return self.load_state_dict(torch.load(cp_file))
+ checkpoint = torch.load(cp_file)
+ if load_optimizer:
+ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ self.load_state_dict(checkpoint['model_state_dict'])
+ _ = checkpoint.pop('optimizer_state_dict', None)
+ _ = checkpoint.pop('model_state_dict', None)
+        training_status = pprint.pformat(checkpoint, compact=True) # TODO: pass sort_dicts=True once Python 3.8 is the minimum supported version
+ out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}'
+ return out_str
def get_optimizer(self, optimizer_params, trainable_variables):
optimizer_name = optimizer_params.optimizer.name
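
Note: write_checkpoint and load_checkpoint above round-trip a plain dict through torch.save/torch.load. A minimal sketch of the non-ensemble layout (key names follow this diff; `model` is assumed to be a set-up BaseModel subclass and the filename is hypothetical):

import torch

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': model.optimizer.state_dict(),
    'epoch': 10,  # training stats from get_train_stats()
    'batch_step': 5000,
    'train_progress': 1.0,
}
torch.save(checkpoint, 'latest_checkpoint.pt')

restored = torch.load('latest_checkpoint.pt')
model.optimizer.load_state_dict(restored.pop('optimizer_state_dict'))
model.load_state_dict(restored.pop('model_state_dict'))
# whatever remains in `restored` is the training-stats payload that
# load_checkpoint pretty-prints and returns
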
@@ -129,8 +216,8 @@ def get_optimizer(self, optimizer_params, trainable_variables):
def setup_optimizer(self):
self.optimizer = self.get_optimizer(
- optimizer_params=self.params,
- trainable_variables=self.parameters())
+ optimizer_params=self.params,
+ trainable_variables=self.parameters())
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
self.optimizer,
milestones=self.params.optimizer.milestones,
@@ -143,21 +230,18 @@ def print_update(self, input_data, input_labels=None, batch_step=0):
input_data: data object containing the current image batch
input_labels: data object containing the current label batch
batch_step: current batch number within the schedule
- NOTE: For the analysis code to parse update statistics, the self.js_dumpstring() call
- must receive a dict object. Additionally, the self.js_dumpstring() output must be
- logged with tags.
- For example: logging.info(''+self.js_dumpstring(output_dictionary)+'')
+ NOTE: For the analysis code to parse update statistics,
+ the logger.log_stats() function must be used
"""
update_dict = self.generate_update_dict(input_data, input_labels, batch_step)
- js_str = self.js_dumpstring(update_dict)
- self.log_info(''+js_str+'')
+ self.logger.log_stats(update_dict)
def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
"""
Generates a dictionary to be logged in the print_update function
"""
if update_dict is None:
- update_dict = dict()
+ update_dict = self.get_train_stats(batch_step)
for param_name, param_var in self.named_parameters():
grad = param_var.grad
update_dict[param_name+'_grad_max_mean_min'] = [
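
Note: generate_update_dict now seeds the dict with get_train_stats and then records per-parameter gradient extrema. A self-contained sketch of that bookkeeping on a toy layer (the max/mean/min ordering follows the key name):

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
layer(torch.randn(3, 4)).sum().backward()  # populate .grad
update_dict = {}
for param_name, param_var in layer.named_parameters():
    grad = param_var.grad
    update_dict[param_name + '_grad_max_mean_min'] = [
        grad.max().item(), grad.mean().item(), grad.min().item()]
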
diff --git a/models/ensemble_model.py b/models/ensemble_model.py
index 60bbb343..d1b334fd 100644
--- a/models/ensemble_model.py
+++ b/models/ensemble_model.py
@@ -1,3 +1,5 @@
+import pprint
+
import torch
import DeepSparseCoding.utils.loaders as loaders
@@ -15,28 +17,96 @@ def setup(self, params, logger=None):
self.setup_optimizer()
def setup_module(self, params):
- for subparams in params.ensemble_params:
+        layer_names = [] # TODO: require that submodule_name (model_type + layer_name) is unique, rather than layer_name alone
+ for sub_index, subparams in enumerate(params.ensemble_params):
+ layer_names.append(subparams.layer_name)
+ assert len(set(layer_names)) == len(layer_names), (
+ 'The "layer_name" parameter must be unique for each module in the ensemble.')
+ subparams.submodule_name = subparams.model_type + '_' + subparams.layer_name
subparams.epoch_size = params.epoch_size
subparams.batches_per_epoch = params.batches_per_epoch
subparams.num_batches = params.num_batches
#subparams.num_val_images = params.num_val_images
#subparams.num_test_images = params.num_test_images
- subparams.data_shape = params.data_shape
+ if not hasattr(subparams, 'data_shape'): # TODO: This is a workaround for a dependency on data_shape in lca module
+ subparams.data_shape = params.data_shape
super(EnsembleModel, self).setup_ensemble_module(params)
self.submodel_classes = []
- for submodel_params in self.params.ensemble_params:
- self.submodel_classes.append(loaders.load_model_class(submodel_params.model_type))
+ for ensemble_index, subparams in enumerate(self.params.ensemble_params):
+ submodule_class = loaders.load_model_class(subparams.model_type)
+ self.submodel_classes.append(submodule_class)
+ if subparams.checkpoint_boot_log != '':
+ checkpoint = self.get_checkpoint_from_log(subparams.checkpoint_boot_log)
+ submodule = self.__getitem__(ensemble_index)
+ module_state_dict_name = subparams.submodule_name+'_module_state_dict'
+ if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble
+ submodule.load_state_dict(checkpoint[module_state_dict_name])
+ else: # it was trained on its own
+ if 'model_state_dict' in checkpoint.keys():
+ submodule.load_state_dict(checkpoint['model_state_dict'])
+ else:
+ assert False, (
+ f'subparams {subparams} has checkpoint_boot_log set to '
+ +f'{subparams.checkpoint_boot_log}, but that log does not have the '
+ +f'appropriate key. The key "{module_state_dict_name}" must be in '
+                        +f'checkpoint.keys() = {checkpoint.keys()}')
def setup_optimizer(self):
for module in self:
module.optimizer = self.get_optimizer(
optimizer_params=module.params,
trainable_variables=module.parameters())
+ if module.params.checkpoint_boot_log != '':
+ checkpoint = self.get_checkpoint_from_log(module.params.checkpoint_boot_log)
+ module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict'
+ if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble
+ module.optimizer.load_state_dict(checkpoint[module_state_dict_name])
+ else: # it was trained on its own
+                module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ for group in module.optimizer.param_groups: # overwrite learning rates
+ group['lr'] = module.params.weight_lr
+ group['initial_lr'] = module.params.weight_lr
+ ## TODO: load scheduler state dict with checkpoint, set last_epoch correctly
+ ## https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html
module.scheduler = torch.optim.lr_scheduler.MultiStepLR(
module.optimizer,
milestones=module.params.optimizer.milestones,
gamma=module.params.optimizer.lr_decay_rate)
+ def load_checkpoint(self, cp_file=None, load_optimizer=False):
+ """
+ Load checkpoint
+ Keyword arguments:
+        cp_file: [str] path to the checkpoint file; defaults to params.cp_latest_filename
+        load_optimizer: [bool] if True, also restore each module's optimizer state
+ """
+ if cp_file is None:
+ cp_file = self.params.cp_latest_filename
+ checkpoint = torch.load(cp_file)
+ for module in self:
+ module_state_dict_name = module.params.submodule_name+'_module_state_dict'
+ if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble
+ module.load_state_dict(checkpoint[module_state_dict_name])
+ _ = checkpoint.pop(module_state_dict_name, None)
+ else: # it was trained on its own
+ module.load_state_dict(checkpoint['model_state_dict'])
+ if load_optimizer:
+ module_state_dict_name = module.params.submodule_name+'_optimizer_state_dict'
+ if module_state_dict_name in checkpoint.keys(): # It was already in an ensemble
+ module.optimizer.load_state_dict(checkpoint[module_state_dict_name])
+ _ = checkpoint.pop(module_state_dict_name, None)
+ else:
+ module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ _ = checkpoint.pop('optimizer_state_dict', None)
+ for group in module.optimizer.param_groups: # overwrite learning rates
+ group['lr'] = module.params.weight_lr
+ group['initial_lr'] = module.params.weight_lr
+ ## TODO: Load scheduler state dict as well
+ _ = checkpoint.pop('model_state_dict', None)
+        training_status = pprint.pformat(checkpoint, compact=True) # TODO: pass sort_dicts=True once Python 3.8 is the minimum supported version
+ out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}'
+ return out_str
+
def preprocess_data(self, data):
"""
We assume that only the first submodel will be preprocessing the input data
@@ -59,7 +129,7 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0):
input_labels, batch_step, update_dict=dict())
for key, value in submodel_update_dict.items():
if key not in ['epoch', 'batch_step']:
- key = submodule.params.model_type+'_'+key
+ key = submodule.params.submodule_name + '_' + key
update_dict[key] = value
x = submodule.get_encodings(x)
return update_dict
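
Note: with submodule_name = model_type + '_' + layer_name, an ensemble checkpoint keys each module's state separately. A hypothetical two-module layout, for reference:

checkpoint = {
    'lca_layer0_module_state_dict': {...},       # per-module weights
    'lca_layer0_optimizer_state_dict': {...},
    'mlp_layer1_module_state_dict': {...},
    'mlp_layer1_optimizer_state_dict': {...},
    'epoch': 10, 'batch_step': 5000, 'train_progress': 1.0,
}

load_checkpoint above falls back to the plain 'model_state_dict' / 'optimizer_state_dict' keys when a module was trained on its own before being placed in the ensemble.
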
diff --git a/models/lca_model.py b/models/lca_model.py
index ec13c014..ad78e3a8 100644
--- a/models/lca_model.py
+++ b/models/lca_model.py
@@ -11,6 +11,10 @@ def setup(self, params, logger=None):
super(LcaModel, self).setup(params, logger)
self.setup_module(params)
self.setup_optimizer()
+ if params.checkpoint_boot_log != '':
+ checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)
+ self.module.load_state_dict(checkpoint['model_state_dict'])
+ self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
def get_total_loss(self, input_tuple):
input_tensor, input_labels = input_tuple
@@ -24,16 +28,12 @@ def get_total_loss(self, input_tuple):
def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
if update_dict is None:
update_dict = super(LcaModel, self).generate_update_dict(input_data, input_labels, batch_step)
- epoch = batch_step / self.params.batches_per_epoch
- stat_dict = {
- 'epoch':int(epoch),
- 'batch_step':batch_step,
- 'train_progress':np.round(batch_step/self.params.num_batches, 3),
- 'weight_lr':self.scheduler.get_lr()[0]}
+ stat_dict = dict()
latents = self.get_encodings(input_data)
recon = self.get_recon_from_latents(latents)
recon_loss = losses.half_squared_l2(input_data, recon).item()
sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()
+ stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0]
stat_dict['loss_recon'] = recon_loss
stat_dict['loss_sparse'] = sparse_loss
stat_dict['loss_total'] = recon_loss + sparse_loss
@@ -41,7 +41,20 @@ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, upda
input_data.max().item(), input_data.mean().item(), input_data.min().item()]
stat_dict['recon_max_mean_min'] = [
recon.max().item(), recon.mean().item(), recon.min().item()]
- latent_nnz = torch.sum(latents != 0).item() # TODO: github issue 23907 requests torch.count_nonzero
- stat_dict['latents_fraction_active'] = latent_nnz / latents.numel()
+ def count_nonzero(array, dim):
+ # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7
+            return torch.sum(array != 0, dim=dim, dtype=torch.float)
+ latent_dims = tuple([i for i in range(len(latents.shape))])
+ latent_nnz = count_nonzero(latents, dim=latent_dims).item()
+ stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel()
+ if self.params.layer_types[0] == 'conv':
+ latent_map_dims = latent_dims[2:]
+ latent_map_size = np.prod(list(latents.shape[2:]))
+ latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size
+ latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item()
+ stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz
+ num_channels = latents.shape[1]
+ latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item()
+ stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz
update_dict.update(stat_dict)
return update_dict
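
Note: the activity fractions reduce the [batch, channels, height, width] latent tensor over different axes. A standalone sketch (manual count_nonzero because torch.count_nonzero only arrives in torch 1.7):

import torch

latents = torch.randn(8, 16, 10, 10).clamp(min=0)  # rectified, so roughly half zeros

def count_nonzero(array, dim):
    return torch.sum(array != 0, dim=dim, dtype=torch.float)

frac_all = count_nonzero(latents, dim=(0, 1, 2, 3)).item() / latents.numel()
frac_per_channel = torch.mean(count_nonzero(latents, dim=(2, 3)) / (10 * 10)).item()
frac_per_patch = torch.mean(count_nonzero(latents, dim=1) / 16).item()
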
diff --git a/models/mlp_model.py b/models/mlp_model.py
index bf755d12..28d0f261 100644
--- a/models/mlp_model.py
+++ b/models/mlp_model.py
@@ -10,6 +10,10 @@ def setup(self, params, logger=None):
super(MlpModel, self).setup(params, logger)
self.setup_module(params)
self.setup_optimizer()
+ if params.checkpoint_boot_log != '':
+ checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)
+ self.module.load_state_dict(checkpoint['model_state_dict'])
+ self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
def get_total_loss(self, input_tuple):
input_tensor, input_label = input_tuple
@@ -20,16 +24,12 @@ def get_total_loss(self, input_tuple):
def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
if update_dict is None:
update_dict = super(MlpModel, self).generate_update_dict(input_data, input_labels, batch_step)
- epoch = batch_step / self.params.batches_per_epoch
- stat_dict = {
- 'epoch':int(epoch),
- 'batch_step':batch_step,
- 'train_progress':np.round(batch_step/self.params.num_batches, 3)}
+ stat_dict = dict()
pred = self.forward(input_data)
- #total_loss = F.nll_loss(pred, input_labels)
total_loss = self.loss_fn(pred, input_labels)
pred = pred.max(1, keepdim=True)[1]
correct = pred.eq(input_labels.view_as(pred)).sum().item()
+ stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0] # one LR for all parameters
stat_dict['loss'] = total_loss.item()
stat_dict['train_accuracy'] = 100. * correct / self.params.batch_size
update_dict.update(stat_dict)
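
Note: the accuracy bookkeeping above is the usual argmax-and-compare pattern; as a standalone sketch:

import torch

logits = torch.randn(5, 10)                # [batch, num_classes]
input_labels = torch.randint(0, 10, (5,))
pred = logits.max(1, keepdim=True)[1]      # index of the highest-scoring class
correct = pred.eq(input_labels.view_as(pred)).sum().item()
train_accuracy = 100. * correct / logits.size(0)
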
diff --git a/models/pooling_model.py b/models/pooling_model.py
new file mode 100644
index 00000000..3f5caf1a
--- /dev/null
+++ b/models/pooling_model.py
@@ -0,0 +1,46 @@
+import torch
+
+import DeepSparseCoding.modules.losses as losses
+
+from DeepSparseCoding.models.base_model import BaseModel
+from DeepSparseCoding.modules.pooling_module import PoolingModule
+
+class PoolingModel(BaseModel, PoolingModule):
+ """
+ TODO: rename pool_ksize and pool_stride to just kernel_size and stride
+ """
+ def setup(self, params, logger=None):
+ self.setup_module(params)
+ self.setup_optimizer()
+ if params.checkpoint_boot_log != '':
+ checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)
+ self.module.load_state_dict(checkpoint['model_state_dict'])
+ self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+ def get_total_loss(self, input_tuple):
+ def loss_fn(model_output):
+ output_loss = losses.trace_covariance(model_output)
+ w_stride = self.params.pool_stride
+ weight_loss = losses.weight_orthogonality(self.weight, stride=w_stride, padding=0)
+ return output_loss + weight_loss
+ input_tensor, input_label = input_tuple
+ layer_output = self.forward(input_tensor)
+ self.loss_fn = loss_fn
+ return self.loss_fn(layer_output)
+
+ def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
+ if update_dict is None:
+            update_dict = super(PoolingModel, self).generate_update_dict(input_data, input_labels, batch_step)
+ stat_dict = dict()
+ rep = self.forward(input_data)
+ def count_nonzero(array, dim):
+ # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7
+            return torch.sum(array != 0, dim=dim, dtype=torch.float)
+ rep_dims = tuple([i for i in range(len(rep.shape))])
+ rep_nnz = count_nonzero(rep, dim=rep_dims).item()
+ stat_dict['fraction_active_all_latents'] = rep_nnz / rep.numel()
+ total_loss = self.loss_fn(rep)
+ stat_dict['weight_lr'] = self.scheduler.get_last_lr()[0]
+ stat_dict['loss'] = total_loss.item()
+ update_dict.update(stat_dict)
+ return update_dict
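
Note: the pooling objective combines two of the losses added in modules/losses.py below: decorrelate the layer outputs (trace_covariance) while keeping the weights orthogonal (weight_orthogonality). A minimal sketch of the composition for a fully-connected layer:

import torch
import DeepSparseCoding.modules.losses as losses

layer_output = torch.randn(16, 8)  # [batch, features]
weight = torch.randn(8, 8)         # stand-in for self.weight
loss = losses.trace_covariance(layer_output) \
    + losses.weight_orthogonality(weight, stride=1, padding=0)
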
diff --git a/modules/activations.py b/modules/activations.py
index 98446536..56a502a2 100644
--- a/modules/activations.py
+++ b/modules/activations.py
@@ -1,17 +1,5 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
-
-def activation_picker(activation_function):
- if activation_function == 'identity':
- return lambda x: x
- if activation_function == 'relu':
- return F.relu
- if activation_function == 'lrelu' or activation_function == 'leaky_relu':
- return F.leaky_relu
- if activation_function == 'lca_threshold':
- return lca_threshold
- assert False, (f'Activation function {activation_function} is not supported.')
def lca_threshold(u_in, thresh_type, rectify, sparse_threshold):
u_zeros = torch.zeros_like(u_in)
@@ -40,3 +28,12 @@ def lca_threshold(u_in, thresh_type, rectify, sparse_threshold):
else:
assert False, (f'Parameter thresh_type must be "soft" or "hard", not {thresh_type}')
return a_out
+
+def activation_picker(activation_function):
+ if activation_function == 'identity':
+ return nn.Identity()
+ if activation_function == 'relu':
+ return nn.ReLU()
+ if activation_function == 'lrelu' or activation_function == 'leaky_relu':
+ return nn.LeakyReLU()
+ assert False, (f'Activation function {activation_function} is not supported.')
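
Note: activation_picker now returns nn.Module instances instead of functional handles, which is what lets mlp_module compose layers, activations, pooling, and dropout into nn.Sequential containers. Usage sketch:

import torch.nn as nn
from DeepSparseCoding.modules.activations import activation_picker

act = activation_picker('lrelu')  # an nn.LeakyReLU() module
stack = nn.Sequential(nn.Linear(8, 8), act)

lca_threshold is dropped from the picker because it takes extra arguments (thresh_type, rectify, sparse_threshold) and is imported directly where needed.
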
diff --git a/modules/ensemble_module.py b/modules/ensemble_module.py
index 9e7b7932..27e60f77 100644
--- a/modules/ensemble_module.py
+++ b/modules/ensemble_module.py
@@ -4,21 +4,19 @@
class EnsembleModule(nn.Sequential):
- def __init__(self): # do not do Sequential's init
- super(nn.Sequential, self).__init__()
-
def setup_ensemble_module(self, params):
self.params = params
for subparams in params.ensemble_params:
submodule = loaders.load_module(subparams.model_type)
submodule.setup_module(subparams)
- self.add_module(subparams.model_type, submodule)
+ self.add_module(subparams.layer_name, submodule)
def forward(self, x):
- self.layer_list = [x]
for module in self:
- self.layer_list.append(module.get_encodings(self.layer_list[-1])) # latent encodings
- return self.layer_list[-1]
+ if module.params.layer_types[0] == 'fc':
+ x = x.view(x.size(0), -1) #flat
+ x = module(x)
+ return x
def get_encodings(self, x):
return self.forward(x)
diff --git a/modules/lca_module.py b/modules/lca_module.py
index 9444a5d6..b8446aae 100644
--- a/modules/lca_module.py
+++ b/modules/lca_module.py
@@ -3,62 +3,132 @@
import torch.nn.functional as F
from DeepSparseCoding.modules.activations import lca_threshold
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
+import DeepSparseCoding.utils.data_processing as dp
class LcaModule(nn.Module):
+ """
+ Keyword arguments:
+ params: [dict] with keys:
+ data_shape [list of int] of shape [elements, channels, height, width]; Assumes h = w (i.e. square inputs)
+ The remaining keys are only used layer_types[0] is "conv":
+ kernel_size: [int] edge size of the square convolving kernel
+ stride: [int] vertical and horizontal stride of the convolution
+ padding: [int] zero-padding added to both sides of the input
+ """
def setup_module(self, params):
self.params = params
- self.w = nn.Parameter(
- F.normalize(
- torch.randn(self.params.num_pixels, self.params.num_latent),
- p=2, dim=0),
- requires_grad=True)
+ if self.params.layer_types[0] == 'fc':
+ self.w_shape = self.params.layer_channels[::-1] #[outputs, inputs]
+ self.layer_output_shape = [self.params.layer_channels[-1]]
+ else:
+ assert (self.params.data_shape[-1] % self.params.stride == 0), (
+ f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')
+ self.w_shape = [
+ self.params.layer_channels[1],
+ self.params.layer_channels[0],
+ self.params.kernel_size,
+ self.params.kernel_size
+ ]
+ output_height = compute_conv_output_shape(
+ self.params.data_shape[1],
+ self.params.kernel_size,
+ self.params.stride,
+ self.params.padding,
+ dilation=1)
+ output_width = compute_conv_output_shape(
+ self.params.data_shape[2],
+ self.params.kernel_size,
+ self.params.stride,
+ self.params.padding,
+ dilation=1)
+ self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width]
+ w_init = torch.randn(self.w_shape)
+ w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)
+ self.weight = nn.Parameter(w_init_normed, requires_grad=True)
def preprocess_data(self, input_tensor):
- input_tensor = input_tensor.view(-1, self.params.num_pixels)
+ if self.params.layer_types[0] == 'fc':
+ input_tensor = input_tensor.view(self.params.batch_size, -1)
return input_tensor
- def compute_excitatory_current(self, input_tensor):
- return torch.matmul(input_tensor, self.w)
+ def compute_excitatory_current(self, input_tensor, a_in):
+ if self.params.layer_types[0] == 'fc':
+ excitatory_current = torch.mm(input_tensor, self.weight.T)
+ else:
+ recon = self.get_recon_from_latents(a_in)
+ recon_error = input_tensor - recon
+ error_injection = F.conv2d(
+ input=recon_error,
+ weight=self.weight,
+ bias=None,
+ stride=self.params.stride,
+ padding=self.params.padding
+ )
+ excitatory_current = error_injection + a_in
+ return excitatory_current
def compute_inhibitory_connectivity(self):
- lca_g = torch.matmul(torch.transpose(self.w, dim0=0, dim1=1),
- self.w) - torch.eye(self.params.num_latent,
+ identity = torch.eye(self.params.layer_channels[1],
requires_grad=True, device=self.params.device)
- return lca_g
+ if self.params.layer_types[0] == 'fc':
+ inhibitory_connectivity = torch.mm(self.weight, self.weight.T) - identity
+ else:
+        conv_kernels = self.weight.view(self.w_shape[0], -1) # [out_channels, in_channels * kernel_h * kernel_w]
+ inhibitory_connectivity = torch.mm(conv_kernels, conv_kernels.T) - identity
+ return inhibitory_connectivity
def threshold_units(self, u_in):
a_out = lca_threshold(u_in, self.params.thresh_type, self.params.rectify_a,
self.params.sparse_mult)
return a_out
- def step_inference(self, u_in, a_in, b, g, step):
- lca_explain_away = torch.matmul(a_in, g)
- du = b - lca_explain_away - u_in
+ def step_inference(self, u_in, a_in, excitatory_current, inhibitory_connectivity, step):
+ if self.params.layer_types[0] == 'fc':
+ lca_explain_away = torch.mm(a_in, inhibitory_connectivity)
+ else:
+ lca_explain_away = 0 # already computed in excitatory_current
+ du = excitatory_current - lca_explain_away - u_in
u_out = u_in + self.params.step_size * du
return u_out, lca_explain_away
def infer_coefficients(self, input_tensor):
- lca_b = self.compute_excitatory_current(input_tensor)
- lca_g = self.compute_inhibitory_connectivity()
- u_list = [torch.zeros([input_tensor.shape[0], self.params.num_latent],
- device=self.params.device)]
+ output_shape = [input_tensor.shape[0]] + self.layer_output_shape
+ u_list = [torch.zeros(output_shape, device=self.params.device)]
a_list = [self.threshold_units(u_list[0])]
- # TODO: look into redoing this with a register_buffer that gets updated? look up simple RNN code...
+ excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1])
+ inhibitory_connectivity = self.compute_inhibitory_connectivity()
for step in range(self.params.num_steps-1):
- u = self.step_inference(u_list[step], a_list[step], lca_b, lca_g, step)[0]
+ u = self.step_inference(
+ u_list[step],
+ a_list[step],
+ excitatory_current,
+ inhibitory_connectivity,
+ step
+ )[0]
u_list.append(u)
a_list.append(self.threshold_units(u))
+ if self.params.layer_types[0] == 'conv':
+ excitatory_current = self.compute_excitatory_current(input_tensor, a_list[-1])
return (u_list, a_list)
- def get_recon_from_latents(self, latents):
- return torch.matmul(latents, torch.transpose(self.w, dim0=0, dim1=1))
+ def get_recon_from_latents(self, a_in):
+ if self.params.layer_types[0] == 'fc':
+ recon = torch.mm(a_in, self.weight)
+ else:
+ recon = F.conv_transpose2d(
+ input=a_in,
+ weight=self.weight,
+ bias=None,
+ stride=self.params.stride,
+ padding=self.params.padding
+ )
+ return recon
- def get_encodings(self, input_tensor):
+ def forward(self, input_tensor):
u_list, a_list = self.infer_coefficients(input_tensor)
return a_list[-1]
- def forward(self, input_tensor):
- latents = self.get_encodings(input_tensor)
- reconstruction = self.get_recon_from_latents(latents)
- return reconstruction
+ def get_encodings(self, input_tensor):
+ return self.forward(input_tensor)
diff --git a/modules/losses.py b/modules/losses.py
index d9ea61b7..4176a10e 100644
--- a/modules/losses.py
+++ b/modules/losses.py
@@ -1,15 +1,24 @@
+import numpy as np
import torch
import DeepSparseCoding.utils.data_processing as dp
+#def l2_flatness(z1, z2, z3, weight):
+# """
+# Minimized when a straight line can be drawn through [z1, z2, z3].
+# Extended from equations 8 and 12 in
+# Chen, Paiton, Olshausen (2018) - The Sparse Manifold Transform
+# """
+# z_mat =
+
def half_squared_l2(x1, x2):
"""
Computes the standard reconstruction loss. It will average over batch dimensions.
- Args:
+ Keyword arguments:
x1: Tensor with original input image
x2: Tensor with reconstructed image for comparison
- Returns:
+ Outputs:
recon_loss: Tensor representing the squared l2 distance between the inputs, averaged over batch
"""
dp.check_all_same_shape([x1, x2])
@@ -22,15 +31,15 @@ def half_squared_l2(x1, x2):
def half_weight_norm_squared(weight_list):
"""
Computes a loss that encourages each weight in the list of weights to have unit l2 norm.
- Args:
+ Keyword arguments:
weight_list: List of torch variables
- Returns:
- w_norm_loss: 0.5 * sum of (1 - l2_norm(w))^2 for each w in weight_list
+ Outputs:
+ w_norm_loss: 0.5 * sum of (1 - l2_norm(weight))^2 for each weight in weight_list
"""
w_norm_list = []
- for w in weight_list:
- reduc_dim = list(range(1, len(w.shape)))
- w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(tf.pow(w, 2.), axis=reduc_dim)), 2.))
+ for weight in weight_list:
+ reduc_dim = list(range(1, len(weight.shape)))
+        w_norm = torch.sum(torch.pow(1 - torch.sqrt(torch.sum(torch.pow(weight, 2.), dim=reduc_dim)), 2.))
w_norm_list.append(w_norm)
norm_loss = 0.5 * torch.sum(w_norm_list)
return norm_loss
@@ -39,12 +48,12 @@ def half_weight_norm_squared(weight_list):
def weight_decay(weight_list):
"""
Computes typical weight decay loss
- Args:
+ Keyword arguments:
weight_list: List of torch variables
- Returns:
- decay_loss: 0.5 * sum of w^2 for each w in weight_list
+ Outputs:
+ decay_loss: 0.5 * sum of weight^2 for each weight in weight_list
"""
- decay_loss = 0.5 * torch.sum([torch.sum(torch.pow(w, 2.)) for w in weight_list])
+    decay_loss = 0.5 * sum([torch.sum(torch.pow(weight, 2.)) for weight in weight_list])
return decay_loss
@@ -52,11 +61,66 @@ def l1_norm(latents):
"""
Computes the L1 norm of for a batch of input vector
This is the sparsity loss for a Laplacian prior
- Args:
+ Keyword arguments:
latents: torch tensor of any shape, but where first index is always batch
- Returns:
+ Outputs:
sparse_loss: sum of abs of latents, averaged over the batch
"""
reduc_dim = list(range(1, len(latents.shape)))
sparse_loss = torch.mean(torch.sum(torch.abs(latents), dim=reduc_dim, keepdim=False))
return sparse_loss
+
+
+def trace_covariance(latents):
+ """
+ Returns loss that is the trace of the covariance matrix of the latents
+
+ Keyword arguments:
+ latents: torch tensor of shape [num_batch, num_latents] or [num_batch, num_channels, latents_h, latents_w]
+ Outputs:
+ loss
+ """
+ covariance = dp.covariance(latents) # [num_channels, num_channels]
+ if latents.ndim == 4:
+ num_batch, num_channels, latents_h, latents_w = latents.shape
+ covariance = covariance / (latents_h * latents_w - 1.0)
+ trace = torch.trace(covariance)
+    target = torch.trace(torch.eye(covariance.size(0), device=trace.device)) # equals covariance.size(0)
+ return torch.norm(trace - target, p='fro')
+
+
+def weight_orthogonality(weight, stride=1, padding=0):
+ """
+ Returns l2 loss that is minimized when the weight are orthogonal
+
+ Keyword arguments:
+ weight [torch tensor] layer weight, either fully connected or 2d convolutional
+ stride [int] layer stride for convolutional layers
+ padding [int] layer padding for convolutional layers
+
+ Outputs:
+ loss
+
+ Note:
+ Convolutional orthogonalization loss is based on
+ Orthogonal Convolutional Neural Networks
+ https://arxiv.org/abs/1911.12207
+ https://github.com/samaonline/Orthogonal-Convolutional-Neural-Networks
+ """
+ w_shape = weight.shape
+ if weight.ndim == 2: # fully-connected, [inputs, outputs]
+ loss = torch.norm(torch.mm(weight.T, weight) - torch.eye(w_shape[1], device=weight.device))
+ elif weight.ndim == 4: # convolutional, [output_channels, input_channels, height, width]
+ out_channels, in_channels, in_height, in_width = w_shape
+ output = torch.conv2d(weight, weight, stride=stride, padding=padding)
+ out_height = output.shape[-2]
+ out_width = output.shape[-1]
+ target = torch.zeros((out_channels, out_channels, out_height, out_width),
+ device=weight.device)
+ center_h = int(np.floor(out_height / 2))
+ center_w = int(np.floor(out_width / 2))
+ target[:, :, center_h, center_w] = torch.eye(out_channels, device=weight.device)
+ loss = torch.norm(output - target, p='fro')
+ else:
+ assert False, (f'weight ndim must be 2 or 4, not {weight.ndim}')
+ return loss
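
Note: in the convolutional branch of weight_orthogonality, conv2d(weight, weight) computes every kernel's correlation with every other kernel, and the loss drives that self-correlation map toward an identity at the central tap, following the OCNN reference above. A quick standalone check:

import torch

weight = torch.randn(8, 3, 5, 5)  # [out_channels, in_channels, height, width]
output = torch.conv2d(weight, weight, stride=1, padding=0)  # [8, 8, 1, 1]
out_height, out_width = output.shape[-2:]
target = torch.zeros(8, 8, out_height, out_width)
target[:, :, out_height // 2, out_width // 2] = torch.eye(8)
loss = torch.norm(output - target, p='fro')  # zero iff the flattened kernels are orthonormal
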
diff --git a/modules/mlp_module.py b/modules/mlp_module.py
index 4877d8eb..cebda2c3 100644
--- a/modules/mlp_module.py
+++ b/modules/mlp_module.py
@@ -1,7 +1,11 @@
+from collections import OrderedDict
+
+import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from DeepSparseCoding.modules.activations import activation_picker
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
class MlpModule(nn.Module):
@@ -9,28 +13,102 @@ def setup_module(self, params):
self.params = params
self.act_funcs = [activation_picker(act_func_str)
for act_func_str in self.params.activation_functions]
+ self.layer_output_shapes = [self.params.data_shape] # [channels, height, width]
self.layers = []
+ self.pooling = []
self.dropout = []
for layer_index, layer_type in enumerate(self.params.layer_types):
if layer_type == 'fc':
+ if(layer_index > 0 and self.params.layer_types[layer_index-1] == 'conv'):
+ in_features = np.prod(self.layer_output_shapes[-1]).astype(np.int)
+ else:
+ in_features = self.params.layer_channels[layer_index]
layer = nn.Linear(
- in_features = self.params.layer_channels[layer_index],
- out_features = self.params.layer_channels[layer_index+1],
- bias = True)
+ in_features=in_features,
+ out_features=self.params.layer_channels[layer_index + 1],
+ bias=True)
self.register_parameter('fc'+str(layer_index)+'_w', layer.weight)
self.register_parameter('fc'+str(layer_index)+'_b', layer.bias)
self.layers.append(layer)
+ self.layer_output_shapes.append(self.params.layer_channels[layer_index + 1])
+ elif layer_type == 'conv':
+ layer = nn.Conv2d(
+ in_channels=self.params.layer_channels[layer_index],
+ out_channels=self.params.layer_channels[layer_index + 1],
+ kernel_size=self.params.kernel_sizes[layer_index],
+ stride=self.params.strides[layer_index],
+ padding=0,
+ dilation=1,
+ bias=True)
+ self.register_parameter('conv'+str(layer_index)+'_w', layer.weight)
+ self.register_parameter('conv'+str(layer_index)+'_b', layer.bias)
+ self.layers.append(layer)
+ output_channels = self.params.layer_channels[layer_index + 1]
+ output_height = compute_conv_output_shape(
+ self.layer_output_shapes[-1][1],
+ self.params.kernel_sizes[layer_index],
+ self.params.strides[layer_index],
+ padding=0,
+ dilation=1)
+ output_width = compute_conv_output_shape(
+ self.layer_output_shapes[-1][2],
+ self.params.kernel_sizes[layer_index],
+ self.params.strides[layer_index],
+ padding=0,
+ dilation=1)
+ self.layer_output_shapes.append([output_channels, output_height, output_width])
+ else:
+                assert False, ('layer_type parameter must be "fc" or "conv", not %s'%(layer_type))
+ if(self.params.max_pool[layer_index] and layer_type == 'conv'):
+ self.pooling.append(nn.MaxPool2d(
+ kernel_size=self.params.pool_ksizes[layer_index],
+ stride=self.params.pool_strides[layer_index],
+ padding=0,
+ dilation=1))
+ output_channels = self.params.layer_channels[layer_index + 1]
+ output_height = compute_conv_output_shape(
+ self.layer_output_shapes[-1][1],
+ self.params.pool_ksizes[layer_index],
+ self.params.pool_strides[layer_index],
+ padding=0,
+ dilation=1)
+ output_width = compute_conv_output_shape(
+ self.layer_output_shapes[-1][2],
+ self.params.pool_ksizes[layer_index],
+ self.params.pool_strides[layer_index],
+ padding=0,
+ dilation=1)
+ self.layer_output_shapes.append([output_channels, output_height, output_width])
else:
- assert False, ('layer_type parameter must be "fc", not %g'%(layer_type))
+ self.pooling.append(nn.Identity()) # do nothing
self.dropout.append(nn.Dropout(p=self.params.dropout_rate[layer_index]))
+ conv_module_dict = OrderedDict()
+ fc_module_dict = OrderedDict()
+ layer_zip = zip(self.params.layer_types, self.layers, self.act_funcs, self.pooling,
+ self.dropout)
+ for layer_idx, full_layer in enumerate(layer_zip):
+ for component_idx, layer_component in enumerate(full_layer[1:]):
+ component_id = f'{layer_idx:02}-{component_idx:02}'
+ if full_layer[0] == 'fc':
+ fc_module_dict[full_layer[0] + component_id] = layer_component
+ else:
+ conv_module_dict[full_layer[0] + component_id] = layer_component
+ self.conv_sequential = lambda x: x # identity by default
+ self.fc_sequential = lambda x: x # identity by default
+ if len(conv_module_dict) > 0:
+ self.conv_sequential = nn.Sequential(conv_module_dict)
+ if len(fc_module_dict) > 0:
+ self.fc_sequential = nn.Sequential(fc_module_dict)
def preprocess_data(self, input_tensor):
- input_tensor = input_tensor.view(-1, self.params.layer_channels[0])
+ if self.params.layer_types[0] == 'fc':
+ input_tensor = input_tensor.view(input_tensor.size(0), -1) #flat
return input_tensor
def forward(self, x):
- for dropout, act_func, layer in zip(self.dropout, self.act_funcs, self.layers):
- x = dropout(act_func(layer(x)))
+ x = self.conv_sequential(x)
+ x = x.view(x.size(0), -1) #flat
+ x = self.fc_sequential(x)
return x
def get_encodings(self, input_tensor):
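
Note: compute_conv_output_shape is assumed here to implement the standard convolution arithmetic; for reference:

import numpy as np

def conv_output_shape(in_size, kernel_size, stride, padding, dilation=1):
    # floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1
    return int(np.floor((in_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1))

assert conv_output_shape(28, 5, 1, 0) == 24  # e.g. a 5x5 kernel, stride 1, on a 28x28 MNIST image
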
diff --git a/modules/pooling_module.py b/modules/pooling_module.py
new file mode 100644
index 00000000..80f2bca2
--- /dev/null
+++ b/modules/pooling_module.py
@@ -0,0 +1,37 @@
+import torch
+import torch.nn as nn
+
+
+class PoolingModule(nn.Module):
+ def setup_module(self, params):
+ params.weight_decay = 0 # used by base model; pooling layer never has weight decay
+ self.params = params
+ if self.params.layer_types[0] == 'fc':
+ self.layer = nn.Linear(
+ in_features=self.params.layer_channels[0],
+ out_features=self.params.layer_channels[1],
+ bias=False)
+ self.weight = self.layer.weight # [outputs, inputs]
+
+ elif self.params.layer_types[0] == 'conv':
+ self.layer = nn.Conv2d(
+ in_channels=self.params.layer_channels[0],
+ out_channels=self.params.layer_channels[1],
+ kernel_size=self.params.pool_ksize,
+ stride=self.params.pool_stride,
+ padding=0,
+ dilation=1,
+ bias=False)
+ nn.init.orthogonal_(self.layer.weight) # initialize to orthogonal matrix
+ self.weight = self.layer.weight
+
+ else:
+            assert False, ('layer_types[0] parameter must be "fc" or "conv", not %s'%(self.params.layer_types[0]))
+
+ def forward(self, x):
+ if self.params.layer_types[0] == 'fc':
+ x = x.view(x.shape[0], -1) # flat
+ return self.layer(x)
+
+ def get_encodings(self, input_tensor):
+ return self.forward(input_tensor)
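
Note: nn.init.orthogonal_ flattens the trailing kernel dimensions, so the initialized conv weight is row-orthonormal when viewed as [out_channels, in_channels * k * k] (provided out_channels does not exceed the flattened size). A quick check of the property that the weight_orthogonality loss then maintains:

import torch
import torch.nn as nn

layer = nn.Conv2d(16, 8, kernel_size=2, stride=2, bias=False)
nn.init.orthogonal_(layer.weight)
flat = layer.weight.view(8, -1)  # [out_channels, 16 * 2 * 2]
print(torch.allclose(torch.mm(flat, flat.T), torch.eye(8), atol=1e-5))  # True
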
diff --git a/notebooks/monitor_training.ipynb b/notebooks/monitor_training.ipynb
index 2c3ecad2..3796865e 100644
--- a/notebooks/monitor_training.ipynb
+++ b/notebooks/monitor_training.ipynb
@@ -23,7 +23,8 @@
"outputs": [],
"source": [
"workspace_dir = os.path.expanduser('~')+'/Work/'\n",
- "log_file = workspace_dir+'/Torch_projects/lca_768_mlp_mnist/logfiles/lca_768_mlp_mnist_v0.log'\n",
+ "model_name = 'conv_lca_mnist'\n",
+ "log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n",
"logger = Logger(log_file, overwrite=False)\n",
"\n",
"log_text = logger.load_file()\n",
@@ -41,9 +42,9 @@
"outputs": [],
"source": [
"x_key = 'epoch'\n",
- "y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n",
- "y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n",
- "stats_fig = pf.plot_stats(model_stats, x_key=x_key, y_keys=y_keys, y_labels=y_labels, start_index=0)"
+ "#y_keys = ['lca_loss_recon', 'lca_loss_sparse', 'lca_loss_total', 'mlp_loss', 'mlp_train_accuracy']\n",
+ "#y_labels = ['Recon loss', 'Sparse loss', 'Total LCA loss', 'Total MLP loss', 'MLP train accuracy']\n",
+ "stats_fig = pf.plot_stats(model_stats, x_key=x_key)#, y_keys=y_keys, y_labels=y_labels, start_index=0)"
]
},
{
@@ -70,7 +71,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.7"
+ "version": "3.6.8"
},
"varInspector": {
"cols": {
diff --git a/notebooks/smt_analysis.ipynb b/notebooks/smt_analysis.ipynb
new file mode 100644
index 00000000..c79282e9
--- /dev/null
+++ b/notebooks/smt_analysis.ipynb
@@ -0,0 +1,1470 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%matplotlib inline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import sys\n",
+ "\n",
+ "ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))\n",
+ "if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)\n",
+ "\n",
+ "import scipy\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "import proplot as plot\n",
+ "import matplotlib\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.gridspec as gridspec\n",
+ "from matplotlib.colors import LinearSegmentedColormap\n",
+ "\n",
+ "from DeepSparseCoding.utils.file_utils import Logger\n",
+ "import DeepSparseCoding.utils.run_utils as run_utils\n",
+ "import DeepSparseCoding.utils.dataset_utils as dataset_utils\n",
+ "import DeepSparseCoding.utils.loaders as loaders\n",
+ "import DeepSparseCoding.utils.plot_functions as pf\n",
+ "import DeepSparseCoding.utils.data_processing as dp\n",
+ "from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "workspace_dir = '/mnt/qb/bethge/dpaiton/'\n",
+ "model_name = 'smt_cifar10'\n",
+ "model_version = 'lplp'\n",
+ "log_file = workspace_dir + os.path.join(*['Projects', model_name, 'logfiles', f'{model_name}_v{model_version}.log'])\n",
+ "logger = Logger(log_file, overwrite=False)\n",
+ "log_text = logger.load_file()\n",
+ "params = logger.read_params(log_text)[-1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "model_stats = logger.read_stats(log_text)\n",
+ "x_key = \"epoch\"\n",
+ "y_keys = [key for key in list(model_stats.keys()) if 'test_' not in key]\n",
+ "stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)\n",
+ "\n",
+ "if 'test_epoch' in list(model_stats.keys()):\n",
+ " x_key = \"test_epoch\"\n",
+ " y_keys = [key for key in list(model_stats.keys()) if 'test_' in key]\n",
+ " test_stats_fig = pf.plot_stats(model_stats, x_key, y_keys=y_keys)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model = loaders.load_model(params.model_type)\n",
+ "model.setup(params, logger)\n",
+ "model.to(params.device)\n",
+ "model_state_str = model.load_checkpoint()\n",
+ "print(model_state_str)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lca_weights = model.lca_1.weight.detach().cpu().numpy()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "def normalize_data_with_max(data):\n",
+ " \"\"\"\n",
+ " Normalize data by dividing by abs(max(data))\n",
+ " If abs(max(data)) is zero, then output is zero\n",
+ " Inputs:\n",
+ " data: [np.ndarray] data to be normalized\n",
+ " Outputs:\n",
+ " norm_data: [np.ndarray] normalized data\n",
+ " data_max: [float] max that was divided out\n",
+ " \"\"\"\n",
+ " data_max = np.max(np.abs(data), axis=(1,2), keepdims=True)\n",
+ " norm_data = np.divide(data, data_max, out=np.zeros_like(data), where=data_max!=0)\n",
+ " return norm_data, data_max\n",
+ "\n",
+ "def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n",
+ " if normalize:\n",
+ " #matrix = normalize_data_with_max(matrix)[0]\n",
+ " matrix = dp.rescale_data_to_one(torch.from_numpy(matrix), eps=1e-10, samplewise=True)[0].numpy()\n",
+ " num_weights, img_c, img_h, img_w = matrix.shape\n",
+ " #if img_c == 1:\n",
+ " # matrix = matrix.squeeze()\n",
+ " #else:\n",
+ " # # TODO: separate channels, pad each individual one, then recombine.\n",
+ " # assert False, (f'Multiple color channels are not currently supported') \n",
+ " num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n",
+ " matrices = []\n",
+ " for channel_idx in range(img_c):\n",
+ " channel_matrix = matrix[:, channel_idx, ...].copy()\n",
+ " if num_extra_images > 0:\n",
+ " channel_matrix = np.concatenate(\n",
+ " [channel_matrix, np.zeros((num_extra_images, img_h, img_w))], axis=0)\n",
+ " channel_matrix = np.pad(channel_matrix,\n",
+ " pad_width=((0,0), (num_pad_pix, num_pad_pix), (num_pad_pix, num_pad_pix)),\n",
+ " mode='constant', constant_values=pad_value)\n",
+ " padded_img_h, padded_img_w = channel_matrix.shape[1:]\n",
+ " num_edge_tiles = int(np.sqrt(channel_matrix.shape[0]))\n",
+ " tiles = channel_matrix.reshape(num_edge_tiles, num_edge_tiles, padded_img_h, padded_img_w)\n",
+ " tiles = tiles.swapaxes(1, 2)\n",
+ " matrices.append(tiles.reshape(num_edge_tiles * padded_img_h, num_edge_tiles * padded_img_w))\n",
+ " padded_matrix = np.stack(matrices, axis=0) # channel dim first\n",
+ " return padded_matrix\n",
+ " \n",
+ "def plot_matrix(matrix, title='', cmap=None):\n",
+ " fig, ax = plot.subplots(figsize=(10,10))\n",
+ " ax = pf.clear_axis(ax)\n",
+ " ax.imshow(matrix, cmap=cmap)#, vmin=0.0, vmax=1.0)#, cmap='greys_r')\n",
+ " ax.format(title=title)\n",
+ " plot.show()\n",
+ " return fig\n",
+ "\n",
+ "pad_value = 0.5\n",
+ "num_pad_pix = 1\n",
+ "padded_matrix = pad_matrix_to_image(lca_weights, num_pad_pix, pad_value, normalize=True)\n",
+ "fig = plot_matrix(np.transpose(padded_matrix, axes=[1, 2, 0]), title=f'{model.params.model_name} weights')\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/weights_plot_matrix.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def rgb_to_gray(rgb):\n",
+ " num, chan, height, width = rgb.shape\n",
+ " gray = np.zeros((num, 1, height, width))\n",
+ " for neuron_idx in range(num):\n",
+ " gray[neuron_idx, ...] = 0.2125 * rgb[neuron_idx, 0, ...]\n",
+ " gray[neuron_idx, ...] += 0.7154 * rgb[neuron_idx, 1, ...]\n",
+ " gray[neuron_idx, ...] += 0.0721 * rgb[neuron_idx, 2, ...]\n",
+ " return gray\n",
+ "\n",
+ "gray_lca_weights = rgb_to_gray(lca_weights)\n",
+ "pad_value = 0.5\n",
+ "num_pad_pix = 1\n",
+ "padded_matrix = pad_matrix_to_image(gray_lca_weights, num_pad_pix, pad_value, normalize=True)\n",
+ "fig = plot_matrix(np.squeeze(np.transpose(padded_matrix, axes=[1, 2, 0])), title=f'{model.params.model_name} weights', cmap='grays_r')\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/weights_grayscale_plot_matrix.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_gaussian(shape, mean, cov):\n",
+ " \"\"\"\n",
+ " Generate a Gaussian PDF from given mean & cov\n",
+ " Inputs:\n",
+ " shape: [tuple] specifying (num_rows, num_cols)\n",
+ " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n",
+ " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n",
+ " Outputs:\n",
+ " tuple containing (Gaussian PDF, grid_points used to generate PDF)\n",
+ " grid_points are specified as a tuple of (y,x) points\n",
+ " \"\"\"\n",
+ " (y_size, x_size) = shape\n",
+ " y = np.linspace(0, y_size, np.int32(np.floor(y_size)))\n",
+ " x = np.linspace(0, x_size, np.int32(np.floor(x_size)))\n",
+ " y, x = np.meshgrid(y, x)\n",
+ " pos = np.empty(x.shape + (2,)) #x.shape == y.shape\n",
+ " pos[:, :, 0] = y; pos[:, :, 1] = x\n",
+ " gauss = scipy.stats.multivariate_normal(mean, cov)\n",
+ " return (gauss.pdf(pos), (y,x))\n",
+ "\n",
+ "\n",
+ "def gaussian_fit(pyx):\n",
+ " \"\"\"\n",
+ " Compute the expected mean & covariance matrix for a 2-D gaussian fit of input distribution\n",
+ " Inputs:\n",
+ " pyx: [np.ndarray] of shape [num_rows, num_cols] that indicates the probability function to fit\n",
+ " Outputs:\n",
+ " mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n",
+ " cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n",
+ " \"\"\"\n",
+ " assert pyx.ndim == 2, (\n",
+ " \"Input must have 2 dimensions specifying [num_rows, num_cols]\")\n",
+ " mean = np.zeros((1,2), dtype=np.float32) # [mu_y, mu_x]\n",
+ " for idx in np.ndindex(pyx.shape): # [y, x] ticks columns (x) first, then rows (y)\n",
+ " mean += np.asarray([pyx[idx]*idx[0], pyx[idx]*idx[1]])[None,:]\n",
+ " cov = np.zeros((2,2), dtype=np.float32)\n",
+ " for idx in np.ndindex(pyx.shape): # ticks columns first, then rows\n",
+ " cov += np.dot((idx-mean).T, (idx-mean))*pyx[idx] # typically an outer-product\n",
+ " return (np.squeeze(mean), cov)\n",
+ "\n",
+ "\n",
+ "def get_gauss_fit(prob_map, num_attempts=1, perc_mean=0.33):\n",
+ " \"\"\"\n",
+ " Returns a gaussian fit for a given probability map\n",
+ " Fitting is done via robust regression, where a fit is\n",
+ " continuously refined by deleting outliers num_attempts times\n",
+ " Inputs:\n",
+ " prob_map: 2-D probability map to be fit\n",
+ " num_attempts: Number of times to fit & remove outliers\n",
+ " perc_mean: All probability values below perc_mean*mean(gauss_fit) will be\n",
+ " considered outliers for repeated attempts\n",
+ " Outputs:\n",
+ " gauss_fit: [np.ndarray] specifying the 2-D Gaussian PDF\n",
+ " grid: [tuple] containing (y,x) points with which the Gaussian PDF can be plotted\n",
+ " gauss_mean: [np.ndarray] of shape (2,) specifying the 2-D Gaussian center\n",
+ " gauss_cov: [np.ndarray] of shape (2,2) specifying the 2-D Gaussian covariance matrix\n",
+ " \"\"\"\n",
+ " assert prob_map.ndim==2, (\n",
+ " \"get_gauss_fit: Input prob_map must have 2 dimension specifying [num_rows, num_cols\")\n",
+ " if num_attempts < 1:\n",
+ " num_attempts = 1\n",
+ " orig_prob_map = prob_map.copy()\n",
+ " gauss_success = False\n",
+ " while not gauss_success:\n",
+ " prob_map = orig_prob_map.copy()\n",
+ " try:\n",
+ " for i in range(num_attempts):\n",
+ " map_min = np.min(prob_map)\n",
+ " prob_map -= map_min\n",
+ " map_sum = np.sum(prob_map)\n",
+ " if map_sum != 1.0:\n",
+ " prob_map /= map_sum\n",
+ " gauss_mean, gauss_cov = gaussian_fit(prob_map)\n",
+ " gauss_fit, grid = generate_gaussian(prob_map.shape, gauss_mean, gauss_cov)\n",
+ " gauss_fit = (gauss_fit * map_sum) + map_min\n",
+ " if i < num_attempts-1:\n",
+ " gauss_mask = gauss_fit.copy().T\n",
+ " mask_slice = np.where(gauss_mask0)] = 1\n",
+ " prob_map *= gauss_mask\n",
+ " gauss_success = True\n",
+ " except np.linalg.LinAlgError: # Usually means cov matrix is singular\n",
+ " print(\"get_gauss_fit: Failed to fit Gaussian at attempt \",i,\", trying again.\"+\n",
+ " \"\\n To avoid this try decreasing perc_mean.\")\n",
+ " num_attempts = i-1\n",
+ " if num_attempts <= 0:\n",
+ " assert False, (\"get_gauss_fit: np.linalg.LinAlgError - Unable to fit gaussian.\")\n",
+ " return (gauss_fit, grid, gauss_mean, gauss_cov)\n",
+ "\n",
+ "\n",
+ "def hilbert_amplitude(weights, padding=None):\n",
+ " \"\"\"\n",
+ " Compute Hilbert amplitude envelope of weight matrix\n",
+ " Inputs:\n",
+ " weights: [np.ndarray] of shape [num_inputs, num_outputs]\n",
+ " num_inputs must have an even square root\n",
+ " padding: [int] specifying how much 0-padding to use for FFT\n",
+ " default is the closest power of 2 of sqrt(num_inputs)\n",
+ " Outputs:\n",
+ " env: [np.ndarray] of shape [num_outputs, num_inputs]\n",
+ " Hilbert envelope\n",
+ " bff_filt: [np.ndarray] of shape [num_outputs, padded_num_inputs]\n",
+ " Filtered Fourier transform of basis function\n",
+    "    hil_filt: [np.ndarray] of shape [num_outputs, N, N], where N is the padded edge size\n",
+    "      Hilbert filter to be applied in Fourier space\n",
+    "    bffs: [np.ndarray] of shape [num_outputs, N, N]\n",
+    "      Fourier transform of input weights\n",
+ " \"\"\"\n",
+ " cart2pol = lambda x,y: (np.arctan2(y,x), np.hypot(x, y))\n",
+ " num_inputs, num_outputs = weights.shape\n",
+ " assert np.sqrt(num_inputs) == np.floor(np.sqrt(num_inputs)), (\n",
+    "    \"weights.shape[0] must have an integer square root.\")\n",
+ " patch_edge_size = int(np.sqrt(num_inputs))\n",
+ " if padding is None or padding <= patch_edge_size:\n",
+    "    # Amount of zero padding for fft2 (next power of 2 >= patch edge size)\n",
+    "    N = int(2**(np.ceil(np.log2(patch_edge_size))))\n",
+ " else:\n",
+    "    N = int(padding)\n",
+ " # Analytic signal envelope for weights\n",
+    "  # (Hilbert transform of each basis function)\n",
+ " env = np.zeros((num_outputs, num_inputs), dtype=complex)\n",
+ " # Fourier transform of weights\n",
+ " bffs = np.zeros((num_outputs, N, N), dtype=complex)\n",
+ " # Filtered Fourier transform of weights\n",
+ " bff_filt = np.zeros((num_outputs, N**2), dtype=complex)\n",
+ " # Hilbert filters\n",
+ " hil_filt = np.zeros((num_outputs, N, N))\n",
+ " # Grid for creating filter\n",
+ " f = (2/N) * np.pi * np.arange(-N/2.0, N/2.0)\n",
+ " (fx, fy) = np.meshgrid(f, f)\n",
+ " (theta, r) = cart2pol(fx, fy)\n",
+ " for neuron_idx in range(num_outputs):\n",
+ " # Grab single basis function, reshape to a square image\n",
+ " bf = weights[:, neuron_idx].reshape(patch_edge_size, patch_edge_size)\n",
+ " # Convert basis function into DC-centered Fourier domain\n",
+ " bff = np.fft.fftshift(np.fft.fft2(bf-np.mean(bf), [N, N]))\n",
+ " bffs[neuron_idx, ...] = bff\n",
+ " # Find indices of the peak amplitude\n",
+ " max_ys = np.abs(bff).argmax(axis=0) # Returns row index for each col\n",
+ " max_x = np.argmax(np.abs(bff).max(axis=0))\n",
+ " # Convert peak amplitude location into angle in freq domain\n",
+ " fx_ang = f[max_x]\n",
+ " fy_ang = f[max_ys[max_x]]\n",
+ " theta_max = np.arctan2(fy_ang, fx_ang)\n",
+ " # Define the half-plane with respect to the maximum\n",
+ " ang_diff = np.abs(theta-theta_max)\n",
+ " idx = (ang_diff>np.pi).nonzero()\n",
+ " ang_diff[idx] = 2.0 * np.pi - ang_diff[idx]\n",
+ " hil_filt[neuron_idx, ...] = (ang_diff < np.pi/2.0).astype(int)\n",
+ " # Create analytic signal from the inverse FT of the half-plane filtered bf\n",
+ " abf = np.fft.ifft2(np.fft.fftshift(hil_filt[neuron_idx, ...]*bff))\n",
+ " env[neuron_idx, ...] = abf[0:patch_edge_size, 0:patch_edge_size].reshape(num_inputs)\n",
+ " bff_filt[neuron_idx, ...] = (hil_filt[neuron_idx, ...]*bff).reshape(N**2)\n",
+ " return (env, bff_filt, hil_filt, bffs)\n",
+ "\n",
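+    "# Example usage (a minimal sketch with random weights; 64 = 8*8 inputs satisfies\n",
+    "# the integer-square-root requirement):\n",
+    "# env, bff_filt, hil_filt, bffs = hilbert_amplitude(np.random.randn(64, 4))\n",
+    "# env_maps = np.abs(env).reshape(4, 8, 8)  # one envelope image per output\n",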
+ "\n",
+ "def get_dictionary_stats(weights, padding=None, num_gauss_fits=20, gauss_thresh=0.2):\n",
+ " \"\"\"\n",
+ " Compute summary statistics on dictionary elements using Hilbert amplitude envelope\n",
+ " Inputs:\n",
+ " weights: [np.ndarray] of shape [num_inputs, num_outputs]\n",
+ " padding: [int] total image size to pad out to in the FFT computation\n",
+ " num_gauss_fits: [int] total number of attempts to make when fitting the BFs\n",
+ " gauss_thresh: All probability values below gauss_thresh*mean(gauss_fit) will be\n",
+ " considered outliers for repeated fits\n",
+ " Outputs:\n",
+ " The function output is a dictionary containing the keys for each type of analysis\n",
+ " Each key dereferences a list of len num_outputs (i.e. one entry for each weight vector)\n",
+ " The keys and their list entries are as follows:\n",
+ " basis_functions: [np.ndarray] of shape [patch_edge_size, patch_edge_size]\n",
+    "    envelopes: [np.ndarray] of shape [patch_edge_size, patch_edge_size]\n",
+    "      Hilbert amplitude envelope of each basis function\n",
+ " envelope_centers: [tuples of ints] indicating the (y, x) position of the\n",
+ " center of the Hilbert envelope\n",
+ " gauss_fits: [list of np.ndarrays] containing (gaussian_fit, grid) where gaussian_fit\n",
+ " is returned from get_gauss_fit and specifies the 2D Gaussian PDF fit to the Hilbert\n",
+ " envelope and grid is a tuple containing (y,x) points with which the Gaussian PDF\n",
+ " can be plotted\n",
+ " gauss_centers: [list of ints] containing the (y,x) position of the center of\n",
+ " the Gaussian fit\n",
+ " gauss_orientations: [list of np.ndarrays] containing the (eigenvalues, eigenvectors) of\n",
+ " the covariance matrix for the Gaussian fit of the Hilbert amplitude envelope. They are\n",
+ " both sorted according to the highest to lowest Eigenvalue.\n",
+ " fourier_centers: [list of ints] containing the (y,x) position of the center (max) of\n",
+ " the Fourier amplitude map\n",
+ " num_inputs: [int] dim[0] of input weights\n",
+ " num_outputs: [int] dim[1] of input weights\n",
+ " patch_edge_size: [int] int(floor(sqrt(num_inputs)))\n",
+ " areas: [list of floats] area of enclosed ellipse\n",
+    "    spatial_frequencies: [list of floats] dominant spatial frequency for basis function\n",
+ " \"\"\"\n",
+ " envelope, bff_filt, hil_filter, bffs = hilbert_amplitude(weights, padding)\n",
+ " num_inputs, num_outputs = weights.shape\n",
+    "  patch_edge_size = int(np.floor(np.sqrt(num_inputs)))\n",
+ " basis_funcs = [None]*num_outputs\n",
+ " envelopes = [None]*num_outputs\n",
+ " gauss_fits = [None]*num_outputs\n",
+ " gauss_centers = [None]*num_outputs\n",
+ " diameters = [None]*num_outputs\n",
+ " gauss_orientations = [None]*num_outputs\n",
+ " envelope_centers = [None]*num_outputs\n",
+ " fourier_centers = [None]*num_outputs\n",
+ " ellipse_orientations = [None]*num_outputs\n",
+ " fourier_maps = [None]*num_outputs\n",
+ " spatial_frequencies = [None]*num_outputs\n",
+ " areas = [None]*num_outputs\n",
+ " phases = [None]*num_outputs\n",
+ " for bf_idx in range(num_outputs):\n",
+ " # Reformatted individual basis function\n",
+ " basis_funcs[bf_idx] = weights.T[bf_idx,...].reshape((patch_edge_size, patch_edge_size))\n",
+ " # Reformatted individual envelope filter\n",
+ " envelopes[bf_idx] = np.abs(envelope[bf_idx,...]).reshape((patch_edge_size, patch_edge_size))\n",
+ " # Basis function center\n",
+ " max_ys = envelopes[bf_idx].argmax(axis=0) # Returns row index for each col\n",
+ " max_x = np.argmax(envelopes[bf_idx].max(axis=0))\n",
+ " y_cen = max_ys[max_x]\n",
+ " x_cen = max_x\n",
+ " envelope_centers[bf_idx] = (y_cen, x_cen)\n",
+    "    # Gaussian fit to Hilbert amplitude envelope\n",
+ " gauss_fit, grid, gauss_mean, gauss_cov = get_gauss_fit(envelopes[bf_idx],\n",
+ " num_gauss_fits, gauss_thresh)\n",
+ " gauss_fits[bf_idx] = (gauss_fit, grid)\n",
+ " gauss_centers[bf_idx] = gauss_mean\n",
+ " evals, evecs = np.linalg.eigh(gauss_cov)\n",
+ " sort_indices = np.argsort(evals)[::-1]\n",
+ " gauss_orientations[bf_idx] = (evals[sort_indices], evecs[:,sort_indices])\n",
+ " width, height = evals[sort_indices] # Width & height are relative to orientation\n",
+ " diameters[bf_idx] = np.sqrt(width**2+height**2)\n",
+ " # Fourier function center, spatial frequency, orientation\n",
+ " fourier_map = np.sqrt(np.real(bffs[bf_idx, ...])**2+np.imag(bffs[bf_idx, ...])**2)\n",
+ " fourier_maps[bf_idx] = fourier_map\n",
+ " N = fourier_map.shape[0]\n",
+ " center_freq = int(np.floor(N/2))\n",
+ " fourier_map[center_freq, center_freq] = 0 # remove DC component\n",
+ " max_fys = fourier_map.argmax(axis=0)\n",
+ " max_fx = np.argmax(fourier_map.max(axis=0))\n",
+ " fy_cen = (max_fys[max_fx] - (N/2)) * (patch_edge_size/N)\n",
+ " fx_cen = (max_fx - (N/2)) * (patch_edge_size/N)\n",
+ " fourier_centers[bf_idx] = [fy_cen, fx_cen]\n",
+ " # NOTE: we flip fourier_centers because fx_cen is the peak of the x frequency,\n",
+ " # which would be a y coordinate\n",
+ " ellipse_orientations[bf_idx] = np.arctan2(*fourier_centers[bf_idx][::-1])\n",
+ " spatial_frequencies[bf_idx] = np.sqrt(fy_cen**2 + fx_cen**2)\n",
+ " areas[bf_idx] = np.pi * np.prod(evals)\n",
+ " phases[bf_idx] = np.angle(bffs[bf_idx])[y_cen, x_cen]\n",
+ " output = {\"basis_functions\":basis_funcs, \"envelopes\":envelopes, \"gauss_fits\":gauss_fits,\n",
+ " \"gauss_centers\":gauss_centers, \"gauss_orientations\":gauss_orientations, \"areas\":areas,\n",
+ " \"fourier_centers\":fourier_centers, \"fourier_maps\":fourier_maps, \"num_inputs\":num_inputs,\n",
+ " \"spatial_frequencies\":spatial_frequencies, \"envelope_centers\":envelope_centers,\n",
+ " \"num_outputs\":num_outputs, \"patch_edge_size\":patch_edge_size, \"phases\":phases,\n",
+ " \"ellipse_orientations\":ellipse_orientations, \"diameters\":diameters}\n",
+ " return output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "bf_stats = get_dictionary_stats(\n",
+ " gray_lca_weights.reshape(gray_lca_weights.shape[0], -1).T,\n",
+ " padding=32,\n",
+ " num_gauss_fits=20,\n",
+ " gauss_thresh=0.2)\n",
+ "\n",
+ "np.savez(\n",
+ " model.params.save_dir+'bf_summary_stats.npz',\n",
+ " data={'bf_stats':bf_stats})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def clear_axis(ax, spines=\"none\"):\n",
+ " for ax_loc in [\"top\", \"bottom\", \"left\", \"right\"]:\n",
+ " ax.spines[ax_loc].set_color(spines)\n",
+ " ax.set_yticklabels([])\n",
+ " ax.set_xticklabels([])\n",
+ " ax.get_xaxis().set_visible(False)\n",
+ " ax.get_yaxis().set_visible(False)\n",
+ " ax.tick_params(axis=\"both\", bottom=False, top=False, left=False, right=False)\n",
+ " return ax\n",
+ "\n",
+ "def plot_ellipse(axis, center, shape, angle, color_val=\"auto\", alpha=1.0, lines=False,\n",
+ " fill_ellipse=False):\n",
+ " \"\"\"\n",
+ " Add an ellipse to given axis\n",
+ " Inputs:\n",
+ " axis [matplotlib.axes._subplots.AxesSubplot] axis on which ellipse should be drawn\n",
+ " center [tuple or list] specifying [y, x] center coordinates\n",
+ " shape [tuple or list] specifying [width, height] shape of ellipse\n",
+ " angle [float] specifying angle of ellipse\n",
+ " color_val [matplotlib color spec] specifying the color of the edge & face of the ellipse\n",
+ " alpha [float] specifying the transparency of the ellipse\n",
+ " lines [bool] if true, output will be a line, where the secondary axis of the ellipse\n",
+ " is collapsed\n",
+ " fill_ellipse [bool] if true and lines is false then a filled ellipse will be plotted\n",
+ " Outputs:\n",
+ " ellipse [matplotlib.patches.ellipse] ellipse object\n",
+ " \"\"\"\n",
+ " if fill_ellipse:\n",
+ " face_color_val = \"none\" if color_val==\"auto\" else color_val\n",
+ " else:\n",
+ " face_color_val = \"none\"\n",
+ " y_cen, x_cen = center\n",
+ " width, height = shape\n",
+ " if lines:\n",
+ " min_length = 0.1\n",
+ " if width < height:\n",
+ " width = min_length\n",
+ " elif width > height:\n",
+ " height = min_length\n",
+ " ellipse = matplotlib.patches.Ellipse(xy=[x_cen, y_cen], width=width,\n",
+ " height=height, angle=angle, edgecolor=color_val, facecolor=face_color_val,\n",
+ " alpha=alpha, fill=True)\n",
+ " axis.add_artist(ellipse)\n",
+ " ellipse.set_clip_box(axis.bbox)\n",
+ " return ellipse\n",
+ "\n",
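+    "# Example usage (added sketch): draw a single ellipse on a fresh axis.\n",
+    "# _fig, _ax = plt.subplots()\n",
+    "# plot_ellipse(_ax, center=[8, 8], shape=[4, 2], angle=30, color_val='b')\n",
+    "\n",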
+ "def plot_ellipse_summaries(bf_stats, num_bf=-1, lines=False, rand_bf=False):\n",
+ " \"\"\"\n",
+ " Plot basis functions with summary ellipses drawn over them\n",
+ " Inputs:\n",
+ " bf_stats [dict] output of dp.get_dictionary_stats()\n",
+ " num_bf [int] number of basis functions to plot (<=0 is all; >total is all)\n",
+ " lines [bool] If true, will plot lines instead of ellipses\n",
+ " rand_bf [bool] If true, will choose a random set of basis functions\n",
+ " \"\"\"\n",
+ " tot_num_bf = len(bf_stats[\"basis_functions\"])\n",
+ " if num_bf <= 0 or num_bf > tot_num_bf:\n",
+ " num_bf = tot_num_bf\n",
+ " SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n",
+ " for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n",
+ " sf_sort_indices = np.argsort(SFs)\n",
+ " if rand_bf:\n",
+ " bf_range = np.random.choice([i for i in range(tot_num_bf)], num_bf, replace=False)\n",
+ " num_plots_y = int(np.ceil(np.sqrt(num_bf)))\n",
+ " num_plots_x = int(np.ceil(np.sqrt(num_bf)))\n",
+ " gs = gridspec.GridSpec(num_plots_y, num_plots_x)\n",
+ " fig = plt.figure(figsize=(17,17))\n",
+ " filter_idx = 0\n",
+ " for plot_id in np.ndindex((num_plots_y, num_plots_x)):\n",
+ " ax = clear_axis(fig.add_subplot(gs[plot_id]))\n",
+ " if filter_idx < tot_num_bf and filter_idx < num_bf:\n",
+ " if rand_bf:\n",
+ " bf_idx = bf_range[filter_idx]\n",
+ " else:\n",
+ " bf_idx = filter_idx\n",
+ " bf = bf_stats[\"basis_functions\"][bf_idx]\n",
+    "      ax.imshow(bf, interpolation=\"nearest\", cmap=\"Greys_r\")\n",
+ " ax.set_title(str(bf_idx), fontsize=\"8\")\n",
+ " center = bf_stats[\"gauss_centers\"][bf_idx]\n",
+ " evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n",
+ " orientations = bf_stats[\"fourier_centers\"][bf_idx]\n",
+ " angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n",
+ " alpha = 1.0\n",
+ " ellipse = plot_ellipse(ax, center, evals, angle, color_val=\"b\", alpha=alpha, lines=lines)\n",
+ " filter_idx += 1\n",
+ " ax.set_aspect(\"equal\")\n",
+ " plt.show()\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_ellipse_summaries(bf_stats, lines=False)\n",
+ "\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/basis_function_fits.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def bgr_colormap():\n",
+ " \"\"\"\n",
+    "  In cdict, the first column is the anchor point in [0.0, 1.0] - it indicates the normalized value to be plotted\n",
+ " the second column specifies how interpolation should be done from below\n",
+ " the third column specifies how interpolation should be done from above\n",
+ " if the second column does not equal the third, then there will be a break in the colors\n",
+ " \"\"\"\n",
+ " darkness = 0.85 #0 is black, 1 is white\n",
+ " cdict = {\n",
+ " 'red': ((0.0, 0.0, 0.0),\n",
+ " (0.5, darkness, darkness),\n",
+ " (1.0, 1.0, 1.0)),\n",
+ " 'green': ((0.0, 0.0, 0.0),\n",
+ " (0.5, darkness, darkness),\n",
+ " (1.0, 0.0, 0.0)),\n",
+ " 'blue': ((0.0, 1.0, 1.0),\n",
+ " (0.5, darkness, darkness),\n",
+ " (1.0, 0.0, 0.0))\n",
+ " }\n",
+ " return LinearSegmentedColormap(\"bgr\", cdict)\n",
+ "\n",
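+    "# Added illustration: sampling the endpoints and midpoint of the map should give\n",
+    "# blue, light gray (the `darkness` level), and red RGBA tuples, respectively.\n",
+    "# _cmap = bgr_colormap(); _cmap(0.0), _cmap(0.5), _cmap(1.0)\n",
+    "\n",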
+ "def plot_pooling_centers(bf_stats, pooling_filters, num_pooling_filters, num_connected_weights,\n",
+ " spot_size=10, figsize=None):\n",
+ " \"\"\"\n",
+ " Plot 2nd layer (fully-connected) weights in terms of spatial/frequency centers of\n",
+ " 1st layer weights\n",
+ " Inputs:\n",
+ " bf_stats [dict] Output of dp.get_dictionary_stats() which was run on the 1st layer weights\n",
+ " pooling_filters [np.ndarray] 2nd layer weights\n",
+ " should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n",
+ " num_pooling_filters [int] How many 2nd layer neurons to plot\n",
+ " figsize [tuple] Containing the (width, height) of the figure, in inches\n",
+ " spot_size [int] How big to make the points\n",
+ " \"\"\"\n",
+ " num_filters_y = int(np.ceil(np.sqrt(num_pooling_filters)))\n",
+ " num_filters_x = int(np.ceil(np.sqrt(num_pooling_filters)))\n",
+ " tot_pooling_filters = pooling_filters.shape[1]\n",
+ " #filter_indices = np.random.choice(tot_pooling_filters, num_pooling_filters, replace=False)\n",
+ " filter_indices = np.arange(tot_pooling_filters, dtype=np.int32)\n",
+ " cmap = plt.get_cmap(bgr_colormap())# Could also use \"nipy_spectral\", \"coolwarm\", \"bwr\"\n",
+ " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n",
+ " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n",
+ " x_p_cent = [x for (y,x) in bf_stats[\"gauss_centers\"]]# Get raw points\n",
+ " y_p_cent = [y for (y,x) in bf_stats[\"gauss_centers\"]]\n",
+ " x_f_cent = [x for (y,x) in bf_stats[\"fourier_centers\"]]\n",
+ " y_f_cent = [y for (y,x) in bf_stats[\"fourier_centers\"]]\n",
+ " max_sf = np.max(np.abs(x_f_cent+y_f_cent))\n",
+ " pair_w_gap = 0.01\n",
+ " group_w_gap = 0.03\n",
+ " h_gap = 0.03\n",
+ " plt_w = (num_filters_x/num_pooling_filters)\n",
+ " plt_h = plt_w\n",
+ " if figsize is None:\n",
+ " fig = plt.figure()\n",
+ " figsize = (fig.get_figwidth(), fig.get_figheight())\n",
+ " else:\n",
+ " fig = plt.figure(figsize=figsize) #figsize is (w,h)\n",
+ " axes = []\n",
+ " filter_id = 0\n",
+ " for plot_id in np.ndindex((num_filters_y, num_filters_x)):\n",
+ " if all(pid == 0 for pid in plot_id):\n",
+ " axes.append(clear_axis(fig.add_axes([0, plt_h+h_gap, 2*plt_w, plt_h])))\n",
+ " scalarMap._A = []\n",
+ " cbar = fig.colorbar(scalarMap, ax=axes[-1], ticks=[-1, 0, 1], aspect=10, location=\"bottom\")\n",
+ " cbar.ax.set_xticklabels([\"-1\", \"0\", \"1\"])\n",
+ " cbar.ax.xaxis.set_ticks_position('top')\n",
+ " cbar.ax.xaxis.set_label_position('top')\n",
+ " for label in cbar.ax.xaxis.get_ticklabels():\n",
+ " label.set_weight(\"bold\")\n",
+ " label.set_fontsize(10+figsize[0])\n",
+ " if (filter_id < num_pooling_filters):\n",
+ " example_filter = pooling_filters[:, filter_indices[filter_id]]\n",
+ " top_indices = np.argsort(np.abs(example_filter))[::-1] #descending\n",
+ " selected_indices = top_indices[:num_connected_weights][::-1] #select top, plot weakest first\n",
+ " filter_norm = np.max(np.abs(example_filter))\n",
+ " connection_colors = [scalarMap.to_rgba(example_filter[bf_idx]/filter_norm)\n",
+ " for bf_idx in range(bf_stats[\"num_outputs\"])]\n",
+ " if num_connected_weights < top_indices.size:\n",
+ " black_indices = top_indices[num_connected_weights:][::-1]\n",
+ " xp = [x_p_cent[i] for i in black_indices]+[x_p_cent[i] for i in selected_indices]\n",
+ " yp = [y_p_cent[i] for i in black_indices]+[y_p_cent[i] for i in selected_indices]\n",
+ " xf = [x_f_cent[i] for i in black_indices]+[x_f_cent[i] for i in selected_indices]\n",
+ " yf = [y_f_cent[i] for i in black_indices]+[y_f_cent[i] for i in selected_indices]\n",
+ " c = [(0.1,0.1,0.1,1.0) for i in black_indices]+[connection_colors[i] for i in selected_indices]\n",
+ " else:\n",
+ " xp = [x_p_cent[i] for i in selected_indices]\n",
+ " yp = [y_p_cent[i] for i in selected_indices]\n",
+ " xf = [x_f_cent[i] for i in selected_indices]\n",
+ " yf = [y_f_cent[i] for i in selected_indices]\n",
+ " c = [connection_colors[i] for i in selected_indices]\n",
+ " (y_id, x_id) = plot_id\n",
+ " if x_id == 0:\n",
+ " ax_l = 0\n",
+ " ax_b = - y_id * (plt_h+h_gap)\n",
+ " else:\n",
+ " bbox = axes[-1].get_position().get_points()[0]#bbox is [[x0,y0],[x1,y1]]\n",
+ " prev_l = bbox[0]\n",
+ " prev_b = bbox[1]\n",
+ " ax_l = prev_l + plt_w + group_w_gap\n",
+ " ax_b = prev_b\n",
+ " ax_w = plt_w\n",
+ " ax_h = plt_h\n",
+ " axes.append(clear_axis(fig.add_axes([ax_l, ax_b, ax_w, ax_h])))\n",
+ " axes[-1].invert_yaxis()\n",
+ " axes[-1].scatter(xp, yp, c=c, s=spot_size, alpha=0.8)\n",
+ " axes[-1].set_xlim(0, bf_stats[\"patch_edge_size\"]-1)\n",
+ " axes[-1].set_ylim(bf_stats[\"patch_edge_size\"]-1, 0)\n",
+ " axes[-1].set_aspect(\"equal\")\n",
+ " axes[-1].set_facecolor(\"w\")\n",
+ " axes.append(clear_axis(fig.add_axes([ax_l+ax_w+pair_w_gap, ax_b, ax_w, ax_h])))\n",
+ " axes[-1].scatter(xf, yf, c=c, s=spot_size, alpha=0.8)\n",
+ " axes[-1].set_xlim([-max_sf, max_sf])\n",
+ " axes[-1].set_ylim([-max_sf, max_sf])\n",
+ " axes[-1].xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n",
+ " axes[-1].yaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))\n",
+ " axes[-1].set_aspect(\"equal\")\n",
+ " axes[-1].set_facecolor(\"w\")\n",
+ " filter_id += 1\n",
+ " plt.show()\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "kernel_pos = 0\n",
+ "pool_weights = model.pool_1.layer.weight.detach().cpu().numpy()\n",
+ "outputs, inputs, kernel_h, kernel_w = pool_weights.shape\n",
+ "\n",
+ "fig = plot_pooling_centers(\n",
+ " bf_stats,\n",
+ " pool_weights[:, :, kernel_pos, kernel_pos].T,\n",
+ " num_pooling_filters=outputs,\n",
+ " num_connected_weights=inputs,\n",
+ " spot_size=3,\n",
+ " figsize=(5, 5))\n",
+ "\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/pooling_spots.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_pooling_summaries(bf_stats, pooling_filters, num_pooling_filters,\n",
+ " num_connected_weights, lines=False, figsize=None):\n",
+ " \"\"\"\n",
+ " Plot 2nd layer (fully-connected) weights in terms of connection strengths to 1st layer weights\n",
+ " Inputs:\n",
+ " bf_stats [dict] output of dp.get_dictionary_stats() which was run on the 1st layer weights\n",
+ " pooling_filters [np.ndarray] 2nd layer weights\n",
+ " should be shape [num_1st_layer_neurons, num_2nd_layer_neurons]\n",
+ " num_pooling_filters [int] How many 2nd layer neurons to plot\n",
+ " num_connected_weights [int] How many 1st layer weight summaries to include\n",
+ " for a given 2nd layer neuron\n",
+ " lines [bool] if True, 1st layer weight summaries will appear as lines instead of ellipses\n",
+ " \"\"\"\n",
+ " num_inputs = bf_stats[\"num_inputs\"]\n",
+ " num_outputs = bf_stats[\"num_outputs\"]\n",
+ " tot_pooling_filters = pooling_filters.shape[1]\n",
+ " patch_edge_size = np.int32(np.sqrt(num_inputs))\n",
+ " filter_idx_list = np.arange(num_pooling_filters, dtype=np.int32)\n",
+ " assert num_pooling_filters <= num_outputs, (\n",
+ " \"num_pooling_filters must be less than or equal to bf_stats['num_outputs']\")\n",
+ " cmap = bgr_colormap()#plt.get_cmap('bwr')\n",
+ " cNorm = matplotlib.colors.SymLogNorm(linthresh=0.03, linscale=0.01, vmin=-1.0, vmax=1.0)\n",
+ " scalarMap = matplotlib.cm.ScalarMappable(norm=cNorm, cmap=cmap)\n",
+ " num_plots_y = np.int32(np.ceil(np.sqrt(num_pooling_filters)))\n",
+ " num_plots_x = np.int32(np.ceil(np.sqrt(num_pooling_filters)))+1 # +cbar col\n",
+ " gs_widths = [1 for _ in range(num_plots_x-1)]+[0.3]\n",
+ " gs = gridspec.GridSpec(num_plots_y, num_plots_x, width_ratios=gs_widths)\n",
+ " if figsize is None:\n",
+ " fig = plt.figure()\n",
+ " figsize = (fig.get_figwidth(), fig.get_figheight())\n",
+ " else:\n",
+ " fig = plt.figure(figsize=figsize)\n",
+ " filter_total = 0\n",
+ " for plot_id in np.ndindex((num_plots_y, num_plots_x-1)):\n",
+ " (y_id, x_id) = plot_id\n",
+ " ax = fig.add_subplot(gs[plot_id])\n",
+ " if (filter_total < num_pooling_filters and x_id != num_plots_x-1):\n",
+ " ax = clear_axis(ax, spines=\"k\")\n",
+ " filter_idx = filter_idx_list[filter_total]\n",
+ " example_filter = pooling_filters[:, filter_idx]\n",
+ " top_indices = np.argsort(np.abs(example_filter))[::-1] #descending\n",
+ " filter_norm = np.max(np.abs(example_filter))\n",
+ " SFs = np.asarray([np.sqrt(fcent[0]**2 + fcent[1]**2)\n",
+ " for fcent in bf_stats[\"fourier_centers\"]], dtype=np.float32)\n",
+ " # Plot weakest of the top connected filters first because of occlusion\n",
+ " for bf_idx in top_indices[:num_connected_weights][::-1]:\n",
+ " connection_strength = example_filter[bf_idx]/filter_norm\n",
+ " color_val = scalarMap.to_rgba(connection_strength)\n",
+ " center = bf_stats[\"gauss_centers\"][bf_idx]\n",
+ " evals, evecs = bf_stats[\"gauss_orientations\"][bf_idx]\n",
+ " orientations = bf_stats[\"fourier_centers\"][bf_idx]\n",
+ " angle = np.rad2deg(np.pi/2 + np.arctan2(*orientations))\n",
+ " alpha = 0.5#todo:spatial_freq for filled ellipses?\n",
+ " ellipse = plot_ellipse(ax, center, evals, angle, color_val, alpha=alpha, lines=lines)\n",
+ " ax.set_xlim(0, patch_edge_size-1)\n",
+ " ax.set_ylim(patch_edge_size-1, 0)\n",
+ " filter_total += 1\n",
+ " else:\n",
+ " ax = clear_axis(ax, spines=\"none\")\n",
+ " ax.set_aspect(\"equal\")\n",
+ " scalarMap._A = []\n",
+ " ax = clear_axis(fig.add_subplot(gs[0, -1]))\n",
+ " cbar = fig.colorbar(scalarMap, ax=ax, ticks=[-1, 0, 1])\n",
+ " cbar.ax.set_yticklabels([\"-1\", \"0\", \"1\"])\n",
+ " for label in cbar.ax.yaxis.get_ticklabels():\n",
+ " label.set_weight(\"bold\")\n",
+ " label.set_fontsize(14)\n",
+ " plt.show()\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_pooling_summaries(\n",
+ " bf_stats,\n",
+ " pool_weights[:, :, kernel_pos, kernel_pos].T,\n",
+ " num_pooling_filters=outputs,\n",
+ " num_connected_weights=40,\n",
+ " lines=True,\n",
+ " figsize=(18,18))\n",
+ "\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/pooling_lines.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "P = pool_weights[:, :, kernel_pos, kernel_pos] # [outputs, inputs]\n",
+    "p_norm = np.linalg.norm(P, ord=2, axis=0) # embedding norm for each input neuron\n",
+    "affinity = np.dot(P.T, P) # cosine similarity of input neurons in the embedded space (normalized below)\n",
+ "for i in range(affinity.shape[0]):\n",
+ " for j in range(affinity.shape[1]):\n",
+ " affinity[i, j] = affinity[i, j] / (p_norm[i] * p_norm[j])\n",
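+    "# Equivalent vectorized form (added note; same result via broadcasting):\n",
+    "# affinity = np.dot(P.T, P) / np.outer(p_norm, p_norm)\n",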
+ "affinity = affinity.T # [inputs, inputs]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_pooling_centers(\n",
+ " bf_stats,\n",
+ " affinity,\n",
+ " num_pooling_filters=outputs,\n",
+ " num_connected_weights=128, \n",
+ " spot_size=30,\n",
+ " figsize=(5, 5))\n",
+ "\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/affinity_spots.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig = plot_pooling_summaries(\n",
+ " bf_stats,\n",
+ " affinity,\n",
+ " num_pooling_filters=outputs,\n",
+ " num_connected_weights=15,\n",
+ " lines=True,\n",
+ " figsize=(10, 10))\n",
+ "\n",
+ "fig.savefig(\n",
+ " f'{model.params.disp_dir}/affinity_lines.png',\n",
+ " transparent=False,\n",
+ " bbox_inches='tight')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "params.standardize_data = False\n",
+ "train_loader, val_loader, test_loader, data_stats, data_mean_std = dataset_utils.load_dataset(params)\n",
+ "train_mean_image = 0#data_mean_std['dataset_mean_image'].to(model.params.device)\n",
+ "train_std_image = 1#data_mean_std['dataset_std_image'].to(model.params.device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "example_batch = next(iter(train_loader))[0].to(model.params.device)\n",
+ "example_batch = model[0].preprocess_data(example_batch)\n",
+ "#example_batch *= train_std_image\n",
+ "#example_batch += train_mean_image\n",
+ "batch_min = example_batch.min().item()\n",
+ "batch_max = example_batch.max().item()\n",
+ "\n",
+ "example_image = example_batch[0, ...]\n",
+ "print(\n",
+ " f'example image min = {example_image.min().item()}'+\n",
+ " f'\\nexample image mean = {example_image.mean().item()}'+\n",
+ " f'\\nexample image max = {example_image.max().item()}'+\n",
+ " f'\\nexample image std = {example_image.std().item()}')\n",
+ "preproc_image, example_image_mean, example_image_std = dp.standardize(example_image[None, ...], samplewise=True)\n",
+ "print(\n",
+ " f'preproc image min = {preproc_image.min().item()}'+\n",
+ " f'\\npreproc image mean = {preproc_image.mean().item()}'+\n",
+ " f'\\npreproc image max = {preproc_image.max().item()}'+\n",
+ " f'\\npreproc image std = {preproc_image.std().item()}')\n",
+ "\n",
+ "plot_example_image = ((example_image * train_std_image) + train_mean_image).cpu().numpy().transpose(1,2,0)\n",
+ "plot_preproc_image = ((preproc_image - preproc_image.min())/(preproc_image.max() - preproc_image.min()))[0,...].cpu().numpy().transpose(1,2,0)\n",
+ "\n",
+ "fig, axs = plot.subplots(nrows=1, ncols=2)\n",
+ "ax = pf.clear_axis(axs[0])\n",
+ "ax.imshow(plot_example_image, vmin=0, vmax=1)\n",
+ "ax.format(title='Original')\n",
+ "ax = pf.clear_axis(axs[1])\n",
+ "ax.format(title='Preprocessed')\n",
+ "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n",
+ "plot.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "alpha_1 = model.lca_1.get_encodings(preproc_image)\n",
+ "beta_1 = model.pool_1.get_encodings(alpha_1)\n",
+ "alpha_2 = model.lca_2.get_encodings(beta_1)\n",
+ "beta_2 = model(preproc_image)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class lca_2_recon_params(LcaParams):\n",
+ " def set_params(self):\n",
+ " super(lca_2_recon_params, self).set_params()\n",
+ " self.model_type = 'lca'\n",
+ " self.model_name = 'lca_2_recon'\n",
+ " self.version = '0'\n",
+ " self.layer_types = ['fc']\n",
+ " self.standardize_data = False\n",
+ " self.rescale_data_to_one = False\n",
+ " self.center_dataset = False\n",
+ " self.batch_size = 1\n",
+ " self.dt = 0.001\n",
+ " self.tau = 0.2\n",
+ " self.num_steps = 75\n",
+ " self.rectify_a = True\n",
+ " self.thresh_type = 'hard'\n",
+ " self.compute_helper_params()\n",
+ " \n",
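+    "# Added note: this auxiliary LCA model is set up to invert the fully-connected\n",
+    "# pooling layer; with its dictionary tied to pool_2.weight.T (below) and\n",
+    "# sparse_mult = 0, inference amounts to recovering an estimate of alpha_2\n",
+    "# from the top-level code beta_2.\n",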
+ "params = lca_2_recon_params()\n",
+ "params.set_params()\n",
+ "params.layer_channels = list(model.pool_2.weight.shape)\n",
+ "params.sparse_mult = 0.0#model.lca_2.params.sparse_mult\n",
+ "params.data_shape = list(beta_2.shape)\n",
+ "params.epoch_size = 1\n",
+ "params.num_pixels = np.prod(params.data_shape)\n",
+ "\n",
+ "lca_2_recon_model = loaders.load_model(params.model_type)\n",
+ "lca_2_recon_model.setup(params)\n",
+ "lca_2_recon_model.to(params.device)\n",
+ "lca_2_recon_model.eval()\n",
+ "with torch.no_grad():\n",
+ " lca_2_recon_model.weight = nn.Parameter(model.pool_2.weight.T)\n",
+ "a2h_b2 = lca_2_recon_model(beta_2)[:, :, None, None]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "alpha_2_bin = alpha_2.detach().cpu().numpy()>0\n",
+ "alpha_2_hat_bin = a2h_b2.detach().cpu().numpy()>0\n",
+ "\n",
+ "alpha_2_nnz = np.count_nonzero(alpha_2_bin)\n",
+ "alpha_2_hat_nnz = np.count_nonzero(alpha_2_hat_bin)\n",
+ "\n",
+    "active_overlap = np.sum(alpha_2_bin & alpha_2_hat_bin) # count of jointly active indices\n",
+ "\n",
+ "num_alpha_2 = np.prod(list(a2h_b2.shape))\n",
+ "alpha_2_edge = int(np.sqrt(num_alpha_2))\n",
+ "\n",
+ "plot_alpha_2_hat = ((a2h_b2 - a2h_b2.min())/(a2h_b2.max() - a2h_b2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n",
+ "\n",
+ "plot_alpha_2 = ((alpha_2 - alpha_2.min()) / (alpha_2.max() - alpha_2.min())).reshape(alpha_2_edge, alpha_2_edge).detach().cpu().numpy()\n",
+ "\n",
+ "alpha_diff = np.abs(plot_alpha_2 - plot_alpha_2_hat)\n",
+ "\n",
+ "fig, axs = plot.subplots(nrows=1, ncols=3)\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0])\n",
+ "ax.imshow(alpha_diff, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 2 differences')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1])\n",
+ "ax.imshow(plot_alpha_2_hat, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 2 hat')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[2])\n",
+ "m = ax.imshow(plot_alpha_2, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 2')\n",
+ "\n",
+ "ax.colorbar(m, ax=ax)\n",
+    "axs.format(suptitle=f'active index overlap = {active_overlap}; alpha 2 nnz = {alpha_2_nnz}; alpha 2 hat nnz = {alpha_2_hat_nnz}')\n",
+ "plot.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " b1h_a2h_b2 = F.conv_transpose2d(\n",
+ " input=a2h_b2,\n",
+ " weight=model.lca_2.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_2.params.stride,\n",
+ " padding=model.lca_2.params.padding)\n",
+ " \n",
+ " b1h_a2 = F.conv_transpose2d(\n",
+ " input=alpha_2,\n",
+ " weight=model.lca_2.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_2.params.stride,\n",
+ " padding=model.lca_2.params.padding)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "squared_error = torch.pow(beta_1 - b1h_a2h_b2, 2.)\n",
+ "l2_dist = 0.5 * torch.mean(squared_error).detach().cpu().numpy()\n",
+ "\n",
+ "num_beta_1 = np.prod(list(b1h_a2h_b2.shape))\n",
+ "beta_1_edge = int(np.floor(np.sqrt(num_beta_1)))\n",
+ "beta_1_resh = int(beta_1_edge**2)\n",
+ "\n",
+ "plot_beta_1_hat = ((b1h_a2h_b2 - b1h_a2h_b2.min())/(b1h_a2h_b2.max() - b1h_a2h_b2.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n",
+ "\n",
+ "plot_beta_1 = ((beta_1 - beta_1.min()) / (beta_1.max() - beta_1.min())).view(-1)[:beta_1_resh].reshape(beta_1_edge, beta_1_edge).detach().cpu().numpy()\n",
+ "\n",
+ "beta_diff = np.abs(plot_beta_1 - plot_beta_1_hat)\n",
+ "\n",
+ "fig, axs = plot.subplots(nrows=1, ncols=3)\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0])\n",
+ "ax.imshow(beta_diff, vmin=0, vmax=1)\n",
+ "ax.format(title='beta 1 differences')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1])\n",
+ "ax.imshow(plot_beta_1_hat, vmin=0, vmax=1)\n",
+ "ax.format(title='beta 1 hat')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[2])\n",
+ "m = ax.imshow(plot_beta_1, vmin=0, vmax=1)\n",
+ "ax.format(title='beta 1')\n",
+ "\n",
+ "ax.colorbar(m, ax=ax)\n",
+ "axs.format(suptitle=f'l2 distance = {l2_dist:0.5f}')\n",
+ "plot.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from DeepSparseCoding.modules.lca_module import LcaModule\n",
+ "from DeepSparseCoding.models.base_model import BaseModel\n",
+ "from DeepSparseCoding.utils.run_utils import compute_deconv_output_shape\n",
+ "import DeepSparseCoding.modules.losses as losses\n",
+ "\n",
+ "class TransposedLcaModule(LcaModule):\n",
+ " def setup_module(self, params):\n",
+ " super(TransposedLcaModule, self).setup_module(params)\n",
+ " if self.params.layer_types[0] == 'conv':\n",
+ " assert (self.params.data_shape[-1] % self.params.stride == 0), (\n",
+ " f'Stride = {self.params.stride} must divide evenly into input edge size = {self.params.data_shape[-1]}')\n",
+ " self.w_shape = [\n",
+ " self.params.layer_channels[1],\n",
+ " self.params.layer_channels[0],\n",
+ " self.params.kernel_size,\n",
+ " self.params.kernel_size\n",
+ " ]\n",
+ " output_height = compute_deconv_output_shape(\n",
+ " self.params.data_shape[1],\n",
+ " self.params.kernel_size,\n",
+ " self.params.stride,\n",
+ " self.params.padding,\n",
+ " output_padding=self.params.output_padding,\n",
+ " dilation=1)\n",
+ " output_width = compute_deconv_output_shape(\n",
+ " self.params.data_shape[2],\n",
+ " self.params.kernel_size,\n",
+ " self.params.stride,\n",
+ " self.params.padding,\n",
+ " output_padding=self.params.output_padding,\n",
+ " dilation=1)\n",
+ " self.layer_output_shape = [self.params.layer_channels[1], output_height, output_width]\n",
+ " w_init = torch.randn(self.w_shape)\n",
+ " w_init_normed = dp.l2_normalize_weights(w_init, eps=self.params.eps)\n",
+ " self.weight = nn.Parameter(w_init_normed, requires_grad=True)\n",
+ "\n",
+ " def compute_excitatory_current(self, input_tensor, a_in):\n",
+ " if self.params.layer_types[0] == 'fc':\n",
+ " excitatory_current = torch.matmul(input_tensor, self.weight.T)\n",
+ " else:\n",
+ " recon = self.get_recon_from_latents(a_in)\n",
+ " recon_error = input_tensor - recon\n",
+ " error_injection = F.conv_transpose2d(\n",
+ " input=recon_error,\n",
+ " weight=self.weight,\n",
+ " bias=None,\n",
+ " stride=self.params.stride,\n",
+ " padding=self.params.padding,\n",
+ " output_padding=self.params.output_padding,\n",
+ " dilation=1\n",
+ " )\n",
+ " excitatory_current = error_injection + a_in\n",
+ " return excitatory_current\n",
+ "\n",
+ " def get_recon_from_latents(self, a_in):\n",
+ " if self.params.layer_types[0] == 'fc':\n",
+ " recon = torch.matmul(a_in, self.weight)\n",
+ " else:\n",
+ " recon = F.conv2d(\n",
+ " input=a_in,\n",
+ " weight=self.weight,\n",
+ " bias=None,\n",
+ " stride=self.params.stride,\n",
+ " padding=self.params.padding,\n",
+ " dilation=1\n",
+ " )\n",
+ " return recon\n",
+ "\n",
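+    "# Added note: relative to the standard LcaModule, the conv and conv_transpose roles\n",
+    "# are swapped above - the excitatory current is computed with conv_transpose2d and\n",
+    "# the reconstruction with conv2d - so this module inverts a layer whose forward\n",
+    "# operation was a strided convolution (e.g. the learned pooling layer, as used below).\n",
+    "\n",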
+ "class TransposedLcaModel(BaseModel, TransposedLcaModule):\n",
+ " def setup(self, params, logger=None):\n",
+ " super(TransposedLcaModel, self).setup(params, logger)\n",
+ " self.setup_module(params)\n",
+ " self.setup_optimizer()\n",
+ " if params.checkpoint_boot_log != '':\n",
+ " checkpoint = self.get_checkpoint_from_log(params.checkpoint_boot_log)\n",
+ " self.module.load_state_dict(checkpoint['model_state_dict'])\n",
+ " self.module.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
+ "\n",
+ " def get_total_loss(self, input_tuple):\n",
+ " input_tensor, input_labels = input_tuple\n",
+ " latents = self.get_encodings(input_tensor)\n",
+ " recon = self.get_recon_from_latents(latents)\n",
+ " recon_loss = losses.half_squared_l2(input_tensor, recon)\n",
+ " sparse_loss = self.params.sparse_mult * losses.l1_norm(latents)\n",
+ " total_loss = recon_loss + sparse_loss\n",
+ " return total_loss\n",
+ "\n",
+ " def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):\n",
+ " if update_dict is None:\n",
+ " update_dict = super(TransposedLcaModel, self).generate_update_dict(input_data, input_labels, batch_step)\n",
+ " stat_dict = dict()\n",
+ " latents = self.get_encodings(input_data)\n",
+ " recon = self.get_recon_from_latents(latents)\n",
+ " recon_loss = losses.half_squared_l2(input_data, recon).item()\n",
+ " sparse_loss = self.params.sparse_mult * losses.l1_norm(latents).item()\n",
+ " stat_dict['weight_lr'] = self.scheduler.get_lr()[0]\n",
+ " stat_dict['loss_recon'] = recon_loss\n",
+ " stat_dict['loss_sparse'] = sparse_loss\n",
+ " stat_dict['loss_total'] = recon_loss + sparse_loss\n",
+ " stat_dict['input_max_mean_min'] = [\n",
+ " input_data.max().item(), input_data.mean().item(), input_data.min().item()]\n",
+ " stat_dict['recon_max_mean_min'] = [\n",
+ " recon.max().item(), recon.mean().item(), recon.min().item()]\n",
+ " def count_nonzero(array, dim):\n",
+ " # TODO: github issue 23907 requests torch.count_nonzero, integrated in torch 1.7\n",
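+    "            # (added) on torch >= 1.7 this is torch.count_nonzero(array, dim=dim).to(torch.float)\n",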
+ " return torch.sum(array !=0, dim=dim, dtype=torch.float)\n",
+ " latent_dims = tuple([i for i in range(len(latents.shape))])\n",
+ " latent_nnz = count_nonzero(latents, dim=latent_dims).item()\n",
+ " stat_dict['fraction_active_all_latents'] = latent_nnz / latents.numel()\n",
+ " if self.params.layer_types[0] == 'conv':\n",
+ " latent_map_dims = latent_dims[2:]\n",
+ " latent_map_size = np.prod(list(latents.shape[2:]))\n",
+ " latent_channel_nnz = count_nonzero(latents, dim=latent_map_dims)/latent_map_size\n",
+ " latent_channel_mean_nnz = torch.mean(latent_channel_nnz).item()\n",
+ " stat_dict['fraction_active_latents_per_channel'] = latent_channel_mean_nnz\n",
+ " num_channels = latents.shape[1]\n",
+ " latent_patch_mean_nnz = torch.mean(count_nonzero(latents, dim=1)/num_channels).item()\n",
+ " stat_dict['fraction_active_latents_per_patch'] = latent_patch_mean_nnz\n",
+ " update_dict.update(stat_dict)\n",
+ " return update_dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class lca_1_recon_params(LcaParams):\n",
+ " def set_params(self):\n",
+ " super(lca_1_recon_params, self).set_params()\n",
+ " self.model_type = 'lca'\n",
+ " self.model_name = 'lca_1_recon'\n",
+ " self.version = '0'\n",
+ " self.layer_types = ['conv']\n",
+ " self.standardize_data = False\n",
+ " self.rescale_data_to_one = False\n",
+ " self.center_dataset = False\n",
+ " self.batch_size = 1\n",
+ " self.dt = 0.001\n",
+ " self.tau = 0.2\n",
+ " self.num_steps = 75\n",
+ " self.rectify_a = True\n",
+ " self.thresh_type = 'hard'\n",
+ " self.compute_helper_params()\n",
+ " \n",
+ "params = lca_1_recon_params()\n",
+ "params.set_params()\n",
+ "params.layer_channels = model.pool_1.params.layer_channels[::-1]\n",
+ "params.kernel_size = model.pool_1.params.pool_ksize\n",
+ "params.stride = model.pool_1.params.pool_stride\n",
+ "params.padding = 0\n",
+ "params.sparse_mult = 0.00#model.lca_1.params.sparse_mult\n",
+ "params.data_shape = list(b1h_a2h_b2.shape[1:])\n",
+ "params.epoch_size = 1\n",
+ "params.output_padding = 1\n",
+ "params.num_pixels = np.prod(params.data_shape)\n",
+ "\n",
+ "lca_1_recon_model = TransposedLcaModel()\n",
+ "lca_1_recon_model.setup(params)\n",
+ "lca_1_recon_model.to(params.device)\n",
+ "lca_1_recon_model.eval()\n",
+ "with torch.no_grad():\n",
+ " lca_1_recon_model.weight = nn.Parameter(model.pool_1.weight)\n",
+ "a1h_b1h_a2h_b2 = lca_1_recon_model(b1h_a2h_b2)\n",
+ "a1h_b1h_a2 = lca_1_recon_model(b1h_a2)\n",
+ "a1h_b1 = lca_1_recon_model(beta_1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "alpha_1_bin = alpha_1.detach().cpu().numpy()>0\n",
+ "alpha_1_hat_bin = a1h_b1h_a2h_b2.detach().cpu().numpy()>0\n",
+    "active_overlap = np.sum(alpha_1_bin & alpha_1_hat_bin) # count of jointly active indices\n",
+ "alpha_1_nnz = np.count_nonzero(alpha_1_bin)\n",
+ "alpha_1_hat_nnz = np.count_nonzero(alpha_1_hat_bin)\n",
+ "\n",
+ "num_alpha_1 = np.prod(list(a1h_b1h_a2h_b2.shape))\n",
+ "alpha_1_edge = int(np.floor(np.sqrt(num_alpha_1)))\n",
+ "alpha_1_resh = int(alpha_1_edge**2)\n",
+ "\n",
+ "plot_alpha_1_hat = ((a1h_b1h_a2h_b2 - a1h_b1h_a2h_b2.min())/(a1h_b1h_a2h_b2.max() - a1h_b1h_a2h_b2.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n",
+ "\n",
+ "plot_alpha_1 = ((alpha_1 - alpha_1.min()) / (alpha_1.max() - alpha_1.min())).view(-1)[:alpha_1_resh].reshape(alpha_1_edge, alpha_1_edge).detach().cpu().numpy()\n",
+ "\n",
+ "alpha_diff = np.abs(plot_alpha_1 - plot_alpha_1_hat)\n",
+ "\n",
+ "fig, axs = plot.subplots(nrows=1, ncols=3)\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0])\n",
+ "ax.imshow(alpha_diff, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 1 differences')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1])\n",
+ "ax.imshow(plot_alpha_1_hat, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 1 hat')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[2])\n",
+ "m = ax.imshow(plot_alpha_1, vmin=0, vmax=1)\n",
+ "ax.format(title='alpha 1')\n",
+ "\n",
+ "ax.colorbar(m, ax=ax)\n",
+    "axs.format(suptitle=f'active index overlap = {active_overlap}; alpha 1 nnz = {alpha_1_nnz}; alpha 1 hat nnz = {alpha_1_hat_nnz}')\n",
+ "plot.show()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
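+    "# Added note: each conv_transpose2d below linearly decodes a different latent back\n",
+    "# to image space through the layer-1 dictionary; the deeper latents have first\n",
+    "# been mapped back through the corresponding inversion stages above.\n",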
+ "with torch.no_grad():\n",
+ " recon_from_alpha_1 = F.conv_transpose2d(\n",
+ " input=alpha_1,\n",
+ " weight=model.lca_1.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_1.params.stride,\n",
+ " padding=model.lca_1.params.padding)\n",
+ " \n",
+ " recon_from_beta_1 = F.conv_transpose2d(\n",
+ " input=a1h_b1,\n",
+ " weight=model.lca_1.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_1.params.stride,\n",
+ " padding=model.lca_1.params.padding)\n",
+ " \n",
+ " recon_from_alpha_2 = F.conv_transpose2d(\n",
+ " input=a1h_b1h_a2,\n",
+ " weight=model.lca_1.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_1.params.stride,\n",
+ " padding=model.lca_1.params.padding)\n",
+ " \n",
+ " recon_from_beta_2 = F.conv_transpose2d(\n",
+ " input=a1h_b1h_a2h_b2,\n",
+ " weight=model.lca_1.weight,\n",
+ " bias=None,\n",
+ " stride=model.lca_1.params.stride,\n",
+ " padding=model.lca_1.params.padding)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_func(x):\n",
+ " x *= example_image_std\n",
+ " x += example_image_mean\n",
+ " x = ((x - x.min()) / (x.max() - x.min())).squeeze().cpu().numpy().transpose(1,2,0)\n",
+ " return(x)\n",
+ "\n",
+ "alpha_2_nnz = torch.sum(a2h_b2 !=0,\n",
+ " dim=tuple([i for i in range(len(a2h_b2.shape))]),\n",
+ " dtype=torch.float)/a2h_b2.numel()\n",
+ "alpha_1_nnz = torch.sum(a1h_b1h_a2h_b2 !=0,\n",
+ " dim=tuple([i for i in range(len(a1h_b1h_a2h_b2.shape))]),\n",
+ " dtype=torch.float)/a1h_b1h_a2h_b2.numel()\n",
+ "print(\n",
+ " f'beta2 shape = {beta_2.shape}' + \n",
+ " f'\\nalpha2^ nnz = {alpha_2_nnz}'+\n",
+ " f'\\nalpha2^ shape = {a2h_b2.shape}'+\n",
+    "  f'\\nbeta1^ shape = {b1h_a2h_b2.shape}'+\n",
+ " f'\\nalpha1^ nnz = {alpha_1_nnz}'+\n",
+ " f'\\nalpha1^ shape = {a1h_b1h_a2h_b2.shape}'+\n",
+ " f'\\nimage^ shape = {recon_from_beta_2.shape}'\n",
+ ")\n",
+ "print(\n",
+ " f'recon min = {recon_from_beta_2.min().item()}'+\n",
+ " f'\\nrecon mean = {recon_from_beta_2.mean().item()}'+\n",
+ " f'\\nrecon max = {recon_from_beta_2.max().item()}'+\n",
+ " f'\\nrecon std = {recon_from_beta_2.std().item()}'\n",
+ ")\n",
+ "\n",
+ "fig, axs = plot.subplots(nrows=2, ncols=3)\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0,0])\n",
+ "ax.imshow(plot_preproc_image, vmin=0, vmax=1)\n",
+ "ax.format(title='original')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0,1])\n",
+ "ax.imshow(plot_func(recon_from_alpha_1), vmin=0, vmax=1)\n",
+ "ax.format(title='recon from alpha 1')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[0,2])\n",
+ "ax.imshow(plot_func(recon_from_beta_1), vmin=0, vmax=1)\n",
+ "ax.format(title='recon from beta 1')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1,0])\n",
+ "ax.imshow(plot_func(recon_from_alpha_2), vmin=0, vmax=1)\n",
+ "ax.format(title='recon from alpha 2')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1,1])\n",
+ "ax.imshow(plot_func(recon_from_beta_2), vmin=0, vmax=1)\n",
+ "ax.format(title='recon from beta 2')\n",
+ "\n",
+ "ax = pf.clear_axis(axs[1,2])\n",
+ "plot.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.7"
+ },
+ "varInspector": {
+ "cols": {
+ "lenName": 16,
+ "lenType": 16,
+ "lenVar": 40
+ },
+ "kernels_config": {
+ "python": {
+ "delete_cmd_postfix": "",
+ "delete_cmd_prefix": "del ",
+ "library": "var_list.py",
+ "varRefreshCmd": "print(var_dic_list())"
+ },
+ "r": {
+ "delete_cmd_postfix": ") ",
+ "delete_cmd_prefix": "rm(",
+ "library": "var_list.r",
+ "varRefreshCmd": "cat(var_dic_list()) "
+ }
+ },
+ "types_to_exclude": [
+ "module",
+ "function",
+ "builtin_function_or_method",
+ "instance",
+ "_Feature"
+ ],
+ "window_display": false
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/notebooks/visualize_model_weights.ipynb b/notebooks/visualize_model_weights.ipynb
index 84678ccb..a6bb3460 100644
--- a/notebooks/visualize_model_weights.ipynb
+++ b/notebooks/visualize_model_weights.ipynb
@@ -29,16 +29,12 @@
"outputs": [],
"source": [
"workspace_dir = os.path.expanduser(\"~\")+\"/Work/\"\n",
- "model_name = 'lca_dsprites'\n",
- "num_epochs = 100\n",
- "sparse_mult = 0.05\n",
- "model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n",
+ "model_name = 'conv_lca_mnist'\n",
+ "#num_epochs = 200\n",
+ "#sparse_mult = 0.25\n",
+ "#model_name += '_{}_{}'.format(sparse_mult, num_epochs)\n",
"log_file = workspace_dir+'/Torch_projects/{}/logfiles/{}_v0.log'.format(model_name, model_name)\n",
"logger = Logger(log_file, overwrite=False)\n",
- "\n",
- "target_index = 1\n",
- "\n",
- "logger = Logger(log_files[target_index], overwrite=False)\n",
"log_text = logger.load_file()\n",
"params = logger.read_params(log_text)[-1]"
]
@@ -94,7 +90,7 @@
" paired_pics = [paired_pics[i, :, :, :] for i in range(paired_pics.shape[0])]\n",
" print(np.array(paired_pics).shape)\n",
" visualize_util.grid_save_images(paired_pics, os.path.join('', \"reconstructions.jpg\"))\n",
- "else:\n",
+ "elif model.params.model_type.lower() in ['mlp', 'ensemble']:\n",
" test_results = run_utils.test_epoch(0, model, test_loader, log_to_file=False)\n",
" print(test_results)"
]
@@ -111,8 +107,14 @@
"else:\n",
" weights = list(model.parameters())[0].data.cpu().numpy()\n",
"\n",
- "num_neurons, num_pixels = weights.shape\n",
- "weights = np.reshape(weights, [num_neurons, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])"
+ "if weights.ndim == 4:\n",
+ " num_neurons, num_channels, num_h, num_w = weights.shape\n",
+ " num_pixels = num_channels * num_h * num_w\n",
+ "elif weights.ndim == 2:\n",
+ " num_neurons, num_pixels = weights.shape\n",
+ " weights = np.reshape(weights, [num_neurons, 1, int(np.sqrt(num_pixels)), int(np.sqrt(num_pixels))])\n",
+ "else:\n",
+ " assert False, (f'weights.ndim == {weights.ndim} must be 2 or 4')\n"
]
},
{
@@ -141,7 +143,11 @@
"def pad_matrix_to_image(matrix, pad_size=0, pad_value=0, normalize=False):\n",
" if normalize:\n",
" matrix = normalize_data_with_max(matrix)[0]\n",
- " num_weights, img_h, img_w = matrix.shape\n",
+ " num_weights, img_c, img_h, img_w = matrix.shape\n",
+ " if img_c == 1:\n",
+ " matrix = matrix.squeeze()\n",
+ " else:\n",
+    "        assert False, ('Multiple color channels are not currently supported') # TODO\n",
" num_extra_images = int(np.ceil(np.sqrt(num_weights))**2 - num_weights)\n",
" if num_extra_images > 0:\n",
" matrix = np.concatenate(\n",
@@ -181,7 +187,8 @@
"\n",
"tfpf.plot_image(pad_matrix_to_image(weights), vmin=None, vmax=None, title=\"\", save_filename=model.params.disp_dir+\"/weights_plot_image.png\")\n",
"tfpf.plot_weights(weights, save_filename=model.params.disp_dir+\"/weights_plot_weights.png\")\n",
- "tfpf.plot_data_tiled(weights[..., None], save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")"
+ "tfpf.plot_data_tiled(np.transpose(weights, (0, 2, 3, 1)),\n",
+ " save_filename=model.params.disp_dir+\"/weights_plot_data_tiled.png\")"
]
},
{
diff --git a/params/base_params.py b/params/base_params.py
index c82fe486..056dd3c9 100644
--- a/params/base_params.py
+++ b/params/base_params.py
@@ -9,10 +9,15 @@ class BaseParams(object):
"""
all models
batch_size [int] number of images in a training batch
+ center_dataset [bool] if True, subtract the mean dataset image from all datapoints
+ checkpoint_boot_log [str] path to a training log file for booting from checkpoint
+ if set, all specified model params must match those in the log file #TODO: meaningful errors if not
data_dir [str] location of dataset folders
device [str] which device to run on
dtype [torch dtype] dtype for network variables
eps [float] small value to avoid division by zero
+ fast_mnist [bool] if True, use the fastMNIST dataset,
+ which loads faster but does not allow for torchvision transforms like flip and rotate
lib_root_dir [str] system location of this library directory
log_to_file [bool] if set, log to file, else log to stderr
model_name [str] name for model (can be anything)
@@ -34,6 +39,10 @@ class BaseParams(object):
standardize_data [bool] if set, z-score data to have mean=0 and standard deviation=1 using numpy operators
train_logs_per_epoch [int or None] how often to send updates to the logfile
workspace_dir [str] system directory that is the parent to the primary repository directory
+ num_validation [int] number of images to reserve for the validation set (only works with some datasets)
+
+ ensemble
+ allow_parent_grads [bool] if True, allow loss gradients to propagate through all members of the ensemble
mlp
activation_functions [list of str] strings correspond to activation functions for layers.
@@ -44,6 +53,11 @@ class BaseParams(object):
layer_types [list of str] weight connectivity type, either "conv" or "fc"
len must be equal to the len of layer_channels - 1
layer_channels [list of int] number of outputs per layer, including the input layer
+ kernel_sizes [list of ints] number of pixels on the edge of a square kernel, only used if layer_types is "conv"
+ strides [list of ints] number of pixels for the convolutional stride, assumes equal horizontal and vertical strides and is only used if layer_types is "conv"
+ max_pool [list of bools] if True, the network includes a max pooling op after the conv/fc op and before the dropout op
+ pool_ksizes [list of ints] number of pixels on the edge of a square max pooling kernel
+ pool_strides [list of ints] number of pixels in pooling stride, assumes equal horizontal and vertical strides
lca
dt [float] discrete global time constant for neuron dynamics
@@ -52,7 +66,7 @@ class BaseParams(object):
num_steps [int] number of lca inference steps to take
rectify_a [bool] if set, rectify the layer 1 neuron activity
sparse_mult [float] multiplyer placed in front of the sparsity loss term
- tau [float] LCA time constant
+ tau [float] LCA time constant; larger values result in smaller step sizes (i.e. slower convergence)
lca update rule (step_size) is multiplied by dt/tau
thresh_type [str] specifying LCA threshold function; can be "hard" or "soft"
@@ -66,8 +80,10 @@ def __init__(self):
self.compute_helper_params()
def set_params(self):
- self.standardize_data = False
+ self.center_dataset = False
+ self.checkpoint_boot_log = ''
self.rescale_data_to_one = False
+ self.standardize_data = False
self.model_type = None
self.log_to_file = True
self.train_logs_per_epoch = None
@@ -75,6 +91,7 @@ def set_params(self):
self.shuffle_data = True
self.eps = 1e-12
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ self.num_validation = 0
self.rand_seed = 123456789
self.rand_state = np.random.RandomState(self.rand_seed)
self.workspace_dir = os.path.join(os.path.expanduser('~'), 'Work')
diff --git a/params/lca_cifar10_params.py b/params/lca_cifar10_params.py
new file mode 100644
index 00000000..40390315
--- /dev/null
+++ b/params/lca_cifar10_params.py
@@ -0,0 +1,47 @@
+import types
+
+from DeepSparseCoding.params.base_params import BaseParams
+
+
+class params(BaseParams):
+ def set_params(self):
+ super(params, self).set_params()
+ self.model_type = 'lca'
+ self.model_name = 'lca_cifar10'
+ self.version = '0'
+ self.dataset = 'cifar10'
+ self.layer_types = ['conv']
+ self.num_validation = 10000
+ self.standardize_data = True
+ self.rescale_data_to_one = False
+ self.center_dataset = False
+ self.batch_size = 25
+ self.num_epochs = 500
+ self.train_logs_per_epoch = 6
+ self.renormalize_weights = True
+ self.layer_channels = [3, 128]
+ self.kernel_size = 8
+ self.stride = 2
+ self.padding = 0
+ self.weight_decay = 0.0
+ self.weight_lr = 0.001
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.dt = 0.001
+ self.tau = 0.1#0.2
+ self.num_steps = 37#75
+ self.rectify_a = True
+ self.thresh_type = 'hard'
+ self.sparse_mult = 0.35#0.30
+ self.compute_helper_params()
+
+ def compute_helper_params(self):
+ super(params, self).compute_helper_params()
+ self.optimizer.milestones = [frac * self.num_epochs
+ for frac in self.optimizer.lr_annealing_milestone_frac]
+ self.step_size = self.dt / self.tau
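+ # e.g. with dt=0.001 and tau=0.1 above, step_size = dt/tau = 0.01 per inference step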
+ self.num_pixels = 3072
+ self.in_channels = self.layer_channels[0]
+ self.out_channels = self.layer_channels[1]
diff --git a/params/lca_dsprites_params.py b/params/lca_dsprites_params.py
index d498b076..f0090e40 100644
--- a/params/lca_dsprites_params.py
+++ b/params/lca_dsprites_params.py
@@ -12,6 +12,7 @@ def set_params(self):
super(params, self).set_params()
self.model_type = 'lca'
self.model_name = 'lca_dsprites'
+ self.layer_types = ['fc']
self.version = '0'
self.dataset = 'dsprites'
self.standardize_data = False
@@ -26,13 +27,13 @@ def set_params(self):
self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs
self.optimizer.lr_decay_rate = 0.5
self.renormalize_weights = True
+ self.layer_channels = [1, int(self.num_pixels*1.5)]
self.dt = 0.001
self.tau = 0.03
self.num_steps = 75
self.rectify_a = False
self.thresh_type = 'soft'
self.sparse_mult = 0.25
- self.num_latent = int(self.num_pixels*1.5)
self.compute_helper_params()
def compute_helper_params(self):
diff --git a/params/lca_mlp_cifar10_params.py b/params/lca_mlp_cifar10_params.py
new file mode 100644
index 00000000..475fdee2
--- /dev/null
+++ b/params/lca_mlp_cifar10_params.py
@@ -0,0 +1,93 @@
+import os
+import types
+import numpy as np
+import torch
+
+from DeepSparseCoding.params.base_params import BaseParams
+from DeepSparseCoding.params.lca_mnist_params import params as LcaParams
+from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
+
+
+class shared_params(object):
+ def __init__(self):
+ self.model_type = 'ensemble'
+ self.model_name = 'lca_mlp_cifar10'
+ self.version = '0'
+ self.dataset = 'cifar10'
+ self.standardize_data = True
+ self.batch_size = 25
+ self.num_epochs = 10  # originally 1000
+ self.train_logs_per_epoch = 4
+ self.allow_parent_grads = True
+
+
+class lca_params(LcaParams):
+ def set_params(self):
+ super(lca_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ self.model_type = 'lca'
+ self.layer_name = 'lca'
+ self.layer_types = ['conv']
+ self.weight_decay = 0.0
+ self.weight_lr = 0.001
+ self.renormalize_weights = True
+ self.layer_channels = [3, 512]
+ self.kernel_size = 8
+ self.stride = 2
+ self.padding = 0
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.dt = 0.001
+ self.tau = 0.2
+ self.num_steps = 75
+ self.rectify_a = True
+ self.thresh_type = 'hard'
+ self.sparse_mult = 0.30
+ self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/conv_lca_cifar10/logfiles/lca_cifar10_v0.log'
+ self.compute_helper_params()
+
+
+class mlp_params(MlpParams):
+ def set_params(self):
+ super(mlp_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ self.model_type = 'mlp'
+ self.layer_name = 'classifier'
+ self.weight_lr = 2e-3
+ self.weight_decay = 1e-6
+ self.layer_types = ['fc']
+ self.layer_channels = [None, 10]
+ self.activation_functions = ['identity']
+ self.dropout_rate = [0.0] # probability of value being set to zero
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.compute_helper_params()
+
+
+class params(BaseParams):
+ def set_params(self):
+ super(params, self).set_params()
+ lca_params_inst = lca_params()
+ mlp_params_inst = mlp_params()
+ lca_output_height = compute_conv_output_shape(
+ 32, # TODO: infer this? currently hardcoded CIFAR10 size
+ lca_params_inst.kernel_size,
+ lca_params_inst.stride,
+ lca_params_inst.padding,
+ dilation=1)
+ lca_output_width = compute_conv_output_shape(
+ 32,
+ lca_params_inst.kernel_size,
+ lca_params_inst.stride,
+ lca_params_inst.padding,
+ dilation=1)
+ lca_output_shape = [lca_params_inst.layer_channels[1], lca_output_height, lca_output_width]
+ mlp_params_inst.layer_channels[0] = np.prod(lca_output_shape)
+ self.ensemble_params = [lca_params_inst, mlp_params_inst]
+ for key, value in shared_params().__dict__.items():
+ setattr(self, key, value)
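The classifier input size above depends on compute_conv_output_shape from run_utils; a plausible definition under the standard convolution arithmetic (an assumption, since the helper itself is not in this diff):

    import math

    def compute_conv_output_shape(in_size, kernel_size, stride, padding, dilation=1):
        # floor((in + 2*pad - dilation*(kernel - 1) - 1) / stride + 1)
        return math.floor((in_size + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1)

With the values above: (32 - 7 - 1) / 2 + 1 = 13, so lca_output_shape is [512, 13, 13] and the classifier sees 512 * 13 * 13 = 86528 inputs.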
diff --git a/params/lca_mlp_mnist_params.py b/params/lca_mlp_mnist_params.py
index 212de71f..6f230aac 100644
--- a/params/lca_mlp_mnist_params.py
+++ b/params/lca_mlp_mnist_params.py
@@ -14,11 +14,13 @@ def __init__(self):
self.model_name = 'lca_768_mlp_mnist'
self.version = '0'
self.dataset = 'mnist'
+ self.fast_mnist = True
self.standardize_data = False
self.num_pixels = 28*28*1
self.batch_size = 100
self.num_epochs = 1200
self.train_logs_per_epoch = 4
+ self.allow_parent_grads = False
class lca_params(LcaParams):
@@ -27,6 +29,7 @@ def set_params(self):
for key, value in shared_params().__dict__.items():
setattr(self, key, value)
self.model_type = 'lca'
+ self.layer_name = 'lca'
self.weight_decay = 0.0
self.weight_lr = 0.1
self.optimizer = types.SimpleNamespace()
@@ -34,14 +37,14 @@ def set_params(self):
self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs
self.optimizer.lr_decay_rate = 0.5
self.renormalize_weights = True
+ self.layer_channels = [1, 768]
self.dt = 0.001
self.tau = 0.03
self.num_steps = 75
self.rectify_a = True
self.thresh_type = 'soft'
self.sparse_mult = 0.25
- self.num_latent = 768#self.num_pixels*4
- #self.allow_parent_grads = False # TODO: enable this param
+ self.checkpoint_boot_log = ''
self.compute_helper_params()
@@ -51,10 +54,11 @@ def set_params(self):
for key, value in shared_params().__dict__.items():
setattr(self, key, value)
self.model_type = 'mlp'
+ self.layer_name = 'classifier'
self.weight_lr = 1e-4
self.weight_decay = 0.0
self.layer_types = ['fc']
- self.layer_channels = [768, 10]#[self.num_pixels*4, 10]
+ self.layer_channels = [768, 10]
self.activation_functions = ['identity']
self.dropout_rate = [0.0] # probability of value being set to zero
self.optimizer = types.SimpleNamespace()
diff --git a/params/lca_mnist_params.py b/params/lca_mnist_params.py
index dd9cdf34..6274ef54 100644
--- a/params/lca_mnist_params.py
+++ b/params/lca_mnist_params.py
@@ -1,38 +1,54 @@
-import os
import types
-import numpy as np
-import torch
-
from DeepSparseCoding.params.base_params import BaseParams
+CONV = True
+
+
class params(BaseParams):
def set_params(self):
super(params, self).set_params()
self.model_type = 'lca'
- self.model_name = 'lca_768_mnist'
self.version = '0'
self.dataset = 'mnist'
+ self.fast_mnist = True
self.standardize_data = False
self.num_pixels = 784
- self.batch_size = 100
- self.num_epochs = 1000
- self.weight_decay = 0.
- self.weight_lr = 0.1
- self.train_logs_per_epoch = 6
- self.optimizer = types.SimpleNamespace()
- self.optimizer.name = 'sgd'
- self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs
- self.optimizer.lr_decay_rate = 0.5
- self.renormalize_weights = True
self.dt = 0.001
self.tau = 0.03
self.num_steps = 75
self.rectify_a = True
self.thresh_type = 'soft'
self.sparse_mult = 0.25
- self.num_latent = 768#self.num_pixels*4
+ self.renormalize_weights = True
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.num_epochs = 1000
+ self.weight_decay = 0.0
+ self.train_logs_per_epoch = 6
+ if CONV:
+ self.layer_types = ['conv']
+ self.model_name = 'conv_lca_mnist'
+ self.rescale_data_to_one = True
+ self.batch_size = 50
+ self.weight_lr = 0.001
+ self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.layer_channels = [1, 128]
+ self.kernel_size = 8
+ self.stride = 2
+ self.padding = 0
+ else:
+ self.layer_types = ['fc']
+ self.model_name = 'lca_768_mnist'
+ self.rescale_data_to_one = False
+ self.batch_size = 100
+ self.weight_lr = 0.1
+ self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.5
+ self.layer_channels = [1, 768]  # alternatively self.num_pixels * 4
self.compute_helper_params()
def compute_helper_params(self):
diff --git a/params/mlp_cifar10_params.py b/params/mlp_cifar10_params.py
new file mode 100644
index 00000000..24aa8f29
--- /dev/null
+++ b/params/mlp_cifar10_params.py
@@ -0,0 +1,43 @@
+import os
+import types
+
+import numpy as np
+import torch
+
+from DeepSparseCoding.params.base_params import BaseParams
+
+
+class params(BaseParams):
+ def set_params(self):
+ super(params, self).set_params()
+ self.model_type = 'mlp'
+ self.model_name = 'mlp_cifar10'
+ self.version = '0'
+ self.dataset = 'cifar10'
+ self.standardize_data = True
+ self.rescale_data_to_one = False
+ self.center_data = False
+ self.num_validation = 1000
+ self.batch_size = 50
+ self.num_epochs = 500
+ self.weight_decay = 3e-6
+ self.weight_lr = 2e-3
+ self.layer_types = ['conv', 'fc']
+ self.layer_channels = [3, 512, 10]
+ self.kernel_sizes = [8, None]
+ self.strides = [2, None]
+ self.activation_functions = ['lrelu', 'identity']
+ self.dropout_rate = [0.5, 0.0] # probability of value being set to zero
+ self.max_pool = [True, False]
+ self.pool_ksizes = [5, None]
+ self.pool_strides = [4, None]
+ self.train_logs_per_epoch = 4
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'adam'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.1
+
+ def compute_helper_params(self):
+ super(params, self).compute_helper_params()
+ self.optimizer.milestones = [frac * self.num_epochs
+ for frac in self.optimizer.lr_annealing_milestone_frac]
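The milestones computed in compute_helper_params are epoch fractions converted to epoch indices; they are presumably consumed by something like torch.optim.lr_scheduler.MultiStepLR (an assumption, as the training loop is not in this diff):

    import torch

    net = torch.nn.Linear(3072, 10)  # illustrative stand-in for the model above
    optimizer = torch.optim.Adam(net.parameters(), lr=2e-3, weight_decay=3e-6)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[int(0.3 * 500)], gamma=0.1)  # lr *= 0.1 at epoch 150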
diff --git a/params/mlp_mnist_params.py b/params/mlp_mnist_params.py
index 901ab60e..dcc7e195 100644
--- a/params/mlp_mnist_params.py
+++ b/params/mlp_mnist_params.py
@@ -14,17 +14,19 @@ def set_params(self):
self.model_name = 'mlp_768_mnist'
self.version = '0'
self.dataset = 'mnist'
+ self.fast_mnist = True
self.standardize_data = False
self.rescale_data_to_one = False
- self.num_pixels = 28*28*1
self.batch_size = 50
self.num_epochs = 300
self.weight_lr = 5e-4
self.weight_decay = 2e-6
self.layer_types = ['fc', 'fc']
+ self.num_pixels = 28*28*1
self.layer_channels = [self.num_pixels, 768, 10]
self.activation_functions = ['lrelu', 'identity']
self.dropout_rate = [0.5, 0.0] # probability of value being set to zero
+ self.max_pool = [False, False]
self.train_logs_per_epoch = 4
self.optimizer = types.SimpleNamespace()
self.optimizer.name = 'adam'
diff --git a/params/smt_cifar10_params.py b/params/smt_cifar10_params.py
new file mode 100644
index 00000000..940d4fcf
--- /dev/null
+++ b/params/smt_cifar10_params.py
@@ -0,0 +1,259 @@
+import os
+import types
+
+import numpy as np
+import torch
+
+from DeepSparseCoding.params.base_params import BaseParams
+from DeepSparseCoding.params.lca_cifar10_params import params as LcaParams
+from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams
+from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
+
+
+class shared_params(object):
+ def __init__(self):
+ self.model_type = 'ensemble'
+ self.model_name = 'test_smt_cifar10'
+ #self.version = 'lplpm'
+ self.version = '2lp'
+ self.dataset = 'cifar10'
+ self.standardize_data = True
+ self.rescale_data_to_one = False
+ self.center_dataset = False
+ self.batch_size = 30
+ self.num_epochs = 200
+ self.train_logs_per_epoch = 4
+ self.allow_parent_grads = False
+
+
+class lca_1_params(LcaParams):
+ def set_params(self):
+ super(lca_1_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ self.model_type = 'lca'
+ self.layer_name = 'lca_1'
+ self.layer_types = ['conv']
+ self.weight_decay = 0.0
+ #self.weight_lr = 1e-3
+ self.weight_lr = 0.0 # For next layer training
+ self.renormalize_weights = True
+ #self.layer_channels = [3, 128]
+ self.layer_channels = [3, 256]
+ self.kernel_size = 8
+ #self.stride = 2
+ self.stride = 1
+ self.padding = 0
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.dt = 0.001
+ #self.tau = 0.2#0.10
+ self.tau = 0.25
+ #self.num_steps = 75#37
+ self.num_steps = 75
+ self.rectify_a = True
+ self.thresh_type = 'hard'
+ #self.sparse_mult = 0.35
+ self.sparse_mult = 0.28
+ #self.checkpoint_boot_log = ''
+ #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log'
+ self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2l.log'
+ self.compute_helper_params()
+
+
+class pooling_1_params(BaseParams):
+ def set_params(self):
+ super(pooling_1_params, self).set_params()
+ for key, value in shared_params().__dict__.items():
+ setattr(self, key, value)
+ self.model_type = 'pooling'
+ self.layer_name = 'pool_1'
+ self.layer_types = ['conv']
+ self.weight_lr = 1e-3
+ #self.weight_lr = 0.0 # For next layer training
+ #self.layer_channels = [128, 32]
+ self.layer_channels = [256, 32]
+ #self.pool_ksize = 2
+ self.pool_ksize = 4
+ self.pool_stride = 2
+ self.renormalize_weights = True
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.checkpoint_boot_log = ''
+ #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log'
+ #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_v2lp.log'
+ self.compute_helper_params()
+
+ def compute_helper_params(self):
+ super(pooling_1_params, self).compute_helper_params()
+ self.optimizer.milestones = [frac * self.num_epochs
+ for frac in self.optimizer.lr_annealing_milestone_frac]
+
+
+class lca_2_params(LcaParams):
+ def set_params(self):
+ super(lca_2_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ for key, value in lca_1_params().__dict__.items(): setattr(self, key, value)
+ self.layer_name = 'lca_2'
+ self.weight_lr = 1e-3
+ #self.weight_lr = 0.0 # For next layer training
+ #self.layer_channels = [32, 256]
+ self.layer_channels = [32, 512]
+ #self.kernel_size = 6
+ self.kernel_size = 8
+ self.stride = 1
+ self.padding = 0
+ self.sparse_mult = 0.15
+ self.tau = 0.20
+ self.checkpoint_boot_log = ''
+ #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log'
+ self.compute_helper_params()
+
+class pooling_2_params(BaseParams):
+ def set_params(self):
+ super(pooling_2_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ for key, value in pooling_1_params().__dict__.items(): setattr(self, key, value)
+ self.layer_name = 'pool_2'
+ self.weight_lr = 1e-3
+ #self.weight_lr = 0.0 # For next layer training
+ #self.layer_types = ['fc']
+ self.layer_types = ['conv']
+ #self.layer_channels = [None, 64]
+ self.layer_channels = [512, 150]
+ self.pool_ksize = 4
+ self.pool_stride = 1
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.checkpoint_boot_log = ''
+ #self.checkpoint_boot_log = '/mnt/qb/bethge/dpaiton/Projects/smt_cifar10/logfiles/smt_cifar10_vlplp.log'
+ self.compute_helper_params()
+
+ def compute_helper_params(self):
+ super(pooling_2_params, self).compute_helper_params()
+ self.optimizer.milestones = [frac * self.num_epochs
+ for frac in self.optimizer.lr_annealing_milestone_frac]
+
+
+class mlp_params(MlpParams):
+ def set_params(self):
+ super(mlp_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ self.model_type = 'mlp'
+ self.layer_name = 'classifier'
+ self.weight_lr = 1e-2
+ self.weight_decay = 1e-6
+ self.layer_types = ['fc']
+ #self.layer_channels = [64, 10]
+ self.layer_channels = [150, 10]
+ self.activation_functions = ['identity']
+ self.dropout_rate = [0.0] # probability of value being set to zero
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.compute_helper_params()
+
+
+class params(BaseParams):
+ def set_params(self):
+ super(params, self).set_params()
+ lca_1_params_inst = lca_1_params()
+ pooling_1_params_inst = pooling_1_params()
+ lca_2_params_inst = lca_2_params()
+ pooling_2_params_inst = pooling_2_params()
+ mlp_params_inst = mlp_params()
+ data_shape = [3, 32, 32]
+ lca_1_output_height = compute_conv_output_shape(
+ data_shape[1],
+ lca_1_params_inst.kernel_size,
+ lca_1_params_inst.stride,
+ lca_1_params_inst.padding,
+ dilation=1)
+ lca_1_output_width = compute_conv_output_shape(
+ data_shape[2],
+ lca_1_params_inst.kernel_size,
+ lca_1_params_inst.stride,
+ lca_1_params_inst.padding,
+ dilation=1)
+ lca_1_shape = [
+ lca_1_params_inst.layer_channels[-1],
+ lca_1_output_height,
+ lca_1_output_width
+ ]
+ pooling_1_output_height = compute_conv_output_shape(
+ lca_1_output_height,
+ pooling_1_params_inst.pool_ksize,
+ pooling_1_params_inst.pool_stride,
+ padding=0,
+ dilation=1)
+ pooling_1_output_width = compute_conv_output_shape(
+ lca_1_output_width,
+ pooling_1_params_inst.pool_ksize,
+ pooling_1_params_inst.pool_stride,
+ padding=0,
+ dilation=1)
+ pooling_1_shape = [
+ pooling_1_params_inst.layer_channels[-1],
+ pooling_1_output_height,
+ pooling_1_output_width
+ ]
+ lca_2_params_inst.data_shape = [
+ int(pooling_1_params_inst.layer_channels[-1]),
+ int(pooling_1_output_height),
+ int(pooling_1_output_width)]
+ lca_2_output_height = compute_conv_output_shape(
+ pooling_1_output_height,
+ lca_2_params_inst.kernel_size,
+ lca_2_params_inst.stride,
+ lca_2_params_inst.padding,
+ dilation=1)
+ lca_2_output_width = compute_conv_output_shape(
+ pooling_1_output_width,
+ lca_2_params_inst.kernel_size,
+ lca_2_params_inst.stride,
+ lca_2_params_inst.padding,
+ dilation=1)
+ lca_2_shape = [
+ lca_2_params_inst.layer_channels[-1],
+ lca_2_output_height,
+ lca_2_output_width
+ ]
+ lca_2_flat_dim = int(np.prod(lca_2_shape))
+ pooling_2_params_inst.layer_channels[0] = lca_2_flat_dim
+ pooling_2_output_height = compute_conv_output_shape(
+ lca_2_output_height,
+ pooling_2_params_inst.pool_ksize,
+ pooling_2_params_inst.pool_stride,
+ padding=0,
+ dilation=1)
+ pooling_2_output_width = compute_conv_output_shape(
+ lca_2_output_width,
+ pooling_2_params_inst.pool_ksize,
+ pooling_2_params_inst.pool_stride,
+ padding=0,
+ dilation=1)
+ pooling_2_shape = [
+ pooling_2_params_inst.layer_channels[-1],
+ pooling_2_output_height,
+ pooling_2_output_width
+ ]
+ l1_overcompleteness = np.prod(lca_1_shape) / np.prod(data_shape)
+ p1_overcompleteness = np.prod(pooling_1_shape) / np.prod(lca_1_shape)
+ l2_overcompleteness = np.prod(lca_2_shape) / np.prod(pooling_1_shape)
+ p2_overcompleteness = np.prod(pooling_2_shape) / np.prod(lca_2_shape)
+ self.ensemble_params = [
+ lca_1_params_inst,
+ pooling_1_params_inst,
+ #lca_2_params_inst,
+ #pooling_2_params_inst,
+ #mlp_params_inst
+ ]
+ for key, value in shared_params().__dict__.items():
+ setattr(self, key, value)
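A worked check of the shape bookkeeping above, using the standard conv-output formula (values computed here, not taken from the diff):

    # lca_1:  (32 - 8) / 1 + 1 = 25        -> lca_1_shape     = [256, 25, 25]
    # pool_1: floor((25 - 4) / 2) + 1 = 11 -> pooling_1_shape = [32, 11, 11]
    # l1_overcompleteness = 256*25*25 / (3*32*32) ~= 52.1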
diff --git a/params/test_params.py b/params/test_params.py
index 01a687b5..80a25b6b 100644
--- a/params/test_params.py
+++ b/params/test_params.py
@@ -1,13 +1,14 @@
import os
import sys
import types
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
import torch
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
from DeepSparseCoding.params.base_params import BaseParams
from DeepSparseCoding.params.lca_mnist_params import params as LcaParams
from DeepSparseCoding.params.mlp_mnist_params import params as MlpParams
@@ -35,6 +36,7 @@ def __init__(self):
self.num_test_images = 0
self.standardize_data = False
self.rescale_data_to_one = False
+ self.allow_parent_grads = False
self.num_epochs = 3
self.train_logs_per_epoch = 1
@@ -42,18 +44,17 @@ def __init__(self):
class base_params(BaseParams):
def set_params(self):
super(base_params, self).set_params()
- for key, value in shared_params().__dict__.items():
- setattr(self, key, value)
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
class lca_params(BaseParams):
def set_params(self):
super(lca_params, self).set_params()
- for key, value in shared_params().__dict__.items():
- setattr(self, key, value)
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
self.model_type = 'lca'
self.weight_decay = 0.0
self.weight_lr = 0.1
+ self.layer_types = ['fc']
self.optimizer = types.SimpleNamespace()
self.optimizer.name = 'sgd'
self.optimizer.lr_annealing_milestone_frac = [0.7] # fraction of num_epochs
@@ -65,17 +66,49 @@ def set_params(self):
self.rectify_a = True
self.thresh_type = 'soft'
self.sparse_mult = 0.25
- self.num_latent = 128
+ self.layer_channels = [64, 128]
self.optimizer.milestones = [frac * self.num_epochs
for frac in self.optimizer.lr_annealing_milestone_frac]
self.step_size = self.dt / self.tau
+# TODO: Add ability to test multiple param values
+#class conv_lca_params(lca_params):
+# def set_params(self):
+# super(conv_lca_params, self).set_params()
+# self.layer_types = ['conv']
+# self.kernel_size = 8
+# self.stride = 2
+# self.padding = 0
+# self.optimizer.milestones = [frac * self.num_epochs
+# for frac in self.optimizer.lr_annealing_milestone_frac]
+# self.step_size = self.dt / self.tau
+# self.out_channels = self.layer_channels
+# self.in_channels = 1
+
+
+class pooling_params(BaseParams):
+ def set_params(self):
+ super(pooling_params, self).set_params()
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
+ self.model_type = 'pooling'
+ self.layer_name = 'test_pool_1'
+ self.weight_lr = 1e-3
+ self.layer_types = ['conv']
+ self.layer_channels = [128, 32]
+ self.pool_ksize = 2
+ self.pool_stride = 2 # non-overlapping
+ self.optimizer = types.SimpleNamespace()
+ self.optimizer.name = 'sgd'
+ self.optimizer.lr_annealing_milestone_frac = [0.3] # fraction of num_epochs
+ self.optimizer.lr_decay_rate = 0.8
+ self.optimizer.milestones = [frac * self.num_epochs
+ for frac in self.optimizer.lr_annealing_milestone_frac]
+
class mlp_params(BaseParams):
def set_params(self):
super(mlp_params, self).set_params()
- for key, value in shared_params().__dict__.items():
- setattr(self, key, value)
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
self.model_type = 'mlp'
self.weight_lr = 1e-4
self.weight_decay = 0.0
@@ -83,6 +116,7 @@ def set_params(self):
self.layer_channels = [128, 10]
self.activation_functions = ['identity']
self.dropout_rate = [0.0] # probability of value being set to zero
+ self.max_pool = [False]
self.optimizer = types.SimpleNamespace()
self.optimizer.name = 'adam'
self.optimizer.lr_annealing_milestone_frac = [0.8] # fraction of num_epochs
@@ -94,6 +128,9 @@ def set_params(self):
class ensemble_params(BaseParams):
def set_params(self):
super(ensemble_params, self).set_params()
- self.ensemble_params = [lca_params(), mlp_params()]
- for key, value in shared_params().__dict__.items():
- setattr(self, key, value)
+ layer1_params = lca_params()
+ layer1_params.layer_name = 'layer1'
+ layer2_params = mlp_params()
+ layer2_params.layer_name = 'layer2'
+ self.ensemble_params = [layer1_params, layer2_params]
+ for key, value in shared_params().__dict__.items(): setattr(self, key, value)
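Naming the submodels matters because each layer_name becomes the parameter prefix in the ensemble's state_dict, which the updated test relies on:

    # Assumed correspondence, consistent with tests/test_models.py later in this diff:
    # models['ensemble'].state_dict().keys() -> ['layer1.weight', 'layer2.weight', ...]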
diff --git a/requirements.txt b/requirements.txt
index f7e76024..02e1795b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,9 +11,4 @@ Pillow>=5.3.0
scikit-image>=0.14.1
scikit-learn>=0.20.0
scipy>=1.1.0
-seaborn>=0.9.0
-tensorflow-gpu==1.15.2
-tensorflow-estimator==1.15.1
-tensorboard==1.15
-tensorflow-probability==0.8.0
-tensorflow-compression
+seaborn>=0.9.0
\ No newline at end of file
diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py
index a615b4d0..7a3f8af6 100644
--- a/tests/test_data_processing.py
+++ b/tests/test_data_processing.py
@@ -1,15 +1,14 @@
import os
import sys
import unittest
+from os.path import dirname as up
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import torch
import numpy as np
-
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.utils.data_processing as dp
@@ -21,9 +20,9 @@ def test_reshape_data(self):
function call: reshape_data(data, flatten=None, out_shape=None):
24 possible conditions:
data: [np.ndarray] data of shape:
- n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k
- (l) - single data point of shape l, assumes 1 color channel
- (n, l) - n data points, each of shape l (flattened)
+ n is num_examples, i is num_channels, j is num_rows, k is num_cols
+ (i) - single data point of shape i
+ (n, i) - n data points, each of shape i (flattened)
(i, j, k) - single datapoint of of shape (i, j, k)
(n, i, j, k) - n data points, each of shape (i,j,k)
flatten: True, False, None
@@ -44,8 +43,8 @@ def test_reshape_data(self):
input_array_list = [
np.zeros((num_elements)), # assumed num_examples == 1
np.zeros((num_examples, num_elements)),
- np.zeros((num_rows, num_cols, num_channels)), # assumed num_examples == 1
- np.zeros((num_examples, num_rows, num_cols, num_channels))]
+ np.zeros((num_channels, num_rows, num_cols)), # assumed num_examples == 1
+ np.zeros((num_examples, num_channels, num_rows, num_cols))]
for input_array in input_array_list:
input_shape = input_array.shape
input_ndim = input_array.ndim
@@ -54,7 +53,7 @@ def test_reshape_data(self):
out_shape_list = [
None,
(num_elements,),
- (num_rows, num_cols, num_channels)]
+ (num_channels, num_rows, num_cols)]
if(num_channels == 1):
out_shape_list.append((num_rows, num_cols))
else:
@@ -62,7 +61,7 @@ def test_reshape_data(self):
out_shape_list = [
None,
(num_examples, num_elements),
- (num_examples, num_rows, num_cols, num_channels)]
+ (num_examples, num_channels, num_rows, num_cols)]
if(num_channels == 1):
out_shape_list.append((num_examples, num_rows, num_cols))
for out_shape in out_shape_list:
@@ -83,7 +82,7 @@ def test_reshape_data(self):
reshaped_array = reshape_outputs[0].numpy()
err_msg += f'\nreshaped_array.shape={reshaped_array.shape}'
self.assertEqual(reshape_outputs[1], input_shape, err_msg) # orig_shape
- (resh_num_examples, resh_num_rows, resh_num_cols, resh_num_channels) = reshape_outputs[2:]
+ (resh_num_examples, resh_num_channels, resh_num_rows, resh_num_cols) = reshape_outputs[2:]
err_msg += (f'\nfunction_shape_outputs={reshape_outputs[2:]}')
if(out_shape is None):
if(flatten is None):
@@ -105,26 +104,26 @@ def test_reshape_data(self):
expected_out_shape,
err_msg)
self.assertEqual(
- resh_num_rows*resh_num_cols*resh_num_channels,
+ resh_num_channels * resh_num_rows * resh_num_cols,
expected_out_shape[1],
err_msg)
elif(flatten == False):
- expected_out_shape = (num_examples, num_rows, num_cols, num_channels)
+ expected_out_shape = (num_examples, num_channels, num_rows, num_cols)
err_msg += f'\nexpected_out_shape={expected_out_shape}'
self.assertEqual(
reshaped_array.shape,
expected_out_shape,
err_msg)
self.assertEqual(
- resh_num_rows,
+ resh_num_channels,
expected_out_shape[1],
err_msg)
self.assertEqual(
- resh_num_cols,
+ resh_num_rows,
expected_out_shape[2],
err_msg)
self.assertEqual(
- resh_num_channels,
+ resh_num_cols,
expected_out_shape[3],
err_msg)
else:
@@ -135,22 +134,12 @@ def test_reshape_data(self):
self.assertEqual(reshaped_array.shape, expected_out_shape, err_msg)
self.assertEqual(resh_num_examples, None, err_msg)
-
- def test_flatten_feature_map(self):
- unflat_shape = [8, 4, 4, 3]
- flat_shape = [8, 4*4*3]
- shapes = [unflat_shape, flat_shape]
- for shape in shapes:
- test_map = torch.zeros(shape)
- flat_map = dp.flatten_feature_map(test_map).numpy()
- self.assertEqual(list(flat_map.shape), flat_shape)
-
def test_standardize(self):
num_tolerance_decimals = 5
unflat_shape = [8, 4, 4, 3]
flat_shape = [8, 4*4*3]
shape_options = [unflat_shape, flat_shape]
- eps_options = [1e-6, None]
+ eps_options = [1e-8, None]
samplewise_options = [True, False]
for shape in shape_options:
for eps_val in eps_options:
@@ -237,3 +226,40 @@ def test_label_conversion(self):
np.testing.assert_equal(func_dense, dense_labels)
func_one_hot = dp.dense_to_one_hot(torch.tensor(dense_labels), num_classes).numpy()
np.testing.assert_equal(func_one_hot, one_hot_labels)
+
+ def test_atleastkd(self):
+ x = np.random.standard_normal([2, 3, 4])
+ ks = [0, 3, 8, 10]
+ for k in ks:
+ new_x = dp.atleast_kd(torch.tensor(x), k).numpy()
+ test_nd = np.maximum(k, x.ndim)
+ np.testing.assert_equal(new_x.ndim, test_nd)
+
+ def test_l2_weight_norm(self):
+ w_fc = np.random.standard_normal([38, 24])
+ w_conv = np.random.standard_normal([38, 24, 8, 8])
+ for w in [w_fc, w_conv]:
+ w_norm = dp.get_weights_l2_norm(torch.tensor(w), eps=1e-12).numpy()
+ normed_w = dp.l2_normalize_weights(torch.tensor(w), eps=1e-12).numpy()
+ normed_w_norm = dp.get_weights_l2_norm(torch.tensor(normed_w), eps=1e-12).numpy()
+ np.testing.assert_allclose(normed_w_norm, 1.0, rtol=1e-10)
+ np.testing.assert_allclose(w / w_norm, normed_w, rtol=1e-10)
+
+ def test_patches(self):
+ err = 1e-6
+ rand_mean = 0; rand_var = 1
+ num_im = 10; im_edge = 512; im_chan = 1; patch_edge = 16
+ num_patches = int(num_im * (im_edge / patch_edge)**2)
+ rand_seed = 1234
+ rand_state = np.random.RandomState(rand_seed)
+ data = np.stack([rand_state.normal(rand_mean, rand_std, size=[im_chan, im_edge, im_edge])
+ for _ in range(num_im)])
+ data_shape = list(data.shape)
+ patch_shape = [im_chan, patch_edge, patch_edge]
+ datapoint = torch.tensor(data[0, ...])
+ datapoint_patches = dp.single_image_to_patches(datapoint, patch_shape)
+ datapoint_recon = dp.patches_to_single_image(datapoint_patches, data_shape[1:])
+ np.testing.assert_allclose(datapoint.numpy(), datapoint_recon.numpy(), rtol=err)
+ patches = dp.images_to_patches(torch.tensor(data), patch_shape)
+ data_recon = dp.patches_to_images(patches, data_shape[1:])
+ np.testing.assert_allclose(data, data_recon.numpy(), rtol=err)
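Minimal sketches of two of the helpers these new tests pin down, assuming the reduction conventions the assertions imply; the actual implementations live in utils/data_processing.py, outside this diff:

    import torch

    def atleast_kd(x, k):
        # Append trailing singleton dims until x has at least k dims.
        return x.reshape(x.shape + (1,) * max(k - x.dim(), 0))

    def get_weights_l2_norm(w, eps=1e-12):
        # Per-output-channel l2 norm, keeping dims so that w / norm broadcasts
        # for both fc [out, in] and conv [out, in, kh, kw] weights.
        reduce_dims = tuple(range(1, w.dim()))
        return torch.sqrt(torch.sum(w ** 2, dim=reduce_dims, keepdim=True) + eps)

    def l2_normalize_weights(w, eps=1e-12):
        return w / get_weights_l2_norm(w, eps=eps)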
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index d1c0cde4..c2ac7ea2 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -2,13 +2,14 @@
import sys
import unittest
import types
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
from torchvision import datasets
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.utils.dataset_utils as dataset_utils
class TestDatasets(unittest.TestCase):
@@ -33,7 +34,7 @@ def test_mnist(self):
params.dataset = 'mnist'
params.shuffle_data = True
params.batch_size = 10000
- train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
+ train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4]
for key, value in data_params.items():
setattr(params, key, value)
assert len(train_loader.dataset) == params.epoch_size
@@ -60,10 +61,15 @@ def test_synthetic(self):
params.dist_type = dist_type
params.num_classes = num_classes
params.rand_state = rand_state
- train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)
+ train_loader, val_loader, test_loader, data_params = dataset_utils.load_dataset(params)[:4]
for key, value in data_params.items():
setattr(params, key, value)
assert len(train_loader.dataset) == epoch_size
for batch_idx, (data, target) in enumerate(train_loader):
- assert data.numpy().shape == (params.batch_size, params.data_edge_size, params.data_edge_size, 1)
+ expected_size = (
+ params.batch_size,
+ 1,
+ params.data_edge_size,
+ params.data_edge_size)
+ assert data.numpy().shape == expected_size
assert batch_idx + 1 == epoch_size // params.batch_size
diff --git a/tests/test_foolbox.py b/tests/test_foolbox.py
index 02dc807d..24ca9774 100644
--- a/tests/test_foolbox.py
+++ b/tests/test_foolbox.py
@@ -1,16 +1,17 @@
import os
import sys
import unittest
+from os.path import dirname as up
-#import numpy as np
-import eagerpy as ep
-from foolbox import PyTorchModel, accuracy, samples
-import foolbox.attacks as fa
-
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-import DeepSparseCoding.utils.loaders as loaders
+#import numpy as np
+#import eagerpy as ep
+#from foolbox import PyTorchModel, accuracy, samples
+#import foolbox.attacks as fa
+
+#import DeepSparseCoding.utils.loaders as loaders
#import DeepSparseCoding.utils.dataset_utils as datasets
#import DeepSparseCoding.utils.run_utils as run_utils
@@ -28,7 +29,7 @@
# 'steps':3}} # max perturbation it can reach is 0.5
# attack = fa.LinfPGD(**attack_params['linfPGD'])
# epsilons = [0.3] # allowed perturbation size
-# params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params')
+# params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params')
# params['ensemble'].train_logs_per_epoch = None
# params['ensemble'].shuffle_data = False
# train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['ensemble'])
@@ -42,4 +43,4 @@
# fmodel = PyTorchModel(model.eval(), bounds=(0, 1))
# model_output = fmodel.forward()
# adv_model_outputs, adv_images, success = attack(fmodel, train_data_batch, train_target_batch, epsilons=epsilons)
-#
\ No newline at end of file
+#
diff --git a/tests/test_models.py b/tests/test_models.py
index a5e16ccd..7fdd172a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,12 +1,13 @@
import os
import sys
import unittest
+from os.path import dirname as up
-import numpy as np
-
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+import numpy as np
+
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.dataset_utils as datasets
import DeepSparseCoding.utils.run_utils as run_utils
@@ -18,37 +19,54 @@ def setUp(self):
self.model_list = loaders.get_model_list(self.dsc_dir)
self.test_params_file = os.path.join(*[self.dsc_dir, 'params', 'test_params.py'])
+ ### TODO - define endpoint function for checkpoint loading & test independently
+ ### TODO - add ability to test multiple options (e.g. 'conv' and 'fc') from test params
def test_model_loading(self):
for model_type in self.model_list:
- model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end
+ model_type = '_'.join(model_type.split('_')[:-1]) # remove '_model' at the end
model = loaders.load_model(model_type)
- params = loaders.load_params(self.test_params_file, key=model_type+'_params')
- train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params)
+ params = loaders.load_params_file(self.test_params_file, key=model_type+'_params')
+ train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params)[:4]
for key, value in data_params.items():
setattr(params, key, value)
model.setup(params)
- ### TODO - more basic test to compute gradients per model###
+
+ ### TODO - more basic test to compute gradients per model
#def test_gradients(self):
# for model_type in self.model_list:
# model_type = ''.join(model_type.split('_')[:-1]) # remove '_model' at the end
# model = loaders.load_model(model_type)
+ ### TODO - test for gradient blocking
+ #def test_get_module_encodings(self):
+ # """
+ # Test for gradient blocking in the get_module_encodings function
+
+ # construct test model1 & model2
+ # construct test ensemble model = model1 -> model2
+ # get encoding & grads for allow_grads={True, False}
+ # False: compare grads for model1 alone vs model1 in ensemble
+ # True: ensure that grad is different from model1 alone
+ # * Should also manually compute grads to compare?
+ # """
+ # # test should utilize run_utils.get_module_encodings()
+
+
def test_lca_ensemble_gradients(self):
+ ## Load models
params = {}
models = {}
- params['lca'] = loaders.load_params(self.test_params_file, key='lca_params')
+ params['lca'] = loaders.load_params_file(self.test_params_file, key='lca_params')
params['lca'].train_logs_per_epoch = None
params['lca'].shuffle_data = False
- train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca'])
- for key, value in data_params.items():
- setattr(params['lca'], key, value)
+ train_loader, val_loader, test_loader, data_params = datasets.load_dataset(params['lca'])[:4]
+ for key, value in data_params.items(): setattr(params['lca'], key, value)
models['lca'] = loaders.load_model(params['lca'].model_type)
models['lca'].setup(params['lca'])
models['lca'].to(params['lca'].device)
- params['ensemble'] = loaders.load_params(self.test_params_file, key='ensemble_params')
- for key, value in data_params.items():
- setattr(params['ensemble'], key, value)
+ params['ensemble'] = loaders.load_params_file(self.test_params_file, key='ensemble_params')
+ for key, value in data_params.items(): setattr(params['ensemble'], key, value)
err_msg = f'\ndata_shape={params["ensemble"].data_shape}'
err_msg += f'\nnum_pixels={params["ensemble"].num_pixels}'
err_msg += f'\nbatch_size={params["ensemble"].batch_size}'
@@ -56,9 +74,11 @@ def test_lca_ensemble_gradients(self):
models['ensemble'] = loaders.load_model(params['ensemble'].model_type)
models['ensemble'].setup(params['ensemble'])
models['ensemble'].to(params['ensemble'].device)
+ ## Overwrite weight initialization so that they have the same weights
ensemble_state_dict = models['ensemble'].state_dict()
- ensemble_state_dict['lca.w'] = models['lca'].w.clone()
+ ensemble_state_dict['layer1.weight'] = models['lca'].weight.clone()
models['ensemble'].load_state_dict(ensemble_state_dict)
+ ## Load data
data, target = next(iter(train_loader))
train_data_batch = models['lca'].preprocess_data(data.to(params['lca'].device))
train_target_batch = target.to(params['lca'].device)
@@ -66,27 +86,37 @@ def test_lca_ensemble_gradients(self):
for submodel in models['ensemble']:
submodel.optimizer.zero_grad()
inputs = [train_data_batch] # only the first model acts on input
+ ## Verify feedforward encodings
+ lca_encoding = models['lca'](inputs[0]).cpu().detach().numpy()
+ ensemble_encoding = models['ensemble'][0].get_encodings(inputs[0]).cpu().detach().numpy()
+ assert np.all(lca_encoding == ensemble_encoding), (err_msg+'\n'
+ +'Forward encodings for lca and ensemble[0] should be equal, but are not')
+ ## Verify LCA loss
+ lca_loss = models['lca'].get_total_loss((inputs[0], train_target_batch))
+ ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)]
+ lca_loss_val = lca_loss.cpu().detach().numpy()
+ ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy()
+ assert lca_loss_val == ensemble_loss_val, (err_msg+'\n'
+ +f'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}')
+ ## Compute remaining ensemble outputs
for submodel in models['ensemble']:
inputs.append(submodel.get_encodings(inputs[-1]).detach())
- lca_loss = models['lca'].get_total_loss((train_data_batch, train_target_batch))
- ensemble_losses = [models['ensemble'].get_total_loss((inputs[0], train_target_batch), 0)]
ensemble_losses.append(models['ensemble'].get_total_loss((inputs[1], train_target_batch), 1))
+ ## Verify lca grad & ensemble grad are equal
lca_loss.backward()
ensemble_losses[0].backward()
ensemble_losses[1].backward()
- lca_loss_val = lca_loss.cpu().detach().numpy()
- lca_w_grad = models['lca'].w.grad.cpu().numpy()
- ensemble_loss_val = ensemble_losses[0].cpu().detach().numpy()
- ensemble_w_grad = models['ensemble'][0].w.grad.cpu().numpy()
- assert lca_loss_val == ensemble_loss_val, (err_msg+'\n'
- +'Losses should be equal, but are lca={lca_loss_val} and ensemble={ensemble_loss_val}')
- assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\nGrads should be equal, but are not.')
- lca_pre_train_w = models['lca'].w.cpu().detach().numpy().copy()
- ensemble_pre_train_w = models['ensemble'][0].w.cpu().detach().numpy().copy()
+ lca_w_grad = models['lca'].weight.grad.cpu().numpy()
+ ensemble_w_grad = models['ensemble'][0].weight.grad.cpu().numpy()
+ assert np.all(lca_w_grad == ensemble_w_grad), (err_msg+'\n'
+ +'Grads should be equal, but are not.')
+ ## Verify weight updates are equal
+ lca_pre_train_w = models['lca'].weight.cpu().detach().numpy().copy()
+ ensemble_pre_train_w = models['ensemble'][0].weight.cpu().detach().numpy().copy()
run_utils.train_epoch(1, models['lca'], train_loader)
run_utils.train_epoch(1, models['ensemble'], train_loader)
- lca_w = models['lca'].w.cpu().detach().numpy().copy()
- ensemble_w = models['ensemble'][0].w.cpu().detach().numpy().copy()
+ lca_w = models['lca'].weight.cpu().detach().numpy().copy()
+ ensemble_w = models['ensemble'][0].weight.cpu().detach().numpy().copy()
assert np.all(lca_pre_train_w == ensemble_pre_train_w), (err_msg+'\n'
+"lca & ensemble weights are not equal before one epoch of training")
assert not np.all(lca_pre_train_w == lca_w), (err_msg+'\n'
diff --git a/tests/test_param_loading.py b/tests/test_param_loading.py
index 721b6a59..7d488a73 100644
--- a/tests/test_param_loading.py
+++ b/tests/test_param_loading.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import DeepSparseCoding.utils.loaders as loaders
@@ -12,4 +13,4 @@ def test_param_loading():
for params_name in params_list:
if 'test_' not in params_name:
params_file = os.path.join(*[dsc_dir, 'params', params_name+'.py'])
- params = loaders.load_params(params_file, key='params')
+ params = loaders.load_params_file(params_file, key='params')
diff --git a/tf1x/additional_requirements.txt b/tf1x/additional_requirements.txt
new file mode 100644
index 00000000..8e5be756
--- /dev/null
+++ b/tf1x/additional_requirements.txt
@@ -0,0 +1,5 @@
+tensorflow-gpu==1.15.2
+tensorflow-estimator==1.15.1
+tensorboard==1.15
+tensorflow-probability==0.8.0
+tensorflow-compression
\ No newline at end of file
diff --git a/tf1x/analysis/iso_response_analysis.py b/tf1x/analysis/iso_response_analysis.py
index 754dd044..21cbd701 100644
--- a/tf1x/analysis/iso_response_analysis.py
+++ b/tf1x/analysis/iso_response_analysis.py
@@ -219,29 +219,32 @@ def __init__(self):
cont_analysis['min_angle'] = 15
cont_analysis['batch_size'] = 100
cont_analysis['vh_image_scale'] = 31.773287 # Mean of the l2 norm of the training set
- cont_analysis['comparison_method'] = 'closest' # rand or closest
-
+ cont_analysis['comparison_method'] = 'rand' # rand or closest
+ cont_analysis['measure_upper_right'] = False
+ cont_analysis['bounds'] = ((-1, 1), (-1, 1))
+ cont_analysis['target_act'] = 0.5
cont_analysis['num_neurons'] = 100 # How many neurons to plot
cont_analysis['num_comparisons'] = 300 # How many planes to construct (None is all of them)
cont_analysis['x_range'] = [-2.0, 2.0]
cont_analysis['y_range'] = [-2.0, 2.0]
cont_analysis['num_images'] = int(30**2)
- cont_analysis['params_list'] = [lca_512_vh_params()]
+ cont_analysis['params_list'] = [lca_1024_vh_params(), lca_2560_vh_params()]
#cont_analysis['params_list'] = [lca_768_vh_params()]
#cont_analysis['params_list'] = [lca_1024_vh_params()]
#cont_analysis['params_list'] = [lca_2560_vh_params()]
#cont_analysis['iso_save_name'] = "iso_curvature_xrange1.3_yrange-2.2_"
#cont_analysis['iso_save_name'] = "iso_curvature_ryan_"
- cont_analysis['iso_save_name'] = "rescaled_closecomp_"
+ cont_analysis['iso_save_name'] = "newfits_rescaled_randomcomp_"
#cont_analysis['iso_save_name'] = ''
- np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz",
- data=cont_analysis)
analyzer_list = [load_analyzer(params) for params in cont_analysis['params_list']]
for analyzer, params in zip(analyzer_list, cont_analysis['params_list']):
+ save_root=analyzer.analysis_out_dir+'savefiles/'
+ np.savez(save_root+'iso_params_'+cont_analysis['iso_save_name']+params.save_info+".npz",
+ data=cont_analysis)
print(analyzer.analysis_params.display_name)
print("Computing the iso-response vectors...")
cont_analysis['target_neuron_ids'] = iso_data.get_rand_target_neuron_ids(
@@ -297,7 +300,6 @@ def __init__(self):
datapoints,
get_dsc_activations_cell,
activation_function_kwargs)
- save_root=analyzer.analysis_out_dir+'savefiles/'
if use_rand_orth_vects:
np.savez(save_root+'iso_rand_activations_'+cont_analysis['iso_save_name']+params.save_info+'.npz',
data=activations)
@@ -310,10 +312,11 @@ def __init__(self):
data=contour_dataset)
cont_analysis['comparison_neuron_ids'] = analyzer.comparison_neuron_ids
cont_analysis['contour_dataset'] = contour_dataset
+ cont_analysis['activations'] = activations
curvatures, fits = hist_funcs.iso_response_curvature_poly_fits(
cont_analysis['activations'],
target_act=cont_analysis['target_act'],
- measure_upper_right=False
+ bounds=cont_analysis['bounds']
)
cont_analysis['curvatures'] = np.stack(np.stack(curvatures, axis=0), axis=0)
np.savez(save_root+'group_iso_vectors_'+cont_analysis['iso_save_name']+params.save_info+'.npz',
diff --git a/tf1x/analyze_model.py b/tf1x/analyze_model.py
index c8db1e28..a30975aa 100644
--- a/tf1x/analyze_model.py
+++ b/tf1x/analyze_model.py
@@ -1,13 +1,14 @@
import os
import sys
import argparse
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
import tensorflow as tf
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
from DeepSparseCoding.tf1x.utils.logger import Logger
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.tf1x.data.data_selector as ds
diff --git a/tf1x/params/lca_conv_params.py b/tf1x/params/lca_conv_params.py
index cedd0d96..55f492d2 100644
--- a/tf1x/params/lca_conv_params.py
+++ b/tf1x/params/lca_conv_params.py
@@ -57,19 +57,24 @@ def set_data_params(self, data_type):
self.data_type = data_type
if data_type.lower() == "mnist":
self.model_name += "_mnist"
+ self.log_int = 200
self.rescale_data = True
self.center_data = False
self.whiten_data = False
self.lpf_data = False # only for ZCA
+ self.num_val = 0
+ self.batch_size = 50
self.lpf_cutoff = 0.7
- self.num_neurons = 768
+ self.num_neurons = 128
+ self.num_steps = 75
self.stride_y = 2
self.stride_x = 2
self.patch_size_y = 8 # weight receptive field
self.patch_size_x = 8
for schedule_idx in range(len(self.schedule)):
- self.schedule[schedule_idx]["sparse_mult"] = 0.21
- self.schedule[schedule_idx]["weight_lr"] = [0.1]
+ self.schedule[schedule_idx]["num_batches"] = int(6e5)
+ self.schedule[schedule_idx]["sparse_mult"] = 0.25
+ self.schedule[schedule_idx]["weight_lr"] = [0.001]
elif data_type.lower() == "vanhateren":
self.model_name += "_vh"
diff --git a/tf1x/tests/analysis/atas_test.py b/tf1x/tests/analysis/atas_test.py
index af071a2b..061a12aa 100644
--- a/tf1x/tests/analysis/atas_test.py
+++ b/tf1x/tests/analysis/atas_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/data/data_selector_test.py b/tf1x/tests/data/data_selector_test.py
index 43864e26..76ec089b 100644
--- a/tf1x/tests/data/data_selector_test.py
+++ b/tf1x/tests/data/data_selector_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/models/build_test.py b/tf1x/tests/models/build_test.py
index c2acca84..70ffb516 100644
--- a/tf1x/tests/models/build_test.py
+++ b/tf1x/tests/models/build_test.py
@@ -1,8 +1,9 @@
import copy
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/models/comb_test.py b/tf1x/tests/models/comb_test.py
index f54396f4..6d10e833 100644
--- a/tf1x/tests/models/comb_test.py
+++ b/tf1x/tests/models/comb_test.py
@@ -1,8 +1,9 @@
import copy
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/models/run_test.py b/tf1x/tests/models/run_test.py
index 20f379be..7342d63a 100644
--- a/tf1x/tests/models/run_test.py
+++ b/tf1x/tests/models/run_test.py
@@ -1,8 +1,9 @@
import copy
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/utils/checkpoint_test.py b/tf1x/tests/utils/checkpoint_test.py
index 4833d59e..bd7c5617 100644
--- a/tf1x/tests/utils/checkpoint_test.py
+++ b/tf1x/tests/utils/checkpoint_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/utils/contrast_normalize_test.py b/tf1x/tests/utils/contrast_normalize_test.py
index 3ec003fe..32712405 100644
--- a/tf1x/tests/utils/contrast_normalize_test.py
+++ b/tf1x/tests/utils/contrast_normalize_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/utils/patches_test.py b/tf1x/tests/utils/patches_test.py
index 4639a162..ad0e7c78 100644
--- a/tf1x/tests/utils/patches_test.py
+++ b/tf1x/tests/utils/patches_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/utils/reshape_data_test.py b/tf1x/tests/utils/reshape_data_test.py
index ff6922ad..18b76eed 100644
--- a/tf1x/tests/utils/reshape_data_test.py
+++ b/tf1x/tests/utils/reshape_data_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/tests/utils/standardize_data_test.py b/tf1x/tests/utils/standardize_data_test.py
index 524d3f18..e918bd1a 100644
--- a/tf1x/tests/utils/standardize_data_test.py
+++ b/tf1x/tests/utils/standardize_data_test.py
@@ -1,7 +1,8 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
+ROOT_DIR = up(up(up(up(up(os.path.realpath(__file__))))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
diff --git a/tf1x/train_model.py b/tf1x/train_model.py
index a11ff070..bc2cd08c 100644
--- a/tf1x/train_model.py
+++ b/tf1x/train_model.py
@@ -2,15 +2,16 @@
import sys
import time as ti
import argparse
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import matplotlib
matplotlib.use("Agg")
import numpy as np
import tensorflow as tf
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.tf1x.params.param_picker as pp
import DeepSparseCoding.tf1x.models.model_picker as mp
import DeepSparseCoding.tf1x.data.data_selector as ds
diff --git a/tf1x/utils/data_processing.py b/tf1x/utils/data_processing.py
index 1fd62606..999a8dcf 100644
--- a/tf1x/utils/data_processing.py
+++ b/tf1x/utils/data_processing.py
@@ -88,7 +88,7 @@ def reshape_data(data, flatten=None, out_shape=None):
if flatten == True:
data = np.reshape(data, (num_examples, num_rows*num_cols*num_channels))
else:
- assert False, ("Data must have 1, 2, 3, or 4 dimensions.")
+ assert False, (f'Data must have 1, 2, 3, or 4 dimensions, not {orig_ndim}')
else:
num_examples = None; num_rows=None; num_cols=None; num_channels=None
data = np.reshape(data, out_shape)
@@ -284,7 +284,7 @@ def generate_grating(patch_edge_size, location, diameter, orientation, frequency
"""
vals = np.linspace(-np.pi, np.pi, patch_edge_size)
X, Y = np.meshgrid(vals, vals)
- Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # countercloclwise
+ Xr = np.cos(orientation)*X + -np.sin(orientation)*Y # counterclockwise
Yr = np.sin(orientation)*X + np.cos(orientation)*Y
stim = contrast*np.sin(Yr*frequency+phase)
if diameter > 0: # Generate mask
@@ -958,13 +958,13 @@ def pca_reduction(data, num_pcs=-1):
data_mean = data.mean(axis=(1))[:,None]
data -= data_mean
Cov = np.cov(data.T) # Covariance matrix
- U, S, V = np.linalg.svd(Cov) # SVD decomposition
+ U, S, VT = np.linalg.svd(Cov) # SVD decomposition
diagS = np.diag(S)
if num_pcs <= 0:
n = num_rows
else:
n = num_pcs
- data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), V[:n, :]))
+ data_reduc = np.dot(data, np.dot(np.dot(U[:, :n], diagS[:n, :n]), VT[:n, :]))
return data_reduc
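The rename reflects what numpy actually returns: the third output of np.linalg.svd is the right factor already transposed, so Cov == U @ diag(S) @ VT and projecting with VT[:n, :] is correct:

    import numpy as np

    A = np.random.randn(4, 4)
    U, S, VT = np.linalg.svd(A)
    assert np.allclose(A, U @ np.diag(S) @ VT)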
def compute_power_spectrum(data):
diff --git a/tf1x/utils/jov_funcs.py b/tf1x/utils/jov_funcs.py
index be46539e..5dd8f78e 100644
--- a/tf1x/utils/jov_funcs.py
+++ b/tf1x/utils/jov_funcs.py
@@ -119,10 +119,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh
width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length,
fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth)
tenth_range_shift = ((max(analysis_dict['x_range']) - min(analysis_dict['x_range']))/10) # For shifting labels
- #text_handle = curve_axes[-1].text(
- # target_vector_x+(tenth_range_shift*phi_k_text_x_offset),
- # target_vector_y+(tenth_range_shift*phi_k_text_y_offset),
- # r'$\Phi_{k}$', horizontalalignment='center', verticalalignment='center')
# plot comparison neuron arrow & label
proj_comparison = analysis_dict['contour_dataset']['proj_comparison_vect'][neuron_index][orth_index]
comparison_vector_x = proj_comparison[0].item()
@@ -130,10 +126,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh
curve_axes[-1].arrow(0, 0, comparison_vector_x, comparison_vector_y,
width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length,
fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth)
- #text_handle = curve_axes[-1].text(
- # comparison_vector_x+(tenth_range_shift*phi_j_text_x_offset),
- # comparison_vector_y+(tenth_range_shift*phi_j_text_y_offset),
- # r'$\Phi_{j}$', horizontalalignment='center', verticalalignment='center')
# Plot orthogonal vector Nu
proj_orth = analysis_dict['contour_dataset']['proj_orth_vect'][neuron_index][orth_index]
orth_vector_x = proj_orth[0].item()
@@ -141,10 +133,6 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh
curve_axes[-1].arrow(0, 0, orth_vector_x, orth_vector_y,
width=arrow_width, head_width=arrow_head_width, head_length=arrow_head_length,
fc='k', ec='k', linestyle='-', linewidth=arrow_linewidth)
- #text_handle = curve_axes[-1].text(
- # orth_vector_x+(tenth_range_shift*nu_text_x_offset),
- # orth_vector_y+(tenth_range_shift*nu_text_y_offset),
- # r'$\nu$', horizontalalignment='center', verticalalignment='center')
# Plot axes
curve_axes[-1].set_aspect('equal')
curve_axes[-1].plot(analysis_dict['x_range'], [0,0], color='k', linewidth=arrow_linewidth/2)
@@ -153,14 +141,11 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh
k_idx = analysis_dict["target_neuron_ids"][neuron_index]
curv_val = curvatures[y_id, x_id]
curve_axes[-1].set_title(
- #f'k={k_idx}; j={j_idx}\nC={curv_val:.3f}',
f'C={curv_val:.3f}',
fontsize=rcParams['axes.titlesize']/2,
pad=2,
horizontalalignment='center'
)
- #curve_axes[-1].text(x=-0.08, y=1.75, s=f'C={curvatures[y_id, x_id]:.3f}',
- # horizontalalignment='right', verticalalignment='center', fontsize=6)
if y_id==0:
curve_axes[-1].set_ylabel(str(neuron_index), visible=True)
# Add colorbar
@@ -180,6 +165,123 @@ def plot_iso_contour_set(analysis_dict, curvatures, num_levels, num_x, num_y, sh
return fig, contour_handles
+def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list,
+ label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims,
+ scatter, log=True, text_width=200, width_ratio=1.0, dpi=100):
+ gs0_wspace = 0.5
+ hspace_hist = 0.7
+ wspace_hist = 0.08
+ iso_response_line_thickness = 2
+ response_attenuation_line_thickness = 2
+ num_y_plots = 2 if log else 1
+ num_x_plots = 1
+ fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)
+ gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace)
+ if log:
+ curve_ax = fig.add_subplot(gs_base[0], projection='3d')
+ curve_ax.minorticks_off()
+ x_mesh, y_mesh = np.meshgrid(*contour_pts)
+ curve_ax.set_zlim(0, 1)
+ curve_ax.set_xlim3d(5, 200)
+ curve_ax.grid(False)
+ curve_ax.set_xticklabels([])
+ curve_ax.set_yticklabels([])
+ curve_ax.set_zticklabels([])
+ curve_ax.zaxis.set_rotate_label(False)
+ if scatter:
+ curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01)
+ else:
+ curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1,
+ linestyles='dotted', linewidths=0.3, alpha=1.0)
+ # Plane vector visualizations
+ v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.],
+ [0, 0.0], mutation_scale=10,
+ lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed')
+ curve_ax.add_artist(v)
+ curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red')
+ phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.],
+ [0, 0.0], mutation_scale=10,
+ lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed')
+ curve_ax.add_artist(phi_k)
+ curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red')
+ # Iso-response curve
+ loc0, loc1, loc2 = contour_text_loc[0]
+ curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10)
+ lines = np.array([0.2, 0.203, 0.197]) - 0.1
+ for i in lines:
+ curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2)
+ # Response attenuation curve
+ loc0, loc1, loc2 = contour_text_loc[1]
+ curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10)
+ att_line_offset = 165
+ x, y = contour_pts
+ curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset],
+ color='black', lw=2, zorder=2)
+ # Activity label
+ #loc0, loc1, loc2 = contour_text_loc[2]
+ #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z')
+ # Additional settings
+ curve_ax.view_init(view_elevation, contour_angle)
+ scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
+ curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect
+ curve_ax._axis3don = False
+ gs_base_idx = 1 if log else 0
+ # Histogram plots
+ num_hist_y_plots = 2
+ num_hist_x_plots = 2
+ gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[gs_base_idx],
+ hspace=hspace_hist, wspace=wspace_hist)
+ orig_ax = fig.add_subplot(gs_hist[0,0])
+ axes = []
+ for sub_plt_y in range(0, num_hist_y_plots):
+ axes.append([])
+ for sub_plt_x in range(0, num_hist_x_plots):
+ if (sub_plt_x, sub_plt_y) == (0,0):
+ axes[sub_plt_y].append(orig_ax)
+ else:
+ axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))
+ all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)
+ for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):
+ sub_bins = np.squeeze(sub_bins)
+ all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)
+ for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):
+ axes[axis_y][axis_x].spines['top'].set_visible(False)
+ axes[axis_y][axis_x].spines['right'].set_visible(False)
+ axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)
+ axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False)
+ axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))
+ for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):
+ neuron_hist = np.squeeze(neuron_hist)
+ if log:
+ axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-',
+ drawstyle='steps-mid', label=label)
+ axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation())
+ else:
+ axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)
+ axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)
+ if axis_y == 0:
+ axes[axis_y][axis_x].set_title(sub_title)
+ axes[axis_y][axis_x].set_xlabel(sub_xlabel)
+ if axis_x == 0:
+ if log:
+ axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency')
+ else:
+ axes[axis_y][axis_x].set_ylabel('Relative\nFrequency')
+ ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()
+ legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',
+ ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,
+ labelspacing=0., bbox_to_anchor=(0.95, 0.95))
+ legend.get_frame().set_linewidth(0.0)
+ for text, color in zip(legend.get_texts(), axis_colors):
+ text.set_color(color)
+ for item in legend.legendHandles:
+ item.set_visible(False)
+ if axis_x == 1:
+ axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)
+ plt.show()
+ return fig
+
+
def plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices, num_levels, x_range, y_range, show_contours=True,
curvature=None, text_width=200, width_fraction=1.0, dpi=100):
arrow_width = 0.0
@@ -421,140 +523,6 @@ def compute_curvature_hists(analyzer_list, num_bins):
rand_hist, _ = np.histogram(flat_rand_curvatures, attn_bins, density=False)
analyzer.attn_rand_hist = rand_hist / len(flat_rand_curvatures)
-def plot_curvature_histograms(activity, contour_pts, contour_angle, contour_text_loc, hist_list, label_list,
- color_list, mesh_color, bin_centers, title, xlabel, curve_lims, log=True,
- text_width=200, width_ratio=1.0, dpi=100):
- gs0_wspace = 0.5
- hspace_hist = 0.7
- wspace_hist = 0.08
- view_elevation = 30
- iso_response_line_thickness = 2
- respone_attenuation_line_thickness = 2
- num_y_plots = 2
- num_x_plots = 1
- fig = plt.figure(figsize=set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)
- gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace)
- curve_ax = fig.add_subplot(gs_base[0], projection='3d')
- x_mesh, y_mesh = np.meshgrid(*contour_pts)
- curve_ax.set_zlim(0, 1)
- curve_ax.set_xlim3d(5, 200)
- curve_ax.grid(b=False, zorder=0)
- x_ticks = curve_ax.get_xticks().tolist()
- x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1],
- len(x_ticks)), 1).astype(str)
- a_x = [' ']*len(x_ticks)
- a_x[1] = x_ticks[1]
- a_x[-1] = x_ticks[-1]
- curve_ax.set_xticklabels(a_x)
- y_ticks = curve_ax.get_yticks().tolist()
- y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1],
- len(y_ticks)), 1).astype(str)
- a_y = [' ']*len(y_ticks)
- a_y[1] = y_ticks[1]
- a_y[-1] = y_ticks[-1]
- curve_ax.set_yticklabels(a_y)
- curve_ax.set_zticklabels([])
- curve_ax.zaxis.set_rotate_label(False)
- curve_ax.set_zlabel('Normalized\nActivity', rotation=95, labelpad=-12., position=(-10., 0.))
- #curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01)
- curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1,
- linestyles='dotted', linewidths=0.5, alpha=1.0)
- # Plane vector visualizations
- v = Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.],
- [0, 0.0], mutation_scale=10,
- lw=1, arrowstyle='-|>', color='red', linestyle='dashed')
- curve_ax.add_artist(v)
- curve_ax.text(-300/3., 280/3.0, 0.0, r'$\nu$', color='red')
- phi_k = Arrow3D([-200/3., 0.], [200/2., 200/2.],
- [0, 0.0], mutation_scale=10,
- lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed')
- curve_ax.add_artist(phi_k)
- curve_ax.text(-175/3., 250/3.0, 0.0, r'${\phi}_{k}$', color='red')
- # Iso-response curve
- loc0, loc1, loc2 = contour_text_loc[0]
- curve_ax.text(loc0, loc1, loc2, 'Iso-\nresponse', color='black', weight='bold', zorder=10)
- iso_line_offset = 165
- x, y = contour_pts
- curve_ax.plot(np.zeros_like(x)+iso_line_offset, y, activity[:, iso_line_offset],
- color='black', lw=2, zorder=2)
- # Response attenuation curve
- loc0, loc1, loc2 = contour_text_loc[1]
- curve_ax.text(loc0, loc1, loc2, 'Response\nAttenuation', color='black', weight='bold', zorder=10)
- lines = np.array([0.2, 0.203, 0.197]) - 0.1
- for i in lines:
- curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2)
- # Additional settings
- curve_ax.view_init(view_elevation, contour_angle)
- scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
- curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # make sure it has a square aspect
- num_hist_y_plots = 2
- num_hist_x_plots = 2
- gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1],
- hspace=hspace_hist, wspace=wspace_hist)
- orig_ax = fig.add_subplot(gs_hist[0,0])
- axes = []
- for sub_plt_y in range(0, num_hist_y_plots):
- axes.append([])
- for sub_plt_x in range(0, num_hist_x_plots):
- if (sub_plt_x, sub_plt_y) == (0,0):
- axes[sub_plt_y].append(orig_ax)
- else:
- axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))
- #[curvature type] [iso/att]
- #[dataset type] [comp/rand]
- #[target neuron id]
- all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)
- for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):
- sub_bins = np.squeeze(sub_bins)
- #max_hist_val = 0.001
- #min_hist_val = 100
- all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)
- for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):
- axes[axis_y][axis_x].spines['top'].set_visible(False)
- axes[axis_y][axis_x].spines['right'].set_visible(False)
- axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)
- axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False)
- axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))
- for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):
- neuron_hist = np.squeeze(neuron_hist)
- plot_hist = []
-
- if log:
- axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)
- #axes[axis_y][axis_x].set_yscale('log')
- axes[axis_y][axis_x].yaxis.set_major_formatter(
- ticker.FuncFormatter(
- lambda y,pos: ('{{:.{:1d}f}}'.format(int(np.maximum(-np.log10(y),0)))).format(y)
- )
- )
- else:
- axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)
- #if np.max(hist) > max_hist_val:
- # max_hist_val = np.max(hist)
- #if np.min(hist) < min_hist_val:
- # min_hist_val = np.min(hist)
- axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)
- if axis_y == 0:
- axes[axis_y][axis_x].set_title(sub_title)
- axes[axis_y][axis_x].set_xlabel(sub_xlabel)
- if axis_x == 0:
- if log:
- axes[axis_y][axis_x].set_ylabel('Relative\nLog Frequency')
- else:
- axes[axis_y][axis_x].set_ylabel('Relative\nFrequency')
- ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()
- legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',
- ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,
- labelspacing=0., bbox_to_anchor=(0.95, 0.95))
- legend.get_frame().set_linewidth(0.0)
- for text, color in zip(legend.get_texts(), axis_colors):
- text.set_color(color)
- for item in legend.legendHandles:
- item.set_visible(False)
- if axis_x == 1:
- axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)
- plt.show()
- return fig
def plot_contrast_orientation_tuning(bf_indices, contrasts, orientations, activations, figsize=(32,32)):
'''
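
Reviewer note: the new plot_curvature_histograms builds its histogram panel as a nested GridSpec in which one anchor axis owns the y scale and every other panel shares it, with redundant y tick labels hidden. A minimal self-contained sketch of that pattern, with illustrative names and synthetic data (nothing below is repository code):

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

def make_shared_y_hist_grid(num_rows=2, num_cols=2):
    # one anchor axis owns the y scale; the rest share it
    fig = plt.figure(figsize=(6, 4))
    gs = gridspec.GridSpec(num_rows, num_cols, hspace=0.7, wspace=0.08)
    anchor_ax = fig.add_subplot(gs[0, 0])
    axes = []
    for row in range(num_rows):
        axes.append([])
        for col in range(num_cols):
            if (row, col) == (0, 0):
                axes[row].append(anchor_ax)
            else:
                axes[row].append(fig.add_subplot(gs[row, col], sharey=anchor_ax))
            if col > 0:  # hide redundant y tick labels off the left column
                axes[row][col].tick_params(axis='y', labelleft=False)
    return fig, axes

fig, axes = make_shared_y_hist_grid()
bins = np.linspace(-1.0, 1.0, 21)
for row in axes:
    for ax in row:
        hist, edges = np.histogram(np.random.randn(1000), bins)
        centers = 0.5 * (edges[:-1] + edges[1:])
        ax.plot(centers, hist, drawstyle='steps-mid')
plt.show()
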
diff --git a/tf1x/utils/logger.py b/tf1x/utils/logger.py
index c527e488..edfb404e 100644
--- a/tf1x/utils/logger.py
+++ b/tf1x/utils/logger.py
@@ -90,7 +90,13 @@ def read_js(self, tokens, text):
assert len(tokens) == 2, ("Input variable tokens must be a list of length 2")
matches = re.findall(re.escape(tokens[0])+"([\s\S]*?)"+re.escape(tokens[1]), text)
if len(matches) > 1:
- js_matches = [js.loads(match) for match in matches]
+ js_matches = []
+ for match_idx, match in enumerate(matches):
+ try:
+ js_matches.append(js.loads(match))
+        except ValueError as err:  # json.JSONDecodeError subclasses ValueError
+          print(f'ERROR: JSON load failed on match index {match_idx}')
+          raise
else:
js_matches = [js.loads(matches[0])]
return js_matches
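
The logger change above trades a list comprehension for an explicit loop so a single malformed match can be identified. A hedged standalone sketch of the same token-delimited JSON extraction (extract_json_between is a hypothetical name; the repository method is Logger.read_js, which aliases json as js):

import json
import re

def extract_json_between(tokens, text):
    # find every substring bracketed by tokens[0]...tokens[1] and parse it as JSON
    assert len(tokens) == 2, 'tokens must be a list of length 2'
    pattern = re.escape(tokens[0]) + r'([\s\S]*?)' + re.escape(tokens[1])
    results = []
    for match_idx, match in enumerate(re.findall(pattern, text)):
        try:
            results.append(json.loads(match))
        except ValueError as err:  # json.JSONDecodeError subclasses ValueError
            raise ValueError(f'JSON load failed on match index {match_idx}') from err
    return results

text = '<js>{"lr": 0.1}</js> log noise <js>{"lr": 0.01}</js>'
print(extract_json_between(['<js>', '</js>'], text))  # [{'lr': 0.1}, {'lr': 0.01}]
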
diff --git a/tf1x/vis/JOV_Euler_Attacks.ipynb b/tf1x/vis/JOV_Euler_Attacks.ipynb
index efef9419..60a7d42c 100644
--- a/tf1x/vis/JOV_Euler_Attacks.ipynb
+++ b/tf1x/vis/JOV_Euler_Attacks.ipynb
@@ -32,11 +32,13 @@
"outputs": [],
"source": [
"import autograd.numpy as np\n",
+ "import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"\n",
"root_path = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))\n",
"if root_path not in sys.path: sys.path.append(root_path)\n",
"\n",
+ "import DeepSparseCoding.tf1x.utils.jov_funcs as jov\n",
"import DeepSparseCoding.tf1x.analysis.analysis_picker as ap\n",
"from DeepSparseCoding.tf1x.data.dataset import Dataset\n",
"import DeepSparseCoding.tf1x.utils.data_processing as dp"
@@ -65,7 +67,8 @@
" import schematic_utils\n",
"except ImportError:\n",
" import sys\n",
- " sys.path.append(\"../schematic_figure/\")\n",
+ " usr = os.path.expanduser('~')\n",
+ " sys.path.append(usr+'/Work/DeepSparseCoding/tf1x/')\n",
" import schematic_utils"
]
},
@@ -75,9 +78,30 @@
"metadata": {},
"outputs": [],
"source": [
- "figsize = (16, 16)\n",
+    "text_width = 540.60236  # pt (previously 416.83269 pt = 14.65 cm)\n",
+    "text_width_cm = 18.9973  # previously 14.705\n",
+ "fontsize = 10\n",
+ "dpi = 300\n",
+ "file_extensions = ['.pdf']#, '.eps', '.png']\n",
+ "#figsize = (16, 16)\n",
+ "num_y_plots = 3\n",
+ "num_x_plots = 1\n",
+ "width_ratio = 1.0\n",
+ "figsize = jov.set_size(text_width, width_ratio, [num_y_plots, num_x_plots])\n",
"fontsize = 20\n",
- "dpi = 200"
+ "font_settings = {\n",
+ " \"text.usetex\": True,\n",
+ " \"font.family\": 'serif',\n",
+ " \"font.serif\": 'Computer Modern Roman',\n",
+ " \"axes.labelsize\": fontsize,\n",
+ " \"axes.titlesize\": fontsize,\n",
+ " \"figure.titlesize\": fontsize+2,\n",
+ " \"font.size\": fontsize,\n",
+ " \"legend.fontsize\": fontsize,\n",
+ " \"xtick.labelsize\": fontsize-2,\n",
+ " \"ytick.labelsize\": fontsize-2,\n",
+ "}\n",
+ "mpl.rcParams.update(font_settings)"
]
},
{
@@ -364,7 +388,8 @@
"metadata": {},
"outputs": [],
"source": [
- "f = plt.figure(figsize=(2*figsize[0],figsize[1]), dpi=dpi)\n",
+ "figsize = (2*16,16)#(figsize[0], figsize[1])\n",
+ "f = plt.figure(figsize=figsize, dpi=dpi)\n",
"fig_shape = (1, 4)\n",
"\n",
"mlp_ax = plt.subplot2grid(fig_shape, loc=(0, 0), colspan=1, fig=f)\n",
@@ -404,11 +429,57 @@
"metadata": {},
"outputs": [],
"source": [
- "for ext in [\".png\", \".eps\"]:\n",
- " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic\"\n",
+ "for ext in file_extensions:#[\".png\", \".eps\"]:\n",
+ " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_new\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig_size_inches = f.get_size_inches()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "fig_size_inches"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "f.set_size_inches(fig_size_inches/2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for ext in file_extensions:#[\".png\", \".eps\"]:\n",
+ " save_name = (analyzer.analysis_out_dir+\"/vis/contours_and_gradients_schematic_small\"\n",
+ " +\"_\"+analyzer.analysis_params.save_info+ext)\n",
+ " f.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/tf1x/vis/JOV_figs.ipynb b/tf1x/vis/JOV_figs.ipynb
index 7c92c4ca..76111b37 100644
--- a/tf1x/vis/JOV_figs.ipynb
+++ b/tf1x/vis/JOV_figs.ipynb
@@ -85,10 +85,16 @@
"metadata": {},
"outputs": [],
"source": [
- "text_width = 416.83269 #pt = 14.65cm\n",
- "text_width_cm = 14.705\n",
- "fontsize = 12\n",
- "dpi = 1200"
+ "\"\"\"\n",
+ "textwidth in pt: 540.60236pt\n",
+ "textwidth in cm: 18.9973cm\n",
+ "textwidth in in: 7.48178in\n",
+ "\"\"\"\n",
+    "text_width = 540.60236  # pt (previously 416.83269 pt = 14.65 cm)\n",
+    "text_width_cm = 18.9973  # previously 14.705\n",
+ "fontsize = 10\n",
+ "dpi = 300\n",
+ "file_extensions = ['.pdf']#, '.eps', '.png']"
]
},
{
@@ -99,10 +105,11 @@
"source": [
"font_settings = {\n",
" \"text.usetex\": True,\n",
- " \"font.family\": \"serif\",\n",
+ " \"font.family\": 'serif',\n",
+ " \"font.serif\": 'Computer Modern Roman',\n",
" \"axes.labelsize\": fontsize,\n",
" \"axes.titlesize\": fontsize,\n",
- " \"figure.titlesize\": fontsize,\n",
+ " \"figure.titlesize\": fontsize+2,\n",
" \"font.size\": fontsize,\n",
" \"legend.fontsize\": fontsize,\n",
" \"xtick.labelsize\": fontsize-2,\n",
@@ -305,18 +312,8 @@
"for analyzer, save_name in zip(analyzer_list, save_names):\n",
" analyzer.iso_params = np.load(analyzer.analysis_out_dir+'savefiles/iso_params_'+save_name\n",
" +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n",
- " #min_angle = analyzer.iso_params['min_angle']\n",
- " #batch_size = analyzer.iso_params['batch_size']\n",
- " #vh_image_scale = analyzer.iso_params['vh_image_scale']\n",
- " #comparison_method = analyzer.iso_params['comparison_method']\n",
- " #num_neurons = analyzer.iso_params['num_neurons']\n",
- " #analyzer.num_comparison_vectors = analyzer.iso_params['num_comparisons']\n",
" x_range = analyzer.iso_params['x_range']\n",
" y_range = analyzer.iso_params['y_range']\n",
- " #num_images = analyzer.iso_params['num_images']\n",
- " #params_list = analyzer.iso_params['params_list']\n",
- " #iso_save_name = analyzer.iso_params['iso_save_name']\n",
- " #target_neuron_ids = analyzer.iso_params['target_neuron_ids']\n",
"\n",
" iso_vectors = np.load(analyzer.analysis_out_dir+'savefiles/iso_vectors_'+save_name\n",
" +analyzer.analysis_params.save_info+'.npz', allow_pickle=True)['data'].item()\n",
@@ -340,6 +337,28 @@
" analyzer = add_analyzer_keys(analyzer)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "target_act = 0.5  # target activity between the min and max of normalized activity (between 0 and 1)\n",
+ "lca_activations = analyzer_list[-1].comp_activations\n",
+ "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n",
+ " lca_activations,\n",
+ " target_act=target_act\n",
+ ")\n",
+ "max_comp_indices = []\n",
+ "max_vals = []\n",
+ "for target_neuron_id in range(len(curvatures)):\n",
+ " max_idx = np.argmax(curvatures[target_neuron_id])\n",
+ " max_comp_indices.append(max_idx)\n",
+ " max_vals.append(curvatures[target_neuron_id][max_idx])\n",
+ "max_target_id = np.argmax(max_vals)\n",
+ "max_comparison_id = max_comp_indices[max_target_id]"
+ ]
+ },
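
The new cell above scans a nested curvatures structure, indexed [target neuron][comparison plane], for the globally largest value. A compact equivalent on illustrative data, for readers checking the index bookkeeping:

import numpy as np

curvatures = [np.array([0.1, 0.4]), np.array([0.3, 0.2])]  # [target][comparison]
max_comp_indices = [int(np.argmax(c)) for c in curvatures]       # per-target argmax
max_vals = [c[i] for c, i in zip(curvatures, max_comp_indices)]  # per-target max
max_target_id = int(np.argmax(max_vals))                         # best target neuron
max_comparison_id = max_comp_indices[max_target_id]              # its comparison plane
assert (max_target_id, max_comparison_id) == (0, 1)
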
{
"cell_type": "code",
"execution_count": null,
@@ -363,27 +382,29 @@
"min_comparison_id = target_min_idx[min_target_id]\n",
"\n",
"# 8(.039), 17(.028) 23(.037), 25(.036), 41(.033), 48(.035), 49(0.039)\n",
- "neuron_indices = [0, 0, 0, min_target_id]\n",
- "orth_indices = [0, 0, 0, min_comparison_id]\n",
- "target_act = 0.5 # target activity spot between min & max value of normalized activity (btwn 0 and 1)\n",
+ "neuron_indices = [0, 0, 0, max_target_id]#min_target_id]\n",
+ "orth_indices = [0, 0, 0, max_comparison_id]#min_comparison_id]\n",
"num_plots_y = 2\n",
"num_plots_x = 2\n",
"width_fraction = 1.0\n",
"show_contours = True\n",
"\n",
"lca_activations = analyzer_list[-1].comp_activations[neuron_indices[-1], orth_indices[-1], ...][None, None, ...]\n",
- "curvatures, fits = ha.iso_response_curvature_poly_fits(\n",
+ "curvatures, fits, contours = ha.iso_response_curvature_poly_fits(\n",
" lca_activations,\n",
- " target_act=target_act,\n",
- " measure_upper_right=False\n",
+ " target_act=target_act\n",
")\n",
"curvature = [None, None, None, curvatures[0][0]]\n",
"\n",
+ "#for analyzer in analyzer_list:\n",
+ "# analyzer.comp_activations = analyzer.comp_activations - analyzer.comp_activations.min()\n",
+ "# analyzer.comp_activations = analyzer.comp_activations / analyzer.comp_activations.max()\n",
+ "\n",
"contour_fig, contour_handles = nc.plot_group_iso_contours(analyzer_list, neuron_indices, orth_indices,\n",
" num_levels, x_range, y_range, show_contours, curvature, text_width, width_fraction, dpi)\n",
"\n",
"for analyzer, neuron_index, orth_index, save_suffix in zip(analyzer_list, neuron_indices, orth_indices, save_names):\n",
- " for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n",
+ " for ext in file_extensions:\n",
" neuron_str = str(analyzer.target_neuron_ids[neuron_index])\n",
" orth_str = str(analyzer.comparison_neuron_ids[neuron_index][orth_index])\n",
" save_name = analyzer.analysis_out_dir+\"/vis/iso_contour_comparison_\"\n",
@@ -432,11 +453,11 @@
" num_y,\n",
" show_contours,\n",
" text_width,\n",
- " width_fraction,\n",
+ " 1.00,\n",
" dpi\n",
")\n",
"\n",
- "for ext in [\".eps\"]:#[\".png\", \".eps\"]:\n",
+ "for ext in file_extensions:\n",
" save_name = analyzer.analysis_out_dir+\"/vis/scaled_iso_contours_set_\"\n",
" if not show_contours:\n",
" save_name += \"continuous_\"\n",
@@ -550,154 +571,6 @@
"full_xlabel = [\"Curvature (Comparison)\", \"Curvature (Random)\"]"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "def plot_curvature_histograms(activity, contour_pts, contour_angle, view_elevation, contour_text_loc, hist_list,\n",
- " label_list, color_list, mesh_color, bin_centers, title, xlabel, curve_lims,\n",
- " scatter, log=True, text_width=200, width_ratio=1.0, dpi=100):\n",
- " gs0_wspace = 0.5\n",
- " hspace_hist = 0.7\n",
- " wspace_hist = 0.08\n",
- " iso_response_line_thickness = 2\n",
- " respone_attenuation_line_thickness = 2\n",
- " num_y_plots = 2\n",
- " num_x_plots = 1\n",
- " fig = plt.figure(figsize=nc.set_size(text_width, width_ratio, [num_y_plots, num_x_plots]), dpi=dpi)\n",
- " gs_base = gridspec.GridSpec(num_y_plots, num_x_plots, wspace=gs0_wspace)\n",
- " \n",
- " curve_ax = fig.add_subplot(gs_base[0], projection='3d')\n",
- " curve_ax.minorticks_off()\n",
- " x_mesh, y_mesh = np.meshgrid(*contour_pts)\n",
- " curve_ax.set_zlim(0, 1)\n",
- " curve_ax.set_xlim3d(5, 200)\n",
- " curve_ax.grid(False)\n",
- " #x_ticks = curve_ax.get_xticks().tolist()\n",
- " #x_ticks = np.round(np.linspace(curve_lims['x'][0], curve_lims['x'][1],\n",
- " # len(x_ticks)), 1).astype(str)\n",
- " #a_x = [' ']*len(x_ticks)\n",
- " #a_x[1] = x_ticks[1]\n",
- " #a_x[-1] = x_ticks[-1]\n",
- " curve_ax.set_xticklabels([])#a_x)\n",
- " #y_ticks = curve_ax.get_yticks().tolist()\n",
- " #y_ticks = np.round(np.linspace(curve_lims['y'][0], curve_lims['y'][1],\n",
- " # len(y_ticks)), 1).astype(str)\n",
- " #a_y = [' ']*len(y_ticks)\n",
- " #a_y[1] = y_ticks[1]\n",
- " #a_y[-1] = y_ticks[-1]\n",
- " curve_ax.set_yticklabels([])#a_y)\n",
- " curve_ax.set_zticklabels([])\n",
- " curve_ax.zaxis.set_rotate_label(False)\n",
- " #curve_ax.set_zlabel('Activity', rotation=95, labelpad=-15., position=(-10., 0.))\n",
- " if scatter:\n",
- " curve_ax.scatter(x_mesh, y_mesh, activity, color=mesh_color, s=0.01)\n",
- " else:\n",
- " curve_ax.plot_wireframe(x_mesh, y_mesh, activity, rcount=100, ccount=100, color=mesh_color, zorder=1,\n",
- " linestyles='dotted', linewidths=0.3, alpha=1.0)\n",
- " \n",
- " # Plane vector visualizations\n",
- " v = nc.Arrow3D([-200/3., -200/3.], [200/2., 200/2.+200/16.], \n",
- " [0, 0.0], mutation_scale=10, \n",
- " lw=0.5, arrowstyle='-|>', color='red', linestyle='dashed')\n",
- " curve_ax.add_artist(v)\n",
- " curve_ax.text(-300/3., 280/3.0, 0.0, r'$\\nu$', color='red')\n",
- " phi_k = nc.Arrow3D([-200/3., 0.], [200/2., 200/2.], \n",
- " [0, 0.0], mutation_scale=10, \n",
- " lw=1, arrowstyle='-|>', color='red', linestyle = 'dashed')\n",
- " curve_ax.add_artist(phi_k)\n",
- " curve_ax.text(-175/3., 250/3.0, 0.0, r'${\\phi}_{k}$', color='red')\n",
- " \n",
- " # Iso-response curve\n",
- " loc0, loc1, loc2 = contour_text_loc[0]\n",
- " curve_ax.text(loc0, loc1, loc2, 'Iso-\\nresponse', color='black', weight='bold', zorder=10)\n",
- " lines = np.array([0.2, 0.203, 0.197]) - 0.1\n",
- " for i in lines:\n",
- " curve_ax.contour3D(x_mesh, y_mesh, activity, [i], colors='black', linewidths=2, zorder=2)\n",
- " \n",
- " # Response attenuation curve\n",
- " loc0, loc1, loc2 = contour_text_loc[1]\n",
- " curve_ax.text(loc0, loc1, loc2, 'Response\\nAttenuation', color='black', weight='bold', zorder=10)\n",
- " att_line_offset = 165\n",
- " x, y = contour_pts\n",
- " curve_ax.plot(np.zeros_like(x)+att_line_offset, y, activity[:, att_line_offset],\n",
- " color='black', lw=2, zorder=2)\n",
- " \n",
- " # Activity label\n",
- " #loc0, loc1, loc2 = contour_text_loc[2]\n",
- " #curve_ax.text(loc0, loc1, loc2, 'Activity', color='black', weight='bold', zorder=10, zdir='z')\n",
- " \n",
- " # Additional settings\n",
- " curve_ax.view_init(view_elevation, contour_angle)\n",
- " scaling = np.array([getattr(curve_ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])\n",
- " curve_ax.auto_scale_xyz(*[[np.min(scaling), np.max(scaling)]]*3) # square aspect\n",
- " #curve_ax.get_xaxis().set_visible(False)\n",
- " #curve_ax.get_yaxis().set_visible(False)\n",
- " curve_ax._axis3don = False\n",
- " #y_aspect = 2\n",
- " #scale_x = [np.min(scaling), np.max(scaling)]\n",
- " #scale_y = [y_aspect*np.min(scaling), y_aspect*np.max(scaling)]\n",
- " #scale_z = [np.min(scaling), np.max(scaling)]\n",
- " #curve_ax.auto_scale_xyz(scale_x, scale_y, scale_z)\n",
- " \n",
- " # Histogram plots\n",
- " num_hist_y_plots = 2\n",
- " num_hist_x_plots = 2\n",
- " gs_hist = gridspec.GridSpecFromSubplotSpec(num_hist_y_plots, num_hist_x_plots, gs_base[1],\n",
- " hspace=hspace_hist, wspace=wspace_hist)\n",
- " orig_ax = fig.add_subplot(gs_hist[0,0])\n",
- " axes = []\n",
- " for sub_plt_y in range(0, num_hist_y_plots):\n",
- " axes.append([])\n",
- " for sub_plt_x in range(0, num_hist_x_plots):\n",
- " if (sub_plt_x, sub_plt_y) == (0,0):\n",
- " axes[sub_plt_y].append(orig_ax)\n",
- " else:\n",
- " axes[sub_plt_y].append(fig.add_subplot(gs_hist[sub_plt_y, sub_plt_x], sharey=orig_ax))\n",
- " all_x_lists = zip(hist_list, label_list, color_list, bin_centers, title)\n",
- " for axis_x, (curvature_hist, sub_label, sub_color, sub_bins, sub_title) in enumerate(all_x_lists):\n",
- " sub_bins = np.squeeze(sub_bins)\n",
- " all_y_lists = zip(curvature_hist, sub_label, sub_color, xlabel)\n",
- " for axis_y, (dataset_hist, axis_labels, axis_colors, sub_xlabel) in enumerate(all_y_lists):\n",
- " axes[axis_y][axis_x].spines['top'].set_visible(False)\n",
- " axes[axis_y][axis_x].spines['right'].set_visible(False)\n",
- " axes[axis_y][axis_x].set_xticks(sub_bins, minor=True)\n",
- " axes[axis_y][axis_x].set_xticks(sub_bins[::int(len(sub_bins)/4)], minor=False)\n",
- " axes[axis_y][axis_x].xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.3f'))\n",
- " for neuron_hist, label, color in zip(dataset_hist, axis_labels, axis_colors):\n",
- " neuron_hist = np.squeeze(neuron_hist)\n",
- " if log:\n",
- " axes[axis_y][axis_x].semilogy(sub_bins, neuron_hist, color=color, linestyle='-',\n",
- " drawstyle='steps-mid', label=label)\n",
- " axes[axis_y][axis_x].yaxis.set_major_formatter(matplotlib.ticker.LogFormatterSciNotation())\n",
- " else:\n",
- " axes[axis_y][axis_x].plot(sub_bins, neuron_hist, color=color, linestyle='-', drawstyle='steps-mid', label=label)\n",
- " axes[axis_y][axis_x].axvline(0.0, color='black', linestyle='dashed', linewidth=1)\n",
- " if axis_y == 0:\n",
- " axes[axis_y][axis_x].set_title(sub_title)\n",
- " axes[axis_y][axis_x].set_xlabel(sub_xlabel)\n",
- " if axis_x == 0:\n",
- " if log:\n",
- " axes[axis_y][axis_x].set_ylabel('Relative\\nLog Frequency')\n",
- " else:\n",
- " axes[axis_y][axis_x].set_ylabel('Relative\\nFrequency')\n",
- " ax_handles, ax_labels = axes[axis_y][axis_x].get_legend_handles_labels()\n",
- " legend = axes[axis_y][axis_x].legend(handles=ax_handles, labels=ax_labels, loc='upper right',\n",
- " ncol=3, borderaxespad=0., borderpad=0., handlelength=0., columnspacing=-0.5,\n",
- " labelspacing=0., bbox_to_anchor=(0.95, 0.95))\n",
- " legend.get_frame().set_linewidth(0.0)\n",
- " for text, color in zip(legend.get_texts(), axis_colors):\n",
- " text.set_color(color)\n",
- " for item in legend.legendHandles:\n",
- " item.set_visible(False)\n",
- " if axis_x == 1:\n",
- " axes[axis_y][axis_x].tick_params(axis='y', labelleft=False)\n",
- " plt.show()\n",
- " return fig"
- ]
- },
{
"cell_type": "code",
"execution_count": null,
@@ -718,12 +591,12 @@
"activity_loc = [-27, 150, 1.5]\n",
"contour_text_loc = [iso_resp_loc, resp_att_loc, activity_loc]\n",
"\n",
- "curvature_log_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n",
+ "curvature_log_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation, \n",
" contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n",
- " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=1.0, dpi=dpi)\n",
+ " full_title, full_xlabel, curve_lims, scatter, log=True, text_width=text_width, width_ratio=0.75, dpi=dpi)\n",
"\n",
"for analyzer in analyzer_list:\n",
- " for ext in [\".pdf\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_logy\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" curvature_log_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)"
@@ -737,12 +610,12 @@
},
"outputs": [],
"source": [
- "curvature_lin_fig = plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n",
+ "curvature_lin_fig = nc.plot_curvature_histograms(contour_activity, contour_pts, contour_angle, view_elevation,\n",
" contour_text_loc, full_hist_list, full_label_list, full_color_list, mesh_color, full_bin_centers,\n",
- " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=1.0, dpi=dpi)\n",
+ " full_title, full_xlabel, curve_lims, scatter, log=False, text_width=text_width, width_ratio=0.75, dpi=dpi)\n",
"\n",
"for analyzer in analyzer_list:\n",
- " for ext in [\".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = (analyzer.analysis_out_dir+\"/vis/\"+iso_save_name+\"curvatures_and_histograms_liny\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" curvature_lin_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)"
@@ -820,10 +693,10 @@
"density = False\n",
"\n",
"circ_var_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n",
- " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n",
+ " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n",
"\n",
"for analyzer in analyzer_list:\n",
- " for ext in [\".png\", \".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = (analyzer.analysis_out_dir+\"/vis/circular_variance_combo\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" circ_var_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
@@ -838,7 +711,7 @@
"spatial_frequencies = np.stack([np.array(analyzer.bf_stats[\"spatial_frequencies\"]) for analyzer in analyzer_list], axis=0)\n",
"circular_variances = np.stack([variance for variance in circ_var_list], axis=0)\n",
"\n",
- "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width), dpi=dpi)\n",
+ "cv_vs_sf_fig = plt.figure(figsize=nc.set_size(text_width, fraction=0.75), dpi=dpi)\n",
"ax = cv_vs_sf_fig.add_subplot()\n",
"for analyzer_idx in range(len(analyzer_list)):\n",
" ax.scatter(spatial_frequencies[analyzer_idx, :], circular_variances[analyzer_idx, :],\n",
@@ -858,7 +731,7 @@
"plt.show()\n",
"\n",
"for analyzer in analyzer_list:\n",
- " for ext in [\".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = (analyzer.analysis_out_dir+\"/vis/spatial_freq_vs_circular_variance\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" cv_vs_sf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
@@ -871,7 +744,7 @@
"outputs": [],
"source": [
"params_list = [lca_512_vh_params(), lca_1024_vh_params(), lca_2560_vh_params()]\n",
- "display_names = [\"512 Neurons\", \"768 Neurons\", \"1024 Neurons\"]#, \"2560 Neurons\"]\n",
+ "display_names = [\"512 Neurons\", \"1024 Neurons\", \"2560 Neurons\"]\n",
"for params, display_name in zip(params_list, display_names):\n",
" params.display_name = display_name\n",
" params.model_dir = (os.path.expanduser(\"~\")+\"/Work/Projects/\"+params.model_name)\n",
@@ -925,10 +798,10 @@
"density = True\n",
"\n",
"oc_vs_cv_fig = nc.plot_circ_variance_histogram(analyzer_list, circ_var_list, color_list, label_list, num_bins,\n",
- " density, width_ratios, height_ratios, text_width=text_width, width_ratio=1.0, dpi=dpi)\n",
+ " density, width_ratios, height_ratios, text_width=text_width, width_ratio=0.75, dpi=dpi)\n",
"\n",
"for analyzer in analyzer_list:\n",
- " for ext in [\".png\", \".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = (analyzer.analysis_out_dir+\"/vis/overcompleteness_vs_circular_variance\"\n",
" +\"_\"+analyzer.analysis_params.save_info+ext)\n",
" oc_vs_cv_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)"
@@ -944,7 +817,9 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
"params_list = [lca_512_vh_params(), lca_768_vh_params(), lca_2560_vh_params()]\n",
@@ -970,6 +845,156 @@
" analyzer_list.append(analyzer)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "def closest_val_in_array(num, arr):\n",
+    "    # return the arr element closest to num (first one on ties); avoids the\n",
+    "    # np.argwhere(...).item() crash when several elements tie for closest\n",
+    "    return min(arr, key=lambda val: abs(num - val))"
+ ]
+ },
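
A quick usage check of the helper defined in the cell above; the interval values mirror the tick-rounding code later in the notebook, and the asserts are illustrative:

interval_list = list(range(0, 251, 50))  # [0, 50, 100, 150, 200, 250]
assert closest_val_in_array(173, interval_list) == 150  # 23 away beats 27
assert closest_val_in_array(176, interval_list) == 200  # 24 away beats 26
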
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "code_folding": [],
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "num_interesting_vals = [\n",
+ " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n",
+ " analyzer.nat_selectivity['num_interesting_img_l']])\n",
+ " for analyzer in analyzer_list]\n",
+ "\n",
+ "num_interesting_medians = np.stack(\n",
+ " [np.array([np.median(np.array(analyzer.nat_selectivity['num_interesting_img_nl'])),\n",
+ " np.median(np.array(analyzer.nat_selectivity['num_interesting_img_l']))])\n",
+ " for analyzer in analyzer_list], axis=0)\n",
+ "\n",
+ "num_interesting_means = np.stack(\n",
+ " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n",
+ " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n",
+ " for analyzer in analyzer_list], axis=0)\n",
+ "\n",
+ "num_interesting_stds = np.stack(\n",
+ " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_std'],\n",
+ " analyzer.nat_selectivity['num_interesting_img_l_std']])\n",
+ " for analyzer in analyzer_list], axis=0)\n",
+ "\n",
+ "array = [\n",
+ " [1, 2, 3],\n",
+ " [4, 5, 6],\n",
+ "]\n",
+ "\n",
+ "scale = 1\n",
+ "rc_kwargs = {\n",
+ " 'fontsize':scale*matplotlib.rcParams['font.size'],\n",
+    "    'fontfamily': matplotlib.rcParams['font.family'],  # a family name, not a number to scale\n",
+ " 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n",
+ " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n",
+ "}\n",
+ "figsize = nc.set_size(text_width, fraction=1.00)\n",
+ "with plot.rc.context(**rc_kwargs):\n",
+ " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, sharex=False, aspect=3.0, figsize=figsize)\n",
+ " for ovc_idx, overcompleteness in enumerate(num_interesting_vals):\n",
+ " ax = axs[ovc_idx]\n",
+ " df = pd.DataFrame(\n",
+ " overcompleteness.T,\n",
+ " columns=pd.Index(['Sparse Coding', 'Linear'])#, name='xlabel')\n",
+ " )\n",
+ " box_parts = ax.boxplot(\n",
+ " df,\n",
+ " notch=True,\n",
+ " fill=False,\n",
+ " whis=(5, 95),\n",
+ " marker='*',\n",
+ " markersize=1.0,\n",
+ " lw=1.2\n",
+ " )\n",
+ " colors = ['md_red', 'md_green']\n",
+ " for pc_idx, box in enumerate(box_parts['boxes']):\n",
+ " box.set_color(color_vals[colors[pc_idx]])\n",
+ " ax.format(\n",
+ " ylocator=50,\n",
+ " ylim=[0, np.max([np.max(val) for val in num_interesting_vals])],\n",
+ " title=analyzer_list[ovc_idx].nat_selectivity['oc_label'],\n",
+    "            ylabel='Average number of\\ninteresting images',\n",
+ " xtickminor=False,\n",
+ " xgrid=False\n",
+ " )\n",
+ "\n",
+ " for idx, analyzer in enumerate(analyzer_list):\n",
+ " ax = axs[idx+3]\n",
+ " angle_min = 0.0\n",
+ " angle_max = 90.0\n",
+ " nbins=20\n",
+ " bins = np.linspace(angle_min, angle_max, nbins)\n",
+ " lin_data = [mean for mean in analyzer.nat_selectivity['lin_means'] if mean>0]\n",
+ " non_lin_data = [mean for mean in analyzer.nat_selectivity['lca_means'] if mean>0]\n",
+ " hist_list = []\n",
+ " color_list = [color_vals['md_green'], color_vals['md_red']]\n",
+ " label_list = ['Linear Autoencoder', 'Sparse Coding']\n",
+ " handles = []\n",
+ " hist_max_list = []\n",
+ " for angles, label, color in zip([lin_data, non_lin_data], label_list, color_list):\n",
+ " # density means the y vals are probability density function at the bin, normalized such that the integral over the range is 1.\n",
+ " hist, bin_edges = np.histogram(np.array(angles).flatten(), bins, density=False)\n",
+ " hist_max_list.append(hist.max())\n",
+ " hist_list.append(hist)\n",
+ " bin_left, bin_right = bin_edges[:-1], bin_edges[1:]\n",
+ " bin_centers = bin_left + (bin_right - bin_left)/2\n",
+ " handles.append(ax.plot(bin_centers, hist, linestyle='-', drawstyle='steps-mid', color=color, label=label))\n",
+ " oc = analyzer.nat_selectivity['oc_label']\n",
+ " ax.spines['top'].set_visible(False)\n",
+ " ax.spines['right'].set_visible(False)\n",
+ " ax.set_xticks(bin_left, minor=True)\n",
+ " ax.set_xticks(bin_left[::2], minor=False)\n",
+ " ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n",
+ " ax.set_xticks([angle_min, angle_max//2, angle_max])\n",
+ " mid_val = max(hist_max_list)//2\n",
+ " max_val = int(max(hist_max_list))\n",
+ " #interval_list = list(range(0, mid_val+51, 50))\n",
+ " #new_mid = closest_val_in_array(mid_val, interval_list)\n",
+ " interval_list = list(range(0, max_val+51, 50))\n",
+ " new_max = closest_val_in_array(max_val, interval_list)\n",
+ " new_mid = new_max//2\n",
+ " ax.set_ylim([0, new_max+0.1*new_max])\n",
+ " ax.set_yticks([0, new_mid, new_max])\n",
+ " #axs[-1].legend(handles, ncol=1, frameon=False, loc='ur', bbox_to_anchor=[1, 1.02])\n",
+ " hist_ax_idx = 3\n",
+ " axs[hist_ax_idx].format(ylabel='Total number of\\ninteresting images')\n",
+ " axs[hist_ax_idx:].format(\n",
+ " suptitle='Sparse Coding Increases Neuron Selectivity for Natural Signals',\n",
+ " xlabel='Mean image-to-weight angle',\n",
+ " xlim=[0, 90],\n",
+ " ygrid=False\n",
+ " )\n",
+ "plot.show()"
+ ]
+ },
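
The angle histograms in the cell above convert np.histogram bin edges into centers before step-plotting. A tiny standalone sketch of that conversion with synthetic angles:

import numpy as np

angles = np.random.uniform(0.0, 90.0, size=500)
bins = np.linspace(0.0, 90.0, 20)
hist, bin_edges = np.histogram(angles, bins, density=False)
bin_left, bin_right = bin_edges[:-1], bin_edges[1:]
bin_centers = bin_left + (bin_right - bin_left) / 2  # one center per bin
assert len(bin_centers) == len(hist) == len(bins) - 1
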
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "for analyzer in analyzer_list:\n",
+ " for ext in file_extensions:\n",
+ " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_box_'\n",
+ " +analyzer.analysis_params.save_info+ext)\n",
+ " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -984,6 +1009,16 @@
"color_list = [color_vals['md_green'], color_vals['md_red']]\n",
"label_list = ['Linear Autoencoder', 'Sparse Coding']\n",
"\n",
+ "num_interesting_vals = [\n",
+ " np.array([analyzer.nat_selectivity['num_interesting_img_nl'],\n",
+ " analyzer.nat_selectivity['num_interesting_img_l']])\n",
+ " for analyzer in analyzer_list]\n",
+ "\n",
+ "num_interesting_medians = np.stack(\n",
+ " [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n",
+ " analyzer.nat_selectivity['num_interesting_img_l_mean']])\n",
+ " for analyzer in analyzer_list], axis=0)\n",
+ "\n",
"num_interesting_means = np.stack(\n",
" [np.array([analyzer.nat_selectivity['num_interesting_img_nl_mean'],\n",
" analyzer.nat_selectivity['num_interesting_img_l_mean']])\n",
@@ -1010,10 +1045,11 @@
" 'fontsize':scale*matplotlib.rcParams['font.size'],\n",
" 'fontfamily':scale*matplotlib.rcParams['font.family'],\n",
" 'legend.fontsize': scale*matplotlib.rcParams['font.size'],\n",
- " 'text.labelsize': scale*matplotlib.rcParams['font.size']-2\n",
+ " 'text.labelsize': scale*matplotlib.rcParams['font.size']\n",
"}\n",
+ "figsize = nc.set_size(text_width, fraction=1.00)\n",
"with plot.rc.context(**rc_kwargs):\n",
- " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, width=0.4*text_width_cm)\n",
+ " interesting_imgs_fig, axs = plot.subplots(array, sharey=False, aspect=3.0, figsize=figsize)#, width=0.4*text_width_cm)\n",
" ax = axs[0]\n",
" obj = ax.bar(\n",
" df,\n",
@@ -1033,7 +1069,7 @@
" xlocator=1,\n",
" xminorlocator=0.5,\n",
" ytickminor=False,\n",
- " ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n",
+ " #ylim=[0, np.max(num_interesting_means)+np.max(num_interesting_stds)],\n",
" #suptitle='Average number of intersting images'\n",
" ylabel='Average number of\\nintersting images',\n",
" xgrid=False\n",
@@ -1062,7 +1098,8 @@
" ax.set_xticks(bin_left[::2], minor=False)\n",
" ax.xaxis.set_major_formatter(plticker.FormatStrFormatter('%0.0f'))\n",
" ax.set_xticks([angle_min, angle_max//2, angle_max])\n",
- " ax.set_ylim([0.0, max(hist_max_list)+0.1*max(hist_max_list)])\n",
+ " ax.set_ylim([0, max(hist_max_list)+0.1*max(hist_max_list)])\n",
+ " ax.set_yticks([0, max(hist_max_list)//2, int(max(hist_max_list))])\n",
" ax.format(title=f'{oc}\\n')#, ygrid=False)\n",
" #ax.grid(b=False, which='both', axis='both')\n",
" axs[1].format(ylabel='Total number of\\ninteresting images')\n",
@@ -1084,10 +1121,10 @@
"outputs": [],
"source": [
"for analyzer in analyzer_list:\n",
- " for ext in ['.png', '.pdf', '.eps']:\n",
- " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_'\n",
+ " for ext in file_extensions:\n",
+ " save_name = (analyzer.analysis_out_dir+'/vis/natural_img_selectivity_bar_'\n",
" +analyzer.analysis_params.save_info+ext)\n",
- " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=interesting_imgs_fig.dpi)"
+ " interesting_imgs_fig.savefig(save_name, transparent=False, pad_inches=0.005, dpi=dpi)"
]
},
{
@@ -1480,7 +1517,7 @@
" COLORS, inner_group_names, outer_group_names, titles, text_width, width_ratio=1.0, dpi=dpi)\n",
"\n",
"#for analyzer in analyzer_list:\n",
- "for ext in [\".png\", \".eps\"]:\n",
+ "for ext in file_extensions:\n",
" save_name = (output_dir+'/adv_mse_comparison_boxplots'+ext)\n",
" adv_fig.savefig(save_name, transparent=False, bbox_inches='tight', pad_inches=0.05, dpi=dpi)"
]
@@ -1785,7 +1822,7 @@
"outputs": [],
"source": [
"#for analyzer in analyzer_list:\n",
- "for ext in [\".png\", \".eps\"]:\n",
+ "for ext in file_extensions:\n",
" save_name = (output_dir+'/adv_mse_comparison_example_images'+ext)\n",
" adv_img_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.05, dpi=dpi)"
]
@@ -1839,9 +1876,8 @@
"\n",
"conf_fig = plot_average_conf_step(files, names)\n",
"#for analyzer in analyzer_list:\n",
- "for ext in [\".png\", \".eps\"]:\n",
- " save_name = (output_dir+'adv_mse_comparison_example_images'\n",
- " +\"_\"+analyzer.analysis_params.save_info+ext)\n",
+ "for ext in file_extensions:\n",
+ " save_name = (output_dir+'adv_mse_comparison_example_images'+ext)\n",
" conf_fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
]
},
@@ -2082,66 +2118,75 @@
"metadata": {},
"outputs": [],
"source": [
- "def plot_average_mse_step(analysis_files, colors, title, model_names, bar_width, hatches, figsize, dpi):\n",
+ "def plot_average_mse_step(analysis_files, recons, confs, colors, title, model_names, bar_width, hatches, figsize, dpi):\n",
" fig = plt.figure(figsize=figsize, dpi=dpi)\n",
- " gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n",
- " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n",
- " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n",
- " group_data = []\n",
- " group_means = []\n",
- " handles = []\n",
+ " num_conditions = len(analysis_files)\n",
+ " gs_top = gridspec.GridSpec(num_conditions, num_conditions)\n",
" axes = []\n",
- " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n",
- " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n",
- " for file_idx, (file, name) in enumerate(zip(analysis_files, model_names)):\n",
- " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n",
- " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n",
- " if x_ax_idx == 0:\n",
- " axes[-1].set_ylabel('Adversarial\\nConfidence')\n",
- " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n",
- " axes[-1].set_ylim([0, 100.1])\n",
- " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n",
- " std_vals = np.std(adv_conf, axis=-1)[1:]\n",
- " else:\n",
- " adv_mse = np.squeeze(analysis['input_adv_mses'])\n",
- " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n",
- " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n",
- " thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n",
- " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n",
- " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n",
- " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n",
- " std_vals = np.std(adv_mse, axis=-1)[1:]\n",
- " group_data.append(adv_mse[first_adv_cross, :])\n",
- " group_means.append(mean_vals[first_adv_cross])\n",
- " max_val = 0.020#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n",
- " axes[-1].set_ylim([0, max_val])\n",
- " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n",
- " lw=2, color=colors[file_idx][0], zorder=1)\n",
- " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n",
- " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\", hatch=hatches[file_idx],\n",
- " rasterized=False)\n",
- " axes[-1].set_xlabel('Attack Step')\n",
+ " for condition, (condition_analysis_files, recon, conf) in enumerate(zip(analysis_files, recons, confs)):\n",
+ " #gs0 = gridspec.GridSpec(1, 2, wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n",
+ " gs0 = gridspec.GridSpecFromSubplotSpec(1, 2, gs_top[condition, :],\n",
+ " wspace=0.2, width_ratios = [2, 1])#, hspace=0.3)\n",
+ " left_gs = gridspec.GridSpecFromSubplotSpec(1, 2, gs0[0], wspace=1.3)\n",
+ " right_gs = gridspec.GridSpecFromSubplotSpec(1, 1, gs0[1], wspace=0.9)\n",
+ " group_data = []\n",
+ " group_means = []\n",
+ " handles = []\n",
+ " for x_ax_idx, key in enumerate(['input_adv_mses', 'adversarial_outputs']):\n",
+ " axes.append(fig.add_subplot(left_gs[x_ax_idx]))\n",
+ " for file_idx, (file, name) in enumerate(zip(condition_analysis_files, model_names)):\n",
+ " analysis = np.load(file, allow_pickle=True)[\"data\"].item()\n",
+ " adv_conf = 100*np.max(np.squeeze(analysis['adversarial_outputs']), axis=-1)\n",
+ " if x_ax_idx == 0:\n",
+ " if condition == 0:\n",
+ " axes[-1].set_ylabel('Adversarial\\nConfidence')\n",
+ " axes[-1].axhline(90.0, color='black', linestyle='dashed', linewidth=1) \n",
+ " axes[-1].set_ylim([0, 100.1])\n",
+ " mean_vals = np.mean(adv_conf, axis=-1)[1:]\n",
+ " std_vals = np.std(adv_conf, axis=-1)[1:]\n",
+ " else:\n",
+ " if condition == 0:\n",
+ " axes[-1].set_ylabel('Adversarial Mean\\nSquared Distance')\n",
+ " adv_mse = np.squeeze(analysis['input_adv_mses'])\n",
+ " axes[-1].yaxis.set_major_formatter(ticker.FormatStrFormatter('%.3f'))\n",
+ " thresh_indices = np.argwhere(np.mean(adv_conf, axis=-1)>90)\n",
+ " first_adv_cross = np.min(thresh_indices[thresh_indices>2]) # first couple are original label\n",
+ " axes[-1].axvline(first_adv_cross, color=colors[file_idx][0], linestyle='dashed', linewidth=1)\n",
+ " mean_vals = np.mean(adv_mse, axis=-1)[1:]\n",
+ " std_vals = np.std(adv_mse, axis=-1)[1:]\n",
+ " group_data.append(adv_mse[first_adv_cross, :])\n",
+ " group_means.append(mean_vals[first_adv_cross])\n",
+ " max_val = 0.03#np.max(mean_vals)+std_vals[np.argmax(mean_vals)]\n",
+ " axes[-1].set_ylim([0, max_val])\n",
+ " axes[-1].plot(range(len(mean_vals)), mean_vals, label=name,\n",
+ " lw=2, color=colors[file_idx][0], zorder=1)\n",
+ " axes[-1].fill_between(range(len(mean_vals)), mean_vals + std_vals , mean_vals - std_vals,\n",
+ " edgecolor=colors[file_idx][1], alpha=1.0, zorder=0, facecolor=\"none\",\n",
+ " hatch=hatches[file_idx], rasterized=False)\n",
+ " if condition == num_conditions-1:\n",
+ " axes[-1].set_xlabel('Attack Step')\n",
+ " axes[-1].grid(False)\n",
+ " axes.append(fig.add_subplot(right_gs[0]))\n",
+ " x_pos = np.arange(2) + 2 * bar_width\n",
+ " linewidth = 1\n",
+ " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n",
+ " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n",
+ " float_colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n",
+ " axes[-1].set_title(f'c={recon}, '+r'$\\kappa$'+f'={conf}')\n",
+ " for data, means, pos, color, name in zip(group_data, group_means, x_pos, float_colors, model_names):\n",
+ " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n",
+ " whiskerprops = boxprops\n",
+ " capprops = boxprops\n",
+ " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n",
+ " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n",
+ " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n",
+ " meanprops=meanprops\n",
+ " ))\n",
+ " axes[-1].set_ylim([0, max_val])\n",
+ " axes[-1].set_yticklabels('')\n",
+ " axes[-1].get_xaxis().set_ticks([])\n",
" axes[-1].grid(False)\n",
- " axes.append(fig.add_subplot(right_gs[0]))\n",
- " x_pos = np.arange(2) + 2 * bar_width\n",
- " linewidth = 1\n",
- " medianprops = dict(linestyle='--', linewidth=linewidth, color='k')\n",
- " meanprops = dict(linestyle='-', linewidth=linewidth, color='k')\n",
- " colors = [[52/255, 152/255, 219/255], [231/255, 76/255, 60/255]] # blue, red\n",
- " for data, means, pos, color, name in zip(group_data, group_means, x_pos, colors, model_names):\n",
- " boxprops = dict(linestyle='-', linewidth=linewidth, color=color)\n",
- " whiskerprops = boxprops\n",
- " capprops = boxprops\n",
- " handles.append(axes[-1].boxplot(data, sym='', positions=[pos],\n",
- " whis=(5, 95), widths=bar_width, meanline=True, showmeans=True, boxprops=boxprops,\n",
- " whiskerprops=whiskerprops, capprops=capprops, medianprops=medianprops,\n",
- " meanprops=meanprops\n",
- " ))\n",
- " axes[-1].set_ylim([0, max_val])\n",
- " axes[-1].set_yticklabels('')\n",
- " axes[-1].get_xaxis().set_ticks([])\n",
- " axes[-1].grid(False)\n",
- " axes[-1].text(pos, 0.002, name, horizontalalignment='center', verticalalignment='center')\n",
+ " axes[-1].text(pos, 0.0025, name, horizontalalignment='center', verticalalignment='center')\n",
" fig.subplots_adjust(top=0.8)\n",
" fig.suptitle(title, y=0.98)\n",
" return fig, axes"
@@ -2156,35 +2201,44 @@
"outputs": [],
"source": [
"colors = [[color_vals['md_blue'], color_vals['lt_blue']], [color_vals['md_red'], color_vals['lt_red']]]\n",
- "model_names = ['w/o\\nLCA', 'w/\\nLCA']\n",
+ "model_names = ['w/o LCA', 'w/ LCA']\n",
"hatches = ['///', '\\\\\\\\\\\\']\n",
"\n",
- "k_file_path = analysis_dir+'savefiles/class_adversary_analysis_test_temp_kurakin_targeted.npz'\n",
- "k_img_path = analysis_dir+'savefiles/class_adversary_images_analysis_test_temp_kurakin_targeted.npz'\n",
- "k_mlp_files = [projects_dir + model_name + k_file_path for model_name in [mnist_mlp_768_2layer]]\n",
- "k_lca_files = [projects_dir + model_name + k_file_path for model_name in [mnist_lca_768_2layer]]\n",
- "k_files = k_mlp_files + k_lca_files\n",
- "\n",
- "run_number = '5'\n",
- "c_file_path = (analysis_dir+'savefiles/class_adversary_analysis_test_temp'\n",
- " +str(run_number)+'_carlini_targeted.npz')\n",
- "c_img_path = (analysis_dir+'savefiles/class_adversary_images_analysis_test_temp'\n",
- " +str(run_number)+'_carlini_targeted.npz')\n",
- "c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n",
- "c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n",
- "c_files = c_mlp_files + c_lca_files\n",
- "\n",
- "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2, 3])\n",
"#carlini_title = 'Networks with an LCA layer require larger\\nperturbations for equal confidence with the Carlini attack'\n",
"#carlini_title = 'Networks with an LCA layer are more robust than without'\n",
- "carlini_title = ''#Networks with an LCA layer are more robust than without'\n",
- "fig, ax = plot_average_mse_step(c_files, colors, carlini_title,\n",
+ "carlini_title = ''\n",
+ "\n",
+ "all_recons = []\n",
+ "all_confs = []\n",
+ "all_files = []\n",
+ "for recon in ['0.5', '1.0']:\n",
+ " for conf in ['0.0', '10.0']:\n",
+ " if conf == '10.0':\n",
+ " extra_str = '_'\n",
+ " temp = '1.00'\n",
+ " else:\n",
+ " extra_str = ''\n",
+ " temp = '1.0'\n",
+ " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n",
+ " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n",
+ " c_mlp_files = [projects_dir + model_name + c_file_path for model_name in [mnist_mlp_768_2layer]]\n",
+ " temp = '0.65'\n",
+ " c_file_path = (f'{analysis_dir}savefiles/class_adversary_analysis_test'+\n",
+ " f'{extra_str}temp{temp}_conf{conf}_recon{recon}_carlini_targeted.npz')\n",
+ " c_lca_files = [projects_dir + model_name + c_file_path for model_name in [mnist_lca_768_2layer]]\n",
+ " c_files = c_mlp_files + c_lca_files\n",
+ " all_recons.append(recon)\n",
+ " all_confs.append(conf)\n",
+ " all_files.append(c_files)\n",
+ "\n",
+ "figsize = nc.set_size(text_width, fraction=1.0, subplot=[2*2, 3])\n",
+ "fig, ax = plot_average_mse_step(all_files, all_recons, all_confs, colors, carlini_title,\n",
" model_names, bar_width, hatches, figsize, dpi)\n",
"\n",
- "out_list = [projects_dir + model_name + '/analysis/0.0/vis/kurakin_carlini_mse_vs_iteration_temp' + str(run_number)\n",
+ "out_list = [projects_dir + model_name + '/analysis/0.0/vis/carlini_mse_vs_iteration_k0.0-10.0_conditions'\n",
" for model_name in [mnist_lca_768_2layer, mnist_mlp_768_2layer]]\n",
"for out_name in out_list:\n",
- " for ext in [\".png\", \".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = out_name+ext\n",
" fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
]
@@ -2334,7 +2388,7 @@
"\n",
"out_list += [path + lista + \"/analysis/0.0/vis/lista_adv_transferability\"]\n",
"for out_name in out_list:\n",
- " for ext in [\".png\", \".eps\"]:\n",
+ " for ext in file_extensions:\n",
" save_name = out_name+ext\n",
" fig.savefig(save_name, transparent=False, bbox_inches=\"tight\", pad_inches=0.01, dpi=dpi)"
]
@@ -2367,7 +2421,7 @@
" image_idx = category_idx + start_idx\n",
" orig_ax = ax_orig_list[category_idx]\n",
" if category_idx == 0:\n",
- " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj, fontsize=fontsize)\n",
+ " orig_ax.set_title(\"Unperturbed\", y=orig_y_adj)#, fontsize=fontsize)\n",
" orig_img = crop(np.squeeze(mlp_grp[0][0][image_idx, ...]), crop_ammount)\n",
" orig_im_handle = show_image_with_label(orig_ax, orig_img, orig_labels[0][0][image_idx], cmap=cmap)\n",
" for model_idx, gs_sub in enumerate([gs_sub0_list[category_idx], gs_sub1_list[category_idx]]):\n",
@@ -2388,27 +2442,27 @@
" vmax = 1.0\n",
" if j == 0: # top left image\n",
" if model_idx == 0:\n",
- " current_ax.set_ylabel(r\"$s^{*}_{T}$\", fontsize=fontsize)\n",
+ " current_ax.set_ylabel(r\"$s^{*}_{T}$\")#, fontsize=fontsize)\n",
" current_target_label = target_labels[j][model_idx][image_idx]\n",
" if category_idx == 0: # top category only\n",
" x_loc = group_name_loc[0]\n",
" y_loc = group_name_loc[1]\n",
- " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx], fontsize=fontsize,\n",
+ " text_handle = current_ax.text(x_loc, y_loc, group_names[j+model_idx],#, fontsize=fontsize,\n",
" horizontalalignment='left', verticalalignment='bottom')\n",
" else: # i == 1\n",
" vmin = np.round(diff_vmin, 2)\n",
" vmax = np.round(diff_vmax, 2)\n",
" if j == 0 and model_idx == 0:\n",
- " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\", fontsize=fontsize)\n",
+ " current_ax.set_ylabel(r\"$s-s^{*}_{T}$\")#, fontsize=fontsize)\n",
" if j == 0 and category_idx == num_categories-1: # bottom left\n",
- " current_ax.set_xlabel(\"w/o\\nLCA\", fontsize=fontsize)\n",
+ " current_ax.set_xlabel(\"w/o\\nLCA\")#, fontsize=fontsize)\n",
" elif j == 1 and category_idx == num_categories-1: # bottom right\n",
- " current_ax.set_xlabel(\"w/\\nLCA\", fontsize=fontsize)\n",
+ " current_ax.set_xlabel(\"w/\\nLCA\")#, fontsize=fontsize)\n",
" im_handle = show_image_with_label(current_ax, current_image, current_target_label, vmin=vmin, vmax=vmax, cmap=cmap)\n",
" if j == 1:\n",
- " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax], labelsize=fontsize/2)\n",
+ " pf.add_colorbar_to_im(im_handle, aspect=10, ax=current_ax, ticks=[vmin, vmax])#, labelsize=fontsize/2)\n",
"\n",
- "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, fontsize, dpi=100):\n",
+ "def plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx, cifar_start_idx, figsize, dpi=100):\n",
" mnist_mlp_grp, mnist_lca_grp = image_groups[0]\n",
" cifar_mlp_grp, cifar_lca_grp = image_groups[1]\n",
" mnist_orig_labels, mnist_target_labels, mnist_img_labels = labels[0]\n",
@@ -2420,22 +2474,23 @@
" sub_wspace = 0.2\n",
" orig_y_adj = 1.10\n",
" img_label_loc = [-8.0, -8.0] # [x, y]\n",
- " fig2 = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n",
- " gs0 = plt.GridSpec(2, 1, figure=fig2, hspace=0.3)\n",
+ " fig = plt.figure(figsize=[figsize[0]/2, figsize[1]], dpi=dpi)\n",
+ " gs0 = plt.GridSpec(2, 1, figure=fig, hspace=0.3)\n",
" \n",
" num_categories=3\n",
" \n",
" gs_mnist = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[0], hspace=hspace, wspace=wspace)\n",
- " make_grid_subplots_with_fontsize(fig2, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n",
+ " make_grid_subplots_with_fontsize(fig, gs_mnist, mnist_mlp_grp, mnist_lca_grp, mnist_orig_labels,\n",
" mnist_target_labels, mnist_img_labels, img_label_loc, orig_y_adj, mnist_start_idx, num_categories,\n",
- " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\", fontsize=fontsize)\n",
+ " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys\")\n",
" \n",
" gs_cifar = gridspec.GridSpecFromSubplotSpec(num_categories, 6, gs0[1], hspace=hspace, wspace=wspace)\n",
- " make_grid_subplots_with_fontsize(fig2, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n",
+ " make_grid_subplots_with_fontsize(fig, gs_cifar, cifar_mlp_grp, cifar_lca_grp, cifar_orig_labels,\n",
" cifar_target_labels, cifar_img_labels, img_label_loc, orig_y_adj, 0, num_categories,\n",
- " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\", fontsize=fontsize)\n",
+ " hspace=sub_hspace, wspace=sub_wspace, cmap=\"Greys_r\")\n",
" \n",
- " plt.show()"
+ " plt.show()\n",
+ " return fig "
]
},
{
@@ -2446,9 +2501,17 @@
},
"outputs": [],
"source": [
- "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, labels, mnist_start_idx=44, cifar_start_idx=0,\n",
- " figsize=(16, 16), fontsize=20, dpi=dpi)"
+ "figsize = nc.set_size(text_width, fraction=1.0, subplot=[16, 16])\n",
+ "full_adv_img_fig = plot_adv_images_with_figsize(image_groups, label_groups, mnist_start_idx=44, cifar_start_idx=0,\n",
+ " figsize=figsize, dpi=dpi)"
]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/tf1x/vis/tsne_analysis.py b/tf1x/vis/tsne_analysis.py
index f0f4430e..b60d7220 100644
--- a/tf1x/vis/tsne_analysis.py
+++ b/tf1x/vis/tsne_analysis.py
@@ -1,5 +1,9 @@
import os
import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
import pickle
@@ -7,9 +11,6 @@
from tensorflow.contrib.tensorboard.plugins import projector
from scipy.misc import imsave
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.analysis.analysis_picker as ap
import DeepSparseCoding.tf1x.utils.data_processing as dp
diff --git a/tf1x/vis/vis_class_adversarial.py b/tf1x/vis/vis_class_adversarial.py
index 7bbd72a7..ac41ce23 100644
--- a/tf1x/vis/vis_class_adversarial.py
+++ b/tf1x/vis/vis_class_adversarial.py
@@ -2,6 +2,11 @@
matplotlib.use('Agg')
import os
import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+
import numpy as np
import matplotlib
@@ -13,9 +18,6 @@
import pandas as pd
import pdb
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.tf1x.utils.plot_functions as pf
diff --git a/tf1x/vis/vis_conv_lca.py b/tf1x/vis/vis_conv_lca.py
index 567dea57..08fbda10 100644
--- a/tf1x/vis/vis_conv_lca.py
+++ b/tf1x/vis/vis_conv_lca.py
@@ -1,6 +1,11 @@
# In[1]:
import os
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+
import numpy as np
import matplotlib
matplotlib.use('Agg')
@@ -11,9 +16,6 @@
import tensorflow as tf
import pdb
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.tf1x.utils.plot_functions as pf
diff --git a/tf1x/vis/vis_corrupt.py b/tf1x/vis/vis_corrupt.py
index 8b6b3df4..dee47ad2 100644
--- a/tf1x/vis/vis_corrupt.py
+++ b/tf1x/vis/vis_corrupt.py
@@ -2,6 +2,10 @@
matplotlib.use('Agg')
import os
import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
import matplotlib
@@ -13,9 +17,6 @@
import pandas as pd
import pickle
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as dp
import DeepSparseCoding.tf1x.utils.plot_functions as pf
diff --git a/tf1x/vis/vis_recon_adversarial.py b/tf1x/vis/vis_recon_adversarial.py
index ab364e0e..ae926b89 100644
--- a/tf1x/vis/vis_recon_adversarial.py
+++ b/tf1x/vis/vis_recon_adversarial.py
@@ -3,6 +3,10 @@
import os
import sys
import pdb
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(up(os.path.realpath(__file__)))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import numpy as np
import matplotlib
@@ -11,9 +15,6 @@
import matplotlib.gridspec as gridspec
from skimage.measure import compare_psnr
-ROOT_DIR = os.path.dirname(os.path.dirname(os.getcwd()))
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
from DeepSparseCoding.tf1x.data.dataset import Dataset
import DeepSparseCoding.tf1x.data.data_selector as ds
import DeepSparseCoding.tf1x.utils.data_processing as dp
diff --git a/train_model.py b/train_model.py
index db2b2778..1bad1364 100644
--- a/train_model.py
+++ b/train_model.py
@@ -2,10 +2,13 @@
import sys
import argparse
import time as ti
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(os.path.realpath(__file__)))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+import torch
+
import DeepSparseCoding.utils.loaders as loaders
import DeepSparseCoding.utils.run_utils as run_utils
import DeepSparseCoding.utils.dataset_utils as dataset_utils
@@ -19,31 +22,33 @@
t0 = ti.time()
# Load params
-params = loaders.load_params(param_file)
+params = loaders.load_params_file(param_file)
# Load data
-train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)
-for key, value in data_stats.items():
- setattr(params, key, value)
+train_loader, val_loader, test_loader, data_stats = dataset_utils.load_dataset(params)[:4]
+for key, value in data_stats.items(): setattr(params, key, value)
# Load model
model = loaders.load_model(params.model_type)
model.setup(params)
model.to(params.device)
+model.log_architecture_details()
# Train model
for epoch in range(1, model.params.num_epochs+1):
run_utils.train_epoch(epoch, model, train_loader)
- if(model.params.model_type.lower() in ['mlp', 'ensemble']):
- run_utils.test_epoch(epoch, model, test_loader)
- model.log_info(f'Completed epoch {epoch}/{model.params.num_epochs}')
+ # TODO: Ensemble models might not actually have a classification objective / need validation
+ #if(model.params.model_type.lower() in ['mlp', 'ensemble']): # TODO: use the validation set here; test at the end of training
+ # run_utils.test_epoch(epoch, model, test_loader)
+ model.logger.log_string(f'Completed epoch {epoch}/{model.params.num_epochs}')
print(f'Completed epoch {epoch}/{model.params.num_epochs}')
+# Final outputs
t1 = ti.time()
tot_time=float(t1-t0)
tot_images = model.params.num_epochs*len(train_loader.dataset)
out_str = f'Training on {tot_images} images is complete. Total time was {tot_time} seconds.\n'
-model.log_info(out_str)
+model.logger.log_string(out_str)
print('Training Complete\n')
model.write_checkpoint()
diff --git a/utils/data_processing.py b/utils/data_processing.py
index fff2fddc..16482dfa 100644
--- a/utils/data_processing.py
+++ b/utils/data_processing.py
@@ -4,34 +4,36 @@
def reshape_data(data, flatten=None, out_shape=None):
"""
- Helper function to reshape input data for processing and return data shape
- Inputs:
+ Reshape input data for processing and return data shape
+
+ Keyword arguments:
data: [tensor] data of shape:
- n is num_examples, i is num_rows, j is num_cols, k is num_channels, l is num_examples = i*j*k
+ n is num_examples, i is num_channels (a.k.a. vector length for flattened inputs), j is num_rows, k is num_cols
if out_shape is not specified, it is assumed that i == j
- (l) - single data point of shape l, assumes 1 color channel
- (n, l) - n data points, each of shape l (flattened)
- (i, j, k) - single datapoint of of shape (i,j, k)
- (n, i, j, k) - n data points, each of shape (i,j,k)
+ (i) - single data point of shape i (flattened; assumed j = k = 1)
+ (n, i) - n data points, each of shape i (flattened; assumed j = k = 1)
+ (i, j, k) - single data point of shape (i, j, k)
+ (n, i, j, k) - n data points, each of shape (i, j, k)
flatten: [bool or None] specify the shape of the output
If out_shape is not None, this arg has no effect
If None, do not reshape data, but add num_examples dimension if necessary
- If True, return ravelled data of shape (num_examples, num_elements)
- If False, return unravelled data of shape (num_examples, sqrt(l), sqrt(l), 1)
- where l is the number of elements (dimensionality) of the datapoints
+ If True, return ravelled data of shape (num_examples, num_elements), where num_elements=i*j*k
+ If False, return unravelled data of shape (num_examples, 1, sqrt(i), sqrt(i))
+ where i is the number of elements (size) of the data points and 1 is the assumed number of channels
If data is flat and flatten==True, or !flat and flatten==False, then None condition will apply
out_shape: [list or tuple] containing the desired output shape
This will overwrite flatten, and return the input reshaped according to out_shape
+
Outputs:
tuple containing:
data: [tensor] data with new shape
- (num_examples, num_rows, num_cols, num_channels) if flatten==False
- (num_examples, num_elements) if flatten==True
+ (num_examples, num_channels, num_rows, num_cols) if flatten==False
+ (num_examples, num_elements) if flatten==True, where num_elements = num_channels*num_rows*num_cols
orig_shape: [tuple of int32] original shape of the input data
num_examples: [int32] number of data examples or None if out_shape is specified
+ num_channels: [int32] number of data channels or None if out_shape is specified
num_rows: [int32] number of data rows or None if out_shape is specified
num_cols: [int32] number of data cols or None if out_shape is specified
- num_channels: [int32] number of data channels or None if out_shape is specified
"""
orig_shape = data.shape
orig_ndim = data.ndim
@@ -48,7 +50,7 @@ def reshape_data(data, flatten=None, out_shape=None):
elif flatten == True:
num_rows = num_elements
num_cols = 1
- data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels))
+ data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
else: # flatten == False
sqrt_num_elements = np.sqrt(num_elements)
assert np.floor(sqrt_num_elements) == np.ceil(sqrt_num_elements), (
@@ -57,7 +59,7 @@ def reshape_data(data, flatten=None, out_shape=None):
+' and data_shape='+str(orig_shape))
num_rows = int(sqrt_num_elements)
num_cols = num_rows
- data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels))
+ data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols))
elif orig_ndim == 2: # already flattened
(num_examples, num_elements) = data.shape
if flatten is None or flatten == True: # don't reshape data
@@ -71,34 +73,35 @@ def reshape_data(data, flatten=None, out_shape=None):
num_rows = int(sqrt_num_elements)
num_cols = num_rows
num_channels = 1
- data = torch.reshape(data, (num_examples, num_rows, num_cols, num_channels))
+ data = torch.reshape(data, (num_examples, num_channels, num_rows, num_cols))
else:
assert False, ('flatten argument must be True, False, or None')
elif orig_ndim == 3: # single data point
num_examples = 1
- num_rows, num_cols, num_channels = data.shape
+ num_channels, num_rows, num_cols = data.shape
if flatten == True:
- data = torch.reshape(data, (num_examples, num_rows * num_cols * num_channels))
+ data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
elif flatten is None or flatten == False: # already not flat
- data = data[None, ...]
+ data = data[None, ...] # add singleton num_examples dimension
else:
assert False, ('flatten argument must be True, False, or None')
- elif orig_ndim == 4: # not flat
- num_examples, num_rows, num_cols, num_channels = data.shape
+ elif orig_ndim == 4: # multiple data points, not flat
+ num_examples, num_channels, num_rows, num_cols = data.shape
if flatten == True:
- data = torch.reshape(data, (num_examples, num_rows*num_cols*num_channels))
+ data = torch.reshape(data, (num_examples, num_channels * num_rows * num_cols))
else:
assert False, ('Data must have 1, 2, 3, or 4 dimensions.')
else:
num_examples = None; num_rows=None; num_cols=None; num_channels=None
data = torch.reshape(data, out_shape)
- return (data, orig_shape, num_examples, num_rows, num_cols, num_channels)
+ return (data, orig_shape, num_examples, num_channels, num_rows, num_cols)
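+
+# A minimal usage sketch of the channels-first convention above (illustrative only,
+# not part of the library API):
+#   >>> import torch
+#   >>> batch = torch.rand(10, 1, 28, 28)            # [n, chan, rows, cols]
+#   >>> flat = reshape_data(batch, flatten=True)[0]  # shape [10, 784]
+#   >>> reshape_data(flat, flatten=False)[0].shape   # torch.Size([10, 1, 28, 28])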
def check_all_same_shape(tensor_list):
"""
Verify that all tensors in the tensor list have the same shape
- Args:
+
+ Keyword arguments:
tensor_list: list of tensors to be checked
Returns:
raises error if the tensors are not the same shape
@@ -107,53 +110,134 @@ def check_all_same_shape(tensor_list):
for index, tensor in enumerate(tensor_list):
if tensor.shape != first_shape:
raise ValueError(
- 'Tensor entry %g in input list has shape %g, but should have shape %g'%(
- index, tensor.shape, first_shape))
+ f'Tensor entry {index} in input list has shape {tensor.shape}, but should have shape {first_shape}')
-def flatten_feature_map(feature_map):
+def get_std_from_dataloader(loader, dataset_mean):
"""
- Flatten input tensor from [batch, y, x, f] to [batch, y*x*f]
- Args:
- feature_map: tensor with shape [batch, y, x, f]
- Returns:
- reshaped_map: tensor with shape [batch, y*x*f]
- """
- map_shape = feature_map.shape
- if(len(map_shape) == 4):
- (batch, y, x, f) = map_shape
- prev_input_features = int(y * x * f)
- resh_map = torch.reshape(feature_map, [-1, prev_input_features])
- elif(len(map_shape) == 2):
- resh_map = feature_map
- else:
- raise ValueError('Input feature_map has incorrect ndims')
- return resh_map
+ TODO: Calculate the standard deviation from all entries in a pytorch data loader
+ NOTE: this stub raises NotImplementedError and is shadowed by the definition of the same name below
+ Keyword arguments:
+ loader: [pytorch DataLoader] containing the full dataset.
+ This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+ dataset_mean: [torch tensor] of the same shape as a single dataset sample
-def standardize(data, eps=None, samplewise=True):
+ Outputs:
+ dataset_std: [torch tensor] of the same shape as a single dataset sample
+ """
+ #dataset_std = torch.zeros(next(iter(loader)).shape[1:])
+ #for data, target in loader:
+ # std_sum_squares += (data - dataset_mean)**2
+ #dataset_std = torch.sqrt(std_sum_squares / len(loader.dataset))
+ #return dataset_std
+ raise NotImplementedError
+
+
+def get_std_from_dataloader(loader, dataset_mean=None):
+ """
+ Calculate the standard deviation from the mean for all entries in a pytorch data loader
+
+ Keyword arguments:
+ loader: [pytorch DataLoader] containing the full dataset.
+ This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+ dataset_mean: [torch tensor] mean of the dataset; computed via get_mean_from_dataloader if None
+
+ Outputs:
+ dataset_std: [torch tensor] of the same shape as a single dataset sample
+ """
+ if dataset_mean is None:
+ dataset_mean = get_mean_from_dataloader(loader)
+ dataset_std = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension
+ num_batches = 0
+ for data, target in loader:
+ dataset_std += torch.std(data - dataset_mean, dim=0, keepdim=False)
+ num_batches += 1
+ return dataset_std / num_batches
+
+
+def get_mean_from_dataloader(loader):
+ """
+ Calculate the mean datapoint from all entries in a pytorch data loader
+
+ Keyword arguments:
+ loader: [pytorch DataLoader] containing the full dataset.
+ This function assumes there is always a target label, i.e. loader.next() returns (data, target)
+
+ Outputs:
+ dataset_mean: [torch tensor] of the same shape as a single dataset sample
+ """
+ dataset_mean = torch.zeros(next(iter(loader))[0].shape[1:]) # don't include batch dimension
+ num_batches = 0
+ for data, target in loader:
+ dataset_mean += data.mean(dim=0, keepdim=False)
+ num_batches += 1
+ return dataset_mean / num_batches
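+
+# Illustrative sketch; `loader` is assumed to be any DataLoader yielding (data, target):
+#   >>> mean_image = get_mean_from_dataloader(loader)
+#   >>> std_image = get_std_from_dataloader(loader, dataset_mean=mean_image)
+# Both return a tensor shaped like one sample, matching the sample_mean / sample_std
+# arguments of standardize() below.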
+
+
+def center(data, samplewise=False, batch_size=100):
+ """
+ Center image dataset to have zero mean
+
+ Keyword arguments:
+ data: [tensor] unnormalized data
+ samplewise: [bool] if True, center each sample individually; if False, compute mean over entire batch
+
+ Outputs:
+ data: [tensor] centered data
+ """
+ data, orig_shape = reshape_data(data, flatten=True)[:2]
+ if(samplewise): # center each input sample individually
+ data_axis = tuple(range(data.ndim)[1:])
+ data_mean = torch.mean(data, dim=data_axis, keepdim=True)
+ else: # center the entire population
+ data_mean = torch.mean(data, dim=0)
+ data = data - data_mean
+ if(data.shape != orig_shape):
+ data = reshape_data(data, out_shape=orig_shape)[0]
+ return data, data_mean
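+
+# Sketch: per-sample centering preserves the input shape (values illustrative):
+#   >>> data = torch.rand(8, 1, 28, 28)
+#   >>> centered, mean = center(data, samplewise=True)
+#   >>> centered.shape                               # torch.Size([8, 1, 28, 28])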
+
+
+def standardize(data, eps=None, samplewise=False, batch_size=100, sample_mean=None, sample_std=None):
"""
Standardize each image data to have zero mean and unit standard-deviation (z-score)
- Uses population standard deviation data.sum() / N, where N = data.shape[0].
- Inputs:
+
+ This function uses population standard deviation data.sum() / N, where N = data.shape[0].
+
+ Keyword arguments:
data: [tensor] unnormalized data
eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data)
+ defaults to 1/data_dim, where data_dim is the total size of a data vector
samplewise: [bool] if True, standardize each sample individually; akin to contrast-normalization
if False, compute mean and std over entire batch
+ sample_mean: [tensor] to be used as the dataset mean instead of calculating it,
+ it should be the same shape as a single data element
+ sample_std: [tensor] to be used as the dataset standard deviation instead of calculating it,
+ it should be the same shape as a single data element
+
Outputs:
data: [tensor] normalized data
"""
if(eps is None):
- eps = 1.0 / np.sqrt(data[0,...].numel())
- data, orig_shape = reshape_data(data, flatten=True)[:2] # Adds channel dimension if it's missing
+ eps = 1.0 / data[0,...].numel()
+ data, orig_shape = reshape_data(data, flatten=True)[:2]
num_examples = data.shape[0]
- if(samplewise): # standardize the entire population
- data_axis = tuple(range(data.ndim)[1:]) # standardize each example individually
- data_mean = torch.mean(data, dim=data_axis, keepdim=True)
- data_true_std = torch.std(data, unbiased=False, dim=data_axis, keepdim=True)
- else: # standardize each input sample individually
- data_mean = torch.mean(data)
- data_true_std = torch.std(data, unbiased=False)
+ if(samplewise): # standardize each input sample individually
+ if sample_mean is None:
+ data_mean = torch.mean(data, dim=1, keepdim=True) # [num_examples, 1]
+ else:
+ data_mean = sample_mean.mean().repeat(num_examples, 1)
+ if sample_std is None:
+ data_true_std = torch.std(data - data_mean, unbiased=False, dim=1, keepdim=True)
+ else:
+ data_true_std = sample_std.mean().repeat(num_examples, 1)
+ else: # standardize the entire population
+ if sample_mean is None:
+ data_mean = torch.mean(data, dim=0, keepdim=True) # [1, sample_dim]
+ else:
+ data_mean = sample_mean.view(1, -1)
+ if sample_std is None:
+ data_true_std = torch.std(data - data_mean, dim=0, unbiased=False, keepdim=True)
+ else:
+ data_true_std = sample_std.view(1, -1)
data_std = torch.where(data_true_std >= eps, data_true_std, eps*torch.ones_like(data_true_std))
data = (data - data_mean) / data_std
if(data.shape != orig_shape):
@@ -164,10 +248,12 @@ def standardize(data, eps=None, samplewise=True):
def rescale_data_to_one(data, eps=None, samplewise=True):
"""
Rescale input data to be between 0 and 1
- Inputs:
+
+ Keyword arguments:
data: [tensor] unnormalized data
eps: [float] if the std(data) is less than eps, then divide by eps instead of std(data)
samplewise: [bool] if True, compute it per-sample, otherwise normalize entire batch
+
Outputs:
data: [tensor] centered data of shape (n, i, j, k) or (n, l)
"""
@@ -175,9 +261,9 @@ def rescale_data_to_one(data, eps=None, samplewise=True):
eps = 1.0 / np.sqrt(data[0,...].numel())
if(samplewise):
data_min = torch.min(data.view(-1, np.prod(data.shape[1:])),
- axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
+ axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1))
data_max = torch.max(data.view(-1, np.prod(data.shape[1:])),
- axis=1, keepdims=False)[0].view(-1, *[1]*(data.ndim-1))
+ axis=1, keepdim=False)[0].view(-1, *[1]*(data.ndim-1))
else:
data_min = torch.min(data)
data_max = torch.max(data)
@@ -186,11 +272,14 @@ def rescale_data_to_one(data, eps=None, samplewise=True):
data = (data - data_min) / data_range
return data, data_min, data_max
+
def one_hot_to_dense(one_hot_labels):
"""
- converts a matrix of one-hot labels to a list of dense labels
- Inputs:
+ Convert a matrix of one-hot labels to a list of dense labels
+
+ Keyword arguments:
one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes]
+
Outputs:
dense_labels: 1D torch tensor array of labels
The integer value indicates the class and 0 is assumed to be a class.
@@ -202,13 +291,247 @@ def one_hot_to_dense(one_hot_labels):
dense_labels[label_id] = torch.nonzero(one_hot_labels[label_id, :] == 1)
return dense_labels
+
def dense_to_one_hot(labels_dense, num_classes):
"""
- converts a (np.ndarray) vector of dense labels to a (np.ndarray) matrix of one-hot labels
- e.g. [0, 1, 1, 3] -> [00, 01, 01, 11]
+ Convert a vector of dense labels to a matrix of one-hot labels, e.g. with num_classes=4, [0, 1, 1, 3] -> [[1,0,0,0], [0,1,0,0], [0,1,0,0], [0,0,0,1]]
+
+ Keyword arguments:
+ labels_dense: dense torch tensor of shape [num_labels], where each entry is an integer indicating the class label
+ num_classes: the total number of classes in the dataset
+
+ Outputs:
+ one_hot_labels: one-hot torch tensor of shape [num_labels, num_classes]
"""
num_labels = labels_dense.shape[0]
index_offset = torch.arange(end=num_labels, dtype=torch.int32) * num_classes
labels_one_hot = torch.zeros((num_labels, num_classes))
labels_one_hot.view(-1)[index_offset + labels_dense.view(-1)] = 1
return labels_one_hot
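+
+# Round-trip sketch (illustrative): dense labels <-> one-hot matrices
+#   >>> dense = torch.tensor([0, 1, 1, 3])
+#   >>> one_hot = dense_to_one_hot(dense, num_classes=4)  # shape [4, 4]
+#   >>> one_hot_to_dense(one_hot)                         # recovers [0, 1, 1, 3]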
+
+
+def atleast_kd(x, k):
+ """
+ Return x reshaped to append singleton dimensions such that x.ndim is at least k
+
+ Keyword arguments:
+ x [Tensor or numpy ndarray]
+ k [int] minimum number of dimensions
+
+ Outputs:
+ x [same as input x] reshaped input to have at least k dimensions
+ """
+ shape = x.shape + (1,) * (k - x.ndim)
+ return x.reshape(shape)
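+
+# Sketch: append singleton dims for broadcasting (illustrative):
+#   >>> atleast_kd(torch.ones(5), 3).shape           # torch.Size([5, 1, 1])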
+
+
+def get_weights_l2_norm(w, eps=1e-12):
+ """
+ Return l2 norm of weight matrix
+
+ Keyword arguments:
+ w [Tensor] assumed to have shape [outC, inC] or [outC, inC, kernH, kernW]
+ norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second
+ eps [float] minimum value to prevent division by zero
+
+ Outputs:
+ norm [Tensor] norm of each of the outC weight vectors
+ """
+ if w.ndim == 2: # fully-connected, [outputs, inputs]
+ norms = torch.norm(w, dim=1, keepdim=True)
+ elif w.ndim == 4: # convolutional, [out_channels, in_channels, kernel_height, kernel_width]
+ norms = torch.norm(w.flatten(start_dim=1), dim=-1, keepdim=True)
+ else:
+ assert False, (f'input w must have ndim = 2 or 4, not {w.ndim}')
+ if(torch.max(norms) <= eps): #TODO: raise proper warnings
+ print(f'Warning: input gradient is less than or equal to {eps}')
+ norms = torch.max(norms, eps*torch.ones_like(norms)) # prevent div by 0 # TODO: Change to torch.maximum when it is stable
+ norms = atleast_kd(norms, w.ndim)
+ return norms
+
+
+def l2_normalize_weights(w, eps=1e-12):
+ """
+ l2 normalize weight matrix
+
+ Keyword arguments:
+ w [Tensor] assumed to have shape [outC, inC] or [outC, inC, kernH, kernW]
+ norm is calculated over vectorized version of inC in the first case or inC*kernH*kernW in the second
+ eps [float] minimum value to prevent division by zero
+
+ Outputs:
+ w [Tensor] same type and shape as input w, but with unitary l2 norm when computed over all input dimensions
+ """
+ norms = get_weights_l2_norm(w, eps)
+ return w / norms
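+
+# Sketch: every output vector has unit l2 norm after normalization
+# (the shapes here are illustrative assumptions):
+#   >>> w = torch.randn(64, 784)                     # [outC, inC]
+#   >>> w_hat = l2_normalize_weights(w)
+#   >>> torch.allclose(torch.norm(w_hat, dim=1), torch.ones(64))  # True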
+
+
+def single_image_to_patches(image, patch_shape):
+ """
+ Extract patches from a single image
+
+ Keyword arguments:
+ image [torch tensor] of shape [im_chan, im_height, im_width]
+ patch_shape [tuple or list] containing the output shape
+ [patch_chan, patch_height, patch_width]
+ patch_chan must be the same as im_chan
+
+ It is recommended, though not required, that the patch height and width divide evenly into
+ the image height and width, respectively.
+
+ Outputs:
+ patches [torch tensor] of patches of shape [num_patches]+list(patch_shape)
+ """
+ try:
+ im_chan, im_height, im_width = image.shape
+ patch_chan, patch_height, patch_width = patch_shape
+ except Exception as e:
+ raise ValueError(
+ f'This function requires that: '
+ +f'1) the input variable "image" must have shape [im_chan, im_height, im_width], and is {image.shape};'
+ +f' and 2) the input variable "patch_shape" must have shape [patch_chan, patch_height, patch_width], and is {patch_shape}.'
+ ) from e
+ num_row_patches = np.floor(im_height / patch_height)
+ num_col_patches = np.floor(im_width / patch_width)
+ num_patches = int(num_row_patches * num_col_patches)
+ patches = torch.zeros((num_patches, patch_chan, patch_height, patch_width))
+ row_id = 0
+ col_id = 0
+ for patch_idx in range(num_patches):
+ row_end = row_id + patch_height
+ col_end = col_id + patch_width
+ try:
+ patches[patch_idx, ...] = image[:, row_id:row_end, col_id:col_end]
+ except Exception as e:
+ raise ValueError('This function requires that im_chan equal patch_chan.') from e
+ row_id += patch_height
+ if row_id >= im_height:
+ row_id = 0
+ col_id += patch_width
+ if col_id >= im_width:
+ col_id = 0
+ return patches
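+
+# Sketch: one 1x28x28 image tiles into four non-overlapping 1x14x14 patches:
+#   >>> img = torch.rand(1, 28, 28)
+#   >>> single_image_to_patches(img, (1, 14, 14)).shape  # torch.Size([4, 1, 14, 14])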
+
+
+def patches_to_single_image(patches, image_shape):
+ """
+ Convert input patches into a single output image
+
+ Keyword arguments:
+ patches [torch tensor] of shape [num_patches, patch_chan, patch_height, patch_width]
+ image_shape [list or tuple] of length 3 containing the image shape [im_chan, im_height, im_width]
+
+ im_chan is assumed to equal patch_chan
+
+ Outputs:
+ image [torch tensor] of shape [im_chan, im_height, im_width]
+ """
+ try:
+ num_patches, patch_chan, patch_height, patch_width = patches.shape
+ im_chan, im_height, im_width = image_shape
+ except Exception as e:
+ raise ValueError(
+ f'This function requires that input patches have shape'
+ f' [num_patches, patch_chan, patch_height, patch_width] and is {patches.shape}'
+ f' and input image_shape is a list or tuple of integers of length 3 containing [im_chan, im_height, im_width] and is {image_shape}'
+ ) from e
+ image = torch.zeros((im_chan, im_height, im_width))
+ row_id = 0
+ col_id = 0
+ for patch_idx in range(num_patches):
+ row_end = row_id + patch_height
+ col_end = col_id + patch_width
+ image[:, row_id:row_end, col_id:col_end] = patches[patch_idx, ...]
+ row_id += patch_height
+ if row_id >= im_height:
+ row_id = 0
+ col_id += patch_width
+ if col_id >= im_width:
+ col_id = 0
+ return image
+
+def images_to_patches(images, patch_shape):
+ """
+ Extract evenly distributed non-overlapping patches from an image dataset
+
+ Keyword arguments:
+ images [torch tensor] of shape [num_images, im_chan, im_height, im_width] or [im_chan, im_height, im_width] for a single image
+ patch_shape [tuple or list] containing the output shape
+ [patch_chan, patch_height, patch_width]
+ patch_chan must be the same as im_chan
+
+ It is recommended, though not required, that the patch height and width divide evenly into the image height and width, respectively.
+
+ Outputs:
+ patches [torch tensor] of patches of shape [num_patches]+list(patch_shape)
+ """
+ if images.ndim == 3: # single image
+ return single_image_to_patches(images, patch_shape)
+ num_im, im_chan, im_height, im_width = images.shape
+ patch_chan, patch_height, patch_width = patch_shape
+ num_row_patches = np.floor(im_height / patch_height)
+ num_col_patches = np.floor(im_width / patch_width)
+ num_patches_per_im = int(num_row_patches * num_col_patches)
+ tot_num_patches = int(num_patches_per_im * num_im)
+ patches = torch.zeros([tot_num_patches, ]+list(patch_shape))
+ patch_id = 0
+ for im_id in range(num_im):
+ image = images[im_id, ...]
+ image_patches = single_image_to_patches(image, patch_shape)
+ patch_end = patch_id + num_patches_per_im
+ patches[patch_id:patch_end, ...] = image_patches
+ patch_id += num_patches_per_im
+ return patches
+
+def patches_to_images(patches, image_shape):
+ """
+ Recombine patches tensor into a dataset of images
+
+ Keyword arguments:
+ patches [torch tensor] holding square patch data of shape [num_patches, patch_chan, patch_height, patch_width]
+ image_shape [list or tuple] containing the image dataset shape [im_chan, im_height, im_width]
+
+ It is assumed that im_chan equals patch_chan
+
+ Outputs:
+ images [torch tensor] holding the recombined image dataset
+ """
+ tot_num_patches, patch_chan, patch_height, patch_width = patches.shape
+ im_chan, im_height, im_width = image_shape
+ num_row_patches = np.floor(im_height / patch_height)
+ num_col_patches = np.floor(im_width / patch_width)
+ num_patches_per_im = int(num_row_patches * num_col_patches)
+ num_im = tot_num_patches // num_patches_per_im
+ images = torch.zeros([num_im] + list(image_shape))  # list() so a tuple image_shape also works
+ patch_id = 0
+ for im_id in range(num_im):
+ patch_end = patch_id + num_patches_per_im
+ patch_batch = patches[patch_id:patch_end, ...]
+ images[im_id, ...] = patches_to_single_image(patch_batch, image_shape)
+ patch_id += num_patches_per_im
+ return images
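+
+# Round-trip sketch (assumes the patch dims divide the image dims evenly):
+#   >>> imgs = torch.rand(2, 1, 28, 28)
+#   >>> patches = images_to_patches(imgs, (1, 14, 14))   # [8, 1, 14, 14]
+#   >>> torch.equal(patches_to_images(patches, [1, 28, 28]), imgs)  # True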
+
+
+def covariance(tensor):
+ """
+ Returns the covariance matrix of the input tensor
+
+ Keyword arguments:
+ tensor [torch tensor] of shape [num_batch, num_channels] or [num_batch, num_channels, elements_h, elements_w]
+ if tensor.ndim is 2 then the covariance is computed for each element over the batch dimension
+ if tensor.ndim is 4 then the covariance is computed over spatial dimensions for each channel and each batch instance
+
+ Outputs:
+ covariance matrix [torch tensor] of shape [num_channels, num_channels]
+ """
+ if tensor.ndim == 2: # [num_batch, num_channels]
+ centered_tensor = tensor - tensor.mean(dim=0, keepdim=True) # subtract mean vector
+ covariance = torch.mm(centered_tensor.T, centered_tensor) # sum over batch
+ covariance = covariance / (centered_tensor.shape[0]-1)
+ elif tensor.ndim == 4: # [num_batch, num_channels, elements_h, elements_w]
+ num_batch, num_channels, elements_h, elements_w = tensor.shape
+ flat_map = tensor.view(num_batch, num_channels, elements_h * elements_w)
+ cent_flat_map = flat_map - flat_map.mean(dim=2, keepdim=True) # subtract mean vector
+ covariance = torch.bmm(cent_flat_map, torch.transpose(cent_flat_map, 1, 2)) # sum over space
+ covariance = covariance.mean(dim=0, keepdim=False) # avg cov over batch
+ return covariance
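+
+# Sketch: covariance of flat activations over the batch dimension (illustrative):
+#   >>> acts = torch.randn(256, 16)                  # [num_batch, num_channels]
+#   >>> covariance(acts).shape                       # torch.Size([16, 16])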
diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py
index cd46541b..8c17df11 100644
--- a/utils/dataset_utils.py
+++ b/utils/dataset_utils.py
@@ -1,17 +1,54 @@
import os
import sys
+from os.path import dirname as up
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+
+from PIL import Image
import numpy as np
import torch
-from torchvision import datasets, transforms
-
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
+from torchvision import transforms
+import torchvision.datasets
import DeepSparseCoding.utils.data_processing as dp
import DeepSparseCoding.datasets.synthetic as synthetic
+class FastMNIST(torchvision.datasets.MNIST):
+ """
+ The torchvision MNIST dataset has additional overhead that slows it down.
+ This loads the entire dataset onto the specified device at init, resulting in a considerable speedup
+ """
+ def __init__(self, *args, **kwargs):
+ device = kwargs.pop('device', 'cpu')
+ super().__init__(*args, **kwargs)
+ # Scale data to [0,1]
+ self.data = self.data.unsqueeze(-1).float().div(255)
+ self.data = self.data.permute(0, 3, 1, 2) # channels first
+ if self.transform is not None:
+ # doing this so that it is consistent with all other datasets
+ # to return a PIL Image
+ for data_idx in range(self.data.shape[0]):
+ self.data[data_idx, ...] = self.transform(
+ Image.fromarray(
+ self.data[data_idx, ...].numpy().squeeze(), mode='L'))[None, ...]
+ if self.target_transform is not None:
+ self.targets = [self.target_transform(int(target)) for target in self.targets]
+ # Put both data and targets on GPU in advance
+ self.data, self.targets = self.data.to(device), self.targets.to(device)
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is index of the target class.
+ """
+ img, target = self.data[index], self.targets[index]
+ return img, target
+
+
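+# Illustrative usage sketch; the root path is an assumption and, because
+# download=False is passed, the MNIST files must already exist under it:
+#   >>> dataset = FastMNIST(root='./data', train=True, download=False,
+#   ...                     transform=None, device='cpu')
+#   >>> img, target = dataset[0]                     # img: [1, 28, 28] float in [0, 1]
+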
class CustomTensorDataset(torch.utils.data.Dataset):
def __init__(self, data_tensor):
self.data_tensor = data_tensor
@@ -24,29 +61,98 @@ def __len__(self):
def load_dataset(params):
- new_params = {}
+ data_stats = {} # dataset statistics
+ extra_outputs = {} # depending on parameters may include dataset_mean, dataset_std, etc
if(params.dataset.lower() == 'mnist'):
preprocessing_pipeline = [
transforms.ToTensor(),
- transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last
- ]
+ ]
if params.standardize_data:
preprocessing_pipeline.append(
transforms.Lambda(lambda x: dp.standardize(x, eps=params.eps)[0]))
if params.rescale_data_to_one:
preprocessing_pipeline.append(
transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0]))
- train_loader = torch.utils.data.DataLoader(
- datasets.MNIST(root=params.data_dir, train=True, download=True,
- transform=transforms.Compose(preprocessing_pipeline)),
- batch_size=params.batch_size, shuffle=params.shuffle_data,
- num_workers=0, pin_memory=False)
- val_loader = None
- test_loader = torch.utils.data.DataLoader(
- datasets.MNIST(root=params.data_dir, train=False, download=True,
- transform=transforms.Compose(preprocessing_pipeline)),
- batch_size=params.batch_size, shuffle=params.shuffle_data,
- num_workers=0, pin_memory=False)
+ kwargs = {
+ 'root':params.data_dir,
+ 'download':False,
+ 'transform':transforms.Compose(preprocessing_pipeline)
+ }
+ if(hasattr(params, 'fast_mnist') and params.fast_mnist):
+ kwargs['device'] = params.device
+ kwargs['train'] = True
+ train_loader = torch.utils.data.DataLoader(
+ FastMNIST(**kwargs), batch_size=params.batch_size,
+ shuffle=params.shuffle_data, num_workers=0, pin_memory=False)
+ kwargs['train'] = False
+ val_loader = None
+ test_loader = torch.utils.data.DataLoader(
+ FastMNIST(**kwargs), batch_size=params.batch_size,
+ shuffle=params.shuffle_data, num_workers=0, pin_memory=False)
+ else:
+ kwargs['train'] = True
+ train_loader = torch.utils.data.DataLoader(
+ torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size,
+ shuffle=params.shuffle_data, num_workers=0, pin_memory=True)
+ kwargs['train'] = False
+ val_loader = None
+ test_loader = torch.utils.data.DataLoader(
+ torchvision.datasets.MNIST(**kwargs), batch_size=params.batch_size,
+ shuffle=params.shuffle_data, num_workers=0, pin_memory=True)
+
+ elif(params.dataset.lower() == 'cifar10'):
+ preprocessing_pipeline = [
+ transforms.ToTensor(),
+ ]
+ kwargs = {
+ 'root': os.path.join(*[params.data_dir, 'cifar10']),
+ 'download': False,
+ 'train': True,
+ 'transform': transforms.Compose(preprocessing_pipeline)
+ }
+ if params.center_dataset:
+ dataset = torchvision.datasets.CIFAR10(**kwargs)
+ data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size,
+ shuffle=False, num_workers=0, pin_memory=True)
+ dataset_mean_image = dp.get_mean_from_dataloader(data_loader)
+ preprocessing_pipeline.append(
+ transforms.Lambda(lambda x: x - dataset_mean_image))
+ extra_outputs['dataset_mean_image'] = dataset_mean_image
+ if params.standardize_data:
+ #dataset = torchvision.datasets.CIFAR10(**kwargs)
+ #data_loader = torch.utils.data.DataLoader(dataset, batch_size=params.batch_size,
+ # shuffle=False, num_workers=0, pin_memory=True)
+ #dataset_mean_image = dp.get_mean_from_dataloader(data_loader)
+ #extra_outputs['dataset_mean_image'] = dataset_mean_image
+ #dataset_std_image = dp.get_std_from_dataloader(data_loader, dataset_mean_image)
+ #extra_outputs['dataset_std_image'] = dataset_std_image
+ preprocessing_pipeline.append(
+ transforms.Lambda(
+ lambda x: dp.standardize(x,
+ eps=params.eps,
+ samplewise=True,#False,
+ batch_size=params.batch_size)[0]
+ #sample_mean=dataset_mean_image,
+ #sample_std=dataset_std_image)[0]
+ )
+ )
+ if params.rescale_data_to_one:
+ preprocessing_pipeline.append(
+ transforms.Lambda(lambda x: dp.rescale_data_to_one(x, eps=params.eps, samplewise=True)[0]))
+ kwargs['transform'] = transforms.Compose(preprocessing_pipeline)
+ kwargs['train'] = True
+ dataset = torchvision.datasets.CIFAR10(**kwargs)
+ kwargs['train'] = False
+ testset = torchvision.datasets.CIFAR10(**kwargs)
+ num_train = len(dataset) - params.num_validation
+ trainset, valset = torch.utils.data.random_split(dataset, [num_train, params.num_validation])
+ train_loader = torch.utils.data.DataLoader(trainset, batch_size=params.batch_size,
+ shuffle=params.shuffle_data, num_workers=0, pin_memory=True)
+ val_loader = torch.utils.data.DataLoader(valset, batch_size=params.batch_size,
+ shuffle=False, num_workers=0, pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(testset, batch_size=params.batch_size,
+ shuffle=False, num_workers=0, pin_memory=True)
+
elif(params.dataset.lower() == 'dsprites'):
root = os.path.join(*[params.data_dir])
dsprites_file = os.path.join(*[root, 'dsprites/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz'])
@@ -67,10 +173,11 @@ def load_dataset(params):
pin_memory=False)
val_loader = None
test_loader = None
+
elif(params.dataset.lower() == 'synthetic'):
- preprocessing_pipeline = [transforms.ToTensor(),
- transforms.Lambda(lambda x: x.permute(1, 2, 0)) # channels last
- ]
+ preprocessing_pipeline = [
+ transforms.ToTensor(),
+ ]
train_loader = torch.utils.data.DataLoader(
synthetic.SyntheticImages(params.epoch_size, params.data_edge_size, params.dist_type,
params.rand_state, params.num_classes,
@@ -79,20 +186,20 @@ def load_dataset(params):
num_workers=0, pin_memory=False)
val_loader = None
test_loader = None
- new_params["num_pixels"] = params.data_edge_size**2
+
else:
assert False, (f'Supported datasets are ["mnist", "dsprites", "synthetic"], not {dataset_name}')
- new_params = {}
- new_params['epoch_size'] = len(train_loader.dataset)
+ data_stats['epoch_size'] = len(train_loader.dataset)
if(not hasattr(params, 'num_val_images')):
if val_loader is None:
- new_params['num_val_images'] = 0
+ data_stats['num_val_images'] = 0
else:
- new_params['num_val_images'] = len(val_loader.dataset)
+ data_stats['num_val_images'] = len(val_loader.dataset)
if(not hasattr(params, 'num_test_images')):
if test_loader is None:
- new_params['num_test_images'] = 0
+ data_stats['num_test_images'] = 0
else:
- new_params['num_test_images'] = len(test_loader.dataset)
- new_params['data_shape'] = list(next(iter(train_loader))[0].shape)[1:]
- return (train_loader, val_loader, test_loader, new_params)
+ data_stats['num_test_images'] = len(test_loader.dataset)
+ data_stats['data_shape'] = list(next(iter(train_loader))[0].shape)[1:]
+ data_stats['num_pixels'] = np.prod(data_stats['data_shape'])
+ return (train_loader, val_loader, test_loader, data_stats, extra_outputs)
diff --git a/utils/file_utils.py b/utils/file_utils.py
index 6dd4b7f8..2ca41dc9 100644
--- a/utils/file_utils.py
+++ b/utils/file_utils.py
@@ -3,6 +3,7 @@
import types
import os
from copy import deepcopy
+from collections import OrderedDict
import importlib
import numpy as np
@@ -27,6 +28,16 @@ def js_dumpstring(self, obj):
"""Dump json string with special CustomEncoder"""
return js.dumps(obj, sort_keys=True, indent=2, cls=CustomEncoder)
+ def log_string(self, string):
+ """Log input string"""
+ now = time.localtime(time.time())
+ time_str = time.strftime('%m/%d/%y %H:%M:%S', now)
+ out_str = '\n' + time_str + ' -- ' + str(string)
+ if(self.log_to_file):
+ self.file_obj.write(out_str)
+ else:
+ print(out_str)
+
def log_trainable_variables(self, name_list):
"""
Use logging to write names of trainable variables in model
@@ -34,7 +45,7 @@ def log_trainable_variables(self, name_list):
name_list: list containing variable names
"""
js_str = self.js_dumpstring(name_list)
- self.log_info(''+js_str+'')
+ self.log_string(''+js_str+'')
def log_params(self, params):
"""
@@ -45,7 +56,7 @@ def log_params(self, params):
out_params = deepcopy(params)
if('ensemble_params' in out_params.keys()):
for sub_idx, sub_params in enumerate(out_params['ensemble_params']):
- sub_params.set_params()
+ #sub_params.set_params()
for key, value in sub_params.__dict__.items():
if(key != 'rand_state'):
new_dict_key = f'{sub_idx}_{key}'
@@ -54,17 +65,17 @@ def log_params(self, params):
if('rand_state' in out_params.keys()):
del out_params['rand_state']
js_str = self.js_dumpstring(out_params)
- self.log_info(''+js_str+'')
+ self.log_string(''+js_str+'')
- def log_info(self, string):
- """Log input string"""
- now = time.localtime(time.time())
- time_str = time.strftime('%m/%d/%y %H:%M:%S', now)
- out_str = '\n' + time_str + ' -- ' + str(string)
- if(self.log_to_file):
- self.file_obj.write(out_str)
- else:
- print(out_str)
+ def log_stats(self, stat_dict):
+ """Log dictionary of training / testing statistics"""
+ js_str = self.js_dumpstring(stat_dict)
+ self.log_string(''+js_str+'')
+
+ def log_info(self, info_dict):
+ """Log input dictionary in tags"""
+ js_str = self.js_dumpstring(info_dict)
+ self.log_string(''+js_str+'')
def load_file(self, filename=None):
"""
@@ -168,6 +179,18 @@ def read_stats(self, text):
stats[key] = [js_match[key]]
return stats
+ def read_architecture(self, text):
+ """
+ Generate a dictionary containing the model architecture details from log text
+ Outputs:
+ architecture: [dict] containing the logged architecture summary
+ Inputs:
+ text: [str] containing text to parse, can be obtained by calling load_file()
+ """
+ tokens = ['', '']
+ js_match = self.read_js(tokens, text)
+ return js_match
+
def __del__(self):
if(self.log_to_file and hasattr(self, 'file_obj')):
self.file_obj.close()
@@ -199,3 +222,92 @@ def python_module_from_file(py_module_name, file_name):
py_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(py_module)
return py_module
+
+def summary_string(model, input_size, batch_size=2, device=torch.device('cuda:0'), dtype=torch.FloatTensor):
+ """
+ Returns a string that summarizes the model architecture, including the number of parameters
+ and layer output sizes
+
+ Code is modified from:
+ https://github.com/sksq96/pytorch-summary
+
+ Keyword arguments:
+ model [torch module, module subclass, or EnsembleModel] model to summarize
+ input_size [tuple or list of tuples] must not include the batch dimension; if it is a list
+ of tuples then the architecture will be computed for each option
+ batch_size [positive int] how many images to feed into the model.
+ The default of 2 will ensure that batch norm works.
+ device [torch.device] which device to run the test on
+ dtype [torch.dtype] for the artificially generated inputs
+ """
+ def register_hook(module):
+ def hook(module, input, output):
+ class_name = str(module.__class__).split('.')[-1].split("'")[0]
+ module_idx = len(summary)
+ m_key = '%s-%i' % (class_name, module_idx + 1)
+ summary[m_key] = OrderedDict()
+ summary[m_key]['input_shape'] = list(input[0].size())
+ summary[m_key]['input_shape'][0] = batch_size
+ if isinstance(output, (list, tuple)):
+ summary[m_key]['output_shape'] = [
+ [-1] + list(o.size())[1:] for o in output
+ ]
+ else:
+ summary[m_key]['output_shape'] = list(output.size())
+ summary[m_key]['output_shape'][0] = batch_size
+ params = 0
+ if hasattr(module, 'weight') and hasattr(module.weight, 'size'):
+ params += torch.prod(torch.LongTensor(list(module.weight.size())))
+ summary[m_key]['trainable'] = module.weight.requires_grad
+ if hasattr(module, 'bias') and hasattr(module.bias, 'size'):
+ params += torch.prod(torch.LongTensor(list(module.bias.size())))
+ summary[m_key]['nb_params'] = params
+ summary[m_key]['gpu_mem'] = round(torch.cuda.memory_allocated(0)/1024**3, 1)
+ if len(list(module.children())) == 0: # only apply hooks at child modules to avoid applying them twice
+ hooks.append(module.register_forward_hook(hook))
+ x = torch.rand(batch_size, *input_size).type(dtype).to(device=device)
+ summary = OrderedDict() # used within hook function to store properties
+ hooks = [] # used within hook function to store registered hooks
+ model.apply(register_hook) # recursively apply register_hook function to model and all children
+ model(x) # make a forward pass
+ for h in hooks:
+ h.remove() # remove the hooks so they are not used at run time
+ summary_str = '----------------------------------------------------------------\n'
+ line_new = '{:>20} {:>25} {:>15}'.format('Layer (type)', 'Output Shape', 'Param #')
+ summary_str += line_new + '\n'
+ summary_str += '================================================================\n'
+ total_params = 0
+ total_output = 0
+ trainable_params = 0
+ for layer in summary:
+ line_new = '{:>20} {:>25} {:>15}'.format(
+ layer,
+ str(summary[layer]['output_shape']),
+ '{0:,}'.format(summary[layer]['nb_params']),
+ ) # input_shape, output_shape, trainable, nb_params
+ total_params += summary[layer]['nb_params']
+ total_output += np.prod(summary[layer]['output_shape'])
+ if 'trainable' in summary[layer]:
+ if summary[layer]['trainable'] == True:
+ trainable_params += summary[layer]['nb_params']
+ summary_str += line_new + '\n'
+ # assume 4 bytes/number (float on cuda).
+ total_input_size = abs(np.prod(input_size)) * batch_size * 4. / (1024 ** 2.)
+ total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
+ total_params_size = abs(total_params * 4. / (1024 ** 2.))
+ total_size = total_params_size + total_output_size + total_input_size
+ summary_str += '================================================================\n'
+ summary_str += f'Total params: {total_params}\n'
+ summary_str += f'Trainable params: {trainable_params}\n'
+ param_diff = total_params - trainable_params
+ summary_str += f'Non-trainable params: {param_diff}\n'
+ summary_str += '----------------------------------------------------------------\n'
+ summary_str += f'Input size (MB): {total_input_size:0.2f}\n'
+ summary_str += f'Forward/backward pass size (MB): {total_output_size:0.2f}\n'
+ summary_str += f'Params size (MB): {total_params_size:0.2f}\n'
+ summary_str += f'Estimated total size (MB): {total_size:0.2f}\n'
+ ## TODO: Update pytorch for this to work
+ #device_memory = torch.cuda.memory_summary(device, abbreviated=True)
+ #summary_str += f'Device memory allocated with batch of inputs (GB): {device_memory}\n'
+ summary_str += '----------------------------------------------------------------\n'
+ return summary_str, (total_params, trainable_params)
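+
+# Usage sketch; the toy model is an illustrative assumption, and the hook queries
+# torch.cuda.memory_allocated, so a CUDA-enabled torch build is assumed:
+#   >>> import torch.nn as nn
+#   >>> net = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
+#   >>> text, (total, trainable) = summary_string(
+#   ...     net, input_size=(784,), device=torch.device('cpu'), dtype=torch.FloatTensor)
+#   >>> print(text)  # layer-by-layer table plus parameter / memory totals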
diff --git a/utils/loaders.py b/utils/loaders.py
index 9b73a8fb..8b395e2a 100644
--- a/utils/loaders.py
+++ b/utils/loaders.py
@@ -1,11 +1,13 @@
import os
import sys
+from os.path import dirname as up
-ROOT_DIR = os.path.dirname(os.getcwd())
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
import DeepSparseCoding.utils.file_utils as file_utils
+
def get_dir_list(target_dir, target_string):
dir_list = [filename.split('.')[0]
for filename in os.listdir(target_dir)
@@ -42,13 +44,16 @@ def load_model_class(model_type):
elif(model_type.lower() == 'lca'):
py_module_name = 'LcaModel'
file_name = os.path.join(*[dsc_dir, 'models', 'lca_model.py'])
+ elif(model_type.lower() == 'pooling'):
+ py_module_name = 'PoolingModel'
+ file_name = os.path.join(*[dsc_dir, 'models', 'pooling_model.py'])
elif(model_type.lower() == 'ensemble'):
py_module_name = 'EnsembleModel'
file_name = os.path.join(*[dsc_dir, 'models', 'ensemble_model.py'])
else:
accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)]
assert False, (
- 'Acceptible model_types are %s, not %s'%(','.join(accepted_names), model_type))
+ 'Acceptable model_types are %s, not %s'%('; '.join(accepted_names), model_type))
py_module = file_utils.python_module_from_file(py_module_name, file_name)
py_module_class = getattr(py_module, py_module_name)
return py_module_class
@@ -66,19 +71,29 @@ def load_module(module_type):
elif(module_type.lower() == 'lca'):
py_module_name = 'LcaModule'
file_name = os.path.join(*[dsc_dir, 'modules', 'lca_module.py'])
+ elif(module_type.lower() == 'pooling'):
+ py_module_name = 'PoolingModule'
+ file_name = os.path.join(*[dsc_dir, 'modules', 'pooling_module.py'])
elif(module_type.lower() == 'ensemble'):
py_module_name = 'EnsembleModule'
file_name = os.path.join(*[dsc_dir, 'modules', 'ensemble_module.py'])
else:
accepted_names = [''.join(name.split('_')[:-1]) for name in get_module_list(dsc_dir)]
assert False, (
- 'Acceptible model_types are %s, not %s'%(','.join(accepted_names), module_type))
+ 'Acceptable module_types are %s, not %s'%('; '.join(accepted_names), module_type))
py_module = file_utils.python_module_from_file(py_module_name, file_name)
py_module_class = getattr(py_module, py_module_name)
return py_module_class()
-def load_params(file_name, key='params'):
+def load_params_from_log(log_file):
+ logger = file_utils.Logger(log_file, overwrite=False)
+ log_text = logger.load_file()
+ params = logger.read_params(log_text)[-1]
+ return params
+
+
+def load_params_file(file_name, key='params'):
params_module = file_utils.python_module_from_file(key, file_name)
params = getattr(params_module, key)()
return params
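+
+
+# Usage sketch (file paths are hypothetical):
+#   params = load_params_from_log('logfiles/my_model_v0.log')  # recover params from a training log
+#   params = load_params_file('params/my_model_params.py')  # build params from a params module file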
diff --git a/utils/plot_functions.py b/utils/plot_functions.py
index 150ab1eb..8374db1b 100644
--- a/utils/plot_functions.py
+++ b/utils/plot_functions.py
@@ -94,3 +94,29 @@ def plot_stats(data, x_key, x_label=None, y_keys=None, y_labels=None, start_inde
return None
plot.show()
return fig
+
+def pad_images(images, pad_values=1):
+ """
+ Convert an array of images into a single tiled image with padded border
+
+ Keyword arguments:
+ images: [np.ndarray] of shape [num_samples, im_height, im_width, im_chan]
+ pad_values: [int] specifying what value will be used for padding
+
+ Outputs:
+ padded_images: [np.ndarray] padded version of input
+ """
+ n = int(np.ceil(np.sqrt(images.shape[0])))
+ padding = (((0, n ** 2 - images.shape[0]),
+ (1, 1), (1, 1)) # add some space between filters
+ + ((0, 0),) * (images.ndim - 3)) # don't pad last dimension (if there is one)
+ padded_images = np.pad(images, padding, mode="constant",
+ constant_values=pad_values)
+ # tile the filters into an image
+ padded_images = padded_images.reshape((
+ (n, n) + padded_images.shape[1:])).transpose((
+ (0, 2, 1, 3) + tuple(range(4, padded_images.ndim + 1))))
+ padded_images = padded_images.reshape((n * padded_images.shape[1],
+ n * padded_images.shape[3]) + padded_images.shape[4:])
+ return padded_images
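+
+# Shape example for pad_images (hypothetical input): 10 grayscale 8x8 images
+# give n = ceil(sqrt(10)) = 4; each image gets a 1-pixel border (10x10), so
+# the tiled output is a 4x4 grid of shape (40, 40).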
+
diff --git a/utils/run_utils.py b/utils/run_utils.py
index 085fb6ca..20a1f7ea 100644
--- a/utils/run_utils.py
+++ b/utils/run_utils.py
@@ -1,5 +1,25 @@
+import numpy as np
import torch
+import DeepSparseCoding.utils.data_processing as dp
+
+
+def compute_conv_output_shape(in_length, kernel_length, stride, padding=0, dilation=1):
+ out_shape = ((in_length + 2 * padding - dilation * (kernel_length - 1) - 1) / stride) + 1
+ return np.floor(out_shape).astype(int)
+
+
+def compute_deconv_output_shape(in_length, kernel_length, stride, padding=0, output_padding=0, dilation=1):
+ out_shape = (in_length - 1) * stride - 2 * padding + dilation * (kernel_length - 1) + output_padding + 1
+ return np.floor(out_shape).astype(int)
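+# Sanity check (hypothetical sizes): compute_conv_output_shape(28, 5, stride=2)
+# gives floor((28 - 4 - 1) / 2) + 1 = 12; the matching transposed conv needs
+# output_padding=1 to invert it:
+# compute_deconv_output_shape(12, 5, 2, output_padding=1) = 22 + 4 + 1 + 1 = 28.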
+
+
+def get_module_encodings(module, data, allow_grads=False):
+ if allow_grads:
+ return module.get_encodings(data)
+ else:
+ return module.get_encodings(data).detach()
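+# Note: with allow_grads=False each ensemble member is trained greedily on the
+# previous member's detached codes; params.allow_parent_grads=True (see
+# train_epoch below) lets gradients flow back through earlier members instead.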
+
def train_single_model(model, loss):
model.optimizer.zero_grad() # clear gradients of all optimized variables
@@ -7,7 +27,7 @@ def train_single_model(model, loss):
model.optimizer.step()
if(hasattr(model.params, 'renormalize_weights') and model.params.renormalize_weights):
with torch.no_grad(): # tell autograd to not record this operation
- model.w.div_(torch.norm(model.w, dim=0, keepdim=True))
+ model.weight.div_(dp.get_weights_l2_norm(model.weight))
def train_epoch(epoch, model, loader):
@@ -18,33 +38,34 @@ def train_epoch(epoch, model, loader):
for batch_idx, (data, target) in enumerate(loader):
data, target = data.to(model.params.device), target.to(model.params.device)
inputs = []
- if(model.params.model_type.lower() == 'ensemble'): # TODO: Move this to train_model
+ if(model.params.model_type.lower() == 'ensemble'):
inputs.append(model[0].preprocess_data(data)) # First model preprocesses the input
for submodule_idx, submodule in enumerate(model):
loss = model.get_total_loss((inputs[-1], target), submodule_idx)
train_single_model(submodule, loss)
- # TODO: include optional parameter to allow gradients to propagate through the entire ensemble.
- inputs.append(submodule.get_encodings(inputs[-1]).detach()) # must detach to prevent gradient leaking
+ encodings = get_module_encodings(submodule, inputs[-1],
+ model.params.allow_parent_grads)
+ inputs.append(encodings)
else:
inputs.append(model.preprocess_data(data))
loss = model.get_total_loss((inputs[-1], target))
train_single_model(model, loss)
if model.params.train_logs_per_epoch is not None:
if(batch_idx % int(num_batches/model.params.train_logs_per_epoch) == 0.):
- batch_step = epoch * model.params.batches_per_epoch + batch_idx
+ batch_step = int((epoch - 1) * model.params.batches_per_epoch + batch_idx)
model.print_update(
input_data=inputs[0], input_labels=target, batch_step=batch_step)
if(model.params.model_type.lower() == 'ensemble'):
for submodule in model:
- submodule.scheduler.step(epoch)
+ submodule.scheduler.step()
else:
- model.scheduler.step(epoch)
+ model.scheduler.step()
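+ # scheduler.step() no longer takes an epoch argument; passing one is
+ # deprecated in recent PyTorch releases, hence the argument-free calls above.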
def test_single_model(model, data, target, epoch):
output = model(data)
#test_loss = torch.nn.functional.nll_loss(output, target, reduction='sum').item()
- test_loss = torch.nn.CorssEntropyLoss()(output, target)
+ test_loss = torch.nn.CrossEntropyLoss()(output, target)
pred = output.max(1, keepdim=True)[1]
correct = pred.eq(target.view_as(pred)).sum().item()
return (test_loss, correct)
@@ -76,13 +97,12 @@ def test_epoch(epoch, model, loader, log_to_file=True):
test_accuracy = 100. * correct / len(loader.dataset)
stat_dict = {
'test_epoch':epoch,
- 'test_loss':test_loss,
+ 'test_loss':test_loss.item(),
'test_correct':correct,
'test_total':len(loader.dataset),
'test_accuracy':test_accuracy}
if log_to_file:
- js_str = model.js_dumpstring(stat_dict)
- model.log_info(''+js_str+'')
+ model.logger.log_stats(stat_dict)
else:
return stat_dict