Smt #45

Open: wants to merge 45 commits into base branch main from branch smt

Commits (45)
3c31803
adds weight normalizer and atleast_kd
dpaiton Sep 8, 2020
15bf4d6
updates weight init to use new normalizer util
dpaiton Sep 8, 2020
793abf5
changes weight norm to use new util; removes deprecated argument to s…
dpaiton Sep 8, 2020
8e64b9f
removes unnecessary imports
dpaiton Sep 8, 2020
3ae4997
adds convolutional lca; training gets NaN after 14 epochs
dpaiton Sep 8, 2020
26bddf3
adds fastMNIST dataset; bugfixes; import fixes
dpaiton Sep 10, 2020
8889ebe
adds proper error message and IPython embed on js load fail
dpaiton Sep 30, 2020
bb17d46
updates plots for final JOV submission
dpaiton Sep 30, 2020
352eeaa
adds scaffolding for SMT model
dpaiton Jan 15, 2021
5d56472
final updates for JOV paper figures
dpaiton Jan 15, 2021
5ce4c9b
updates relative imports
dpaiton Jan 28, 2021
e1b32b7
new util for extracting image patches
dpaiton Feb 4, 2021
15128d9
adds util for converting a batch of images into a single tiled image
dpaiton Feb 4, 2021
db57cd1
renames variables for clarity; docs additions; adds num_validation
dpaiton Feb 8, 2021
b193860
adds normalization options and cifar10 dataset
dpaiton Feb 8, 2021
b45f3f7
working cifar10 conv lca params
dpaiton Feb 8, 2021
cdcecb1
more epochs for training, still not totally converged
dpaiton Feb 10, 2021
0203c3a
no need to specify conv in the params filenames
dpaiton Feb 10, 2021
dcdc5a5
combines conv and fc lca params
dpaiton Feb 10, 2021
2db240e
adds conv mlp; updates checkpointing; new params
dpaiton Feb 11, 2021
75a6a5e
integrates conv_lca model & module into lca
dpaiton Feb 11, 2021
17d9295
simplifies conv -> fc logic
dpaiton Feb 12, 2021
aef8b9f
bugfix in optimizer checkpoint loading for ensemble models
dpaiton Feb 22, 2021
a1515e0
adds manifold pooling layer
dpaiton Feb 22, 2021
079543f
pooling weight orthogonalization is now specified as a loss; bugfixes
dpaiton Feb 22, 2021
8f8260d
Merge branch 'smt' of https://github.com/dpaiton/DeepSparseCoding int…
dpaiton Feb 23, 2021
49259f5
not using smt_model file currently
dpaiton Feb 23, 2021
088d3f7
adds new logging features; bugfixes; pooling tests
dpaiton Feb 26, 2021
4bc2db8
adds tag to architecture logging for easy retrieval
dpaiton Mar 3, 2021
da4bf52
minor linting change
dpaiton Mar 3, 2021
9080984
updated trace loss to have minimum of 0
dpaiton Mar 3, 2021
9a8b0f2
full smt params; minor param bugfixes
dpaiton Mar 3, 2021
d4c7fb1
updates standardization preprocessing
dpaiton Mar 3, 2021
ed00ea1
notebook for visualizing SMT outputs & weights
dpaiton Mar 3, 2021
160ace4
updates to lca module for clarity & consistency
dpaiton Mar 4, 2021
4ee8bc2
adds reconstructions from each layer
dpaiton Mar 10, 2021
893f7ab
more appropriate name
dpaiton Mar 10, 2021
b9d0368
temp fix for learning_rate checkpoint bug
dpaiton Mar 10, 2021
a409210
minor commenting & new train progress outputs
dpaiton Mar 10, 2021
e4564a6
adds function to read architecture information
dpaiton Mar 10, 2021
ceb3577
adds function to compute deconvolutional output shape
dpaiton Mar 10, 2021
1342369
integrated hierarchical params into one file
dpaiton Mar 10, 2021
ff9e453
updates so that tests pass with latest pytorch
dpaiton Dec 3, 2021
4362b04
moved tf requirements out of main list
dpaiton Dec 3, 2021
a2dd266
moved tf requirements out of main list
dpaiton Dec 3, 2021

Files changed

121 changes: 0 additions & 121 deletions adversarial_analysis.py

This file was deleted.

7 changes: 4 additions & 3 deletions datasets/synthetic.py
@@ -1,15 +1,16 @@
 import os
 import sys
+from os.path import dirname as up
+
+ROOT_DIR = up(up(up(os.path.realpath(__file__))))
+if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
 
 import numpy as np
 from scipy.stats import norm
 from PIL import Image
 import torch
 import torchvision
 
-ROOT_DIR = os.path.dirname(os.getcwd())
-if ROOT_DIR not in sys.path: sys.path.append(ROOT_DIR)
-
 import DeepSparseCoding.utils.data_processing as dp
 
 class SyntheticImages(torchvision.datasets.vision.VisionDataset):
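The change above swaps a working-directory-dependent root lookup for one anchored to the file itself, so the DeepSparseCoding package resolves no matter where the script is launched from. A minimal sketch of the difference; the example path and the OLD_/NEW_ variable names are illustrative, not part of the change:

import os
import sys
from os.path import dirname as up

# Old: depends on the process's working directory; only correct when the script
# is launched from inside the repository, e.g. /home/user/DeepSparseCoding.
OLD_ROOT_DIR = os.path.dirname(os.getcwd())

# New: depends only on this file's location on disk. From datasets/synthetic.py,
# three dirname() calls land on the directory that contains the DeepSparseCoding
# package, which is what "import DeepSparseCoding.utils.data_processing" needs.
NEW_ROOT_DIR = up(up(up(os.path.realpath(__file__))))
if NEW_ROOT_DIR not in sys.path:
    sys.path.append(NEW_ROOT_DIR)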
126 changes: 105 additions & 21 deletions models/base_model.py
@@ -1,23 +1,28 @@
import os
import subprocess
import pprint

import numpy as np
import torch

from DeepSparseCoding.utils.file_utils import summary_string
from DeepSparseCoding.utils.file_utils import Logger
from DeepSparseCoding.utils.run_utils import compute_conv_output_shape
import DeepSparseCoding.utils.loaders as loaders


class BaseModel(object):
def setup(self, params, logger=None):
"""
Setup required model components
#TODO: log system info, including git commit hash
"""
self.load_params(params)
self.check_params()
self.make_dirs()
if logger is None:
self.init_logging()
self.log_params()
self.logger.log_info(self.get_env_details())
else:
self.logger = logger

@@ -92,24 +97,106 @@ def log_params(self, params=None):
dump_obj = self.params.__dict__
self.logger.log_params(dump_obj)

def log_info(self, string):
"""Log input string"""
self.logger.log_info(string)
def get_train_stats(self, batch_step=None):
"""
Get default statistics about current training run

Keyword arguments:
batch_step: [int] current batch iteration. The default assumes that training has finished.
"""
if batch_step is None:
batch_step = self.params.num_batches
epoch = batch_step / self.params.batches_per_epoch
stat_dict = {
'epoch':int(epoch),
'batch_step':batch_step,
'train_progress':np.round(batch_step/self.params.num_batches, 3),
}
return stat_dict
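# Worked example (illustrative values): with num_batches=1000 and
# batches_per_epoch=100, get_train_stats(batch_step=250) returns
# {'epoch': 2, 'batch_step': 250, 'train_progress': 0.25}; calling it with no
# argument assumes training has finished, giving epoch 10 and train_progress 1.0.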

def write_checkpoint(self):
"""Write checkpoints"""
torch.save(self.state_dict(), self.params.cp_latest_filename)
self.log_info('Full model saved in file %s'%self.params.cp_latest_filename)
def get_env_details(self):
env = {}
for k in ['SYSTEMROOT', 'PATH']:
v = os.environ.get(k)
if v is not None:
env[k] = v
commit_cmd = ['git', 'rev-parse', 'HEAD']
commit = subprocess.Popen(commit_cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
commit = commit.strip().decode('ascii')
branch_cmd = ['git', 'rev-parse', '--abbrev-ref', 'HEAD']
branch = subprocess.Popen(branch_cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
branch = branch.strip().decode('ascii')
system_details = os.uname()
out_dict = {
'current_branch':branch,
'current_commit_hash':commit,
'sysname':system_details.sysname,
'release':system_details.release,
'machine':system_details.machine
}
if torch.cuda.is_available():
out_dict['gpu_device'] = torch.cuda.get_device_name(0)
return out_dict

def load_checkpoint(self, cp_file=None):
def log_architecture_details(self):
"""
Log model architecture with computed output sizes and number of parameters for each layer
"""
architecture_string = '<architecture>\n'+summary_string(
self,
input_size=tuple(self.params.data_shape),
batch_size=self.params.batch_size,
device=self.params.device,
dtype=torch.FloatTensor
)[0]
architecture_string += '\n</architecture>'
self.logger.log_string(architecture_string)

def write_checkpoint(self, batch_step=None):
"""
Write checkpoints

Keyword arguments:
batch_step: [int] current batch iteration. The default assumes that training has finished.
"""
output_dict = {}
if(self.params.model_type.lower() == 'ensemble'):
for module in self:
module_name = module.params.submodule_name
output_dict[module_name+'_module_state_dict'] = module.state_dict()
output_dict[module_name+'_optimizer_state_dict'] = module.optimizer.state_dict()
else:
output_dict['model_state_dict'] = self.state_dict()
module_state_dict_name = 'optimizer_state_dict'
output_dict[module_state_dict_name] = self.optimizer.state_dict()
## TODO: Save scheduler state dict as well
training_stats = self.get_train_stats(batch_step)
output_dict.update(training_stats)
torch.save(output_dict, self.params.cp_latest_filename)
self.logger.log_string('Full model saved in file %s'%self.params.cp_latest_filename)

def get_checkpoint_from_log(self, logfile):
model_params = loaders.load_params_from_log(logfile)
checkpoint = torch.load(model_params.cp_latest_filename)
return checkpoint

def load_checkpoint(self, cp_file=None, load_optimizer=False):
"""
Load checkpoint
Inputs:
model_dir: String specifying the path to the checkpoint
Keyword arguments:
cp_file: [str] path to the checkpoint file
"""
if cp_file is None:
cp_file = self.params.cp_latest_filename
return self.load_state_dict(torch.load(cp_file))
checkpoint = torch.load(cp_file)
if load_optimizer:
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.load_state_dict(checkpoint['model_state_dict'])
_ = checkpoint.pop('optimizer_state_dict', None)
_ = checkpoint.pop('model_state_dict', None)
training_status = pprint.pformat(checkpoint, compact=True)#, sort_dicts=True #TODO: Python 3.8 adds the sort_dicts parameter
out_str = f'Loaded checkpoint from {cp_file} with the following stats:\n{training_status}'
return out_str

def get_optimizer(self, optimizer_params, trainable_variables):
optimizer_name = optimizer_params.optimizer.name
@@ -129,8 +216,8 @@ def get_optimizer(self, optimizer_params, trainable_variables):

def setup_optimizer(self):
self.optimizer = self.get_optimizer(
optimizer_params=self.params,
trainable_variables=self.parameters())
optimizer_params=self.params,
trainable_variables=self.parameters())
self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
self.optimizer,
milestones=self.params.optimizer.milestones,
@@ -143,21 +230,18 @@ def print_update(self, input_data, input_labels=None, batch_step=0):
input_data: data object containing the current image batch
input_labels: data object containing the current label batch
batch_step: current batch number within the schedule
NOTE: For the analysis code to parse update statistics, the self.js_dumpstring() call
must receive a dict object. Additionally, the self.js_dumpstring() output must be
logged with <stats> </stats> tags.
For example: logging.info('<stats>'+self.js_dumpstring(output_dictionary)+'</stats>')
NOTE: For the analysis code to parse update statistics,
the logger.log_stats() function must be used
"""
update_dict = self.generate_update_dict(input_data, input_labels, batch_step)
js_str = self.js_dumpstring(update_dict)
self.log_info('<stats>'+js_str+'</stats>')
self.logger.log_stats(update_dict)

def generate_update_dict(self, input_data, input_labels=None, batch_step=0, update_dict=None):
"""
Generates a dictionary to be logged in the print_update function
"""
if update_dict is None:
update_dict = dict()
update_dict = self.get_train_stats(batch_step)
for param_name, param_var in self.named_parameters():
grad = param_var.grad
update_dict[param_name+'_grad_max_mean_min'] = [
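Taken together, the write_checkpoint and load_checkpoint changes move a dictionary of model weights, optimizer state, and training statistics through torch.save/torch.load instead of a bare state_dict. A rough usage sketch, assuming model is an already set-up, non-ensemble model instance and batch_step=500 is an arbitrary illustrative value:

import torch

# Write a checkpoint mid-training; batch_step feeds get_train_stats().
model.write_checkpoint(batch_step=500)   # saved to model.params.cp_latest_filename

# Restore it later; optimizer state is only loaded when explicitly requested.
status_str = model.load_checkpoint(load_optimizer=True)
print(status_str)   # reports the stored epoch, batch_step, and train_progress

# The file itself is an ordinary torch checkpoint dict (non-ensemble layout):
checkpoint = torch.load(model.params.cp_latest_filename)
# expected keys: 'model_state_dict', 'optimizer_state_dict',
#                'epoch', 'batch_step', 'train_progress'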