diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..9045e9e --- /dev/null +++ b/.flake8 @@ -0,0 +1,18 @@ +[flake8] +max-line-length = 99 +ignore = + W504, + W503, + W605, # invalid escape sequence '\ ' + E266, + E402, # module level import not at top of file + E226, # missing whitespace around arithmetic operator +exclude = + .git, + __pycache__, + __init__.py, + build, + dist, + docs/* + example/* + scratch/* diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..f5a783b --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,79 @@ +# How to contribute + +If you're interested in contributing to the behavenet package, please contact the project +developer Matt Whiteway at m.whiteway ( at ) columbia.edu. + +If you would like to add a new Pytorch model to the package, you can find more detailed information +[here](behavenet/models/README.md). + +Before submitting a pull request, please follow these steps: + +## Style + +The behavenet package follows the PEP8 style guidelines, and allows for line lengths of up to 99 +characters. To ensure that your code matches these guidelines, please flake your code using the +provided configuration file. You will need to first install flake8 in the behavenet conda +environment: + +```bash +(behavenet) $: pip install flake8 +``` + +Once all code, tests, and documentation are in place, you can run the flaker from from the project +directory: + +```bash +(behavenet) $: flake8 +``` + +## Documentation + +Behavenet uses Sphinx and readthedocs to provide documentation to developers and users. + +* complete all docstrings in new functions using google's style (see source code for examples) +* provide inline comments when necessary; the more the merrier +* add a new user guide if necessary (`docs/source/user_guide.[new_model].rst`) +* update data structure docs if adding to hdf5 (`docs/source/data_structure.rst`) +* add new hyperparams to glossary (`docs/source/glossary.rst`) + +To check the documentation, you can compile it on your local machine first. To do so you will need +to first install sphinx in the behavenet conda environment: + +```bash +(behavenet) $: pip install sphinx==3.2.0 sphinx_rtd_theme==0.4.3 sphinx-automodapi==0.12 +``` + +To compile the documentation, from the behavenet project directory cd to `behavenet/docs` and run +the make file: + +```bash +(behavenet) $: cd docs +(behavenet) $: make html +``` + +## Testing + +Behavenet uses pytest to unit test the package; in addition, there is an integration script +provided to ensure the interlocking pieces play nicely. Please write unit tests for all new +(non-plotting) functions, and if you updated any existing functions please update the corresponding +unit tests. + +To run the unit tests, first install pytest in the behavenet conda environment: + +```bash +(behavenet) $: pip install pytest +``` + +Then, from the project directory, run: + +```bash +(behavenet) $: pytest +``` + +To run the integration script: + +```bash +(behavenet) $: python tests/integration.py +``` + +Running the integration test will take approximately 1 minute with a GPU. diff --git a/README.md b/README.md index 26f0d33..41e886d 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,5 @@ # BehaveNet -NOTE: This is a beta version, we will release the first stable version by early February. - BehaveNet is a probabilistic framework for the analysis of behavioral video and neural activity. This framework provides tools for compression, segmentation, generation, and decoding of behavioral videos. Please see the @@ -12,7 +10,7 @@ for more information about how to install the software and begin fitting models Additionally, we provide an example dataset and several jupyter notebooks that walk you through how to download the dataset, fit models, and analyze the results. The jupyter notebooks can be found -[here](example). +[here](examples). ## Bibtex diff --git a/behavenet/data/__init__.py b/behavenet/data/__init__.py index 6aedd20..fcc2edc 100644 --- a/behavenet/data/__init__.py +++ b/behavenet/data/__init__.py @@ -1 +1 @@ -"""Test string""" +"""Data module""" diff --git a/behavenet/data/data_generator.py b/behavenet/data/data_generator.py index d593894..0cf3e00 100644 --- a/behavenet/data/data_generator.py +++ b/behavenet/data/data_generator.py @@ -192,14 +192,14 @@ def __init__( self.n_trials = None for i, signal in enumerate(signals): if signal == 'images' or signal == 'neural' or signal == 'labels' or \ - signal == 'labels_sc': + signal == 'labels_sc' or signal == 'labels_masks': data_file = paths[i] with h5py.File(data_file, 'r', libver='latest', swmr=True) as f: self.n_trials = len(f[signal]) break elif signal == 'ae_latents': try: - latents = _load_pkl_dict(self.paths[signal], 'latents') #[0] + latents = _load_pkl_dict(self.paths[signal], 'latents') except FileNotFoundError: raise NotImplementedError( ('Could not open %s\nMust create ae latents from model;' + @@ -274,7 +274,8 @@ def __getitem__(self, idx): else: sample[signal] = f[signal][str('trial_%04i' % idx)][()].astype(dtype) - elif signal == 'neural' or signal == 'labels' or signal == 'labels_sc': + elif signal == 'neural' or signal == 'labels' or signal == 'labels_sc' \ + or signal == 'labels_masks': dtype = 'float32' with h5py.File(self.paths[signal], 'r', libver='latest', swmr=True) as f: if idx is None: @@ -286,25 +287,21 @@ def __getitem__(self, idx): else: sample[signal] = [f[signal][str('trial_%04i' % idx)][()].astype(dtype)] - elif signal == 'ae_latents': + elif signal == 'ae_latents' or signal == 'latents': dtype = 'float32' - sample[signal] = self._try_to_load( - signal, key='latents', idx=idx, dtype=dtype) + sample[signal] = self._try_to_load(signal, key='latents', idx=idx, dtype=dtype) elif signal == 'ae_predictions': dtype = 'float32' - sample[signal] = self._try_to_load( - signal, key='predictions', idx=idx, dtype=dtype) + sample[signal] = self._try_to_load(signal, key='predictions', idx=idx, dtype=dtype) elif signal == 'arhmm' or signal == 'arhmm_states': dtype = 'int32' - sample[signal] = self._try_to_load( - signal, key='states', idx=idx, dtype=dtype) + sample[signal] = self._try_to_load(signal, key='states', idx=idx, dtype=dtype) elif signal == 'arhmm_predictions': dtype = 'float32' - sample[signal] = self._try_to_load( - signal, key='predictions', idx=idx, dtype=dtype) + sample[signal] = self._try_to_load(signal, key='predictions', idx=idx, dtype=dtype) else: raise ValueError('"%s" is an invalid signal type' % signal) @@ -626,7 +623,7 @@ def next_batch(self, dtype): if self.as_numpy: for i, signal in enumerate(sample): - if signal is not 'batch_idx': + if signal != 'batch_idx': sample[signal] = [ss.cpu().detach().numpy() for ss in sample[signal]] else: if self.device == 'cuda': diff --git a/behavenet/data/transforms.py b/behavenet/data/transforms.py index 8e6b537..5fcb8cc 100644 --- a/behavenet/data/transforms.py +++ b/behavenet/data/transforms.py @@ -55,161 +55,95 @@ def __repr__(self): raise NotImplementedError -class ClipNormalize(Transform): - """Clip upper level of signal and divide by clip value.""" +class BlockShuffle(Transform): + """Shuffle blocks of contiguous discrete states within each trial.""" - def __init__(self, clip_val): + def __init__(self, rng_seed): """ Parameters ---------- - clip_val : :obj:`float` - signal values above this will be set to this value, then divided by this value so that - signal maximum is 1 + rng_seed : :obj:`int` + to control random number generator """ - if clip_val <= 0: - raise ValueError('clip value must be positive') - self.clip_val = clip_val + self.rng_seed = rng_seed - def __call__(self, signal): + def __call__(self, sample): """ Parameters ---------- - signal : :obj:`np.ndarray` + sample : :obj:`np.ndarray` + dense representation of shape (time) Returns ------- :obj:`np.ndarray` + output shape is (time) """ - signal = np.minimum(signal, self.clip_val) - signal = signal / self.clip_val - return signal - - def __repr__(self): - return str('ClipNormalize(clip_val=%f)' % self.clip_val) - - -# class Resize(Transform): -# """Resize the sample images.""" -# -# def __init__(self, size=(128, 128), order=1): -# """ -# -# Parameters -# ---------- -# size : :obj:`int` or :obj:`tuple` -# desired output size for each image; if type is :obj:`int`, the same value is used for -# both height and width -# order : :obj:`int` -# interpolation order -# -# """ -# assert isinstance(size, (tuple, int)) -# self.order = order -# if isinstance(size, tuple): -# self.x = size[0] -# self.y = size[1] -# else: -# self.x = self.y = size -# -# def __call__(self, sample): -# """ -# -# Parameters -# ---------- -# sample: :obj:`np.ndarray` -# input shape is (trial, time, n_channels) -# -# Returns -# ------- -# :obj:`np.ndarray` -# output shape is (trial, time, n_channels) -# -# """ -# sh = sample.shape -# sample = transform.resize(sample, (sh[0], sh[1], self.y, self.x), order=self.order) -# return sample -# -# def __repr__(self): -# return str('Resize(size=(%i, %i))' % (self.y, self.x)) + np.random.seed(self.rng_seed) + n_time = len(sample) + if not any(np.isnan(sample)): + # mark first time point of state change with a nonzero number + state_change = np.where(np.concatenate([[0], np.diff(sample)], axis=0) != 0)[0] + # collect runs + runs = [] + prev_beg = 0 + for curr_beg in state_change: + runs.append(np.arange(prev_beg, curr_beg)) + prev_beg = curr_beg + runs.append(np.arange(prev_beg, n_time)) + # shuffle runs + rand_perm = np.random.permutation(len(runs)) + runs_shuff = [runs[idx] for idx in rand_perm] + # index back into original labels with shuffled indices + sample_shuff = sample[np.concatenate(runs_shuff)] + else: + sample_shuff = np.full(n_time, fill_value=np.nan) + return sample_shuff -class Threshold(Transform): - """Remove channels of neural activity whose mean value is below a threshold.""" + def __repr__(self): + return str('BlockShuffle(rng_seed=%i)' % self.rng_seed) - def __init__(self, threshold, bin_size): - """ - Parameters - ---------- - threshold : :obj:`float` - threshold in Hz - bin_size : :obj:`float` - bin size of neural activity in ms +class ClipNormalize(Transform): + """Clip upper level of signal and divide by clip value.""" + def __init__(self, clip_val): """ - if bin_size <= 0: - raise ValueError('bin size must be positive') - if threshold < 0: - raise ValueError('threshold must be non-negative') - - self.threshold = threshold - self.bin_size = bin_size - - def __call__(self, sample): - """Calculates firing rate over all time points and thresholds. Parameters ---------- - sample: :obj:`np.ndarray` - input shape is (time, n_channels) - - Returns - ------- - :obj:`np.ndarray` - output shape is (time, n_channels) + clip_val : :obj:`float` + signal values above this will be set to this value, then divided by this value so that + signal maximum is 1 """ - # get firing rates - frs = np.squeeze(np.mean(sample, axis=0)) / (self.bin_size * 1e-3) - fr_mask = frs > self.threshold - # get rid of neurons below fr threshold - sample = sample[:, fr_mask] - return sample.astype(np.float) - - def __repr__(self): - return str('Threshold(threshold=%f, bin_size=%f)' % (self.threshold, self.bin_size)) - - -class ZScore(Transform): - """z-score channel activity.""" - - def __init__(self): - pass + if clip_val <= 0: + raise ValueError('clip value must be positive') + self.clip_val = clip_val - def __call__(self, sample): + def __call__(self, signal): """ Parameters ---------- - sample : :obj:`np.ndarray` - input shape is (time, n_channels) + signal : :obj:`np.ndarray` Returns ------- :obj:`np.ndarray` - output shape is (time, n_channels) """ - sample -= np.mean(sample, axis=0) - sample /= np.std(sample, axis=0) - return sample + signal = np.minimum(signal, self.clip_val) + signal = signal / self.clip_val + return signal def __repr__(self): - return 'ZScore()' + return str('ClipNormalize(clip_val=%f)' % self.clip_val) class MakeOneHot(Transform): @@ -295,36 +229,30 @@ def __call__(self, sample): labels_2d = np.zeros((time, n_labels, self.y_pixels, self.x_pixels)) x_vals = sample[:, :n_labels] + x_vals[np.isnan(x_vals)] = -1 # set nans to 0 x_vals[x_vals > self.x_pixels - 1] = self.x_pixels - 1 x_vals[x_vals < 0] = 0 x_vals = np.round(x_vals).astype(np.int) y_vals = sample[:, n_labels:] + y_vals[np.isnan(y_vals)] = -1 # set nans to 0 y_vals[y_vals > self.y_pixels - 1] = self.y_pixels - 1 y_vals[y_vals < 0] = 0 y_vals = np.round(y_vals).astype(np.int) - for l in range(n_labels): - labels_2d[np.arange(time), l, y_vals[:, l], x_vals[:, l]] = 1 + for n in range(n_labels): + labels_2d[np.arange(time), n, y_vals[:, n], x_vals[:, n]] = 1 return labels_2d def __repr__(self): return str('MakeOneHot2D(y_pixels=%i, x_pixels=%i)' % (self.y_pixels, self.x_pixels)) -class BlockShuffle(Transform): - """Shuffle blocks of contiguous discrete states within each trial.""" - - def __init__(self, rng_seed): - """ - - Parameters - ---------- - rng_seed : :obj:`int` - to control random number generator +class MotionEnergy(Transform): + """Compute motion energy across batch dimension.""" - """ - self.rng_seed = rng_seed + def __init__(self): + pass def __call__(self, sample): """ @@ -332,38 +260,18 @@ def __call__(self, sample): Parameters ---------- sample : :obj:`np.ndarray` - dense representation of shape (time) + input shape is (time, n_channels) Returns ------- :obj:`np.ndarray` - output shape is (time) + output shape is (time, n_channels) """ - - np.random.seed(self.rng_seed) - n_time = len(sample) - if not any(np.isnan(sample)): - # mark first time point of state change with a nonzero number - state_change = np.where(np.concatenate([[0], np.diff(sample)], axis=0) != 0)[0] - # collect runs - runs = [] - prev_beg = 0 - for curr_beg in state_change: - runs.append(np.arange(prev_beg, curr_beg)) - prev_beg = curr_beg - runs.append(np.arange(prev_beg, n_time)) - # shuffle runs - rand_perm = np.random.permutation(len(runs)) - runs_shuff = [runs[idx] for idx in rand_perm] - # index back into original labels with shuffled indices - sample_shuff = sample[np.concatenate(runs_shuff)] - else: - sample_shuff = np.full(n_time, fill_value=np.nan) - return sample_shuff + return np.vstack([np.zeros((1, sample.shape[1])), np.abs(np.diff(sample, axis=0))]) def __repr__(self): - return str('BlockShuffle(rng_seed=%i)' % self.rng_seed) + return 'MotionEnergy()' class SelectIdxs(Transform): @@ -400,3 +308,123 @@ def __call__(self, sample): def __repr__(self): return str('SelectIndxs(idxs=idxs, sample_name=%s)' % self.sample_name) + + +class Threshold(Transform): + """Remove channels of neural activity whose mean value is below a threshold.""" + + def __init__(self, threshold, bin_size): + """ + + Parameters + ---------- + threshold : :obj:`float` + threshold in Hz + bin_size : :obj:`float` + bin size of neural activity in ms + + """ + if bin_size <= 0: + raise ValueError('bin size must be positive') + if threshold < 0: + raise ValueError('threshold must be non-negative') + + self.threshold = threshold + self.bin_size = bin_size + + def __call__(self, sample): + """Calculates firing rate over all time points and thresholds. + + Parameters + ---------- + sample: :obj:`np.ndarray` + input shape is (time, n_channels) + + Returns + ------- + :obj:`np.ndarray` + output shape is (time, n_channels) + + """ + # get firing rates + frs = np.squeeze(np.mean(sample, axis=0)) / (self.bin_size * 1e-3) + fr_mask = frs > self.threshold + # get rid of neurons below fr threshold + sample = sample[:, fr_mask] + return sample.astype(np.float) + + def __repr__(self): + return str('Threshold(threshold=%f, bin_size=%f)' % (self.threshold, self.bin_size)) + + +class ZScore(Transform): + """z-score channel activity.""" + + def __init__(self): + pass + + def __call__(self, sample): + """ + + Parameters + ---------- + sample : :obj:`np.ndarray` + input shape is (time, n_channels) + + Returns + ------- + :obj:`np.ndarray` + output shape is (time, n_channels) + + """ + sample -= np.mean(sample, axis=0) + sample /= np.std(sample, axis=0) + return sample + + def __repr__(self): + return 'ZScore()' + + +# class Resize(Transform): +# """Resize the sample images.""" +# +# def __init__(self, size=(128, 128), order=1): +# """ +# +# Parameters +# ---------- +# size : :obj:`int` or :obj:`tuple` +# desired output size for each image; if type is :obj:`int`, the same value is used for +# both height and width +# order : :obj:`int` +# interpolation order +# +# """ +# assert isinstance(size, (tuple, int)) +# self.order = order +# if isinstance(size, tuple): +# self.x = size[0] +# self.y = size[1] +# else: +# self.x = self.y = size +# +# def __call__(self, sample): +# """ +# +# Parameters +# ---------- +# sample: :obj:`np.ndarray` +# input shape is (trial, time, n_channels) +# +# Returns +# ------- +# :obj:`np.ndarray` +# output shape is (trial, time, n_channels) +# +# """ +# sh = sample.shape +# sample = transform.resize(sample, (sh[0], sh[1], self.y, self.x), order=self.order) +# return sample +# +# def __repr__(self): +# return str('Resize(size=(%i, %i))' % (self.y, self.x)) diff --git a/behavenet/data/utils.py b/behavenet/data/utils.py index cc00248..ef90732 100644 --- a/behavenet/data/utils.py +++ b/behavenet/data/utils.py @@ -4,6 +4,13 @@ import numpy as np import pickle +from behavenet.fitting.utils import export_session_info_to_csv + +# to ignore imports for sphinx-autoapidoc +__all__ = [ + 'get_data_generator_inputs', 'build_data_generator', 'check_same_training_split', + 'get_transforms_paths', 'load_labels_like_latents', 'get_region_list'] + def get_data_generator_inputs(hparams, sess_ids, check_splits=True): """Helper function for generating signals, transforms and paths. @@ -66,7 +73,7 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): elif hparams['model_class'] == 'cond-ae' \ or hparams['model_class'] == 'cond-ae-msp' \ or hparams['model_class'] == 'cond-vae' \ - or hparams['model_class'] == 'sss-vae': + or hparams['model_class'] == 'ps-vae': signals = ['images', 'labels'] transforms = [None, None] @@ -75,6 +82,12 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): signals.append('masks') transforms.append(None) paths.append(os.path.join(data_dir, 'data.hdf5')) + if hparams.get('use_label_mask', False) and ( + hparams['model_class'] == 'cond-ae-msp' + or hparams['model_class'] == 'ps-vae'): + signals.append('labels_masks') + transforms.append(None) + paths.append(os.path.join(data_dir, 'data.hdf5')) if hparams.get('conditional_encoder', False): from behavenet.data.transforms import MakeOneHot2D signals.append('labels_sc') @@ -107,6 +120,23 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): transforms = [neural_transform, ae_transform] paths = [neural_path, ae_path] + elif hparams['model_class'] == 'neural-ae-me': + + hparams['input_signal'] = 'neural' + hparams['output_signal'] = 'ae_latents' + hparams['output_size'] = hparams['n_ae_latents'] + if hparams['model_type'][-2:] == 'mv': + hparams['noise_dist'] = 'gaussian-full' + else: + hparams['noise_dist'] = 'gaussian' + + ae_transform, ae_path = get_transforms_paths( + 'ae_latents_me', hparams, sess_id=sess_id, check_splits=check_splits) + + signals = ['neural', 'ae_latents'] + transforms = [neural_transform, ae_transform] + paths = [neural_path, ae_path] + elif hparams['model_class'] == 'ae-neural': hparams['input_signal'] = 'ae_latents' @@ -127,6 +157,37 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): transforms = [neural_transform, ae_transform] paths = [neural_path, ae_path] + elif hparams['model_class'] == 'neural-labels': + + hparams['input_signal'] = 'neural' + hparams['output_signal'] = 'labels' + hparams['output_size'] = hparams['n_labels'] + if hparams['model_type'][-2:] == 'mv': + hparams['noise_dist'] = 'gaussian-full' + else: + hparams['noise_dist'] = 'gaussian' + + signals = ['neural', 'labels'] + transforms = [neural_transform, None] + paths = [neural_path, os.path.join(data_dir, 'data.hdf5')] + + elif hparams['model_class'] == 'labels-neural': + + hparams['input_signal'] = 'labels' + hparams['output_signal'] = 'neural' + hparams['output_size'] = None # to fill in after data is loaded + if hparams['neural_type'] == 'ca': + if hparams['model_type'][-2:] == 'mv': + hparams['noise_dist'] = 'gaussian-full' + else: + hparams['noise_dist'] = 'gaussian' + elif hparams['neural_type'] == 'spikes': + hparams['noise_dist'] = 'poisson' + + signals = ['neural', 'labels'] + transforms = [neural_transform, None] + paths = [neural_path, os.path.join(data_dir, 'data.hdf5')] + elif hparams['model_class'] == 'neural-arhmm': hparams['input_signal'] = 'neural' @@ -253,6 +314,16 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): signals = [hparams['model_class']] transforms = [None] paths = [os.path.join(data_dir, 'data.hdf5')] + if hparams.get('use_label_mask', False): + signals.append('labels_masks') + transforms.append(None) + paths.append(os.path.join(data_dir, 'data.hdf5')) + + elif hparams['model_class'] == 'labels_masks': + + signals = [hparams['model_class']] + transforms = [None] + paths = [os.path.join(data_dir, 'data.hdf5')] else: raise ValueError('"%s" is an invalid model_class' % hparams['model_class']) @@ -264,16 +335,64 @@ def get_data_generator_inputs(hparams, sess_ids, check_splits=True): return hparams, signals_list, transforms_list, paths_list +def build_data_generator(hparams, sess_ids, export_csv=True): + """Helper function to build data generator from hparams dict. + + Parameters + ---------- + hparams : :obj:`dict` + needs to contain information specifying data inputs to model + sess_ids : :obj:`list` of :obj:`dict` + each entry is a session dict with keys 'lab', 'expt', 'animal', 'session' + export_csv : :obj:`bool`, optional + export csv file containing session info (useful when fitting multi-sessions) + + Returns + ------- + :obj:`ConcatSessionsGenerator` object + data generator + + """ + from behavenet.data.data_generator import ConcatSessionsGenerator + print('using data from following sessions:') + for ids in sess_ids: + print('%s' % os.path.join( + hparams['save_dir'], ids['lab'], ids['expt'], ids['animal'], ids['session'])) + hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids) + if hparams.get('trial_splits', None) is not None: + # assumes string of form 'train;val;test;gap' + trs = [int(tr) for tr in hparams['trial_splits'].split(';')] + trial_splits = {'train_tr': trs[0], 'val_tr': trs[1], 'test_tr': trs[2], 'gap_tr': trs[3]} + else: + trial_splits = None + print('constructing data generator...', end='') + data_generator = ConcatSessionsGenerator( + hparams['data_dir'], sess_ids, + signals_list=signals, transforms_list=transforms, paths_list=paths, + device=hparams['device'], as_numpy=hparams['as_numpy'], batch_load=hparams['batch_load'], + rng_seed=hparams['rng_seed_data'], trial_splits=trial_splits, + train_frac=hparams['train_frac']) + # csv order will reflect dataset order in data generator + if export_csv: + export_session_info_to_csv(os.path.join( + hparams['expt_dir'], str('version_%i' % hparams['version'])), sess_ids) + print('done') + print(data_generator) + return data_generator + + def check_same_training_split(model_path, hparams): """Ensure data rng seed and trial splits are same for two models.""" import_params_file = os.path.join(os.path.dirname(model_path), 'meta_tags.pkl') import_params = pickle.load(open(import_params_file, 'rb')) - if import_params['rng_seed_data'] != hparams['rng_seed_data']: + if import_params['rng_seed_data'] != hparams['rng_seed_data'] and \ + hparams.get('check_rng_seed_data', True): raise ValueError('Different data random seed from existing models') - if import_params['trial_splits'] != hparams['trial_splits']: + if import_params['trial_splits'] != hparams['trial_splits'] and \ + hparams.get('check_trial_splits', True): raise ValueError('Different trial split from existing models') @@ -287,10 +406,19 @@ def get_transforms_paths(data_type, hparams, sess_id, check_splits=True): 'neural_arhmm_predictions' hparams : :obj:`dict` - required keys for :obj:`data_type=neural`: 'neural_type', 'neural_thresh' - - required keys for :obj:`data_type=ae_latents`: 'ae_experiment_name', 'ae_model_type', 'n_ae_latents', 'ae_version' or 'ae_latents_file'; this last option defines either the specific ae version (as 'best' or an int) or a path to a specific ae latents pickle file. - - required keys for :obj:`data_type=arhmm_states`: 'arhmm_experiment_name', 'n_arhmm_states', 'kappa', 'noise_type', 'n_ae_latents', 'arhmm_version' or 'arhmm_states_file'; this last option defines either the specific arhmm version (as 'best' or an int) or a path to a specific ae latents pickle file. - - required keys for :obj:`data_type=neural_ae_predictions`: 'neural_ae_experiment_name', 'neural_ae_model_type', 'neural_ae_version' or 'ae_predictions_file' plus keys for neural and ae_latents data types. - - required keys for :obj:`data_type=neural_arhmm_predictions`: 'neural_arhmm_experiment_name', 'neural_arhmm_model_type', 'neural_arhmm_version' or 'arhmm_predictions_file', plus keys for neural and arhmm_states data types. + - required keys for :obj:`data_type=ae_latents`: 'ae_experiment_name', 'ae_model_type', + 'n_ae_latents', 'ae_version' or 'ae_latents_file'; this last option defines either the + specific ae version (as 'best' or an int) or a path to a specific ae latents pickle file. + - required keys for :obj:`data_type=arhmm_states`: 'arhmm_experiment_name', + 'n_arhmm_states', 'kappa', 'noise_type', 'n_ae_latents', 'arhmm_version' or + 'arhmm_states_file'; this last option defines either the specific arhmm version (as + 'best' or an int) or a path to a specific ae latents pickle file. + - required keys for :obj:`data_type=neural_ae_predictions`: 'neural_ae_experiment_name', + 'neural_ae_model_type', 'neural_ae_version' or 'ae_predictions_file' plus keys for neural + and ae_latents data types. + - required keys for :obj:`data_type=neural_arhmm_predictions`: + 'neural_arhmm_experiment_name', 'neural_arhmm_model_type', 'neural_arhmm_version' or + 'arhmm_predictions_file', plus keys for neural and arhmm_states data types. sess_id : :obj:`dict` each list entry is a session-specific dict with keys 'lab', 'expt', 'animal', 'session' check_splits : :obj:`bool`, optional @@ -304,11 +432,12 @@ def get_transforms_paths(data_type, hparams, sess_id, check_splits=True): """ + from behavenet.data.transforms import BlockShuffle + from behavenet.data.transforms import Compose + from behavenet.data.transforms import MotionEnergy from behavenet.data.transforms import SelectIdxs from behavenet.data.transforms import Threshold from behavenet.data.transforms import ZScore - from behavenet.data.transforms import BlockShuffle - from behavenet.data.transforms import Compose from behavenet.fitting.utils import get_best_model_version from behavenet.fitting.utils import get_expt_dir @@ -358,6 +487,8 @@ def get_transforms_paths(data_type, hparams, sess_id, check_splits=True): if hparams['model_type'][-6:] != 'neural': # don't zscore if predicting calcium activity transforms_.append(ZScore()) + elif hparams['neural_type'] == 'ca-zscored': + pass else: raise ValueError('"%s" is an invalid neural type' % hparams['neural_type']) @@ -367,18 +498,25 @@ def get_transforms_paths(data_type, hparams, sess_id, check_splits=True): else: transform = Compose(transforms_) - elif data_type == 'ae_latents' or data_type == 'latents': + elif data_type == 'ae_latents' or data_type == 'latents' \ + or data_type == 'ae_latents_me' or data_type == 'latents_me': - transform = None + if data_type == 'ae_latents_me' or data_type == 'latents_me': + transform = MotionEnergy() + else: + transform = None if 'ae_latents_file' in hparams: path = hparams['ae_latents_file'] else: ae_dir = get_expt_dir( - hparams, model_class='ae', + hparams, model_class=hparams['ae_model_class'], expt_name=hparams['ae_experiment_name'], model_type=hparams['ae_model_type']) - if 'ae_version' in hparams and isinstance(hparams['ae_version'], int): + if 'ae_version' in hparams and hparams['ae_version'] != 'best': + # json args read as strings + if isinstance(hparams['ae_version'], str): + hparams['ae_version'] = int(hparams['ae_version']) ae_version = str('version_%i' % hparams['ae_version']) else: ae_version = 'version_%i' % get_best_model_version(ae_dir, 'val_loss')[0] @@ -476,7 +614,6 @@ def load_labels_like_latents(hparams, sess_ids, sess_idx, data_key='labels'): """ import copy - from behavenet.fitting.utils import build_data_generator hparams_new = copy.deepcopy(hparams) hparams_new['model_class'] = data_key diff --git a/behavenet/fitting/ae_grid_search.py b/behavenet/fitting/ae_grid_search.py index c08ce5c..1d321b2 100644 --- a/behavenet/fitting/ae_grid_search.py +++ b/behavenet/fitting/ae_grid_search.py @@ -5,13 +5,13 @@ import torch import math +from behavenet.data.utils import build_data_generator from behavenet.fitting.eval import export_train_plots from behavenet.fitting.hyperparam_utils import get_all_params from behavenet.fitting.hyperparam_utils import get_slurm_params from behavenet.fitting.training import fit from behavenet.fitting.utils import _clean_tt_dir from behavenet.fitting.utils import _print_hparams -from behavenet.fitting.utils import build_data_generator from behavenet.fitting.utils import create_tt_experiment from behavenet.fitting.utils import export_hparams from behavenet.models.aes import load_pretrained_ae @@ -65,8 +65,8 @@ def set_n_labels(data_generator, hparams): from behavenet.models import VAE as Model elif hparams['model_class'] == 'beta-tcvae': from behavenet.models import BetaTCVAE as Model - elif hparams['model_class'] == 'sss-vae': - from behavenet.models import SSSVAE as Model + elif hparams['model_class'] == 'ps-vae': + from behavenet.models import PSVAE as Model set_n_labels(data_generator, hparams) elif hparams['model_class'] == 'cond-vae': from behavenet.models import ConditionalVAE as Model @@ -165,5 +165,3 @@ def set_n_labels(data_generator, hparams): main, nb_trials=hyperparams.tt_n_cpu_trials, nb_workers=hyperparams.tt_n_cpu_workers) - - diff --git a/behavenet/fitting/arhmm_grid_search.py b/behavenet/fitting/arhmm_grid_search.py index 9ce1334..4345e8a 100644 --- a/behavenet/fitting/arhmm_grid_search.py +++ b/behavenet/fitting/arhmm_grid_search.py @@ -5,13 +5,13 @@ import ssm import pickle +from behavenet.data.utils import build_data_generator from behavenet.fitting.eval import export_states from behavenet.fitting.eval import export_train_plots from behavenet.fitting.hyperparam_utils import get_all_params from behavenet.fitting.hyperparam_utils import get_slurm_params from behavenet.fitting.utils import _clean_tt_dir from behavenet.fitting.utils import _print_hparams -from behavenet.fitting.utils import build_data_generator from behavenet.fitting.utils import create_tt_experiment from behavenet.fitting.utils import export_hparams from behavenet.plotting.arhmm_utils import get_latent_arrays_by_dtype @@ -59,7 +59,7 @@ def main(hparams): data_generator, sess_idxs=list(range(n_datasets)), data_key=data_key) obs_dim = latents['train'][0].shape[1] - hparams['total_train_length'] = np.sum([l.shape[0] for l in latents['train']]) + hparams['total_train_length'] = np.sum([z.shape[0] for z in latents['train']]) # get separated by dataset as well latents_sess = {d: None for d in range(n_datasets)} trial_idxs_sess = {d: None for d in range(n_datasets)} @@ -206,7 +206,7 @@ def main(hparams): # save model filepath = os.path.join(hparams['expt_dir'], 'version_%i' % exp.version, 'best_val_model.pt') with open(filepath, 'wb') as f: - pickle.dump(hmm, f) + pickle.dump(hmm, f) # ###################### # ### EVALUATE ARHMM ### diff --git a/behavenet/fitting/decoder_grid_search.py b/behavenet/fitting/decoder_grid_search.py index 1222935..4fb4763 100644 --- a/behavenet/fitting/decoder_grid_search.py +++ b/behavenet/fitting/decoder_grid_search.py @@ -5,12 +5,12 @@ import torch import pickle +from behavenet.data.utils import build_data_generator from behavenet.fitting.hyperparam_utils import get_all_params from behavenet.fitting.hyperparam_utils import get_slurm_params from behavenet.fitting.training import fit from behavenet.fitting.utils import _clean_tt_dir from behavenet.fitting.utils import _print_hparams -from behavenet.fitting.utils import build_data_generator from behavenet.fitting.utils import create_tt_experiment from behavenet.fitting.utils import export_hparams from behavenet.models import Decoder @@ -50,13 +50,23 @@ def main(hparams, *args): elif hparams['model_class'] == 'neural-ae': hparams['input_size'] = data_generator.datasets[0][ex_trial][i_sig].shape[1] hparams['output_size'] = hparams['n_ae_latents'] + elif hparams['model_class'] == 'neural-ae-me': + hparams['input_size'] = data_generator.datasets[0][ex_trial][i_sig].shape[1] + hparams['output_size'] = hparams['n_ae_latents'] elif hparams['model_class'] == 'ae-neural': hparams['input_size'] = hparams['n_ae_latents'] hparams['output_size'] = data_generator.datasets[0][ex_trial][o_sig].shape[1] + elif hparams['model_class'] == 'neural-labels': + hparams['input_size'] = data_generator.datasets[0][ex_trial][i_sig].shape[1] + hparams['output_size'] = hparams['n_labels'] + elif hparams['model_class'] == 'labels-neural': + hparams['input_size'] = hparams['n_labels'] + hparams['output_size'] = data_generator.datasets[0][ex_trial][o_sig].shape[1] else: raise ValueError('%s is an invalid model class' % hparams['model_class']) - if hparams['model_class'] == 'neural-ae' or hparams['model_class'] == 'ae-neural': + if hparams['model_class'] == 'neural-ae' or hparams['model_class'] == 'neural-ae' \ + or hparams['model_class'] == 'ae-neural': hparams['ae_model_path'] = os.path.join( os.path.dirname(data_generator.datasets[0].paths['ae_latents'])) hparams['ae_model_latents_file'] = data_generator.datasets[0].paths['ae_latents'] @@ -66,13 +76,12 @@ def main(hparams, *args): hparams['arhmm_model_states_file'] = data_generator.datasets[0].paths['arhmm_states'] # Store which AE was used for the ARHMM - tags = pickle.load(open(hparams['arhmm_model_path'] + '/meta_tags.pkl', 'rb')) + tags = pickle.load(open(os.path.join(hparams['arhmm_model_path'], 'meta_tags.pkl'), 'rb')) hparams['ae_model_latents_file'] = tags['ae_model_latents_file'] # #################### # ### CREATE MODEL ### # #################### - print(hparams['input_size']) print('constructing model...', end='') torch.manual_seed(hparams['rng_seed_model']) torch_rng_seed = torch.get_rng_state() diff --git a/behavenet/fitting/eval.py b/behavenet/fitting/eval.py index 2999ea4..eae9cbf 100644 --- a/behavenet/fitting/eval.py +++ b/behavenet/fitting/eval.py @@ -68,7 +68,7 @@ def export_latents(data_generator, model, filename=None): else: y_in = y[idx_beg:idx_end] output = model.encoding(y_in, dataset=sess) - if model.hparams['model_class'] == 'sss-vae': + if model.hparams['model_class'] == 'ps-vae': curr_latents = torch.cat([output[0], output[1]], axis=1) else: curr_latents = output[0] @@ -84,7 +84,7 @@ def export_latents(data_generator, model, filename=None): else: y_in = y output = model.encoding(y_in, dataset=sess) - if model.hparams['model_class'] == 'sss-vae': + if model.hparams['model_class'] == 'ps-vae': curr_latents = torch.cat([output[0], output[1]], axis=1) else: curr_latents = output[0] @@ -158,7 +158,7 @@ def export_states(hparams, data_generator, model, filename=None): y = data['labels'][0][0] else: y = data['ae_latents'][0][0] - batch_size = y.shape[0] + # batch_size = y.shape[0] curr_states = model.most_likely_states(y) @@ -300,8 +300,8 @@ def get_reconstruction( labels_2d : :obj:`torch.Tensor` object or :obj:`NoneType`, optional label tensor of shape (batch, n_labels, y_pix, x_pix) apply_inverse_transform : :obj:`bool` - if inputs are latents (and model class is 'cond-ae-msp'), apply inverse transform to put in - original latent space + if inputs are latents (and model class is 'cond-ae-msp' or 'ps-vae'), apply inverse + transform to put in original latent space use_mean : :obj:`bool` if inputs are images (and model class is variational), use mean of approximate posterior without sampling @@ -315,9 +315,9 @@ def get_reconstruction( import torch model.eval() - + if not isinstance(inputs, torch.Tensor): - inputs = torch.Tensor(inputs) + inputs = torch.Tensor(inputs).to(model.hparams['device']) # check to see if inputs are images or latents if len(inputs.shape) == 2: @@ -331,7 +331,7 @@ def get_reconstruction( elif model.hparams['model_class'] == 'vae' \ or model.hparams['model_class'] == 'beta-tcvae': ims_recon, latents, _, _ = model(inputs, dataset=dataset, use_mean=use_mean) - elif model.hparams['model_class'] == 'sss-vae': + elif model.hparams['model_class'] == 'ps-vae': ims_recon, _, latents, _, _ = model(inputs, dataset=dataset, use_mean=use_mean) elif model.hparams['model_class'] == 'cond-ae': ims_recon, latents = model(inputs, dataset=dataset, labels=labels, labels_2d=labels_2d) @@ -346,7 +346,7 @@ def get_reconstruction( inputs = torch.cat((inputs, labels), dim=1) elif model.hparams['model_class'] == 'cond-ae-msp' and apply_inverse_transform: inputs = model.get_inverse_transformed_latents(inputs, as_numpy=False) - elif model.hparams['model_class'] == 'sss-vae' and apply_inverse_transform: + elif model.hparams['model_class'] == 'ps-vae' and apply_inverse_transform: # assume "inputs" are [labels, unsupervised latents] where "labels" need to be # transformed into N(0, 1) latent space inputs = model.get_inverse_transformed_latents(inputs, as_numpy=False) @@ -363,7 +363,9 @@ def get_reconstruction( return ims_recon -def get_test_metric(hparams, model_version, metric='r2', sess_idx=0): +def get_test_metric( + hparams, model_version, metric='r2', dtype='test', multioutput='variance_weighted', + sess_idx=0): """Calculate a single R\ :sup:`2` value across all test batches for a decoder. Parameters @@ -373,7 +375,13 @@ def get_test_metric(hparams, model_version, metric='r2', sess_idx=0): model_version : :obj:`int` or :obj:`str` version from test tube experiment defined in :obj:`hparams` or the string 'best' metric : :obj:`str`, optional - 'r2' | 'fc' + 'r2' | 'fc' | 'mse' + dtype : :obj:`str` + type of trials to use for computing metric + 'train' | 'val' | 'test' + multioutput : :obj:`str` + defines how to aggregate multiple r2 scores; see r2_score documentation in sklearn + 'raw_values' | 'uniform_average' | 'variance_weighted' sess_idx : :obj:`int`, optional session index into data generator @@ -392,17 +400,22 @@ def get_test_metric(hparams, model_version, metric='r2', sess_idx=0): model, data_generator = get_best_model_and_data( hparams, Decoder, load_data=True, version=model_version) - n_test_batches = len(data_generator.datasets[sess_idx].batch_idxs['test']) + n_test_batches = len(data_generator.datasets[sess_idx].batch_idxs[dtype]) max_lags = hparams['n_max_lags'] true = [] pred = [] - data_generator.reset_iterators('test') + data_generator.reset_iterators(dtype) for i in range(n_test_batches): - batch, _ = data_generator.next_batch('test') + batch, _ = data_generator.next_batch(dtype) # get true latents/states - if metric == 'r2': - curr_true = batch['ae_latents'][0].cpu().detach().numpy() + if metric == 'r2' or metric == 'mse': + if 'ae_latents' in batch: + curr_true = batch['ae_latents'][0].cpu().detach().numpy() + elif 'labels' in batch: + curr_true = batch['labels'][0].cpu().detach().numpy() + else: + raise ValueError('no valid key in {}'.format(batch.keys())) elif metric == 'fc': curr_true = batch['arhmm_states'][0].cpu().detach().numpy() else: @@ -416,13 +429,14 @@ def get_test_metric(hparams, model_version, metric='r2', sess_idx=0): if metric == 'r2': metric = r2_score( - np.concatenate(true, axis=0), np.concatenate(pred, axis=0), - multioutput='variance_weighted') + np.concatenate(true, axis=0), np.concatenate(pred, axis=0), multioutput=multioutput) + elif metric == 'mse': + metric = np.mean(np.square(np.concatenate(true, axis=0) - np.concatenate(pred, axis=0))) elif metric == 'fc': metric = accuracy_score( np.concatenate(true, axis=0), np.argmax(np.concatenate(pred, axis=0), axis=1)) - return model.hparams, metric + return model.hparams, metric, true, pred def export_train_plots(hparams, dtype, loss_type='mse', save_file=None, format='png'): @@ -446,9 +460,11 @@ def export_train_plots(hparams, dtype, loss_type='mse', save_file=None, format=' import os import pandas as pd import seaborn as sns + import matplotlib as mpl import matplotlib.pyplot as plt from behavenet.fitting.utils import read_session_info_from_csv + mpl.use('Agg') # deal with display-less machines sns.set_style('white') sns.set_context('talk') diff --git a/behavenet/fitting/hyperparam_utils.py b/behavenet/fitting/hyperparam_utils.py index 7d3e840..7e92ce2 100644 --- a/behavenet/fitting/hyperparam_utils.py +++ b/behavenet/fitting/hyperparam_utils.py @@ -68,7 +68,7 @@ def add_dependent_params(parser, namespace): or namespace.model_class == 'cond-vae' \ or namespace.model_class == 'cond-ae' \ or namespace.model_class == 'cond-ae-msp' \ - or namespace.model_class == 'sss-vae' \ + or namespace.model_class == 'ps-vae' \ or namespace.model_class == 'labels-images': max_latents = 64 @@ -130,7 +130,7 @@ def schedule_experiment(self, trial_params, exp_i): self.slurm_files_log_path, '{}_slurm_cmd.sh'.format(timestamp)) run_cmd = self.__get_run_command( trial_params, slurm_cmd_script_path, timestamp, exp_i, self.on_gpu) - sbatch_params = open(self.master_slurm_file,'r').read() + sbatch_params = open(self.master_slurm_file, 'r').read() slurm_cmd = sbatch_params+run_cmd self._SlurmCluster__save_slurm_cmd(slurm_cmd, slurm_cmd_script_path) diff --git a/behavenet/fitting/label_decoder_grid_search.py b/behavenet/fitting/label_decoder_grid_search.py index 69ddd4f..1fcec7b 100644 --- a/behavenet/fitting/label_decoder_grid_search.py +++ b/behavenet/fitting/label_decoder_grid_search.py @@ -4,13 +4,13 @@ import random import torch +from behavenet.data.utils import build_data_generator from behavenet.fitting.eval import export_train_plots from behavenet.fitting.hyperparam_utils import get_all_params from behavenet.fitting.hyperparam_utils import get_slurm_params from behavenet.fitting.training import fit from behavenet.fitting.utils import _clean_tt_dir from behavenet.fitting.utils import _print_hparams -from behavenet.fitting.utils import build_data_generator from behavenet.fitting.utils import create_tt_experiment from behavenet.fitting.utils import export_hparams from behavenet.models import ConvDecoder diff --git a/behavenet/fitting/losses.py b/behavenet/fitting/losses.py index 0e30126..ba24e3c 100644 --- a/behavenet/fitting/losses.py +++ b/behavenet/fitting/losses.py @@ -389,6 +389,6 @@ def subspace_overlap(A, B): """ C = torch.cat([A, B], dim=0) d = C.shape[0] - I = torch.eye(d, device=C.device) - return torch.mean((torch.matmul(C, torch.transpose(C, 1, 0)) - I).pow(2)) + eye = torch.eye(d, device=C.device) + return torch.mean((torch.matmul(C, torch.transpose(C, 1, 0)) - eye).pow(2)) # return torch.mean(torch.matmul(A, torch.transpose(B, 1, 0)).pow(2)) diff --git a/behavenet/fitting/utils.py b/behavenet/fitting/utils.py index 09882f6..f0eac1a 100644 --- a/behavenet/fitting/utils.py +++ b/behavenet/fitting/utils.py @@ -3,14 +3,13 @@ import os import pickle import numpy as np -from behavenet.data.utils import get_data_generator_inputs # to ignore imports for sphinx-autoapidoc __all__ = [ 'get_subdirs', 'get_session_dir', 'get_expt_dir', 'read_session_info_from_csv', 'export_session_info_to_csv', 'contains_session', 'find_session_dirs', 'experiment_exists', 'get_model_params', 'export_hparams', 'get_lab_example', 'get_region_dir', - 'create_tt_experiment', 'build_data_generator', 'get_best_model_version', + 'create_tt_experiment', 'get_best_model_version', 'get_best_model_and_data'] @@ -29,7 +28,7 @@ def get_subdirs(path): """ if not os.path.exists(path): - raise ValueError('%s is not a path' % path) + raise NotADirectoryError('%s is not a path' % path) try: s = next(os.walk(path))[1] except StopIteration: @@ -69,6 +68,8 @@ def _get_multisession_paths(base_dir, lab='', expt='', animal=''): multi_paths.append(os.path.join(base_dir, lab, expt, animal, sub_dir)) except ValueError: print('warning: did not find requested multisession(s)') + except StopIteration: + print('warning: did not find any sessions') return multi_paths @@ -351,7 +352,7 @@ def get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None): or model_class == 'cond-vae' \ or model_class == 'cond-ae' \ or model_class == 'cond-ae-msp' \ - or model_class == 'sss-vae': + or model_class == 'ps-vae': model_path = os.path.join( model_class, model_type, '%02i_latents' % hparams['n_ae_latents']) if hparams.get('ae_multisession', None) is not None: @@ -365,11 +366,15 @@ def get_expt_dir(hparams, model_class=None, model_type=None, expt_name=None): session_dir, _ = get_session_dir(hparams_) else: session_dir = hparams['session_dir'] - elif model_class == 'neural-ae' or model_class == 'ae-neural': + elif model_class == 'neural-ae' or model_class == 'neural-ae-me' or model_class == 'ae-neural': brain_region = get_region_dir(hparams) model_path = os.path.join( model_class, '%02i_latents' % hparams['n_ae_latents'], model_type, brain_region) session_dir = hparams['session_dir'] + elif model_class == 'neural-labels' or model_class == 'labels-neural': + brain_region = get_region_dir(hparams) + model_path = os.path.join(model_class, model_type, brain_region) + session_dir = hparams['session_dir'] elif model_class == 'neural-arhmm' or model_class == 'arhmm-neural': brain_region = get_region_dir(hparams) model_path = os.path.join( @@ -516,7 +521,7 @@ def find_session_dirs(hparams): expts = get_subdirs(os.path.join(hparams['save_dir'], lab)) # need to grab all multi-sessions as well as the single session session_dirs = [] # full paths - session_ids = [] # dict of lab/expt/animal/session + session_ids = [] # dict of lab/expt/animal/session for expt in expts: if expt[:5] == 'multi': session_dir = os.path.join(hparams['save_dir'], lab, expt) @@ -654,7 +659,7 @@ def get_model_params(hparams): or model_class == 'cond-vae' \ or model_class == 'cond-ae' \ or model_class == 'cond-ae-msp' \ - or model_class == 'sss-vae': + or model_class == 'ps-vae': hparams_less['n_ae_latents'] = hparams['n_ae_latents'] hparams_less['fit_sess_io_layers'] = hparams['fit_sess_io_layers'] hparams_less['learning_rate'] = hparams['learning_rate'] @@ -668,10 +673,10 @@ def get_model_params(hparams): # hparams_less['vae.beta_anneal_epochs'] = hparams['vae.beta_anneal_epochs'] if model_class == 'beta-tcvae': hparams_less['beta_tcvae.beta'] = hparams['beta_tcvae.beta'] - if model_class == 'sss-vae': - hparams_less['sss_vae.alpha'] = hparams['sss_vae.alpha'] - hparams_less['sss_vae.beta'] = hparams['sss_vae.beta'] - hparams_less['sss_vae.gamma'] = hparams['sss_vae.gamma'] + if model_class == 'ps-vae': + hparams_less['ps_vae.alpha'] = hparams['ps_vae.alpha'] + hparams_less['ps_vae.beta'] = hparams['ps_vae.beta'] + hparams_less['ps_vae.gamma'] = hparams['ps_vae.gamma'] elif model_class == 'arhmm' or model_class == 'hmm': hparams_less['n_arhmm_lags'] = hparams['n_arhmm_lags'] hparams_less['noise_type'] = hparams['noise_type'] @@ -680,6 +685,7 @@ def get_model_params(hparams): hparams_less['kappa'] = hparams['kappa'] hparams_less['ae_experiment_name'] = hparams['ae_experiment_name'] hparams_less['ae_version'] = hparams['ae_version'] + hparams_less['ae_model_class'] = hparams['ae_model_class'] hparams_less['ae_model_type'] = hparams['ae_model_type'] hparams_less['n_ae_latents'] = hparams['n_ae_latents'] elif model_class == 'arhmm-labels' or model_class == 'hmm-labels': @@ -688,11 +694,14 @@ def get_model_params(hparams): hparams_less['transitions'] = hparams['transitions'] if hparams['transitions'] == 'sticky': hparams_less['kappa'] = hparams['kappa'] - elif model_class == 'neural-ae' or model_class == 'ae-neural': + elif model_class == 'neural-ae' or model_class == 'neural-ae-me' or model_class == 'ae-neural': hparams_less['ae_experiment_name'] = hparams['ae_experiment_name'] hparams_less['ae_version'] = hparams['ae_version'] + hparams_less['ae_model_class'] = hparams['ae_model_class'] hparams_less['ae_model_type'] = hparams['ae_model_type'] hparams_less['n_ae_latents'] = hparams['n_ae_latents'] + elif model_class == 'neural-labels' or model_class == 'labels-neural': + pass elif model_class == 'neural-arhmm' or model_class == 'arhmm-neural': hparams_less['arhmm_experiment_name'] = hparams['arhmm_experiment_name'] hparams_less['arhmm_version'] = hparams['arhmm_version'] @@ -702,6 +711,7 @@ def get_model_params(hparams): hparams_less['transitions'] = hparams['transitions'] if hparams['transitions'] == 'sticky': hparams_less['kappa'] = hparams['kappa'] + hparams_less['ae_model_class'] = hparams['ae_model_class'] hparams_less['ae_model_type'] = hparams['ae_model_type'] hparams_less['n_ae_latents'] = hparams['n_ae_latents'] elif model_class == 'bayesian-decoding': @@ -714,8 +724,10 @@ def get_model_params(hparams): raise NotImplementedError('"%s" is not a valid model class' % model_class) # decoder arch params - if model_class == 'neural-ae' or model_class == 'ae-neural' \ - or model_class == 'neural-arhmm' or model_class == 'arhmm-neural': + if model_class == 'neural-ae' or model_class == 'neural-ae-me' or model_class == 'ae-neural' \ + or model_class == 'neural-arhmm' or model_class == 'arhmm-neural' \ + or model_class == 'neural-labels' or model_class == 'labels-neural': + hparams_less['learning_rate'] = hparams['learning_rate'] hparams_less['n_lags'] = hparams['n_lags'] hparams_less['l2_reg'] = hparams['l2_reg'] hparams_less['model_type'] = hparams['model_type'] @@ -855,53 +867,6 @@ def create_tt_experiment(hparams): return hparams, sess_ids, exp -def build_data_generator(hparams, sess_ids, export_csv=True): - """Helper function to build data generator from hparams dict. - - Parameters - ---------- - hparams : :obj:`dict` - needs to contain information specifying data inputs to model - sess_ids : :obj:`list` of :obj:`dict` - each entry is a session dict with keys 'lab', 'expt', 'animal', 'session' - export_csv : :obj:`bool`, optional - export csv file containing session info (useful when fitting multi-sessions) - - Returns - ------- - :obj:`ConcatSessionsGenerator` object - data generator - - """ - from behavenet.data.data_generator import ConcatSessionsGenerator - from behavenet.data.utils import get_data_generator_inputs - print('using data from following sessions:') - for ids in sess_ids: - print('%s' % os.path.join( - hparams['save_dir'], ids['lab'], ids['expt'], ids['animal'], ids['session'])) - hparams, signals, transforms, paths = get_data_generator_inputs(hparams, sess_ids) - if hparams.get('trial_splits', None) is not None: - # assumes string of form 'train;val;test;gap' - trs = [int(tr) for tr in hparams['trial_splits'].split(';')] - trial_splits = {'train_tr': trs[0], 'val_tr': trs[1], 'test_tr': trs[2], 'gap_tr': trs[3]} - else: - trial_splits = None - print('constructing data generator...', end='') - data_generator = ConcatSessionsGenerator( - hparams['data_dir'], sess_ids, - signals_list=signals, transforms_list=transforms, paths_list=paths, - device=hparams['device'], as_numpy=hparams['as_numpy'], batch_load=hparams['batch_load'], - rng_seed=hparams['rng_seed_data'], trial_splits=trial_splits, - train_frac=hparams['train_frac']) - # csv order will reflect dataset order in data generator - if export_csv: - export_session_info_to_csv(os.path.join( - hparams['expt_dir'], str('version_%i' % hparams['version'])), sess_ids) - print('done') - print(data_generator) - return data_generator - - def get_best_model_version(expt_dir, measure='val_loss', best_def='min', n_best=1): """Get best model version from a test tube experiment. @@ -967,14 +932,14 @@ def get_best_model_version(expt_dir, measure='val_loss', best_def='min', n_best= return best_versions -def get_best_model_and_data(hparams, Model, load_data=True, version='best', data_kwargs=None): +def get_best_model_and_data(hparams, Model=None, load_data=True, version='best', data_kwargs=None): """Load the best model (and data) defined by hparams out of all available test-tube versions. Parameters ---------- hparams : :obj:`dict` needs to contain enough information to specify both a model and the associated data - Model : :obj:`behavenet.models` object + Model : :obj:`behavenet.models` object, optional model type load_data : :obj:`bool`, optional if `False` then data generator is not returned @@ -993,6 +958,7 @@ def get_best_model_and_data(hparams, Model, load_data=True, version='best', data import torch from behavenet.data.data_generator import ConcatSessionsGenerator + from behavenet.data.utils import get_data_generator_inputs # get session_dir hparams['session_dir'], sess_ids = get_session_dir( @@ -1003,6 +969,10 @@ def get_best_model_and_data(hparams, Model, load_data=True, version='best', data if version == 'best': best_version_int = get_best_model_version(expt_dir)[0] best_version = str('version_{}'.format(best_version_int)) + elif version is None: + # try to match hparams + _, version_hp = experiment_exists(hparams, which_version=True) + best_version = str('version_{}'.format(version_hp)) else: if isinstance(version, str) and version[0] == 'v': # assume we got a string of the form 'version_{%i}' @@ -1038,11 +1008,40 @@ def get_best_model_and_data(hparams, Model, load_data=True, version='best', data signals_list=signals, transforms_list=transforms, paths_list=paths, device=hparams_new['device'], as_numpy=hparams_new['as_numpy'], batch_load=hparams_new['batch_load'], rng_seed=hparams_new['rng_seed_data'], - **data_kwargs) + train_frac=hparams_new['train_frac'], **data_kwargs) else: data_generator = None - # build models + # build model + if Model is None: + if hparams['model_class'] == 'ae': + from behavenet.models import AE as Model + elif hparams['model_class'] == 'vae': + from behavenet.models import VAE as Model + elif hparams['model_class'] == 'cond-ae': + from behavenet.models import ConditionalAE as Model + elif hparams['model_class'] == 'cond-vae': + from behavenet.models import ConditionalVAE as Model + elif hparams['model_class'] == 'cond-ae-msp': + from behavenet.models import AEMSP as Model + elif hparams['model_class'] == 'beta-tcvae': + from behavenet.models import BetaTCVAE as Model + elif hparams['model_class'] == 'ps-vae': + from behavenet.models import PSVAE as Model + elif hparams['model_class'] == 'labels-images': + from behavenet.models import ConvDecoder as Model + elif hparams['model_class'] == 'neural-ae' or hparams['model_class'] == 'neural-ae-me' \ + or hparams['model_class'] == 'neural-arhmm' \ + or hparams['model_class'] == 'neural-labels': + from behavenet.models import Decoder as Model + elif hparams['model_class'] == 'ae-neural' or hparams['model_class'] == 'arhmm-neural' \ + or hparams['model_class'] == 'labels-neural': + from behavenet.models import Decoder as Model + elif hparams['model_class'] == 'arhmm': + raise NotImplementedError('Cannot use get_best_model_and_data() for ssm models') + else: + raise NotImplementedError + model = Model(hparams_new) model.version = int(best_version.split('_')[1]) model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage)) diff --git a/behavenet/models/README.md b/behavenet/models/README.md index fa8c46c..fc99023 100644 --- a/behavenet/models/README.md +++ b/behavenet/models/README.md @@ -12,9 +12,10 @@ Model-related code * `behavenet.data.utils.get_data_generator_inputs` [UPDATE UNIT TEST!] * `behavenet.fitting.utils.get_expt_dir` [UPDATE UNIT TEST!] * `behavenet.fitting.utils.get_model_params` [UPDATE UNIT TEST!] + * `behavenet.fitting.utils.get_best_data_and_model` * `behavenet.fitting.eval.export_xxx` (latents, states, predictions, etc) * potential function updates: - * other `behavenet.fitting.eval` methods (like `get_rconstruction`) + * other `behavenet.fitting.eval` methods (like `get_reconstruction`) * `behavenet.fitting.hyperparam_utils.add_dependent_params` [UPDATE UNIT TEST!] * update relevant jsons (e.g. extra hyperparameters) diff --git a/behavenet/models/__init__.py b/behavenet/models/__init__.py index 8db1cde..72429b7 100644 --- a/behavenet/models/__init__.py +++ b/behavenet/models/__init__.py @@ -1,4 +1,4 @@ from behavenet.models.aes import AE, ConditionalAE, AEMSP from behavenet.models.base import CustomDataParallel from behavenet.models.decoders import Decoder, ConvDecoder -from behavenet.models.vaes import VAE, ConditionalVAE, BetaTCVAE, SSSVAE +from behavenet.models.vaes import VAE, ConditionalVAE, BetaTCVAE, PSVAE diff --git a/behavenet/models/ae_model_architecture_generator.py b/behavenet/models/ae_model_architecture_generator.py index fcf5cf6..9ceafb1 100644 --- a/behavenet/models/ae_model_architecture_generator.py +++ b/behavenet/models/ae_model_architecture_generator.py @@ -106,19 +106,19 @@ def get_possible_arch(input_dim, n_ae_latents, arch_seed=0): arch = {} arch['ae_input_dim'] = input_dim - arch['model_type'] = 'conv' + arch['model_type'] = 'conv' arch['n_ae_latents'] = n_ae_latents arch['ae_decoding_last_FF_layer'] = 0 # arch['ae_decoding_last_FF_layer'] = np.random.choice( # np.asarray([0, 1]), p=np.asarray([1 - opts['FF_layer_prob'], opts['FF_layer_prob']])) - arch['ae_batch_norm'] = 0 + arch['ae_batch_norm'] = 0 arch['ae_batch_norm_momentum'] = None # First decide if strides only or max pooling # network_types = ['strides_only', 'max_pooling'] # arch['ae_network_type'] = network_types[np.random.randint(2)] arch['ae_network_type'] = 'strides_only' - + # Then decide if padding is 0 (0) or same (1) for all layers padding_types = ['valid', 'same'] arch['ae_padding_type'] = padding_types[np.random.randint(2)] @@ -255,7 +255,7 @@ def get_encoding_conv_block(arch, opts): break last_dims = arch['ae_encoding_n_channels'][-1] * arch['ae_encoding_y_dim'][-1] * \ - arch['ae_encoding_x_dim'][-1] + arch['ae_encoding_x_dim'][-1] smallest_pix = min(arch['ae_encoding_y_dim'][-1], arch['ae_encoding_x_dim'][-1]) p = opts['prob_stopping'][global_layer] stop_this_layer = np.random.choice([0, 1], p=[1 - p, p]) @@ -348,7 +348,8 @@ def calculate_output_dim(input_dim, kernel, stride, padding_type, layer_type): """Calculate output dimension of a layer/dimension based on input size, kernel size, etc. Inspired by: - - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/common_shape_fns.cc#L21 + - https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/ + common_shape_fns.cc#L21 - https://github.com/pytorch/pytorch/issues/3867 Parameters @@ -506,17 +507,17 @@ def get_handcrafted_dims(arch, symmetric=True): """ arch['model_type'] = 'conv' - + arch['ae_encoding_x_dim'] = [] arch['ae_encoding_y_dim'] = [] arch['ae_encoding_x_padding'] = [] arch['ae_encoding_y_padding'] = [] for i_layer in range(len(arch['ae_encoding_n_channels'])): - + kernel_size = arch['ae_encoding_kernel_size'][i_layer] stride_size = arch['ae_encoding_stride_size'][i_layer] - + if i_layer == 0: # use input dimensions input_dim_y = arch['ae_input_dim'][1] input_dim_x = arch['ae_input_dim'][2] @@ -533,8 +534,8 @@ def get_handcrafted_dims(arch, symmetric=True): arch['ae_encoding_x_dim'].append(output_dim_x) arch['ae_encoding_y_dim'].append(output_dim_y) - arch['ae_encoding_x_padding'].append((x_before_pad,x_after_pad)) - arch['ae_encoding_y_padding'].append((y_before_pad,y_after_pad)) + arch['ae_encoding_x_padding'].append((x_before_pad, x_after_pad)) + arch['ae_encoding_y_padding'].append((y_before_pad, y_after_pad)) if symmetric: arch = get_decoding_conv_block(arch) @@ -552,7 +553,7 @@ def get_handcrafted_dims(arch, symmetric=True): for i_layer in range(len(arch['ae_decoding_n_channels'])): kernel_size = arch['ae_decoding_kernel_size'][i_layer] stride_size = arch['ae_decoding_stride_size'][i_layer] - + if i_layer == 0: # use input dimensions input_dim_y = arch['ae_decoding_starting_dim'][1] input_dim_x = arch['ae_decoding_starting_dim'][2] @@ -562,8 +563,9 @@ def get_handcrafted_dims(arch, symmetric=True): # TODO: not correct if arch['ae_padding_type'] == 'valid': - before_pad = 0 - after_pad = 0 + pass + # before_pad = 0 + # after_pad = 0 elif arch['ae_padding_type'] == 'same': # output_dim_x, x_before_pad, y_before_pad = calculate_output_dim( # input_dim_x, kernel_size, stride_size, 'same', 'conv') @@ -703,16 +705,16 @@ def load_handcrafted_arches( def load_default_arch(): - """Load default convolutional AE architecture used in Batty et al 2019""" + """Load default convolutional AE architecture used in Whiteway et al 2021.""" arch = { 'ae_network_type': 'strides_only', 'ae_padding_type': 'same', 'ae_batch_norm': 0, 'ae_batch_norm_momentum': None, 'symmetric_arch': 1, - 'ae_encoding_n_channels': [32, 64, 256, 512], - 'ae_encoding_kernel_size': [5, 5, 5, 5], - 'ae_encoding_stride_size': [2, 2, 2, 2], - 'ae_encoding_layer_type': ['conv', 'conv', 'conv', 'conv'], + 'ae_encoding_n_channels': [32, 64, 128, 256, 512], + 'ae_encoding_kernel_size': [5, 5, 5, 5, 5], + 'ae_encoding_stride_size': [2, 2, 2, 2, 5], + 'ae_encoding_layer_type': ['conv', 'conv', 'conv', 'conv', 'conv'], 'ae_decoding_last_FF_layer': 0} return arch diff --git a/behavenet/models/aes.py b/behavenet/models/aes.py index e6b727c..c10edfc 100644 --- a/behavenet/models/aes.py +++ b/behavenet/models/aes.py @@ -91,7 +91,8 @@ def build_model(self): if self.hparams['ae_batch_norm']: module = nn.BatchNorm2d( self.hparams['ae_encoding_n_channels'][i_layer], - momentum=self.hparams['ae_batch_norm_momentum']) + momentum=self.hparams.get('ae_batch_norm_momentum', 0.1), + track_running_stats=self.hparams.get('track_running_stats', True)) self.encoder.add_module( str('batchnorm%i' % global_layer_num), module) @@ -115,8 +116,8 @@ def build_model(self): # final ff layer to latents last_conv_size = self.hparams['ae_encoding_n_channels'][-1] \ - * self.hparams['ae_encoding_y_dim'][-1] \ - * self.hparams['ae_encoding_x_dim'][-1] + * self.hparams['ae_encoding_y_dim'][-1] \ + * self.hparams['ae_encoding_x_dim'][-1] self.FF = nn.Linear(last_conv_size, self.hparams['n_ae_latents']) # If VAE model, have additional ff layer to latent variances @@ -260,8 +261,8 @@ def build_model(self): # First ff layer (from latents to size of last encoding layer) first_conv_size = self.hparams['ae_decoding_starting_dim'][0] \ - * self.hparams['ae_decoding_starting_dim'][1] \ - * self.hparams['ae_decoding_starting_dim'][2] + * self.hparams['ae_decoding_starting_dim'][1] \ + * self.hparams['ae_decoding_starting_dim'][2] self.FF = nn.Linear(self.hparams['hidden_layer_size'], first_conv_size) self.decoder = nn.ModuleList() @@ -331,7 +332,8 @@ def build_model(self): if self.hparams['ae_batch_norm']: module = nn.BatchNorm2d( self.hparams['ae_decoding_n_channels'][i_layer], - momentum=self.hparams['ae_batch_norm_momentum']) + momentum=self.hparams.get('ae_batch_norm_momentum', 0.1), + track_running_stats=self.hparams.get('track_running_stats', True)) self.decoder.add_module( str('batchnorm%i' % global_layer_num), module) @@ -473,7 +475,7 @@ def forward(self, x, pool_idx=None, target_output_size=None, dataset=None): # (-i does cropping!) x = functional.pad(x, [-i for i in self.conv_t_pads[name]]) elif isinstance(layer, nn.Linear): - x = x.view(x.shape[0],-1) + x = x.view(x.shape[0], -1) x = layer(x) x = x.view( -1, @@ -659,9 +661,9 @@ def __init__(self, hparams): self.hparams = hparams self.model_type = self.hparams['model_type'] self.img_size = ( - self.hparams['n_input_channels'], - self.hparams['y_pixels'], - self.hparams['x_pixels']) + self.hparams['n_input_channels'], + self.hparams['y_pixels'], + self.hparams['x_pixels']) self.encoding = None self.decoding = None self.build_model() diff --git a/behavenet/models/base.py b/behavenet/models/base.py index dacaad2..06a5f17 100644 --- a/behavenet/models/base.py +++ b/behavenet/models/base.py @@ -3,6 +3,9 @@ import math from torch import nn, save, Tensor +# to ignore imports for sphix-autoapidoc +__all__ = ['BaseModule', 'BaseModel', 'DiagLinear', 'CustomDataParallel'] + class BaseModule(nn.Module): """Template for PyTorch modules.""" diff --git a/behavenet/models/decoders.py b/behavenet/models/decoders.py index 50f992a..af54da2 100644 --- a/behavenet/models/decoders.py +++ b/behavenet/models/decoders.py @@ -384,9 +384,9 @@ def __init__(self, hparams): self.hparams = hparams self.model_type = self.hparams['model_type'] self.img_size = ( - self.hparams['n_input_channels'], - self.hparams['y_pixels'], - self.hparams['x_pixels']) + self.hparams['n_input_channels'], + self.hparams['y_pixels'], + self.hparams['x_pixels']) self.decoding = None self.build_model() diff --git a/behavenet/models/hierarchical_decoders.py b/behavenet/models/hierarchical_decoders.py new file mode 100644 index 0000000..268017b --- /dev/null +++ b/behavenet/models/hierarchical_decoders.py @@ -0,0 +1,461 @@ +"""Hierarchical encoding/decoding models implemented in PyTorch.""" + +import numpy as np +from sklearn.metrics import r2_score, accuracy_score +import torch +from torch import nn +import behavenet.fitting.losses as losses +from behavenet.models.base import BaseModule +from behavenet.models.decoders import Decoder + + +class HierarchicalDecoder(Decoder): + """General wrapper class for hierarchical encoding/decoding models.""" + + def __init__(self, hparams): + """ + + Parameters + ---------- + hparams : :obj:`dict` + - model_type (:obj:`str`): 'mlp' | 'mlp-mv' | 'lstm' + - input_size (:obj:`int`) + - output_size (:obj:`int`) + - n_hid_layers (:obj:`int`) + - n_hid_units (:obj:`int`) + - n_lags (:obj:`int`): number of lags in input data to use for temporal convolution + - noise_dist (:obj:`str`): 'gaussian' | 'gaussian-full' | 'poisson' | 'categorical' + - activation (:obj:`str`): 'linear' | 'relu' | 'lrelu' | 'sigmoid' | 'tanh' + + """ + super().__init__(hparams) + self.hparams = hparams + self.model = None + self.build_model() + # choose loss based on noise distribution of the model + if self.hparams['noise_dist'] == 'gaussian': + self._loss = nn.MSELoss() + elif self.hparams['noise_dist'] == 'gaussian-full': + from behavenet.fitting.losses import GaussianNegLogProb + self._loss = GaussianNegLogProb() # model holds precision mat + elif self.hparams['noise_dist'] == 'poisson': + self._loss = nn.PoissonNLLLoss(log_input=False) + elif self.hparams['noise_dist'] == 'categorical': + self._loss = nn.CrossEntropyLoss() + else: + raise ValueError('"%s" is not a valid noise dist' % self.model['noise_dist']) + + def __str__(self): + """Pretty print model architecture.""" + return self.model.__str__() + + def build_model(self): + """Construct the model using hparams.""" + + # TODO + if self.hparams['model_type'] == 'mlp' or self.hparams['model_type'] == 'mlp-mv': + self.model = HierarchicalMLP(self.hparams) + elif self.hparams['model_type'] == 'lstm': + self.model = HierarchicalLSTM(self.hparams) + else: + raise ValueError('"%s" is not a valid model type' % self.hparams['model_type']) + + def forward(self, x,dataset): + """Process input data.""" + return self.model(x,dataset) + + def loss(self, data,dataset, accumulate_grad=True, chunk_size=200, **kwargs): + # TODO + """Calculate negative log-likelihood loss for supervised models. + + The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory + requirements low; gradients are accumulated across all chunks before a gradient step is + taken. + + Parameters + ---------- + data : :obj:`dict` + signals are of shape (1, time, n_channels) + accumulate_grad : :obj:`bool`, optional + accumulate gradient for training step + chunk_size : :obj:`int`, optional + batch is split into chunks of this size to keep memory requirements low + + Returns + ------- + :obj:`dict` + - 'loss' (:obj:`float`): total loss (negative log-like under specified noise dist) + - 'r2' (:obj:`float`): variance-weighted $R^2$ when noise dist is Gaussian + - 'fc' (:obj:`float`): fraction correct when noise dist is Categorical + + """ + # self.dataset = dataset # it is passed as a kwarg, not sure how else to access this and pass it into forward() + predictors = data[self.hparams['input_signal']][0] + targets = data[self.hparams['output_signal']][0] + + max_lags = self.hparams['n_max_lags'] + + batch_size = targets.shape[0] + n_chunks = int(np.ceil(batch_size / chunk_size)) + + outputs_all = [] + loss_val = 0 + for chunk in range(n_chunks): + + # take chunks of size chunk_size, plus overlap due to max_lags + idx_beg = np.max([chunk * chunk_size - max_lags, 0]) + idx_end = np.min([(chunk + 1) * chunk_size + max_lags, batch_size]) + + outputs, precision = self.forward(predictors[idx_beg:idx_end],dataset) + + # define loss on allowed window of data + if self.hparams['noise_dist'] == 'gaussian-full': + loss = self._loss( + outputs[max_lags:-max_lags], + targets[idx_beg:idx_end][max_lags:-max_lags], + precision[max_lags:-max_lags]) + else: + loss = self._loss( + outputs[max_lags:-max_lags], + targets[idx_beg:idx_end][max_lags:-max_lags]) + + if accumulate_grad: + loss.backward() + + # get loss value (weighted by batch size) + loss_val += loss.item() * outputs[max_lags:-max_lags].shape[0] + + outputs_all.append(outputs[max_lags:-max_lags].cpu().detach().numpy()) + + loss_val /= batch_size + outputs_all = np.concatenate(outputs_all, axis=0) + + if self.hparams['noise_dist'] == 'gaussian' or \ + self.hparams['noise_dist'] == 'gaussian-full': + # use variance-weighted r2s to ignore small-variance latents + r2 = r2_score( + targets[max_lags:-max_lags].cpu().detach().numpy(), + outputs_all, + multioutput='variance_weighted') + fc = 0 + elif self.hparams['noise_dist'] == 'poisson': + raise NotImplementedError + elif self.hparams['noise_dist'] == 'categorical': + r2 = 0 + fc = accuracy_score( + targets[max_lags:-max_lags].cpu().detach().numpy(), + np.argmax(outputs_all, axis=1)) + else: + raise ValueError('"%s" is not a valid noise_dist' % self.hparams['noise_dist']) + + return {'loss': loss_val, 'r2': r2, 'fc': fc} + + +class HierarchicalMLP(BaseModule): + """Feedforward neural network model.""" + + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + self.decoder = None + self.build_model() + + def __str__(self): + """Pretty print model architecture.""" + # TODO + pass + + def build_model(self): + """Construct the model.""" + # TODO + pass + self.decoder = nn.ModuleList() + + global_layer_num = 0 + # Ask if the input size field of the hparams should be populated according to the multiple datasets that are + # present in the datagenerator.datasets somewhere else? Because at the moment hparams just has one single inp dim + out_size = self.hparams['n_hid_units']# fix it to the input size of the global backbone network + # for i,i_layer in enumerate(range(len(sess_ids))): + # in_size = self.hparams['input_size'][i] + # + # # first layer is 1d conv for incorporating past/future neural activity + # # Separate 1d conv for each dataset + # layer = nn.Conv1D(in_channels=in_size, out_channels=out_size, + # kernel_size=self.hparams['n_lags']*2+1, #window around t + # padding=self.hparams['n_lags'])# same output + # name = str('conv1d_layer_%02i'% global_layer_num) + # self.decoder.add_module(name,layer) + # self.final_layer = name + + layer = nn.ModuleList([ + nn.Conv1d(in_channels=in_size, out_channels=out_size, + kernel_size=self.hparams['n_lags']*2+1, + padding=self.hparams['n_lags']) + for in_size in self.hparams['input_size'] + ]) + + name = str('conv1d_layer_%02i' % global_layer_num) + self.decoder.add_module(name, layer) + self.final_layer = name + + # add activation + if self.hparams['n_hid_layers'] == 0: + if self.hparams['noise_dist'] == 'gaussian': + activation = None + elif self.hparams['noise_dist'] == 'gaussian-full': + activation = None + elif self.hparams['noise_dist'] == 'poisson': + activation = nn.Softplus() + elif self.hparams['noise_dist'] == 'categorical': + activation = None + else: + raise ValueError('"%s" is an invalid noise dist'% self.hparams['noise_dist']) + + else: + if self.hparams['activation'] == 'linear': + activation = None + elif self.hparams['activation'] == 'relu': + activation = nn.ReLU() + elif self.hparams['activation'] == 'lrelu': + activation = nn.LeakyReLU(0.05) + elif self.hparams['activation'] == 'sigmoid': + activation = nn.Sigmoid() + elif self.hparams['activation'] == 'tanh': + activation = nn.Tanh() + else: + raise ValueError( + '"%s" is an invalid activation function' % self.hparams['activation']) + + if activation: + name = '%s_%02i' % (self.hparams['activation'], global_layer_num) + self.decoder.add_module(name, activation) + + # add layer for data dependent precision matrix if requires + if self.hparams['n_hid_layers'] == 0 and self.hparams['noise_dist'] == 'gaussian-full': + # build sqrt of precision matrix + self.precision_sqrt = nn.Linear(in_features=in_size, out_features=out_size**2) + else: + self.precision_sqrt = None + + # update layer info + global_layer_num += 1 + in_size = out_size + + # loop over hidden layers + for i_layer in range(self.hparams['n_hid_layers']): + + if i_layer == self.hparams['n_hid_layers'] - 1: + out_size = self.hparams['output_size'] + else: + out_size = self.hparams['n_hid_units'] + + # add layer + layer = nn.Linear(in_features=in_size, out_features=out_size) + name = str('dense_layer_%02i'%global_layer_num) + self.decoder.add_module(name,layer) + self.final_layer = name + + # add activation + if i_layer == self.hparams['n_hid_layers'] - 1: + if self.hparams['noise_dist'] == 'gaussian': + activation = None + elif self.hparams['noise_dist'] == 'gaussian-full': + activation = None + elif self.hparams['noise_dist'] == 'poisson': + activation = nn.Softplus() + elif self.hparams['noise_dist'] == 'categorical': + activation = None + else: + raise ValueError('"%s" is an invalid noise dist' % self.hparams['noise_dist']) + else: + if self.hparams['activation'] == 'linear': + activation = None + elif self.hparams['activation'] == 'relu': + activation = nn.ReLU() + elif self.hparams['activation'] == 'lrelu': + activation = nn.LeakyReLU(0.05) + elif self.hparams['activation'] == 'sigmoid': + activation = nn.Sigmoid() + elif self.hparams['activation'] == 'tanh': + activation = nn.Tanh() + else: + raise ValueError( + '"%s" is an invalid activation function' % self.hparams['activation']) + + if activation: + self.decoder.add_module( + '%s_%02i' % (self.hparams['activation'], global_layer_num), activation) + + # add layer for data-dependent precision matrix if required + if i_layer == self.hparams['n_hid_layers'] - 1 \ + and self.hparams['noise_dist'] == 'gaussian-full': + # build sqrt of precision matrix + self.precision_sqrt = nn.Linear(in_features=in_size, out_features=out_size ** 2) + else: + self.precision_sqrt = None + + # update layer info + global_layer_num += 1 + in_size = out_size + + in_size_list = self.hparams['input_size'] + + # + # if self.hparams['n_hid_layers'] == 0: + # out_size + + def forward(self, x,dataset): + """Process input data. + + Parameters + ---------- + x : :obj:`torch.Tensor` + shape of (time, neurons) + + Returns + ------- + :obj:`tuple` + - x (:obj:`torch.Tensor`): mean prediction of model + - y (:obj:`torch.Tensor`): precision matrix prediction of model (when using 'mlp-mv') + + """ + # sess_id = [s for s in self.hparams['input_size']] + y = None + for name, layer in self.decoder.named_children(): + + if name == 'conv1d_layer_00': + # input is batch x in_channels x time + # output is batch x out_channels x time + x = layer[dataset](x.transpose(1,0).unsqueeze(0)).squeeze().transpose(1,0) + # x = layer(x.transpose(1,0).unsqueeze(0)).squeeze().transpose(1,0) + else: + x = layer(x) + + return x, y + # pass + +class HierarchicalLSTM(BaseModule): + """Feedforward neural network model.""" + + def __init__(self, hparams): + super().__init__() + self.hparams = hparams + self.decoder = None + self.build_model() + self.hidden_cell = (torch.zeros(hparams["stack"], hparams["batch"], hparams["hidden_layer_size"]), + torch.zeros(hparams["stack"], hparams["batch"], hparams["hidden_layer_size"])) + + def __str__(self): + """Pretty print model architecture.""" + # TODO + pass + + def build_model(self): + """Construct the model.""" + # TODO + self.decoder = nn.ModuleList() + + global_layer_num = 0 + + out_size = self.hparams['n_hid_units']# fix it to the input size of the global backbone network + + in_size_1 = self.hparams['input_size'][0] + in_size_2 = self.hparams['input_size'][1] + + + layer = nn.ModuleList( + [ + nn.Linear(in_size_1, self.hparams['lstm_in_size']) + ]) + name = str('InputMLP_layer_%02i' % global_layer_num) + self.decoder.add_module(name, layer) + + # # Add activation + # global_layer_num += 1 + # name = '%s_%02i' % (self.hparams['activation'], global_layer_num) + # activation = nn.ReLU() + # self.decoder.add_module(name, activation) + + # Add a second head of linear and activations + global_layer_num += 1 + layer = nn.ModuleList( + [ + nn.Linear(in_size_2, self.hparams['lstm_in_size']) + ]) + name = str('InputMLP_layer_%02i' % global_layer_num) + self.decoder.add_module(name, layer) + + # # Add activation + # global_layer_num += 1 + # name = '%s_%02i' % (self.hparams['activation'], global_layer_num) + # activation = nn.ReLU() + # self.decoder.add_module(name, activation) + + # update layer info # add lstm layer + global_layer_num += 1 + layer = nn.LSTM(input_size=self.hparams["lstm_in_size"], hidden_size=self.hparams["hidden_layer_size"], num_layers=self.hparams["stack"]) + name = str('lstm_layer_%02i'%global_layer_num) + self.decoder.add_module(name,layer) + + # update layer info + global_layer_num += 1 + in_size = out_size + + # add linear layer + layer = nn.Linear(in_features=self.hparams["hidden_layer_size"],out_features=self.hparams["output_size"]) + name = str('dense_layer_%02i'%global_layer_num) + self.decoder.add_module(name,layer) + self.final_layer = name + + + + + def forward(self, x,dataset): + """Process input data. + + Parameters + ---------- + x : :obj:`torch.Tensor` + shape of (time, neurons) + + Returns + ------- + :obj:`tuple` + - x (:obj:`torch.Tensor`): mean prediction of model + - y (:obj:`torch.Tensor`): precision matrix prediction of model (when using 'mlp-mv') + + """ + # sess_id = [s for s in self.hparams['input_size']] + + + + y = None + for name, layer in self.decoder.named_children(): + + if name == 'InputMLP_layer_00' and dataset==0: + # input is batch x in_channels x time + # output is batch x out_channels x time + x = layer[0](x.unsqueeze(0)).squeeze().transpose(1,0) + + # if name=='relu_01' and dataset==0: + # x = layer(x) + + if name == 'InputMLP_layer_01' and dataset==1: + # input is batch x in_channels x time + # output is batch x out_channels x time + x = layer[0](x.unsqueeze(0)).squeeze().transpose(1,0) + + # if name=='relu_03' and dataset==1: + # x = layer(x) + + if name == 'lstm_layer_02': + x = x.reshape(189,1,-1) + x, _ = layer(x,self.hidden_cell) + + elif name == 'dense_layer_03': + x = layer(x) + + return x.reshape(189,10), y + + + diff --git a/behavenet/models/vaes.py b/behavenet/models/vaes.py index 342876f..893f33c 100644 --- a/behavenet/models/vaes.py +++ b/behavenet/models/vaes.py @@ -9,7 +9,7 @@ from behavenet.models.aes import AE, ConvAEDecoder, ConvAEEncoder # to ignore imports for sphix-autoapidoc -__all__ = ['reparameterize', 'VAE', 'ConditionalVAE', 'BetaTCVAE', 'SSSVAE', 'ConvAESSSEncoder'] +__all__ = ['reparameterize', 'VAE', 'ConditionalVAE', 'BetaTCVAE', 'PSVAE', 'ConvAEPSEncoder'] def reparameterize(mu, logvar): @@ -501,8 +501,8 @@ def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): return loss_dict_vals -class SSSVAE(AE): - """Semi-supervised subspace variational autoencoder class. +class PSVAE(AE): + """Partitioned subspace variational autoencoder class. This class constructs a VAE that... @@ -516,16 +516,16 @@ def __init__(self, hparams): hparams : :obj:`dict` in addition to the standard keys, must also contain: - 'n_labels' (:obj:`n_labels`) - - 'sss.alpha' (:obj:`float`) - - 'sss.beta' (:obj:`float`) - - 'sss.gamma' (:obj:`float`) + - 'ps_vae.alpha' (:obj:`float`) + - 'ps_vae.beta' (:obj:`float`) + - 'ps_vae.gamma' (:obj:`float`) """ if hparams['model_type'] == 'linear': raise NotImplementedError if hparams['n_ae_latents'] < hparams['n_labels']: - raise ValueError('SSS-VAE model must contain at least as many latents as labels') + raise ValueError('PS-VAE model must contain at least as many latents as labels') self.n_latents = hparams['n_ae_latents'] self.n_labels = hparams['n_labels'] @@ -534,9 +534,9 @@ def __init__(self, hparams): super().__init__(hparams) # set up beta annealing - anneal_epochs = self.hparams.get('sss_vae.anneal_epochs', 0) + anneal_epochs = self.hparams.get('ps_vae.anneal_epochs', 0) self.curr_epoch = 0 # must be modified by training script - beta = hparams['sss_vae.beta'] + beta = hparams['ps_vae.beta'] # TODO: these values should not be precomputed if anneal_epochs > 0: # annealing for total correlation term @@ -555,7 +555,7 @@ def build_model(self): """Construct the model using hparams.""" self.hparams['hidden_layer_size'] = self.hparams['n_ae_latents'] if self.model_type == 'conv': - self.encoding = ConvAESSSEncoder(self.hparams) + self.encoding = ConvAEPSEncoder(self.hparams) self.decoding = ConvAEDecoder(self.hparams) elif self.model_type == 'linear': raise NotImplementedError @@ -600,7 +600,7 @@ def forward(self, x, dataset=None, use_mean=False, **kwargs): return x_hat, z, mu, logvar, y_hat def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): - """Calculate modified ELBO loss for SSSVAE. + """Calculate modified ELBO loss for PSVAE. The batch is split into chunks if larger than a hard-coded `chunk_size` to keep memory requirements low; gradients are accumulated across all chunks before a gradient step is @@ -631,15 +631,16 @@ def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): x = data['images'][0] y = data['labels'][0] m = data['masks'][0] if 'masks' in data else None + n = data['labels_masks'][0] if 'labels_masks' in data else None batch_size = x.shape[0] n_chunks = int(np.ceil(batch_size / chunk_size)) n_labels = self.hparams['n_labels'] - n_latents = self.hparams['n_ae_latents'] + # n_latents = self.hparams['n_ae_latents'] # compute hyperparameters - alpha = self.hparams['sss_vae.alpha'] + alpha = self.hparams['ps_vae.alpha'] beta = self.beta_vals[self.curr_epoch] - gamma = self.hparams['sss_vae.gamma'] + gamma = self.hparams['ps_vae.gamma'] kl = self.kl_anneal_vals[self.curr_epoch] loss_strs = [ @@ -659,6 +660,7 @@ def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): x_in = x[idx_beg:idx_end] y_in = y[idx_beg:idx_end] m_in = m[idx_beg:idx_end] if m is not None else None + n_in = n[idx_beg:idx_end] if n is not None else None x_hat, sample, mu, logvar, y_hat = self.forward(x_in, dataset=dataset, use_mean=False) # reset losses @@ -669,7 +671,7 @@ def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): loss_dict_torch['loss'] -= loss_dict_torch['loss_data_ll'] # label log-likelihood - loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(y_in, y_hat) + loss_dict_torch['loss_label_ll'] = losses.gaussian_ll(y_in, y_hat, n_in) loss_dict_torch['loss'] -= alpha * loss_dict_torch['loss_label_ll'] # supervised latents kl @@ -717,11 +719,17 @@ def loss(self, data, dataset=0, accumulate_grad=True, chunk_size=200): # use variance-weighted r2s to ignore small-variance latents y_hat_all = np.concatenate(y_hat_all, axis=0) - r2 = r2_score(y.cpu().detach().numpy(), y_hat_all, multioutput='variance_weighted') + y_all = y.cpu().detach().numpy() + if n is not None: + n_np = n.cpu().detach().numpy() + r2 = r2_score(y_all[n_np == 1], y_hat_all[n_np == 1], multioutput='variance_weighted') + else: + r2 = r2_score(y_all, y_hat_all, multioutput='variance_weighted') # compile (properly weighted) loss terms for key in loss_dict_vals.keys(): loss_dict_vals[key] /= batch_size + # store hyperparams loss_dict_vals['alpha'] = alpha loss_dict_vals['beta'] = beta @@ -852,7 +860,7 @@ def get_inverse_transformed_latents(self, inputs, dataset=None, as_numpy=True): return latents_tr -class ConvAESSSEncoder(ConvAEEncoder): +class ConvAEPSEncoder(ConvAEEncoder): """Convolutional encoder that separates label-related subspace.""" def __init__(self, hparams): diff --git a/behavenet/plotting/__init__.py b/behavenet/plotting/__init__.py index 451709a..c2dc34a 100644 --- a/behavenet/plotting/__init__.py +++ b/behavenet/plotting/__init__.py @@ -1,9 +1,12 @@ """Utility functions shared across multiple plotting modules.""" +from matplotlib.animation import FFMpegWriter import numpy as np import os +import pickle import pandas as pd +from behavenet import make_dir_if_not_exists from behavenet.fitting.utils import experiment_exists from behavenet.fitting.utils import get_expt_dir from behavenet.fitting.utils import get_session_dir @@ -12,7 +15,7 @@ from behavenet.fitting.utils import read_session_info_from_csv # to ignore imports for sphix-autoapidoc -__all__ = ['load_metrics_csv_as_df'] +__all__ = ['concat', 'get_crop', 'load_latents', 'load_metrics_csv_as_df', 'save_movie'] # TODO: use load_metrics_csv_as_df in ae example notebook @@ -35,6 +38,75 @@ def concat(ims, axis=1): return np.concatenate([ims[0, :, :], ims[1, :, :]], axis=axis) +def get_crop(im, y_0, y_ext, x_0, x_ext): + """Get crop of image, filling in borders with zeros. + + Parameters + ---------- + im : :obj:`np.ndarray` + input image + y_0 : :obj:`int` + y-pixel center value + y_ext : :obj:`int` + y-pixel extent; crop in y-direction will be [y_0 - y_ext, y_0 + y_ext] + x_0 : :obj:`int` + y-pixel center value + x_ext : :obj:`int` + x-pixel extent; crop in x-direction will be [x_0 - x_ext, x_0 + x_ext] + + Returns + ------- + :obj:`np.ndarray` + cropped image + + """ + y_min = y_0 - y_ext + y_max = y_0 + y_ext + y_pix = y_max - y_min + x_min = x_0 - x_ext + x_max = x_0 + x_ext + x_pix = x_max - x_min + im_crop = np.copy(im[y_min:y_max, x_min:x_max]) + y_pix_, x_pix_ = im_crop.shape + im_tmp = np.zeros((y_pix, x_pix)) + im_tmp[:y_pix_, :x_pix_] = im_crop + return im_tmp + + +def load_latents(hparams, version, dtype='val'): + """Load all latents as a single array. + + Parameters + ---------- + hparams : :obj:`dict` + needs to contain enough information to specify both a model and the associated data + version : :obj:`int` + version from test tube experiment defined in :obj:`hparams` + dtype : :obj:`str` + 'train' | 'val' | 'test' + + Returns + ------- + :obj:`np.ndarray` + shape (time, n_latents) + + """ + sess_id = str('%s_%s_%s_%s_latents.pkl' % ( + hparams['lab'], hparams['expt'], hparams['animal'], hparams['session'])) + filename = os.path.join( + hparams['expt_dir'], 'version_%i' % version, sess_id) + if not os.path.exists(filename): + raise FileNotFoundError('latents located at %s do not exist' % filename) + latent_dict = pickle.load(open(filename, 'rb')) + print('loaded latents from %s' % filename) + # get all test latents + latents = [] + for trial in latent_dict['trials'][dtype]: + ls = latent_dict['latents'][trial] + latents.append(ls) + return np.concatenate(latents) + + def load_metrics_csv_as_df(hparams, lab, expt, metrics_list, test=False, version='best'): """Load metrics csv file and return as a pandas dataframe for easy plotting. @@ -118,3 +190,34 @@ def load_metrics_csv_as_df(hparams, lab, expt, metrics_list, test=False, version # tr_dict[metric] = row['tr_%s' % metric] # metrics_df.append(pd.DataFrame(tr_dict, index=[0])) return pd.concat(metrics_df, sort=True) + + +def save_movie(save_file, ani, frame_rate=15): + """Save out matplotlib ArtistAnimation + + Parameters + ---------- + save_file : :obj:`str` + full save file (path and filename) + ani : :obj:`matplotlib.animation.ArtistAnimation` object + animation to save + frame_rate : :obj:`int`, optional + frame rate of saved movie + + """ + + if save_file is not None: + make_dir_if_not_exists(save_file) + if save_file[-3:] == 'gif': + print('saving video to %s...' % save_file, end='') + ani.save(save_file, writer='imagemagick', fps=frame_rate) + print('done') + else: + if save_file[-3:] != 'mp4': + save_file += '.mp4' + writer = FFMpegWriter(fps=frame_rate, bitrate=-1) + print('saving video to %s...' % save_file, end='') + ani.save(save_file, writer=writer) + print('done') + + diff --git a/behavenet/plotting/ae_utils.py b/behavenet/plotting/ae_utils.py index 79b0c67..196fc6c 100644 --- a/behavenet/plotting/ae_utils.py +++ b/behavenet/plotting/ae_utils.py @@ -1,21 +1,14 @@ """Plotting and video making functions for autoencoders.""" -import copy import matplotlib.animation as animation import matplotlib.pyplot as plt from matplotlib.gridspec import GridSpec -from matplotlib.animation import FFMpegWriter -import numpy as np -from behavenet.plotting import concat -from behavenet import make_dir_if_not_exists -from behavenet.fitting.utils import get_best_model_and_data from behavenet.fitting.eval import get_reconstruction +from behavenet.fitting.utils import get_best_model_and_data +from behavenet.plotting import concat, save_movie # to ignore imports for sphix-autoapidoc -__all__ = [ - 'make_ae_reconstruction_movie_wrapper', 'make_reconstruction_movie', - 'make_neural_reconstruction_movie_wrapper', 'make_neural_reconstruction_movie', - 'plot_neural_reconstruction_traces_wrapper', 'plot_neural_reconstruction_traces'] +__all__ = ['make_ae_reconstruction_movie_wrapper', 'make_reconstruction_movie'] def make_reconstruction_movie( @@ -99,18 +92,7 @@ def make_reconstruction_movie( plt.tight_layout(pad=0) ani = animation.ArtistAnimation(fig, ims_ani, blit=True, repeat_delay=1000) - writer = FFMpegWriter(fps=frame_rate, bitrate=-1) - - if save_file is not None: - make_dir_if_not_exists(save_file) - if save_file[-3:] != 'mp4': - save_file += '.mp4' - print('saving video to %s...' % save_file, end='') - ani.save(save_file, writer=writer) - # if save_file[-3:] != 'gif': - # save_file += '.gif' - # ani.save(save_file, writer='imagemagick', fps=15) - print('done') + save_movie(save_file, ani, frame_rate=frame_rate) def make_ae_reconstruction_movie_wrapper( @@ -202,433 +184,3 @@ def make_ae_reconstruction_movie_wrapper( make_reconstruction_movie( ims=ims, titles=titles, n_rows=n_rows, n_cols=n_cols, save_file=save_file, frame_rate=frame_rate) - - -def make_neural_reconstruction_movie_wrapper( - hparams, save_file, trial=None, sess_idx=0, max_frames=400, max_latents=8, frame_rate=15): - """Produce movie with original video, ae reconstructed video, and neural reconstructed video. - - This is a high-level function that loads the model described in the hparams dictionary and - produces the necessary predicted video frames. Latent traces are additionally plotted, as well - as the residual between the ae reconstruction and the neural reconstruction. Currently produces - ae latents and decoder predictions from scratch (rather than saved pickle files). - - Parameters - ---------- - hparams : :obj:`dict` - needs to contain enough information to specify an autoencoder - save_file : :obj:`str` - full save file (path and filename) - trial : :obj:`int`, optional - if :obj:`NoneType`, use first test trial - sess_idx : :obj:`int`, optional - session index into data generator - max_frames : :obj:`int`, optional - maximum number of frames to animate from a trial - max_latents : :obj:`int`, optional - maximum number of ae latents to plot - frame_rate : :obj:`float`, optional - frame rate of saved movie - - """ - - from behavenet.models import Decoder - - ############################### - # build ae model/data generator - ############################### - hparams_ae = copy.copy(hparams) - hparams_ae['experiment_name'] = hparams['ae_experiment_name'] - hparams_ae['model_class'] = hparams['ae_model_class'] - hparams_ae['model_type'] = hparams['ae_model_type'] - if hparams['model_class'] == 'ae': - from behavenet.models import AE as Model - elif hparams['model_class'] == 'cond-ae': - from behavenet.models import ConditionalAE as Model - else: - raise NotImplementedError('"%s" is an invalid model class' % hparams['model_class']) - model_ae, data_generator_ae = get_best_model_and_data( - hparams_ae, Model, version=hparams['ae_version']) - # move model to cpu - model_ae.to('cpu') - - if trial is None: - # choose first test trial - trial = data_generator_ae.batch_idxs[sess_idx]['test'][0] - - # get images from data generator (move to cpu) - batch = data_generator_ae.datasets[sess_idx][trial] - ims_orig_pt = batch['images'][:max_frames].cpu() # 400 - if hparams['model_class'] == 'cond-ae': - labels_pt = batch['labels'][:max_frames] - else: - labels_pt = None - - # push images through ae to get reconstruction - ims_recon_ae = get_reconstruction(model_ae, ims_orig_pt, labels=labels_pt) - # push images through ae to get latents - latents_ae_pt, _, _ = model_ae.encoding(ims_orig_pt) - - # mask images for plotting - if hparams.get('use_output_mask', False): - ims_orig_pt *= batch['masks'][:max_frames] - - ####################################### - # build decoder model/no data generator - ####################################### - hparams_dec = copy.copy(hparams) - hparams_dec['experiment_name'] = hparams['decoder_experiment_name'] - hparams_dec['model_class'] = hparams['decoder_model_class'] - hparams_dec['model_type'] = hparams['decoder_model_type'] - - model_dec, data_generator_dec = get_best_model_and_data( - hparams_dec, Decoder, version=hparams['decoder_version']) - # move model to cpu - model_dec.to('cpu') - - # get neural activity from data generator (move to cpu) - batch = data_generator_dec.datasets[0][trial] # 0 not sess_idx since decoders only have 1 sess - neural_activity_pt = batch['neural'][:max_frames].cpu() - - # push neural activity through decoder to get prediction - latents_dec_pt, _ = model_dec(neural_activity_pt) - # push prediction through ae to get reconstruction - ims_recon_dec = get_reconstruction(model_ae, latents_dec_pt, labels=labels_pt) - - # away - make_neural_reconstruction_movie( - ims_orig=ims_orig_pt.cpu().detach().numpy(), - ims_recon_ae=ims_recon_ae, - ims_recon_neural=ims_recon_dec, - latents_ae=latents_ae_pt.cpu().detach().numpy()[:, :max_latents], - latents_neural=latents_dec_pt.cpu().detach().numpy()[:, :max_latents], - save_file=save_file, - frame_rate=frame_rate) - - -def make_neural_reconstruction_movie( - ims_orig, ims_recon_ae, ims_recon_neural, latents_ae, latents_neural, save_file=None, - frame_rate=15): - """Produce movie with original video, ae reconstructed video, and neural reconstructed video. - - Latent traces are additionally plotted, as well as the residual between the ae reconstruction - and the neural reconstruction. - - Parameters - ---------- - ims_orig : :obj:`np.ndarray` - shape (n_frames, n_channels, y_pix, x_pix) - ims_recon_ae : :obj:`np.ndarray` - shape (n_frames, n_channels, y_pix, x_pix) - ims_recon_neural : :obj:`np.ndarray`, optional - shape (n_frames, n_channels, y_pix, x_pix) - latents_ae : :obj:`np.ndarray`, optional - shape (n_frames, n_latents) - save_file : :obj:`str`, optional - full save file (path and filename) - frame_rate : :obj:`float`, optional - frame rate of saved movie - - """ - - means = np.mean(latents_ae, axis=0) - std = np.std(latents_ae) * 2 - - latents_ae_sc = (latents_ae - means) / std - latents_dec_sc = (latents_neural - means) / std - - n_channels, y_pix, x_pix = ims_orig.shape[1:] - n_time, n_ae_latents = latents_ae.shape - - n_cols = 3 - n_rows = 2 - offset = 2 # 0 if ims_recon_lin is None else 1 - scale_ = 5 - fig_width = scale_ * n_cols * n_channels / 2 - fig_height = y_pix / x_pix * scale_ * n_rows / 2 - fig = plt.figure(figsize=(fig_width, fig_height + offset)) - - gs = GridSpec(n_rows, n_cols, figure=fig) - axs = [] - axs.append(fig.add_subplot(gs[0, 0])) # 0: original frames - axs.append(fig.add_subplot(gs[0, 1])) # 1: ae reconstructed frames - axs.append(fig.add_subplot(gs[0, 2])) # 2: neural reconstructed frames - axs.append(fig.add_subplot(gs[1, 0])) # 3: residual - axs.append(fig.add_subplot(gs[1, 1:3])) # 4: ae and predicted ae latents - for i, ax in enumerate(fig.axes): - ax.set_yticks([]) - if i > 2: - ax.get_xaxis().set_tick_params(labelsize=12, direction='in') - axs[0].set_xticks([]) - axs[1].set_xticks([]) - axs[2].set_xticks([]) - axs[3].set_xticks([]) - - # check that the axes are correct - fontsize = 12 - idx = 0 - axs[idx].set_title('Original', fontsize=fontsize); idx += 1 - axs[idx].set_title('AE reconstructed', fontsize=fontsize); idx += 1 - axs[idx].set_title('Neural reconstructed', fontsize=fontsize); idx += 1 - axs[idx].set_title('Reconstructions residual', fontsize=fontsize); idx += 1 - axs[idx].set_title('AE latent predictions', fontsize=fontsize) - axs[idx].set_xlabel('Time (bins)', fontsize=fontsize) - - time = np.arange(n_time) - - ims_res = ims_recon_ae - ims_recon_neural - - im_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1} - tr_kwargs = {'animated': True, 'linewidth': 2} - latents_ae_color = [0.2, 0.2, 0.2] - latents_dec_color = [0, 0, 0] - - # ims is a list of lists, each row is a list of artists to draw in the - # current frame; here we are just animating one artist, the image, in - # each frame - ims = [] - for i in range(n_time): - - ims_curr = [] - idx = 0 - - if i % 100 == 0: - print('processing frame %03i/%03i' % (i, n_time)) - - ################### - # behavioral videos - ################### - # original video - ims_tmp = ims_orig[i, 0] if n_channels == 1 else concat(ims_orig[i]) - im = axs[idx].imshow(ims_tmp, **im_kwargs) - ims_curr.append(im) - idx += 1 - - # ae reconstruction - ims_tmp = ims_recon_ae[i, 0] if n_channels == 1 else concat(ims_recon_ae[i]) - im = axs[idx].imshow(ims_tmp, **im_kwargs) - ims_curr.append(im) - idx += 1 - - # neural reconstruction - ims_tmp = ims_recon_neural[i, 0] if n_channels == 1 else concat(ims_recon_neural[i]) - im = axs[idx].imshow(ims_tmp, **im_kwargs) - ims_curr.append(im) - idx += 1 - - # residual - ims_tmp = ims_res[i, 0] if n_channels == 1 else concat(ims_res[i]) - im = axs[idx].imshow(0.5 + ims_tmp, **im_kwargs) - ims_curr.append(im) - idx += 1 - - ######## - # traces - ######## - # latents over time - for latent in range(n_ae_latents): - # just put labels on last lvs - if latent == n_ae_latents - 1 and i == 0: - label_ae = 'AE latents' - label_dec = 'Predicted AE latents' - else: - label_ae = None - label_dec = None - im = axs[idx].plot( - time[0:i + 1], latent + latents_ae_sc[0:i + 1, latent], - color=latents_ae_color, alpha=0.7, label=label_ae, - **tr_kwargs)[0] - axs[idx].spines['top'].set_visible(False) - axs[idx].spines['right'].set_visible(False) - axs[idx].spines['left'].set_visible(False) - ims_curr.append(im) - im = axs[idx].plot( - time[0:i + 1], latent + latents_dec_sc[0:i + 1, latent], - color=latents_dec_color, label=label_dec, **tr_kwargs)[0] - axs[idx].spines['top'].set_visible(False) - axs[idx].spines['right'].set_visible(False) - axs[idx].spines['left'].set_visible(False) - plt.legend( - loc='lower right', fontsize=fontsize, frameon=True, - framealpha=0.7, edgecolor=[1, 1, 1]) - ims_curr.append(im) - ims.append(ims_curr) - - plt.tight_layout(pad=0) - - ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000) - writer = FFMpegWriter(fps=frame_rate, bitrate=-1) - - if save_file is not None: - make_dir_if_not_exists(save_file) - if save_file[-3:] != 'mp4': - save_file += '.mp4' - print('saving video to %s...' % save_file, end='') - ani.save(save_file, writer=writer) - print('done') - - -def plot_neural_reconstruction_traces_wrapper( - hparams, save_file=None, trial=None, xtick_locs=None, frame_rate=None, format='png'): - """Plot ae latents and their neural reconstructions. - - This is a high-level function that loads the model described in the hparams dictionary and - produces the necessary predicted latents. - - Parameters - ---------- - hparams : :obj:`dict` - needs to contain enough information to specify an ae latent decoder - save_file : :obj:`str` - full save file (path and filename) - trial : :obj:`int`, optional - if :obj:`NoneType`, use first test trial - xtick_locs : :obj:`array-like`, optional - tick locations in units of bins - frame_rate : :obj:`float`, optional - frame rate of behavorial video; to properly relabel xticks - format : :obj:`str`, optional - any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' - - Returns - ------- - :obj:`matplotlib.figure.Figure` - matplotlib figure handle of plot - - """ - - # find good trials - import copy - from behavenet.data.utils import get_transforms_paths - from behavenet.data.data_generator import ConcatSessionsGenerator - - # ae data - hparams_ae = copy.copy(hparams) - hparams_ae['experiment_name'] = hparams['ae_experiment_name'] - hparams_ae['model_class'] = hparams['ae_model_class'] - hparams_ae['model_type'] = hparams['ae_model_type'] - - ae_transform, ae_path = get_transforms_paths('ae_latents', hparams_ae, None) - - # ae predictions data - hparams_dec = copy.copy(hparams) - hparams_dec['neural_ae_experiment_name'] = hparams['decoder_experiment_name'] - hparams_dec['neural_ae_model_class'] = hparams['decoder_model_class'] - hparams_dec['neural_ae_model_type'] = hparams['decoder_model_type'] - ae_pred_transform, ae_pred_path = get_transforms_paths( - 'neural_ae_predictions', hparams_dec, None) - - signals = ['ae_latents', 'ae_predictions'] - transforms = [ae_transform, ae_pred_transform] - paths = [ae_path, ae_pred_path] - - data_generator = ConcatSessionsGenerator( - hparams['data_dir'], [hparams], - signals_list=[signals], transforms_list=[transforms], paths_list=[paths], - device='cpu', as_numpy=False, batch_load=False, rng_seed=0) - - if trial is None: - # choose first test trial - trial = data_generator.datasets[0].batch_idxs['test'][0] - - batch = data_generator.datasets[0][trial] - traces_ae = batch['ae_latents'].cpu().detach().numpy() - traces_neural = batch['ae_predictions'].cpu().detach().numpy() - - fig = plot_neural_reconstruction_traces( - traces_ae, traces_neural, save_file, xtick_locs, frame_rate, format) - - return fig - - -def plot_neural_reconstruction_traces( - traces_ae, traces_neural, save_file=None, xtick_locs=None, frame_rate=None, format='png', - scale=0.5, max_traces=8, add_r2=True): - """Plot ae latents and their neural reconstructions. - - Parameters - ---------- - traces_ae : :obj:`np.ndarray` - shape (n_frames, n_latents) - traces_neural : :obj:`np.ndarray` - shape (n_frames, n_latents) - save_file : :obj:`str`, optional - full save file (path and filename) - xtick_locs : :obj:`array-like`, optional - tick locations in units of bins - frame_rate : :obj:`float`, optional - frame rate of behavorial video; to properly relabel xticks - format : :obj:`str`, optional - any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' - scale : :obj:`int`, optional - scale magnitude of traces - max_traces : :obj:`int`, optional - maximum number of traces to plot, for easier visualization - add_r2 : :obj:`bool`, optional - print R2 value on plot - - Returns - ------- - :obj:`matplotlib.figure.Figure` - matplotlib figure handle - - """ - - import matplotlib.pyplot as plt - import matplotlib.lines as mlines - import seaborn as sns - - sns.set_style('white') - sns.set_context('poster') - - means = np.mean(traces_ae, axis=0) - std = np.std(traces_ae) / scale # scale for better visualization - - traces_ae_sc = (traces_ae - means) / std - traces_neural_sc = (traces_neural - means) / std - - traces_ae_sc = traces_ae_sc[:, :max_traces] - traces_neural_sc = traces_neural_sc[:, :max_traces] - - fig = plt.figure(figsize=(12, 8)) - plt.plot(traces_neural_sc + np.arange(traces_neural_sc.shape[1]), linewidth=3) - plt.plot( - traces_ae_sc + np.arange(traces_ae_sc.shape[1]), color=[0.2, 0.2, 0.2], linewidth=3, - alpha=0.7) - - # add legend - # original latents - gray - orig_line = mlines.Line2D([], [], color=[0.2, 0.2, 0.2], linewidth=3, alpha=0.7) - # predicted latents - cycle through some colors - colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] - dls = [] - for c in range(5): - dls.append(mlines.Line2D( - [], [], linewidth=3, linestyle='--', dashes=(0, 3 * c, 20, 1), color='%s' % colors[c])) - plt.legend( - [orig_line, tuple(dls)], ['Original latents', 'Predicted latents'], - loc='lower right', frameon=True, framealpha=0.7, edgecolor=[1, 1, 1]) - - # add r2 info if desired - if add_r2: - from sklearn.metrics import r2_score - r2 = r2_score(traces_ae, traces_neural, multioutput='variance_weighted') - plt.text( - 0.05, 0.06, '$R^2$=%1.3f' % r2, horizontalalignment='left', verticalalignment='bottom', - transform=plt.gca().transAxes, - bbox=dict(facecolor='white', alpha=0.7, edgecolor=[1, 1, 1])) - - if xtick_locs is not None and frame_rate is not None: - plt.xticks(xtick_locs, (np.asarray(xtick_locs) / frame_rate).astype('int')) - plt.xlabel('Time (s)') - else: - plt.xlabel('Time (bins)') - plt.ylabel('Latent state') - plt.yticks([]) - - if save_file is not None: - make_dir_if_not_exists(save_file) - plt.savefig(save_file + '.' + format, dpi=300, format=format) - - plt.show() - return fig diff --git a/behavenet/plotting/arhmm_utils.py b/behavenet/plotting/arhmm_utils.py index f6fb28f..9fbfd68 100644 --- a/behavenet/plotting/arhmm_utils.py +++ b/behavenet/plotting/arhmm_utils.py @@ -4,17 +4,17 @@ import os import numpy as np import torch -import scipy import matplotlib.pyplot as plt import matplotlib import matplotlib.animation as animation -from matplotlib.animation import FFMpegWriter from behavenet import make_dir_if_not_exists from behavenet.models import AE as AE +from behavenet.plotting import save_movie # to ignore imports for sphix-autoapidoc __all__ = [ - 'get_discrete_chunks', 'get_state_durations', 'get_latent_arrays_by_dtype', 'get_model_latents_states', + 'get_discrete_chunks', 'get_state_durations', 'get_latent_arrays_by_dtype', + 'get_model_latents_states', 'make_syllable_movies_wrapper', 'make_syllable_movies', 'real_vs_sampled_wrapper', 'make_real_vs_sampled_movies', 'plot_real_vs_sampled', 'plot_states_overlaid_with_latents', 'plot_state_transition_matrix', 'plot_dynamics_matrices', @@ -56,9 +56,11 @@ def get_discrete_chunks(states, include_edges=True): which_state = chunk[split_indices[i]+1] if not include_edges: if split_indices[i] != 0 and split_indices[i+1] != (len(chunk)-2): - indexing_list[which_state].append([i_chunk, split_indices[i], split_indices[i+1]]) + indexing_list[which_state].append( + [i_chunk, split_indices[i], split_indices[i+1]]) else: - indexing_list[which_state].append([i_chunk, split_indices[i], split_indices[i + 1]]) + indexing_list[which_state].append( + [i_chunk, split_indices[i], split_indices[i+1]]) # convert lists to numpy arrays indexing_list = [np.asarray(indexing_list[i_state]) for i_state in range(max_state + 1)] @@ -184,7 +186,8 @@ def get_model_latents_states( else: _, version = experiment_exists(hparams, which_version=True) if version is None: - raise FileNotFoundError('Could not find the specified model version in %s' % hparams['expt_dir']) + raise FileNotFoundError( + 'Could not find the specified model version in %s' % hparams['expt_dir']) # load model model_file = os.path.join(hparams['expt_dir'], 'version_%i' % version, 'best_val_model.pt') @@ -374,13 +377,13 @@ def make_syllable_movies( maximum number of frames to animate frame_rate : :obj:`float`, optional frame rate of saved movie - n_buffer : :obj:`int` + n_buffer : :obj:`int`, optional number of blank frames between syllable instances - n_pre_frames : :obj:`int` + n_pre_frames : :obj:`int`, optional number of behavioral frames to precede each syllable instance - n_rows : :obj:`int` or :obj:`NoneType` + n_rows : :obj:`int` or :obj:`NoneType`, optional number of rows in output movie - single_syllable : :obj:`int` or :obj:`NoneType` + single_syllable : :obj:`int` or :obj:`NoneType`, optional choose only a single state for movie """ @@ -492,7 +495,6 @@ def make_syllable_movies( ani = animation.ArtistAnimation( fig, [ims[i] for i in range(len(ims)) if ims[i] != []], interval=20, blit=True, repeat=False) - writer = FFMpegWriter(fps=max(frame_rate, 10), bitrate=-1) print('done') if save_file is not None: @@ -505,10 +507,7 @@ def make_syllable_movies( state_str = '' save_file += state_str save_file += '.mp4' - make_dir_if_not_exists(save_file) - print('saving video to %s...' % save_file, end='') - ani.save(save_file, writer=writer) - print('done') + save_movie(save_file, ani, frame_rate=frame_rate) def real_vs_sampled_wrapper( @@ -636,7 +635,8 @@ def real_vs_sampled_wrapper( fig = plot_real_vs_sampled( latents, latents_samp, states, states_samp, save_file=save_file, xtick_locs=xtick_locs, - frame_rate=frame_rate_beh, format=format) + frame_rate=hparams['frame_rate'] if frame_rate_beh is None else frame_rate_beh, + format=format) if output_type == 'movie': return None @@ -697,15 +697,7 @@ def make_real_vs_sampled_movies( ims.append(ims_curr) ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000) - writer = FFMpegWriter(fps=frame_rate, bitrate=-1) - - if save_file is not None: - make_dir_if_not_exists(save_file) - if save_file[-3:] != 'mp4': - save_file += '.mp4' - print('saving video to %s...' % save_file, end='') - ani.save(save_file, writer=writer) - print('done') + save_movie(save_file, ani, frame_rate=frame_rate) def plot_real_vs_sampled( @@ -768,7 +760,8 @@ def plot_real_vs_sampled( def plot_states_overlaid_with_latents( - latents, states, save_file=None, ax=None, xtick_locs=None, frame_rate=None, format='png'): + latents, states, save_file=None, ax=None, xtick_locs=None, frame_rate=None, cmap='tab20b', + format='png'): """Plot states for a single trial overlaid with latents. Parameters @@ -785,6 +778,8 @@ def plot_states_overlaid_with_latents( tick locations in bin values for plot frame_rate : :obj:`float`, optional behavioral video framerate; to properly relabel xticks + cmap : :obj:`str`, optional + matplotlib colormap format : :obj:`str`, optional any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' @@ -802,10 +797,10 @@ def plot_states_overlaid_with_latents( spc = 1.1 * abs(latents.max()) n_latents = latents.shape[1] plotting_latents = latents + spc * np.arange(n_latents) - ymin = min(-spc - 1, np.min(plotting_latents)) + ymin = min(-spc, np.min(plotting_latents)) ymax = max(spc * n_latents, np.max(plotting_latents)) ax.imshow( - states[None, :], aspect='auto', extent=(0, len(latents), ymin, ymax), cmap='tab20b', + states[None, :], aspect='auto', extent=(0, len(latents), ymin, ymax), cmap=cmap, alpha=1.0) ax.plot(plotting_latents, '-k', lw=3) ax.set_ylim([ymin, ymax]) @@ -916,8 +911,8 @@ def plot_dynamics_matrices(model, deridge=False): for k in range(K): plt.subplot(n_rows, n_cols, k + 1) im = plt.imshow(mats[k], cmap='RdBu_r', clim=[-clim, clim]) - for l in range(n_lags - 1): - plt.axvline((l + 1) * D - 0.5, ymin=0, ymax=K, color=[0, 0, 0]) + for lag in range(n_lags - 1): + plt.axvline((lag + 1) * D - 0.5, ymin=0, ymax=K, color=[0, 0, 0]) plt.xticks([]) plt.yticks([]) plt.title('State %i' % k) diff --git a/behavenet/plotting/cond_ae_utils.py b/behavenet/plotting/cond_ae_utils.py index 8da6167..8d22ef1 100644 --- a/behavenet/plotting/cond_ae_utils.py +++ b/behavenet/plotting/cond_ae_utils.py @@ -2,58 +2,45 @@ import copy import pickle import numpy as np +import matplotlib.animation as animation import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns import torch +from tqdm import tqdm +from behavenet import get_user_dir from behavenet import make_dir_if_not_exists +from behavenet.data.utils import build_data_generator from behavenet.data.utils import load_labels_like_latents from behavenet.fitting.eval import get_reconstruction +from behavenet.fitting.utils import experiment_exists +from behavenet.fitting.utils import get_best_model_and_data +from behavenet.fitting.utils import get_expt_dir +from behavenet.fitting.utils import get_lab_example from behavenet.fitting.utils import get_session_dir +from behavenet.plotting import concat +from behavenet.plotting import get_crop +from behavenet.plotting import load_latents +from behavenet.plotting import load_metrics_csv_as_df +from behavenet.plotting import save_movie # to ignore imports for sphix-autoapidoc __all__ = [ - 'get_crop', 'get_input_range', 'compute_range', 'get_labels_2d_for_trial', 'get_model_input', - 'interpolate_2d', 'interpolate_1d', 'plot_2d_frame_array', 'plot_1d_frame_array'] + 'get_input_range', 'compute_range', 'get_labels_2d_for_trial', 'get_model_input', + 'interpolate_2d', 'interpolate_1d', 'interpolate_point_path', 'plot_2d_frame_array', + 'plot_1d_frame_array', 'make_interpolated', 'make_interpolated_multipanel', + 'plot_psvae_training_curves', 'plot_hyperparameter_search_results', + 'plot_label_reconstructions', 'plot_latent_traversals', 'make_latent_traversal_movie'] -def get_crop(im, y_0, y_ext, x_0, x_ext): - """Get crop of image, filling in borders with zeros. - - Parameters - ---------- - im : :obj:`np.ndarray` - input image - y_0 : :obj:`int` - y-pixel center value - y_ext : :obj:`int` - y-pixel extent; crop in y-direction will be [y_0 - y_ext, y_0 + y_ext] - x_0 : :obj:`int` - y-pixel center value - x_ext : :obj:`int` - x-pixel extent; crop in x-direction will be [x_0 - x_ext, x_0 + x_ext] - - Returns - ------- - :obj:`np.ndarray` - cropped image - - """ - y_min = y_0 - y_ext - y_max = y_0 + y_ext - y_pix = y_max - y_min - x_min = x_0 - x_ext - x_max = x_0 + x_ext - x_pix = x_max - x_min - im_crop = np.copy(im[y_min:y_max, x_min:x_max]) - y_pix_, x_pix_ = im_crop.shape - im_tmp = np.zeros((y_pix, x_pix)) - im_tmp[:y_pix_, :x_pix_] = im_crop - return im_tmp - +# ---------------------------------------- +# low-level util functions +# ---------------------------------------- def get_input_range( input_type, hparams, sess_ids=None, sess_idx=0, model=None, data_gen=None, version=0, - min_p=5, max_p=95): + min_p=5, max_p=95, apply_label_masks=False): """Helper function to compute input range for a variety of data types. Parameters @@ -77,6 +64,8 @@ def get_input_range( defines lower end of range; percentile in [0, 100] max_p : :obj:`int`, optional defines upper end of range; percentile in [0, 100] + apply_label_masks : :obj:`bool`, optional + `True` to set masked values to NaN in labels Returns ------- @@ -110,6 +99,13 @@ def get_input_range( inputs = labels_sc['latents'] else: raise NotImplementedError + + if apply_label_masks: + masks = load_labels_like_latents( + hparams, sess_ids, sess_idx=sess_idx, data_key='labels_masks') + for i, m in zip(inputs, masks): + i[m == 0] = np.nan + input_range = compute_range(inputs, min_p=min_p, max_p=max_p) return input_range @@ -141,8 +137,8 @@ def compute_range(values_list, min_p=5, max_p=95): else: values = np.vstack(values_list) ranges = { - 'min': np.percentile(values, min_p, axis=0), - 'max': np.percentile(values, max_p, axis=0)} + 'min': np.nanpercentile(values, min_p, axis=0), + 'max': np.nanpercentile(values, max_p, axis=0)} return ranges @@ -181,7 +177,6 @@ def get_labels_2d_for_trial( raise ValueError('only one of "trial" or "trial_idx" can be specified') if data_gen is None: - from behavenet.fitting.utils import build_data_generator hparams_new = copy.deepcopy(hparams) hparams_new['conditional_encoder'] = True # ensure scaled labels are returned hparams_new['device'] = 'cpu' @@ -200,7 +195,7 @@ def get_labels_2d_for_trial( def get_model_input( - data_generator, hparams, model, trial=None, trial_idx=None, sess_idx=0, max_frames=100, + data_generator, hparams, model, trial=None, trial_idx=None, sess_idx=0, max_frames=200, compute_latents=False, compute_2d_labels=True, compute_scaled_labels=False, dtype='test'): """Return images, latents, and labels for a given trial. @@ -247,6 +242,8 @@ def get_model_input( if (trial_idx is not None) and (trial is not None): raise ValueError('only one of "trial" or "trial_idx" can be specified') + if (trial_idx is None) and (trial is None): + raise ValueError('one of "trial" or "trial_idx" must be specified') # get trial if trial is None: @@ -264,7 +261,7 @@ def get_model_input( elif hparams['model_class'] == 'cond-ae' \ or hparams['model_class'] == 'cond-vae' \ or hparams['model_class'] == 'cond-ae-msp' \ - or hparams['model_class'] == 'sss-vae' \ + or hparams['model_class'] == 'ps-vae' \ or hparams['model_class'] == 'labels-images': labels_pt = batch['labels'][:max_frames] labels_np = labels_pt.cpu().detach().numpy() @@ -290,7 +287,7 @@ def get_model_input( # latents if compute_latents: - if hparams['model_class'] == 'cond-ae-msp' or hparams['model_class'] == 'sss-vae': + if hparams['model_class'] == 'cond-ae-msp' or hparams['model_class'] == 'ps-vae': latents_np = model.get_transformed_latents(ims_pt, dataset=sess_idx, as_numpy=True) else: _, latents_np = get_reconstruction( @@ -415,7 +412,7 @@ def interpolate_2d( if model.hparams['model_class'] == 'ae' \ or model.hparams['model_class'] == 'vae' \ or model.hparams['model_class'] == 'beta-tcvae' \ - or model.hparams['model_class'] == 'sss-vae': + or model.hparams['model_class'] == 'ps-vae': labels = None elif model.hparams['model_class'] == 'cond-ae' \ or model.hparams['model_class'] == 'cond-vae': @@ -441,7 +438,7 @@ def interpolate_2d( labels_2d = None if model.hparams['model_class'] == 'cond-ae-msp' \ - or model.hparams['model_class'] == 'sss-vae': + or model.hparams['model_class'] == 'ps-vae': # change latents that correspond to desired labels latents = np.copy(latents_0) latents[0, input_idxs[0]] = inputs[0][i0] @@ -605,7 +602,7 @@ def interpolate_1d( if model.hparams['model_class'] == 'ae' \ or model.hparams['model_class'] == 'vae' \ or model.hparams['model_class'] == 'beta-tcvae' \ - or model.hparams['model_class'] == 'sss-vae': + or model.hparams['model_class'] == 'ps-vae': labels = None elif model.hparams['model_class'] == 'cond-ae' \ or model.hparams['model_class'] == 'cond-vae': @@ -631,7 +628,7 @@ def interpolate_1d( labels_2d = None if model.hparams['model_class'] == 'cond-ae-msp' \ - or model.hparams['model_class'] == 'sss-vae': + or model.hparams['model_class'] == 'ps-vae': # change latents that correspond to desired labels latents = np.copy(latents_0) latents[0, input_idxs[i0]] = inputs[i0][i1] @@ -684,6 +681,118 @@ def interpolate_1d( return ims_list, labels_list, ims_crop_list +def interpolate_point_path( + interp_type, model, ims_0, labels_0, points, n_frames=10, ch=0, crop_kwargs=None, + apply_inverse_transform=True): + """Return reconstructed images created by interpolating through multiple points. + + This function is a simplified version of :func:`interpolate_1d()`; this function computes a + traversal for a single dimension instead of all dimensions; also, this function does not + support conditional encoders, nor does it attempt to compute the interpolated, scaled values + of the labels as :func:`interpolate_1d()` does. This function should supercede + :func:`interpolate_1d()` in a future refactor. Also note that this function is utilized by + the code to make traversal movies, whereas :func:`interpolate_1d()` is utilized by the code to + make traversal plots. + + Parameters + ---------- + interp_type : :obj:`str` + 'latents' | 'labels' + model : :obj:`behavenet.models` object + autoencoder model + ims_0 : :obj:`np.ndarray` + base images for interpolating labels, of shape (1, n_channels, y_pix, x_pix) + labels_0 : :obj:`np.ndarray` + base labels of shape (1, n_labels); these values will be used if + `interp_type='latents'`, and they will be ignored if `inter_type='labels'` + (since `points` will be used) + points : :obj:`list` + one entry for each point in path; each entry is an np.ndarray of shape (n_latents,) + n_frames : :obj:`int` or :obj:`array-like` + number of interpolation points between each point; can be an integer that is used + for all paths, or an array/list of length one less than number of points + ch : :obj:`int`, optional + specify which channel of input images to return (can only be a single value) + crop_kwargs : :obj:`dict`, optional + if crop_type is not None, provides information about the crop (for a fixed crop window) + keys : 'y_0', 'x_0', 'y_ext', 'x_ext'; window is + (y_0 - y_ext, y_0 + y_ext) in vertical direction and + (x_0 - x_ext, x_0 + x_ext) in horizontal direction + apply_inverse_transform : :obj:`bool` + if inputs are latents (and model class is 'cond-ae-msp' or 'ps-vae'), apply inverse + transform to put in original latent space + + Returns + ------- + :obj:`tuple` + - ims_list (:obj:`list` of :obj:`np.ndarray`) interpolated images + - inputs_list (:obj:`list` of :obj:`np.ndarray`) interpolated values + + """ + + if model.hparams.get('conditional_encoder', False): + raise NotImplementedError + + n_points = len(points) + if isinstance(n_frames, int): + n_frames = [n_frames] * (n_points - 1) + assert len(n_frames) == (n_points - 1) + + ims_list = [] + inputs_list = [] + + for p in range(n_points - 1): + + p0 = points[None, p] + p1 = points[None, p + 1] + p_vec = (p1 - p0) / n_frames[p] + + for pn in range(n_frames[p]): + + vec = p0 + pn * p_vec + + if interp_type == 'latents': + + if model.hparams['model_class'] == 'cond-ae' \ + or model.hparams['model_class'] == 'cond-vae': + im_tmp = get_reconstruction( + model, vec, apply_inverse_transform=apply_inverse_transform, + labels=torch.from_numpy(labels_0).float().to(model.hparams['device'])) + else: + im_tmp = get_reconstruction( + model, vec, apply_inverse_transform=apply_inverse_transform) + + elif interp_type == 'labels': + + if model.hparams['model_class'] == 'cond-ae-msp' \ + or model.hparams['model_class'] == 'ps-vae': + im_tmp = get_reconstruction( + model, vec, apply_inverse_transform=True) + else: # cond-ae + im_tmp = get_reconstruction( + model, ims_0, + labels=torch.from_numpy(vec).float().to(model.hparams['device'])) + else: + raise NotImplementedError + + if crop_kwargs is not None: + if not isinstance(ch, int): + raise ValueError('"ch" must be an integer to use crop_kwargs') + ims_list.append(get_crop( + im_tmp[0, ch], + crop_kwargs['y_0'], crop_kwargs['y_ext'], + crop_kwargs['x_0'], crop_kwargs['x_ext'])) + else: + if isinstance(ch, int): + ims_list.append(np.copy(im_tmp[0, ch])) + else: + ims_list.append(np.copy(concat(im_tmp[0]))) + + inputs_list.append(vec) + + return ims_list, inputs_list + + def _get_updated_scaled_labels(labels_og, idxs=None, vals=None): """Helper function for interpolate_xd functions.""" @@ -714,8 +823,13 @@ def _get_updated_scaled_labels(labels_og, idxs=None, vals=None): return labels_sc +# ---------------------------------------- +# mid-level plotting functions +# ---------------------------------------- + def plot_2d_frame_array( - ims_list, markers=None, im_kwargs=None, marker_kwargs=None, figsize=None, save_file=None): + ims_list, markers=None, im_kwargs=None, marker_kwargs=None, figsize=None, save_file=None, + format='pdf'): """Plot list of list of interpolated images output by :func:`interpolate_2d()` in a 2d grid. Parameters @@ -729,10 +843,12 @@ def plot_2d_frame_array( kwargs for `matplotlib.pyplot.imshow()` function (vmin, vmax, cmap, etc) marker_kwargs : :obj:`dict` or NoneType, optional kwargs for `matplotlib.pyplot.plot()` function (markersize, markeredgewidth, etc) - figsize : :obj:`tuple` + figsize : :obj:`tuple`, optional (width, height) in inches save_file : :obj:`str` or NoneType, optional figure saved if not None + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... """ @@ -761,13 +877,13 @@ def plot_2d_frame_array( plt.subplots_adjust(wspace=0, hspace=0, bottom=0, left=0, top=1, right=1) if save_file is not None: make_dir_if_not_exists(save_file) - plt.savefig(save_file, dpi=300, bbox_inches='tight') + plt.savefig(save_file + '.' + format, dpi=300, bbox_inches='tight') plt.show() def plot_1d_frame_array( - ims_list, markers=None, im_kwargs=None, marker_kwargs=None, figsize=None, save_file=None, - plot_ims=True, plot_diffs=True): + ims_list, markers=None, im_kwargs=None, marker_kwargs=None, plot_ims=True, plot_diffs=True, + figsize=None, save_file=None, format='pdf'): """Plot list of list of interpolated images output by :func:`interpolate_1d()` in a 2d grid. Parameters @@ -781,14 +897,16 @@ def plot_1d_frame_array( kwargs for `matplotlib.pyplot.imshow()` function (vmin, vmax, cmap, etc) marker_kwargs : :obj:`dict` or NoneType, optional kwargs for `matplotlib.pyplot.plot()` function (markersize, markeredgewidth, etc) - figsize : :obj:`tuple` + plot_ims : :obj:`bool`, optional + plot images + plot_diffs : :obj:`bool`, optional + plot differences + figsize : :obj:`tuple`, optional (width, height) in inches save_file : :obj:`str` or NoneType, optional figure saved if not None - plot_ims : :obj:`bool` - plot images - plot_diffs : :obj:`bool` - plot differences + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... """ @@ -838,5 +956,1362 @@ def plot_1d_frame_array( plt.subplots_adjust(wspace=0, hspace=0, bottom=0, left=0, top=1, right=1) if save_file is not None: make_dir_if_not_exists(save_file) - plt.savefig(save_file, dpi=300, bbox_inches='tight') + plt.savefig(save_file + '.' + format, dpi=300, bbox_inches='tight') plt.show() + + +def make_interpolated( + ims, save_file, markers=None, text=None, text_title=None, text_color=[1, 1, 1], + frame_rate=20, scale=3, markersize=10, markeredgecolor='w', markeredgewidth=1, ax=None): + """Make a latent space interpolation movie. + + Parameters + ---------- + ims : :obj:`list` of :obj:`np.ndarray` + each list element is an array of shape (y_pix, x_pix) + save_file : :obj:`str` + absolute path of save file; does not need file extension, will automatically be saved as + mp4. To save as a gif, include the '.gif' file extension in `save_file`. The movie will + only be saved if `ax` is `NoneType`; else the list of animated frames is returned + markers : :obj:`array-like`, optional + array of size (n_frames, 2) which specifies the (x, y) coordinates of a marker on each + frame + text : :obj:`array-like`, optional + array of size (n_frames) which specifies text printed in the lower left corner of each + frame + text_title : :obj:`array-like`, optional + array of size (n_frames) which specifies text printed in the upper left corner of each + frame + text_color : :obj:`array-like`, optional + rgb array specifying color of `text` and `text_title`, if applicable + frame_rate : :obj:`float`, optional + frame rate of saved movie + scale : :obj:`float`, optional + width of panel is (scale / 2) inches + markersize : :obj:`float`, optional + size of marker if `markers` is not `NoneType` + markeredgecolor : :obj:`float`, optional + color of marker edge if `markers` is not `NoneType` + markeredgewidth : :obj:`float`, optional + width of marker edge if `markers` is not `NoneType` + ax : :obj:`matplotlib.axes.Axes` object + optional axis in which to plot the frames; if this argument is not `NoneType` the list of + animated frames is returned and the movie is not saved + + Returns + ------- + :obj:`list` + list of list of animated frames if `ax` is True; else save movie + + """ + + y_pix, x_pix = ims[0].shape + + if ax is None: + fig_width = scale / 2 + fig_height = y_pix / x_pix * scale / 2 + fig = plt.figure(figsize=(fig_width, fig_height), dpi=300) + ax = plt.gca() + return_ims = False + else: + return_ims = True + + ax.set_xticks([]) + ax.set_yticks([]) + + default_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1} + txt_kwargs = { + 'fontsize': 4, 'color': text_color, 'fontname': 'monospace', + 'horizontalalignment': 'left', 'verticalalignment': 'center', + 'transform': ax.transAxes} + + # ims is a list of lists, each row is a list of artists to draw in the current frame; here we + # are just animating one artist, the image, in each frame + ims_ani = [] + for i, im in enumerate(ims): + im_tmp = [] + im_tmp.append(ax.imshow(im, **default_kwargs)) + # [s.set_visible(False) for s in ax.spines.values()] + if markers is not None: + im_tmp.append(ax.plot( + markers[i, 0], markers[i, 1], '.r', markersize=markersize, + markeredgecolor=markeredgecolor, markeredgewidth=markeredgewidth)[0]) + if text is not None: + im_tmp.append(ax.text(0.02, 0.06, text[i], **txt_kwargs)) + if text_title is not None: + im_tmp.append(ax.text(0.02, 0.92, text_title[i], **txt_kwargs)) + ims_ani.append(im_tmp) + + if return_ims: + return ims_ani + else: + plt.tight_layout(pad=0) + ani = animation.ArtistAnimation(fig, ims_ani, blit=True, repeat_delay=1000) + save_movie(save_file, ani, frame_rate=frame_rate) + + +def make_interpolated_multipanel( + ims, save_file, markers=None, text=None, text_title=None, frame_rate=20, n_cols=3, scale=1, + **kwargs): + """Make a multi-panel latent space interpolation movie. + + Parameters + ---------- + ims : :obj:`list` of :obj:`list` of :obj:`np.ndarray` + each list element is used to for a single panel, and is another list that contains arrays + of shape (y_pix, x_pix) + save_file : :obj:`str` + absolute path of save file; does not need file extension, will automatically be saved as + mp4. To save as a gif, include the '.gif' file extension in `save_file`. + markers : :obj:`list` of :obj:`array-like`, optional + each list element is used for a single panel, and is an array of size (n_frames, 2) + which specifies the (x, y) coordinates of a marker on each frame for that panel + text : :obj:`list` of :obj:`array-like`, optional + each list element is used for a single panel, and is an array of size (n_frames) which + specifies text printed in the lower left corner of each frame for that panel + text_title : :obj:`list` of :obj:`array-like`, optional + each list element is used for a single panel, and is an array of size (n_frames) which + specifies text printed in the upper left corner of each frame for that panel + frame_rate : :obj:`float`, optional + frame rate of saved movie + n_cols : :obj:`int`, optional + movie is `n_cols` panels wide + scale : :obj:`float`, optional + width of panel is (scale / 2) inches + kwargs + arguments are additional arguments to :func:`make_interpolated`, like 'markersize', + 'markeredgewidth', 'markeredgecolor', etc. + + """ + + n_panels = len(ims) + + markers = [None] * n_panels if markers is None else markers + text = [None] * n_panels if text is None else text + + y_pix, x_pix = ims[0][0].shape + n_rows = int(np.ceil(n_panels / n_cols)) + fig_width = scale / 2 * n_cols + fig_height = y_pix / x_pix * scale / 2 * n_rows + fig, axes = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), dpi=300) + plt.subplots_adjust(wspace=0, hspace=0, left=0, bottom=0, right=1, top=1) + + # fill out empty panels with black frames + while len(ims) < n_rows * n_cols: + ims.append(np.zeros(ims[0].shape)) + markers.append(None) + text.append(None) + + # ims is a list of lists, each row is a list of artists to draw in the current frame; here we + # are just animating one artist, the image, in each frame + ims_ani = [] + for i, (ims_curr, markers_curr, text_curr) in enumerate(zip(ims, markers, text)): + col = i % n_cols + row = int(np.floor(i / n_cols)) + if i == 0: + text_title_str = text_title + else: + text_title_str = None + ims_ani_curr = make_interpolated( + ims=ims_curr, markers=markers_curr, text=text_curr, text_title=text_title_str, + ax=axes[row, col], save_file=None, **kwargs) + ims_ani.append(ims_ani_curr) + + # turn off other axes + i += 1 + while i < n_rows * n_cols: + col = i % n_cols + row = int(np.floor(i / n_cols)) + axes[row, col].set_axis_off() + i += 1 + + # rearrange ims: + # currently a list of length n_panels, each element of which is a list of length n_t + # we need a list of length n_t, each element of which is a list of length n_panels + n_frames = len(ims_ani[0]) + ims_final = [[] for _ in range(n_frames)] + for i in range(n_frames): + for j in range(n_panels): + ims_final[i] += ims_ani[j][i] + + ani = animation.ArtistAnimation(fig, ims_final, blit=True, repeat_delay=1000) + save_movie(save_file, ani, frame_rate=frame_rate) + + +# ---------------------------------------- +# high-level plotting functions +# ---------------------------------------- + +def _get_psvae_hparams(**kwargs): + hparams = { + 'data_dir': get_user_dir('data'), + 'save_dir': get_user_dir('save'), + 'model_class': 'ps-vae', + 'model_type': 'conv', + 'rng_seed_data': 0, + 'trial_splits': '8;1;1;0', + 'train_frac': 1.0, + 'rng_seed_model': 0, + 'fit_sess_io_layers': False, + 'learning_rate': 1e-4, + 'l2_reg': 0, + 'conditional_encoder': False, + 'vae.beta': 1} + # update hparams + for key, val in kwargs.items(): + if key == 'alpha' or key == 'beta' or key == 'gamma': + hparams['ps_vae.%s' % key] = val + else: + hparams[key] = val + return hparams + + +def plot_psvae_training_curves( + lab, expt, animal, session, alphas, betas, gammas, n_ae_latents, rng_seeds_model, + experiment_name, n_labels, dtype='val', save_file=None, format='pdf', **kwargs): + """Create training plots for each term in the ps-vae objective function. + + The `dtype` argument controls which type of trials are plotted ('train' or 'val'). + Additionally, multiple models can be plotted simultaneously by varying one (and only one) of + the following parameters: + + - alpha + - beta + - gamma + - number of unsupervised latents + - random seed used to initialize model weights + + Each of these entries must be an array of length 1 except for one option, which can be an array + of arbitrary length (corresponding to already trained models). This function generates a single + plot with panels for each of the following terms: + + - total loss + - pixel mse + - label R^2 (note the objective function contains the label MSE, but R^2 is easier to parse) + - KL divergence of supervised latents + - index-code mutual information of unsupervised latents + - total correlation of unsupervised latents + - dimension-wise KL of unsupervised latents + - subspace overlap + + Parameters + ---------- + lab : :obj:`str` + lab id + expt : :obj:`str` + expt id + animal : :obj:`str` + animal id + session : :obj:`str` + session id + alphas : :obj:`array-like` + alpha values to plot + betas : :obj:`array-like` + beta values to plot + gammas : :obj:`array-like` + gamma values to plot + n_ae_latents : :obj:`array-like` + unsupervised dimensionalities to plot + rng_seeds_model : :obj:`array-like` + model seeds to plot + experiment_name : :obj:`str` + test-tube experiment name + n_labels : :obj:`int` + dimensionality of supervised latent space + dtype : :obj:`str` + 'train' | 'val' + save_file : :obj:`str`, optional + absolute path of save file; does not need file extension + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... + kwargs + arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc. + + """ + # check for arrays, turn ints into lists + n_arrays = 0 + hue = None + if len(alphas) > 1: + n_arrays += 1 + hue = 'alpha' + if len(betas) > 1: + n_arrays += 1 + hue = 'beta' + if len(gammas) > 1: + n_arrays += 1 + hue = 'gamma' + if len(n_ae_latents) > 1: + n_arrays += 1 + hue = 'n latents' + if len(rng_seeds_model) > 1: + n_arrays += 1 + hue = 'rng seed' + if n_arrays > 1: + raise ValueError( + 'Can only set one of "alphas", "betas", "gammas", "n_ae_latents", or ' + + '"rng_seeds_model" as an array') + + # set model info + hparams = _get_psvae_hparams(experiment_name=experiment_name, **kwargs) + + metrics_list = [ + 'loss', 'loss_data_mse', 'label_r2', + 'loss_zs_kl', 'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl', 'loss_AB_orth'] + + metrics_dfs = [] + i = 0 + for alpha in alphas: + for beta in betas: + for gamma in gammas: + for n_latents in n_ae_latents: + for rng in rng_seeds_model: + + # update hparams + hparams['ps_vae.alpha'] = alpha + hparams['ps_vae.beta'] = beta + hparams['ps_vae.gamma'] = gamma + hparams['n_ae_latents'] = n_latents + n_labels + hparams['rng_seed_model'] = rng + + try: + + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + _, version = experiment_exists(hparams, which_version=True) + + print( + 'loading results with alpha=%i, beta=%i, gamma=%i (version %i)' % + (alpha, beta, gamma, version)) + + metrics_dfs.append(load_metrics_csv_as_df( + hparams, lab, expt, metrics_list, version=None)) + + metrics_dfs[i]['alpha'] = alpha + metrics_dfs[i]['beta'] = beta + metrics_dfs[i]['gamma'] = gamma + metrics_dfs[i]['n latents'] = hparams['n_ae_latents'] + metrics_dfs[i]['rng seed'] = rng + i += 1 + + except TypeError: + print( + 'could not find model for alpha=%i, beta=%i, gamma=%i' % + (alpha, beta, gamma)) + continue + + metrics_df = pd.concat(metrics_dfs, sort=False) + + sns.set_style('white') + sns.set_context('talk') + data_queried = metrics_df[ + (metrics_df.epoch > 10) & ~pd.isna(metrics_df.val) & (metrics_df.dtype == dtype)] + g = sns.FacetGrid( + data_queried, col='loss', col_wrap=3, hue=hue, sharey=False, height=4) + g = g.map(plt.plot, 'epoch', 'val').add_legend() # , color=".3", fit_reg=False, x_jitter=.1); + + if save_file is not None: + make_dir_if_not_exists(save_file) + g.savefig(save_file + '.' + format, dpi=300, format=format) + + +def plot_hyperparameter_search_results( + lab, expt, animal, session, n_labels, label_names, alpha_weights, alpha_n_ae_latents, + alpha_expt_name, beta_weights, gamma_weights, beta_gamma_n_ae_latents, + beta_gamma_expt_name, alpha, beta, gamma, save_file, batch_size=None, format='pdf', + **kwargs): + """Create a variety of diagnostic plots to assess the ps-vae hyperparameters. + + These diagnostic plots are based on the recommended way to perform a hyperparameter search in + the ps-vae models; first, fix beta=1 and gamma=0, and do a sweep over alpha values and number + of latents (for example alpha=[50, 100, 500, 1000] and n_ae_latents=[2, 4, 8, 16]). The best + alpha value is subjective because it involves a tradeoff between pixel mse and label mse. After + choosing a suitable value, fix alpha and the number of latents and vary beta and gamma. This + function will then plot the following panels: + + - pixel mse as a function of alpha/num latents (for fixed beta/gamma) + - label mse as a function of alpha/num_latents (for fixed beta/gamma) + - pixel mse as a function of beta/gamma (for fixed alpha/n_ae_latents) + - label mse as a function of beta/gamma (for fixed alpha/n_ae_latents) + - index-code mutual information (part of the KL decomposition) as a function of beta/gamma (for + fixed alpha/n_ae_latents) + - total correlation(part of the KL decomposition) as a function of beta/gamma (for fixed + alpha/n_ae_latents) + - dimension-wise KL (part of the KL decomposition) as a function of beta/gamma (for fixed + alpha/n_ae_latents) + - average correlation coefficient across all pairs of unsupervised latent dims as a function of + beta/gamma (for fixed alpha/n_ae_latents) + - subspace overlap computed as ||[A; B] - I||_2^2 for A, B the projections to the supervised + and unsupervised subspaces, respectively, and I the identity - as a function of beta/gamma + (for fixed alpha/n_ae_latents) + - example subspace overlap matrix for gamma=0 and beta=1, with fixed alpha/n_ae_latents + - example subspace overlap matrix for gamma=1000 and beta=1, with fixed alpha/n_ae_latents + + Parameters + ---------- + lab : :obj:`str` + lab id + expt : :obj:`str` + expt id + animal : :obj:`str` + animal id + session : :obj:`str` + session id + n_labels : :obj:`str` + number of label dims + label_names : :obj:`array-like` + names of label dims + alpha_weights : :obj:`array-like` + array of alpha weights for fixed values of beta, gamma + alpha_n_ae_latents : :obj:`array-like` + array of latent dimensionalities for fixed values of beta, gamma using alpha_weights + alpha_expt_name : :obj:`str` + test-tube experiment name of alpha-based hyperparam search + beta_weights : :obj:`array-like` + array of beta weights for a fixed value of alpha + gamma_weights : :obj:`array-like` + array of beta weights for a fixed value of alpha + beta_gamma_n_ae_latents : :obj:`int` + latent dimensionality used for beta-gamma hyperparam search + beta_gamma_expt_name : :obj:`str` + test-tube experiment name of beta-gamma hyperparam search + alpha : :obj:`float` + fixed value of alpha for beta-gamma search + beta : :obj:`float` + fixed value of beta for alpha search + gamma : :obj:`float` + fixed value of gamma for alpha search + save_file : :obj:`str` + absolute path of save file; does not need file extension + batch_size : :obj:`int`, optional + size of batches, used to compute correlation coefficient per batch; if NoneType, the + correlation coefficient is computed across all time points + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... + kwargs + arguments are keys of `hparams`, preceded by either `alpha_` or `beta_gamma_`. For example, + to set the train frac of the alpha models, use `alpha_train_frac`; to set the rng_data_seed + of the beta-gamma models, use `beta_gamma_rng_data_seed`. + + """ + + def apply_masks(data, masks): + return data[masks == 1] + + def get_label_r2(hparams, model, data_generator, version, dtype='val', overwrite=False): + from sklearn.metrics import r2_score + save_file = os.path.join( + hparams['expt_dir'], 'version_%i' % version, 'r2_supervised.csv') + if not os.path.exists(save_file) or overwrite: + if not os.path.exists(save_file): + print('R^2 metrics do not exist; computing from scratch') + else: + print('overwriting metrics at %s' % save_file) + metrics_df = [] + data_generator.reset_iterators(dtype) + for i_test in tqdm(range(data_generator.n_tot_batches[dtype])): + # get next minibatch and put it on the device + data, sess = data_generator.next_batch(dtype) + x = data['images'][0] + y = data['labels'][0].cpu().detach().numpy() + if 'labels_masks' in data: + n = data['labels_masks'][0].cpu().detach().numpy() + else: + n = np.ones_like(y) + z = model.get_transformed_latents(x, dataset=sess) + for i in range(n_labels): + y_true = apply_masks(y[:, i], n[:, i]) + y_pred = apply_masks(z[:, i], n[:, i]) + if len(y_true) > 10: + r2 = r2_score(y_true, y_pred, multioutput='variance_weighted') + mse = np.mean(np.square(y_true - y_pred)) + else: + r2 = np.nan + mse = np.nan + metrics_df.append(pd.DataFrame({ + 'Trial': data['batch_idx'].item(), + 'Label': label_names[i], + 'R2': r2, + 'MSE': mse, + 'Model': 'PS-VAE'}, index=[0])) + + metrics_df = pd.concat(metrics_df) + print('saving results to %s' % save_file) + metrics_df.to_csv(save_file, index=False, header=True) + else: + print('loading results from %s' % save_file) + metrics_df = pd.read_csv(save_file) + return metrics_df + + # ----------------------------------------------------- + # load pixel/label MSE as a function of n_latents/alpha + # ----------------------------------------------------- + + # set model info + hparams = _get_psvae_hparams(experiment_name=alpha_expt_name) + # update hparams + for key, val in kwargs.items(): + # hparam vals should be named 'alpha_[property]', for example 'alpha_train_frac' + if key.split('_')[0] == 'alpha': + prop = key[6:] + hparams[prop] = val + + metrics_list = ['loss_data_mse'] + + metrics_dfs_frame = [] + metrics_dfs_marker = [] + for n_latent in alpha_n_ae_latents: + hparams['n_ae_latents'] = n_latent + n_labels + for alpha_ in alpha_weights: + hparams['ps_vae.alpha'] = alpha_ + hparams['ps_vae.beta'] = beta + hparams['ps_vae.gamma'] = gamma + try: + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + _, version = experiment_exists(hparams, which_version=True) + print('loading results with alpha=%i, beta=%i, gamma=%i (version %i)' % ( + hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.gamma'], + version)) + # get frame mse + metrics_dfs_frame.append(load_metrics_csv_as_df( + hparams, lab, expt, metrics_list, version=None, test=True)) + metrics_dfs_frame[-1]['alpha'] = alpha_ + metrics_dfs_frame[-1]['n_latents'] = hparams['n_ae_latents'] + # get marker mse + model, data_gen = get_best_model_and_data( + hparams, Model=None, load_data=True, version=version) + metrics_df_ = get_label_r2(hparams, model, data_gen, version, dtype='val') + metrics_df_['alpha'] = alpha_ + metrics_df_['n_latents'] = hparams['n_ae_latents'] + metrics_dfs_marker.append(metrics_df_[metrics_df_.Model == 'PS-VAE']) + except TypeError: + print('could not find model for alpha=%i, beta=%i, gamma=%i' % ( + hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.gamma'])) + continue + metrics_df_frame = pd.concat(metrics_dfs_frame, sort=False) + metrics_df_marker = pd.concat(metrics_dfs_marker, sort=False) + print('done') + + # ----------------------------------------------------- + # load pixel/label MSE as a function of beta/gamma + # ----------------------------------------------------- + # update hparams + hparams['experiment_name'] = beta_gamma_expt_name + for key, val in kwargs.items(): + # hparam vals should be named 'beta_gamma_[property]', for example 'alpha_train_frac' + if key.split('_')[0] == 'beta' and key.split('_')[1] == 'gamma': + prop = key[11:] + hparams[prop] = val + + metrics_list = ['loss_data_mse', 'loss_zu_mi', 'loss_zu_tc', 'loss_zu_dwkl', 'loss_AB_orth'] + + metrics_dfs_frame_bg = [] + metrics_dfs_marker_bg = [] + metrics_dfs_corr_bg = [] + overlaps = {} + for beta in beta_weights: + for gamma in gamma_weights: + hparams['n_ae_latents'] = beta_gamma_n_ae_latents + n_labels + hparams['ps_vae.alpha'] = alpha + hparams['ps_vae.beta'] = beta + hparams['ps_vae.gamma'] = gamma + try: + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + _, version = experiment_exists(hparams, which_version=True) + print('loading results with alpha=%i, beta=%i, gamma=%i (version %i)' % ( + hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.gamma'], + version)) + # get frame mse + metrics_dfs_frame_bg.append(load_metrics_csv_as_df( + hparams, lab, expt, metrics_list, version=None, test=True)) + metrics_dfs_frame_bg[-1]['beta'] = beta + metrics_dfs_frame_bg[-1]['gamma'] = gamma + # get marker mse + model, data_gen = get_best_model_and_data( + hparams, Model=None, load_data=True, version=version) + metrics_df_ = get_label_r2(hparams, model, data_gen, version, dtype='val') + metrics_df_['beta'] = beta + metrics_df_['gamma'] = gamma + metrics_dfs_marker_bg.append(metrics_df_[metrics_df_.Model == 'PS-VAE']) + # get subspace overlap + A = model.encoding.A.weight.data.cpu().detach().numpy() + B = model.encoding.B.weight.data.cpu().detach().numpy() + C = np.concatenate([A, B], axis=0) + overlap = np.matmul(C, C.T) + overlaps['beta=%i_gamma=%i' % (beta, gamma)] = overlap + # get corr + latents = load_latents(hparams, version, dtype='test') + if batch_size is None: + corr = np.corrcoef(latents[:, n_labels + np.array([0, 1])].T) + metrics_dfs_corr_bg.append(pd.DataFrame({ + 'loss': 'corr', + 'dtype': 'test', + 'val': np.abs(corr[0, 1]), + 'beta': beta, + 'gamma': gamma}, index=[0])) + else: + n_batches = int(np.ceil(latents.shape[0] / batch_size)) + for i in range(n_batches): + corr = np.corrcoef( + latents[i * batch_size:(i + 1) * batch_size, + n_labels + np.array([0, 1])].T) + metrics_dfs_corr_bg.append(pd.DataFrame({ + 'loss': 'corr', + 'dtype': 'test', + 'val': np.abs(corr[0, 1]), + 'beta': beta, + 'gamma': gamma}, index=[0])) + except TypeError: + print('could not find model for alpha=%i, beta=%i, gamma=%i' % ( + hparams['ps_vae.alpha'], hparams['ps_vae.beta'], hparams['ps_vae.gamma'])) + continue + print() + metrics_df_frame_bg = pd.concat(metrics_dfs_frame_bg, sort=False) + metrics_df_marker_bg = pd.concat(metrics_dfs_marker_bg, sort=False) + metrics_df_corr_bg = pd.concat(metrics_dfs_corr_bg, sort=False) + print('done') + + # ----------------------------------------------------- + # ----------------- PLOT DATA ------------------------- + # ----------------------------------------------------- + sns.set_style('white') + sns.set_context('paper', font_scale=1.2) + + alpha_palette = sns.color_palette('Greens') + beta_palette = sns.color_palette('Reds', len(metrics_df_corr_bg.beta.unique())) + gamma_palette = sns.color_palette('Blues', len(metrics_df_corr_bg.gamma.unique())) + + from matplotlib.gridspec import GridSpec + + fig = plt.figure(figsize=(12, 10), dpi=300) + + n_rows = 3 + n_cols = 12 + gs = GridSpec(n_rows, n_cols, figure=fig) + + def despine(ax): + ax.spines['top'].set_visible(False) + ax.spines['right'].set_visible(False) + + sns.set_palette(alpha_palette) + + # -------------------------------------------------- + # MSE per pixel + # -------------------------------------------------- + ax_pixel_mse_alpha = fig.add_subplot(gs[0, 0:3]) + data_queried = metrics_df_frame[(metrics_df_frame.dtype == 'test')] + sns.barplot(x='n_latents', y='val', hue='alpha', data=data_queried, ax=ax_pixel_mse_alpha) + ax_pixel_mse_alpha.legend().set_visible(False) + ax_pixel_mse_alpha.set_xlabel('Latent dimension') + ax_pixel_mse_alpha.set_ylabel('MSE per pixel') + ax_pixel_mse_alpha.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3)) + ax_pixel_mse_alpha.set_title('Beta=1, Gamma=0') + despine(ax_pixel_mse_alpha) + + # -------------------------------------------------- + # MSE per marker + # -------------------------------------------------- + ax_marker_mse_alpha = fig.add_subplot(gs[0, 3:6]) + data_queried = metrics_df_marker + sns.barplot(x='n_latents', y='MSE', hue='alpha', data=data_queried, ax=ax_marker_mse_alpha) + ax_marker_mse_alpha.set_xlabel('Latent dimension') + ax_marker_mse_alpha.set_ylabel('MSE per marker') + ax_marker_mse_alpha.set_title('Beta=1, Gamma=0') + ax_marker_mse_alpha.legend(frameon=True, title='Alpha') + despine(ax_marker_mse_alpha) + + sns.set_palette(gamma_palette) + + # -------------------------------------------------- + # MSE per pixel (beta/gamma) + # -------------------------------------------------- + ax_pixel_mse_bg = fig.add_subplot(gs[0, 6:9]) + data_queried = metrics_df_frame_bg[ + (metrics_df_frame_bg.dtype == 'test') & + (metrics_df_frame_bg.loss == 'loss_data_mse') & + (metrics_df_frame_bg.epoch == 200)] + sns.barplot(x='beta', y='val', hue='gamma', data=data_queried, ax=ax_pixel_mse_bg) + ax_pixel_mse_bg.legend().set_visible(False) + ax_pixel_mse_bg.set_xlabel('Beta') + ax_pixel_mse_bg.set_ylabel('MSE per pixel') + ax_pixel_mse_bg.ticklabel_format(axis='y', style='sci', scilimits=(-3, 3)) + ax_pixel_mse_bg.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_pixel_mse_bg) + + # -------------------------------------------------- + # MSE per marker (beta/gamma) + # -------------------------------------------------- + ax_marker_mse_bg = fig.add_subplot(gs[0, 9:12]) + data_queried = metrics_df_marker_bg + sns.barplot(x='beta', y='MSE', hue='gamma', data=data_queried, ax=ax_marker_mse_bg) + ax_marker_mse_bg.set_xlabel('Beta') + ax_marker_mse_bg.set_ylabel('MSE per marker') + ax_marker_mse_bg.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + ax_marker_mse_bg.legend(frameon=True, title='Gamma', loc='lower left') + despine(ax_marker_mse_bg) + + # -------------------------------------------------- + # ICMI + # -------------------------------------------------- + ax_icmi = fig.add_subplot(gs[1, 0:4]) + data_queried = metrics_df_frame_bg[ + (metrics_df_frame_bg.dtype == 'test') & + (metrics_df_frame_bg.loss == 'loss_zu_mi') & + (metrics_df_frame_bg.epoch == 200)] + sns.lineplot( + x='beta', y='val', hue='gamma', data=data_queried, ax=ax_icmi, ci=None, + palette=gamma_palette) + ax_icmi.legend().set_visible(False) + ax_icmi.set_xlabel('Beta') + ax_icmi.set_ylabel('Index-code Mutual Information') + ax_icmi.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_icmi) + + # -------------------------------------------------- + # TC + # -------------------------------------------------- + ax_tc = fig.add_subplot(gs[1, 4:8]) + data_queried = metrics_df_frame_bg[ + (metrics_df_frame_bg.dtype == 'test') & + (metrics_df_frame_bg.loss == 'loss_zu_tc') & + (metrics_df_frame_bg.epoch == 200)] + sns.lineplot( + x='beta', y='val', hue='gamma', data=data_queried, ax=ax_tc, ci=None, + palette=gamma_palette) + ax_tc.legend().set_visible(False) + ax_tc.set_xlabel('Beta') + ax_tc.set_ylabel('Total Correlation') + ax_tc.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_tc) + + # -------------------------------------------------- + # DWKL + # -------------------------------------------------- + ax_dwkl = fig.add_subplot(gs[1, 8:12]) + data_queried = metrics_df_frame_bg[ + (metrics_df_frame_bg.dtype == 'test') & + (metrics_df_frame_bg.loss == 'loss_zu_dwkl') & + (metrics_df_frame_bg.epoch == 200)] + sns.lineplot( + x='beta', y='val', hue='gamma', data=data_queried, ax=ax_dwkl, ci=None, + palette=gamma_palette) + ax_dwkl.legend().set_visible(False) + ax_dwkl.set_xlabel('Beta') + ax_dwkl.set_ylabel('Dimension-wise KL') + ax_dwkl.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_dwkl) + + # -------------------------------------------------- + # CC + # -------------------------------------------------- + ax_cc = fig.add_subplot(gs[2, 0:3]) + data_queried = metrics_df_corr_bg + sns.lineplot( + x='beta', y='val', hue='gamma', data=data_queried, ax=ax_cc, ci=None, + palette=gamma_palette) + ax_cc.legend().set_visible(False) + ax_cc.set_xlabel('Beta') + ax_cc.set_ylabel('Correlation Coefficient') + ax_cc.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_cc) + + # -------------------------------------------------- + # AB orth + # -------------------------------------------------- + ax_orth = fig.add_subplot(gs[2, 3:6]) + data_queried = metrics_df_frame_bg[ + (metrics_df_frame_bg.dtype == 'test') & + (metrics_df_frame_bg.loss == 'loss_AB_orth') & + (metrics_df_frame_bg.epoch == 200) & + ~metrics_df_frame_bg.val.isna()] + sns.lineplot( + x='gamma', y='val', hue='beta', data=data_queried, ax=ax_orth, ci=None, + palette=beta_palette) + ax_orth.legend(frameon=False, title='Beta') + ax_orth.set_xlabel('Gamma') + ax_orth.set_ylabel('Subspace overlap') + ax_orth.set_title('Latents=%i, Alpha=1000' % hparams['n_ae_latents']) + despine(ax_orth) + + # -------------------------------------------------- + # Gamma = 0 overlap + # -------------------------------------------------- + ax_gamma0 = fig.add_subplot(gs[2, 6:9]) + overlap = overlaps['beta=%i_gamma=%i' % (1, 0)] + im = ax_gamma0.imshow(overlap, cmap='PuOr', vmin=-1, vmax=1) + ax_gamma0.set_xticks(np.arange(overlap.shape[1])) + ax_gamma0.set_yticks(np.arange(overlap.shape[0])) + ax_gamma0.set_title('Subspace overlap\nGamma=0') + fig.colorbar(im, ax=ax_gamma0, orientation='vertical', shrink=0.75) + + # -------------------------------------------------- + # Gamma = 1000 overlap + # -------------------------------------------------- + ax_gamma1 = fig.add_subplot(gs[2, 9:12]) + overlap = overlaps['beta=%i_gamma=%i' % (1, 1000)] + im = ax_gamma1.imshow(overlap, cmap='PuOr', vmin=-1, vmax=1) + ax_gamma1.set_xticks(np.arange(overlap.shape[1])) + ax_gamma1.set_yticks(np.arange(overlap.shape[0])) + ax_gamma1.set_title('Subspace overlap\nGamma=1000') + fig.colorbar(im, ax=ax_gamma1, orientation='vertical', shrink=0.75) + + plt.tight_layout(h_pad=3) # h_pad is fraction of font size + + # reset to default color palette + # sns.set_palette(sns.color_palette(None, 10)) + sns.reset_orig() + + if save_file is not None: + make_dir_if_not_exists(save_file) + plt.savefig(save_file + '.' + format, dpi=300, format=format) + + +def plot_label_reconstructions( + lab, expt, animal, session, n_ae_latents, experiment_name, n_labels, trials, version=None, + plot_scale=0.5, sess_idx=0, save_file=None, format='pdf', **kwargs): + """Plot labels and their reconstructions from an ps-vae. + + Parameters + ---------- + lab : :obj:`str` + lab id + expt : :obj:`str` + expt id + animal : :obj:`str` + animal id + session : :obj:`str` + session id + n_ae_latents : :obj:`str` + dimensionality of unsupervised latent space; n_labels will be added to this + experiment_name : :obj:`str` + test-tube experiment name + n_labels : :obj:`str` + dimensionality of supervised latent space + trials : :obj:`array-like` + array of trials to reconstruct + version : :obj:`str` or :obj:`int`, optional + can be 'best' to load best model, and integer to load a specific model, or NoneType to use + the values in hparams to load a specific model + plot_scale : :obj:`float` + scale the magnitude of reconstructions + sess_idx : :obj:`int`, optional + session index into data generator + save_file : :obj:`str`, optional + absolute path of save file; does not need file extension + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... + kwargs + arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc. + + """ + + from behavenet.plotting.decoder_utils import plot_neural_reconstruction_traces + + # set model info + hparams = _get_psvae_hparams( + experiment_name=experiment_name, n_ae_latents=n_ae_latents + n_labels, **kwargs) + + # programmatically fill out other hparams options + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + + model, data_generator = get_best_model_and_data( + hparams, Model=None, load_data=True, version=version, data_kwargs=None) + print(data_generator) + print('alpha: %i' % model.hparams['ps_vae.alpha']) + print('beta: %i' % model.hparams['ps_vae.beta']) + print('gamma: %i' % model.hparams['ps_vae.gamma']) + print('model seed: %i' % model.hparams['rng_seed_model']) + + for trial in trials: + batch = data_generator.datasets[sess_idx][trial] + labels_og = batch['labels'].detach().cpu().numpy() + labels_pred = model.get_predicted_labels(batch['images']).detach().cpu().numpy() + if save_file is not None: + save_file_trial = save_file + '_trial-%i' % trial + else: + save_file_trial = None + plot_neural_reconstruction_traces( + labels_og, labels_pred, scale=plot_scale, save_file=save_file_trial, format=format) + + +def plot_latent_traversals( + lab, expt, animal, session, model_class, alpha, beta, gamma, n_ae_latents, rng_seed_model, + experiment_name, n_labels, label_idxs, label_min_p=5, label_max_p=95, + channel=0, n_frames_zs=4, n_frames_zu=4, trial=None, trial_idx=1, batch_idx=1, + crop_type=None, crop_kwargs=None, sess_idx=0, save_file=None, format='pdf', **kwargs): + """Plot video frames representing the traversal of individual dimensions of the latent space. + + Parameters + ---------- + lab : :obj:`str` + lab id + expt : :obj:`str` + expt id + animal : :obj:`str` + animal id + session : :obj:`str` + session id + model_class : :obj:`str` + model class in which to perform traversal; currently supported models are: + 'ae' | 'vae' | 'cond-ae' | 'cond-vae' | 'beta-tcvae' | 'cond-ae-msp' | 'ps-vae' + note that models with conditional encoders are not currently supported + alpha : :obj:`float` + ps-vae alpha value + beta : :obj:`float` + ps-vae beta value + gamma : :obj:`array-like` + ps-vae gamma value + n_ae_latents : :obj:`int` + dimensionality of unsupervised latents + rng_seed_model : :obj:`int` + model seed + experiment_name : :obj:`str` + test-tube experiment name + n_labels : :obj:`str` + dimensionality of supervised latent space (ignored when using fully unsupervised models) + label_idxs : :obj:`array-like`, optional + set of label indices (dimensions) to individually traverse + label_min_p : :obj:`float`, optional + lower percentile of training data used to compute range of traversal + label_max_p : :obj:`float`, optional + upper percentile of training data used to compute range of traversal + channel : :obj:`int`, optional + image channel to plot + n_frames_zs : :obj:`int`, optional + number of frames (points) to display for traversal through supervised dimensions + n_frames_zu : :obj:`int`, optional + number of frames (points) to display for traversal through unsupervised dimensions + trial : :obj:`int`, optional + trial index into all possible trials (train, val, test); one of `trial` or `trial_idx` + must be specified; `trial` takes precedence over `trial_idx` + trial_idx : :obj:`int`, optional + trial index of base frame used for interpolation + batch_idx : :obj:`int`, optional + batch index of base frame used for interpolation + crop_type : :obj:`str`, optional + cropping method used on interpolated frames + 'fixed' | None + crop_kwargs : :obj:`dict`, optional + if crop_type is not None, provides information about the crop + keys for 'fixed' type: 'y_0', 'x_0', 'y_ext', 'x_ext'; window is + (y_0 - y_ext, y_0 + y_ext) in vertical direction and + (x_0 - x_ext, x_0 + x_ext) in horizontal direction + sess_idx : :obj:`int`, optional + session index into data generator + save_file : :obj:`str`, optional + absolute path of save file; does not need file extension + format : :obj:`str`, optional + format of saved image; 'pdf' | 'png' | 'jpeg' | ... + kwargs + arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc. + + """ + + hparams = _get_psvae_hparams( + model_class=model_class, alpha=alpha, beta=beta, gamma=gamma, n_ae_latents=n_ae_latents, + experiment_name=experiment_name, rng_seed_model=rng_seed_model, **kwargs) + + if model_class == 'cond-ae-msp' or model_class == 'ps-vae': + hparams['n_ae_latents'] += n_labels + + # programmatically fill out other hparams options + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + _, version = experiment_exists(hparams, which_version=True) + model_ae, data_generator = get_best_model_and_data(hparams, Model=None, version=version) + + # get latent/label info + latent_range = get_input_range( + 'latents', hparams, model=model_ae, data_gen=data_generator, min_p=15, max_p=85, + version=version) + label_range = get_input_range( + 'labels', hparams, sess_ids=sess_ids, sess_idx=sess_idx, + min_p=label_min_p, max_p=label_max_p) + try: + label_sc_range = get_input_range( + 'labels_sc', hparams, sess_ids=sess_ids, sess_idx=sess_idx, + min_p=label_min_p, max_p=label_max_p) + except KeyError: + import copy + label_sc_range = copy.deepcopy(label_range) + + # ---------------------------------------- + # label traversals + # ---------------------------------------- + interp_func_label = interpolate_1d + plot_func_label = plot_1d_frame_array + save_file_new = save_file + '_label-traversals' + + if model_class == 'cond-ae' or model_class == 'cond-ae-msp' or model_class == 'ps-vae' or \ + model_class == 'cond-vae': + + # get model input for this trial + ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \ + get_model_input( + data_generator, hparams, model_ae, trial_idx=trial_idx, trial=trial, + compute_latents=True, compute_scaled_labels=False, compute_2d_labels=False) + + if labels_2d_np is None: + labels_2d_np = np.copy(labels_np) + if crop_type == 'fixed': + crop_kwargs_ = crop_kwargs + else: + crop_kwargs_ = None + + # perform interpolation + ims_label, markers_loc_label, ims_crop_label = interp_func_label( + 'labels', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :], + labels_np[None, batch_idx, :], labels_2d_np[None, batch_idx, :], + mins=label_range['min'], maxes=label_range['max'], + n_frames=n_frames_zs, input_idxs=label_idxs, crop_type=crop_type, + mins_sc=label_sc_range['min'], maxes_sc=label_sc_range['max'], + crop_kwargs=crop_kwargs_, ch=channel) + + # plot interpolation + if crop_type: + marker_kwargs = { + 'markersize': 30, 'markeredgewidth': 8, 'markeredgecolor': [1, 1, 0], + 'fillstyle': 'none'} + plot_func_label( + ims_crop_label, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new, + format=format) + else: + marker_kwargs = { + 'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0], + 'fillstyle': 'none'} + plot_func_label( + ims_label, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new, + format=format) + + # ---------------------------------------- + # latent traversals + # ---------------------------------------- + interp_func_latent = interpolate_1d + plot_func_latent = plot_1d_frame_array + save_file_new = save_file + '_latent-traversals' + + if hparams['model_class'] == 'cond-ae-msp' or hparams['model_class'] == 'ps-vae': + latent_idxs = n_labels + np.arange(n_ae_latents) + elif hparams['model_class'] == 'ae' \ + or hparams['model_class'] == 'vae' \ + or hparams['model_class'] == 'cond-vae' \ + or hparams['model_class'] == 'beta-tcvae': + latent_idxs = np.arange(n_ae_latents) + else: + raise NotImplementedError + + # simplify options here + scaled_labels = False + twod_labels = False + crop_type = None + crop_kwargs = None + labels_2d_np_sel = None + + # get model input for this trial + ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \ + get_model_input( + data_generator, hparams, model_ae, trial=trial, trial_idx=trial_idx, + compute_latents=True, compute_scaled_labels=scaled_labels, + compute_2d_labels=twod_labels) + + latents_np[:, n_labels:] = 0 + + if hparams['model_class'] == 'ae' or hparams['model_class'] == 'beta-tcvae': + labels_np_sel = labels_np + else: + labels_np_sel = labels_np[None, batch_idx, :] + + # perform interpolation + ims_latent, markers_loc_latent_, ims_crop_latent = interp_func_latent( + 'latents', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :], + labels_np_sel, labels_2d_np_sel, + mins=latent_range['min'], maxes=latent_range['max'], + n_frames=n_frames_zu, input_idxs=latent_idxs, crop_type=crop_type, + mins_sc=None, maxes_sc=None, crop_kwargs=crop_kwargs, ch=channel) + + # plot interpolation + marker_kwargs = { + 'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0], + 'fillstyle': 'none'} + plot_func_latent( + ims_latent, markers=None, marker_kwargs=marker_kwargs, save_file=save_file_new, + format=format) + + +def make_latent_traversal_movie( + lab, expt, animal, session, model_class, alpha, beta, gamma, n_ae_latents, + rng_seed_model, experiment_name, n_labels, trial_idxs, batch_idxs, trials, + label_min_p=5, label_max_p=95, channel=0, sess_idx=0, n_frames=10, n_buffer_frames=5, + crop_kwargs=None, n_cols=3, movie_kwargs={}, panel_titles=None, order_idxs=None, + save_file=None, **kwargs): + """Create a multi-panel movie with each panel showing traversals of an individual latent dim. + + The traversals will start at a lower bound, increase to an upper bound, then return to a lower + bound; the traversal of each dimension occurs simultaneously. It is also possible to specify + multiple base frames for the traversals; the traversal of each base frame is separated by + several blank frames. Note that support for plotting markers on top of the corresponding + supervised dimensions is not supported by this function. + + Parameters + ---------- + lab : :obj:`str` + lab id + expt : :obj:`str` + expt id + animal : :obj:`str` + animal id + session : :obj:`str` + session id + model_class : :obj:`str` + model class in which to perform traversal; currently supported models are: + 'ae' | 'vae' | 'cond-ae' | 'cond-vae' | 'ps-vae' + note that models with conditional encoders are not currently supported + alpha : :obj:`float` + ps-vae alpha value + beta : :obj:`float` + ps-vae beta value + gamma : :obj:`array-like` + ps-vae gamma value + n_ae_latents : :obj:`int` + dimensionality of unsupervised latents + rng_seed_model : :obj:`int` + model seed + experiment_name : :obj:`str` + test-tube experiment name + n_labels : :obj:`str` + dimensionality of supervised latent space (ignored when using fully unsupervised models) + trial_idxs : :obj:`array-like` of :obj:`int` + trial indices of base frames used for interpolation; if an entry is an integer, the + corresponding entry in `trials` must be `None`. This value is a trial index into all + *test* trials, and is not affected by how the test trials are shuffled. The `trials` + argument (see below) takes precedence over `trial_idxs`. + batch_idxs : :obj:`array-like` of :obj:`int` + batch indices of base frames used for interpolation; correspond to entries in `trial_idxs` + and `trials` + trials : :obj:`array-like` of :obj:`int` + trials of base frame used for interpolation; if an entry is an integer, the + corresponding entry in `trial_idxs` must be `None`. This value is a trial index into all + possible trials (train, val, test), whereas `trial_idxs` is an index only into test trials + label_min_p : :obj:`float`, optional + lower percentile of training data used to compute range of traversal + label_max_p : :obj:`float`, optional + upper percentile of training data used to compute range of traversal + channel : :obj:`int`, optional + image channel to plot + sess_idx : :obj:`int`, optional + session index into data generator + n_frames : :obj:`int`, optional + number of frames (points) to display for traversal across latent dimensions; the movie + will display a traversal of `n_frames` across each dim, then another traversal of + `n_frames` in the opposite direction + n_buffer_frames : :obj:`int`, optional + number of blank frames to insert between base frames + crop_kwargs : :obj:`dict`, optional + if crop_type is not None, provides information about the crop (for a fixed crop window) + keys : 'y_0', 'x_0', 'y_ext', 'x_ext'; window is + (y_0 - y_ext, y_0 + y_ext) in vertical direction and + (x_0 - x_ext, x_0 + x_ext) in horizontal direction + n_cols : :obj:`int`, optional + movie is `n_cols` panels wide + movie_kwargs : :obj:`dict`, optional + additional kwargs for individual panels; possible keys are 'markersize', 'markeredgecolor', + 'markeredgewidth', and 'text_color' + panel_titles : :obj:`list` of :obj:`str`, optional + optional titles for each panel + order_idxs : :obj:`array-like`, optional + used to reorder panels (which are plotted in row-major order) if desired + save_file : :obj:`str`, optional + absolute path of save file; does not need file extension, will automatically be saved as + mp4. To save as a gif, include the '.gif' file extension in `save_file` + kwargs + arguments are keys of `hparams`, for example to set `train_frac`, `rng_seed_model`, etc. + + """ + + panel_titles = [''] * (n_labels + n_ae_latents) if panel_titles is None else panel_titles + + hparams = _get_psvae_hparams( + model_class=model_class, alpha=alpha, beta=beta, gamma=gamma, n_ae_latents=n_ae_latents, + experiment_name=experiment_name, rng_seed_model=rng_seed_model, **kwargs) + + if model_class == 'cond-ae-msp' or model_class == 'ps-vae': + hparams['n_ae_latents'] += n_labels + + # programmatically fill out other hparams options + get_lab_example(hparams, lab, expt) + hparams['animal'] = animal + hparams['session'] = session + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + _, version = experiment_exists(hparams, which_version=True) + model_ae, data_generator = get_best_model_and_data(hparams, Model=None, version=version) + + # get latent/label info + latent_range = get_input_range( + 'latents', hparams, model=model_ae, data_gen=data_generator, min_p=15, max_p=85, + version=version) + label_range = get_input_range( + 'labels', hparams, sess_ids=sess_ids, sess_idx=sess_idx, + min_p=label_min_p, max_p=label_max_p) + + # ---------------------------------------- + # collect frames/latents/labels + # ---------------------------------------- + if hparams['model_class'] == 'vae': + csl = False + c2dl = False + else: + csl = True + c2dl = False + + ims_pt = [] + ims_np = [] + latents_np = [] + labels_pt = [] + labels_np = [] + labels_2d_pt = [] + labels_2d_np = [] + for trial, trial_idx in zip(trials, trial_idxs): + ims_pt_, ims_np_, latents_np_, labels_pt_, labels_np_, labels_2d_pt_, labels_2d_np_ = \ + get_model_input( + data_generator, hparams, model_ae, trial_idx=trial_idx, trial=trial, + compute_latents=True, compute_scaled_labels=csl, compute_2d_labels=c2dl, + max_frames=200) + ims_pt.append(ims_pt_) + ims_np.append(ims_np_) + latents_np.append(latents_np_) + labels_pt.append(labels_pt_) + labels_np.append(labels_np_) + labels_2d_pt.append(labels_2d_pt_) + labels_2d_np.append(labels_2d_np_) + + if hparams['model_class'] == 'ps-vae': + label_idxs = np.arange(n_labels) + latent_idxs = n_labels + np.arange(n_ae_latents) + elif hparams['model_class'] == 'vae': + label_idxs = [] + latent_idxs = np.arange(hparams['n_ae_latents']) + elif hparams['model_class'] == 'cond-vae': + label_idxs = np.arange(n_labels) + latent_idxs = np.arange(hparams['n_ae_latents']) + else: + raise Exception + + # ---------------------------------------- + # label traversals + # ---------------------------------------- + ims_all = [] + txt_strs_all = [] + txt_strs_titles = [] + + for label_idx in label_idxs: + + ims = [] + txt_strs = [] + + for b, batch_idx in enumerate(batch_idxs): + if hparams['model_class'] == 'ps-vae': + points = np.array([latents_np[b][batch_idx, :]] * 3) + elif hparams['model_class'] == 'cond-vae': + points = np.array([labels_np[b][batch_idx, :]] * 3) + else: + raise Exception + points[0, label_idx] = label_range['min'][label_idx] + points[1, label_idx] = label_range['max'][label_idx] + points[2, label_idx] = label_range['min'][label_idx] + ims_curr, inputs = interpolate_point_path( + 'labels', model_ae, ims_pt[b][None, batch_idx, :], + labels_np[b][None, batch_idx, :], points=points, n_frames=n_frames, ch=channel, + crop_kwargs=crop_kwargs) + ims.append(ims_curr) + txt_strs += [panel_titles[label_idx] for _ in range(len(ims_curr))] + + if label_idx == 0: + tmp = trial_idxs[b] if trial_idxs[b] is not None else trials[b] + txt_strs_titles += [ + 'base frame %02i-%02i' % (tmp, batch_idx) for _ in range(len(ims_curr))] + + # add blank frames + y_pix, x_pix = ims_curr[0].shape + ims.append([np.zeros((y_pix, x_pix)) for _ in range(n_buffer_frames)]) + txt_strs += ['' for _ in range(n_buffer_frames)] + if label_idx == 0: + txt_strs_titles += ['' for _ in range(n_buffer_frames)] + + ims_all.append(np.vstack(ims)) + txt_strs_all.append(txt_strs) + + # ---------------------------------------- + # latent traversals + # ---------------------------------------- + crop_kwargs_ = None + for latent_idx in latent_idxs: + + ims = [] + txt_strs = [] + + for b, batch_idx in enumerate(batch_idxs): + + points = np.array([latents_np[b][batch_idx, :]] * 3) + + # points[:, latent_idxs] = 0 + points[0, latent_idx] = latent_range['min'][latent_idx] + points[1, latent_idx] = latent_range['max'][latent_idx] + points[2, latent_idx] = latent_range['min'][latent_idx] + if hparams['model_class'] == 'vae': + labels_curr = None + else: + labels_curr = labels_np[b][None, batch_idx, :] + ims_curr, inputs = interpolate_point_path( + 'latents', model_ae, ims_pt[b][None, batch_idx, :], + labels_curr, points=points, n_frames=n_frames, ch=channel, + crop_kwargs=crop_kwargs_) + ims.append(ims_curr) + if hparams['model_class'] == 'cond-vae': + txt_strs += [panel_titles[latent_idx + n_labels] for _ in range(len(ims_curr))] + else: + txt_strs += [panel_titles[latent_idx] for _ in range(len(ims_curr))] + + if latent_idx == 0 and len(label_idxs) == 0: + # add frame ids here if skipping labels + tmp = trial_idxs[b] if trial_idxs[b] is not None else trials[b] + txt_strs_titles += [ + 'base frame %02i-%02i' % (tmp, batch_idx) for _ in range(len(ims_curr))] + + # add blank frames + y_pix, x_pix = ims_curr[0].shape + ims.append([np.zeros((y_pix, x_pix)) for _ in range(n_buffer_frames)]) + txt_strs += ['' for _ in range(n_buffer_frames)] + if latent_idx == 0 and len(label_idxs) == 0: + txt_strs_titles += ['' for _ in range(n_buffer_frames)] + + ims_all.append(np.vstack(ims)) + txt_strs_all.append(txt_strs) + + # ---------------------------------------- + # make video + # ---------------------------------------- + if order_idxs is None: + # don't change order of latents + order_idxs = np.arange(len(ims_all)) + + make_interpolated_multipanel( + ims=[ims_all[i] for i in order_idxs], + text=[txt_strs_all[i] for i in order_idxs], + text_title=txt_strs_titles, + save_file=save_file, scale=2, n_cols=n_cols, **movie_kwargs) diff --git a/behavenet/plotting/decoder_utils.py b/behavenet/plotting/decoder_utils.py index 3c04042..d776183 100644 --- a/behavenet/plotting/decoder_utils.py +++ b/behavenet/plotting/decoder_utils.py @@ -1,15 +1,28 @@ """Plotting functions for decoders.""" +import copy +import matplotlib.animation as animation +import matplotlib.lines as mlines +import matplotlib.pyplot as plt +from matplotlib.gridspec import GridSpec +import numpy as np import os import pandas as pd import pickle +from behavenet import make_dir_if_not_exists +from behavenet.fitting.eval import get_reconstruction +from behavenet.fitting.utils import get_best_model_and_data from behavenet.data.utils import get_region_list from behavenet.fitting.utils import get_expt_dir from behavenet.fitting.utils import get_session_dir from behavenet.fitting.utils import get_subdirs +from behavenet.plotting import concat, save_movie # to ignore imports for sphix-autoapidoc -__all__ = ['get_r2s_by_trial', 'get_best_models', 'get_r2s_across_trials'] +__all__ = [ + 'get_r2s_by_trial', 'get_best_models', 'get_r2s_across_trials', + 'make_neural_reconstruction_movie_wrapper', 'make_neural_reconstruction_movie', + 'plot_neural_reconstruction_traces_wrapper', 'plot_neural_reconstruction_traces'] def _get_dataset_str(hparams): @@ -63,10 +76,9 @@ def get_r2s_by_trial(hparams, model_types): # read metrics csv file model_dir = os.path.join(expt_dir, version) try: - metric = pd.read_csv( - os.path.join(model_dir, 'metrics.csv')) + metric = pd.read_csv(os.path.join(model_dir, 'metrics.csv')) model_counter += 1 - except: + except FileNotFoundError: continue with open(os.path.join(model_dir, 'meta_tags.pkl'), 'rb') as f: hparams = pickle.load(f) @@ -177,3 +189,527 @@ def get_r2s_across_trials(hparams, best_models_df): 'model_type': hparams['model_type'], 'r2': r2}, index=[0])) return pd.concat(all_test_r2s) + + +def make_neural_reconstruction_movie_wrapper( + hparams, save_file, trials=None, sess_idx=0, max_frames=400, max_latents=8, + zscore_by_dim=False, colored_predictions=False, xtick_locs=None, frame_rate=15): + """Produce movie with original video, ae reconstructed video, and neural reconstructed video. + + This is a high-level function that loads the model described in the hparams dictionary and + produces the necessary predicted video frames. Latent traces are additionally plotted, as well + as the residual between the ae reconstruction and the neural reconstruction. Currently produces + ae latents and decoder predictions from scratch (rather than saved pickle files). + + Parameters + ---------- + hparams : :obj:`dict` + needs to contain enough information to specify an autoencoder + save_file : :obj:`str` + full save file (path and filename) + trials : :obj:`int` or :obj:`list`, optional + if :obj:`NoneType`, use first test trial + sess_idx : :obj:`int`, optional + session index into data generator + max_frames : :obj:`int`, optional + maximum number of frames to animate from a trial + max_latents : :obj:`int`, optional + maximum number of ae latents to plot + zscore_by_dim : :obj:`bool`, optional + True to z-score each dim, False to leave relative scales + colored_predictions : :obj:`bool`, optional + False to plot reconstructions in black, True to plot in different colors + xtick_locs : :obj:`array-like`, optional + tick locations in units of bins + frame_rate : :obj:`float`, optional + frame rate of saved movie + + """ + + from behavenet.models import Decoder + + # define number of frames that separate trials + n_buffer = 5 + + ############################### + # build ae model/data generator + ############################### + hparams_ae = copy.copy(hparams) + hparams_ae['experiment_name'] = hparams['ae_experiment_name'] + hparams_ae['model_class'] = hparams['ae_model_class'] + hparams_ae['model_type'] = hparams['ae_model_type'] + model_ae, data_generator_ae = get_best_model_and_data( + hparams_ae, Model=None, version=hparams['ae_version']) + # move model to cpu + model_ae.to('cpu') + + ####################################### + # build decoder model/no data generator + ####################################### + hparams_dec = copy.copy(hparams) + hparams_dec['experiment_name'] = hparams['decoder_experiment_name'] + hparams_dec['model_class'] = hparams['decoder_model_class'] + hparams_dec['model_type'] = hparams['decoder_model_type'] + + model_dec, data_generator_dec = get_best_model_and_data( + hparams_dec, Decoder, version=hparams['decoder_version']) + # move model to cpu + model_dec.to('cpu') + + if trials is None: + # choose first test trial, put in list + trials = data_generator_ae.batch_idxs[sess_idx]['test'][0] + + if isinstance(trials, int): + trials = [trials] + + # loop over trials, putting black frames/nans in between + ims_orig = [] + ims_recon_ae = [] + ims_recon_neural = [] + latents_ae = [] + latents_neural = [] + for i, trial in enumerate(trials): + + # get images from data generator (move to cpu) + batch = data_generator_ae.datasets[sess_idx][trial] + ims_orig_pt = batch['images'][:max_frames].cpu() # 400 + if hparams_ae['model_class'] == 'cond-ae': + labels_pt = batch['labels'][:max_frames] + else: + labels_pt = None + + # push images through ae to get reconstruction + ims_recon_ae_curr, latents_ae_curr = get_reconstruction( + model_ae, ims_orig_pt, labels=labels_pt, return_latents=True) + + # mask images for plotting + if hparams_ae.get('use_output_mask', False): + ims_orig_pt *= batch['masks'][:max_frames] + + # get neural activity from data generator (move to cpu) + # 0, not sess_idx, since decoders only have 1 sess + batch = data_generator_dec.datasets[0][trial] + neural_activity_pt = batch['neural'][:max_frames].cpu() + + # push neural activity through decoder to get prediction + latents_dec_pt, _ = model_dec(neural_activity_pt) + # push prediction through ae to get reconstruction + ims_recon_dec_curr = get_reconstruction(model_ae, latents_dec_pt, labels=labels_pt) + + # store all relevant quantities + ims_orig.append(ims_orig_pt.cpu().detach().numpy()) + ims_recon_ae.append(ims_recon_ae_curr) + ims_recon_neural.append(ims_recon_dec_curr) + latents_ae.append(latents_ae_curr[:, :max_latents]) + latents_neural.append(latents_dec_pt.cpu().detach().numpy()[:, :max_latents]) + + # add blank frames + if i < len(trials) - 1: + n_channels, y_pix, x_pix = ims_orig[-1].shape[1:] + n = latents_ae[-1].shape[1] + ims_orig.append(np.zeros((n_buffer, n_channels, y_pix, x_pix))) + ims_recon_ae.append(np.zeros((n_buffer, n_channels, y_pix, x_pix))) + ims_recon_neural.append(np.zeros((n_buffer, n_channels, y_pix, x_pix))) + latents_ae.append(np.nan * np.zeros((n_buffer, n))) + latents_neural.append(np.nan * np.zeros((n_buffer, n))) + + latents_ae = np.vstack(latents_ae) + latents_neural = np.vstack(latents_neural) + if zscore_by_dim: + means = np.nanmean(latents_ae, axis=0) + std = np.nanstd(latents_ae, axis=0) + latents_ae = (latents_ae - means) / std + latents_neural = (latents_neural - means) / std + + # away + make_neural_reconstruction_movie( + ims_orig=np.vstack(ims_orig), + ims_recon_ae=np.vstack(ims_recon_ae), + ims_recon_neural=np.vstack(ims_recon_neural), + latents_ae=latents_ae, + latents_neural=latents_neural, + ae_model_class=hparams_ae['model_class'].upper(), + colored_predictions=colored_predictions, + xtick_locs=xtick_locs, + frame_rate_beh=hparams['frame_rate'], + save_file=save_file, + frame_rate=frame_rate) + + +def make_neural_reconstruction_movie( + ims_orig, ims_recon_ae, ims_recon_neural, latents_ae, latents_neural, ae_model_class='AE', + colored_predictions=False, scale=0.5, xtick_locs=None, frame_rate_beh=None, save_file=None, + frame_rate=15): + """Produce movie with original video, ae reconstructed video, and neural reconstructed video. + + Latent traces are additionally plotted, as well as the residual between the ae reconstruction + and the neural reconstruction. + + Parameters + ---------- + ims_orig : :obj:`np.ndarray` + original images; shape (n_frames, n_channels, y_pix, x_pix) + ims_recon_ae : :obj:`np.ndarray` + images reconstructed by AE; shape (n_frames, n_channels, y_pix, x_pix) + ims_recon_neural : :obj:`np.ndarray` + images reconstructed by neural activity; shape (n_frames, n_channels, y_pix, x_pix) + latents_ae : :obj:`np.ndarray` + original AE latents; shape (n_frames, n_latents) + latents_neural : :obj:`np.ndarray` + latents reconstruted by neural activity; shape (n_frames, n_latents) + ae_model_class : :obj:`str`, optional + 'AE', 'VAE', etc. for plot titles + colored_predictions : :obj:`bool`, optional + False to plot reconstructions in black, True to plot in different colors + scale : :obj:`int`, optional + scale magnitude of traces + xtick_locs : :obj:`array-like`, optional + tick locations in units of bins + frame_rate_beh : :obj:`float`, optional + frame rate of behavorial video; to properly relabel xticks + save_file : :obj:`str`, optional + full save file (path and filename) + frame_rate : :obj:`float`, optional + frame rate of saved movie + + """ + + means = np.nanmean(latents_ae, axis=0) + std = np.nanstd(latents_ae) / scale + + latents_ae_sc = (latents_ae - means) / std + latents_dec_sc = (latents_neural - means) / std + + n_channels, y_pix, x_pix = ims_orig.shape[1:] + n_time, n_ae_latents = latents_ae.shape + + n_cols = 3 + n_rows = 2 + offset = 2 # 0 if ims_recon_lin is None else 1 + scale_ = 5 + fig_width = scale_ * n_cols * n_channels / 2 + fig_height = y_pix / x_pix * scale_ * n_rows / 2 + fig = plt.figure(figsize=(fig_width, fig_height + offset)) + + gs = GridSpec(n_rows, n_cols, figure=fig) + axs = [] + axs.append(fig.add_subplot(gs[0, 0])) # 0: original frames + axs.append(fig.add_subplot(gs[0, 1])) # 1: ae reconstructed frames + axs.append(fig.add_subplot(gs[0, 2])) # 2: neural reconstructed frames + axs.append(fig.add_subplot(gs[1, 0])) # 3: residual + axs.append(fig.add_subplot(gs[1, 1:3])) # 4: ae and predicted ae latents + for i, ax in enumerate(fig.axes): + ax.set_yticks([]) + if i > 2: + ax.get_xaxis().set_tick_params(labelsize=12, direction='in') + axs[0].set_xticks([]) + axs[1].set_xticks([]) + axs[2].set_xticks([]) + axs[3].set_xticks([]) + + # check that the axes are correct + fontsize = 12 + idx = 0 + axs[idx].set_title('Original', fontsize=fontsize) + idx += 1 + axs[idx].set_title('%s reconstructed' % ae_model_class, fontsize=fontsize) + idx += 1 + axs[idx].set_title('Neural reconstructed', fontsize=fontsize) + idx += 1 + axs[idx].set_title('Reconstructions residual', fontsize=fontsize) + idx += 1 + axs[idx].set_title('%s latent predictions' % ae_model_class, fontsize=fontsize) + if xtick_locs is not None and frame_rate_beh is not None: + axs[idx].set_xticks(xtick_locs) + axs[idx].set_xticklabels((np.asarray(xtick_locs) / frame_rate_beh).astype('int')) + axs[idx].set_xlabel('Time (s)', fontsize=fontsize) + else: + axs[idx].set_xlabel('Time (bins)', fontsize=fontsize) + + time = np.arange(n_time) + + ims_res = ims_recon_ae - ims_recon_neural + + im_kwargs = {'animated': True, 'cmap': 'gray', 'vmin': 0, 'vmax': 1} + tr_kwargs = {'animated': True, 'linewidth': 2} + latents_ae_color = [0.2, 0.2, 0.2] + + label_ae_base = '%s latents' % ae_model_class + label_dec_base = 'Predicted %s latents' % ae_model_class + + # ims is a list of lists, each row is a list of artists to draw in the + # current frame; here we are just animating one artist, the image, in + # each frame + ims = [] + for i in range(n_time): + + ims_curr = [] + idx = 0 + + if i % 100 == 0: + print('processing frame %03i/%03i' % (i, n_time)) + + ################### + # behavioral videos + ################### + # original video + ims_tmp = ims_orig[i, 0] if n_channels == 1 else concat(ims_orig[i]) + im = axs[idx].imshow(ims_tmp, **im_kwargs) + ims_curr.append(im) + idx += 1 + + # ae reconstruction + ims_tmp = ims_recon_ae[i, 0] if n_channels == 1 else concat(ims_recon_ae[i]) + im = axs[idx].imshow(ims_tmp, **im_kwargs) + ims_curr.append(im) + idx += 1 + + # neural reconstruction + ims_tmp = ims_recon_neural[i, 0] if n_channels == 1 else concat(ims_recon_neural[i]) + im = axs[idx].imshow(ims_tmp, **im_kwargs) + ims_curr.append(im) + idx += 1 + + # residual + ims_tmp = ims_res[i, 0] if n_channels == 1 else concat(ims_res[i]) + im = axs[idx].imshow(0.5 + ims_tmp, **im_kwargs) + ims_curr.append(im) + idx += 1 + + ######## + # traces + ######## + # latents over time + axs[idx].set_prop_cycle(None) # reset colors + for latent in range(n_ae_latents): + if colored_predictions: + latents_dec_color = axs[idx]._get_lines.get_next_color() + else: + latents_dec_color = [0, 0, 0] + # just put labels on last lvs + if latent == n_ae_latents - 1 and i == 0: + label_ae = label_ae_base + label_dec = label_dec_base + else: + label_ae = None + label_dec = None + im = axs[idx].plot( + time[0:i + 1], latent + latents_ae_sc[0:i + 1, latent], + color=latents_ae_color, alpha=0.7, label=label_ae, + **tr_kwargs)[0] + axs[idx].spines['top'].set_visible(False) + axs[idx].spines['right'].set_visible(False) + axs[idx].spines['left'].set_visible(False) + ims_curr.append(im) + im = axs[idx].plot( + time[0:i + 1], latent + latents_dec_sc[0:i + 1, latent], + color=latents_dec_color, label=label_dec, **tr_kwargs)[0] + axs[idx].spines['top'].set_visible(False) + axs[idx].spines['right'].set_visible(False) + axs[idx].spines['left'].set_visible(False) + if colored_predictions: + # original latents - gray + orig_line = mlines.Line2D([], [], color=[0.2, 0.2, 0.2], linewidth=3, alpha=0.7) + # predicted latents - cycle through some colors + colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + dls = [] + for c in range(5): + dls.append(mlines.Line2D( + [], [], linewidth=3, linestyle='--', dashes=(0, 3 * c, 20, 1), + color='%s' % colors[c])) + plt.legend( + [orig_line, tuple(dls)], [label_ae_base, label_dec_base], + loc='lower right', fontsize=fontsize, frameon=True, framealpha=0.7, + edgecolor=[1, 1, 1]) + else: + plt.legend( + loc='lower right', fontsize=fontsize, frameon=True, + framealpha=0.7, edgecolor=[1, 1, 1]) + ims_curr.append(im) + ims.append(ims_curr) + + plt.tight_layout(pad=0) + + ani = animation.ArtistAnimation(fig, ims, blit=True, repeat_delay=1000) + save_movie(save_file, ani, frame_rate=frame_rate) + + +def plot_neural_reconstruction_traces_wrapper( + hparams, save_file=None, trial=None, xtick_locs=None, frame_rate=None, format='png', + **kwargs): + """Plot ae latents and their neural reconstructions. + + This is a high-level function that loads the model described in the hparams dictionary and + produces the necessary predicted latents. + + Parameters + ---------- + hparams : :obj:`dict` + needs to contain enough information to specify an ae latent decoder + save_file : :obj:`str` + full save file (path and filename) + trial : :obj:`int`, optional + if :obj:`NoneType`, use first test trial + xtick_locs : :obj:`array-like`, optional + tick locations in units of bins + frame_rate : :obj:`float`, optional + frame rate of behavorial video; to properly relabel xticks + format : :obj:`str`, optional + any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' + + Returns + ------- + :obj:`matplotlib.figure.Figure` + matplotlib figure handle of plot + + """ + + # find good trials + import copy + from behavenet.data.utils import get_transforms_paths + from behavenet.data.data_generator import ConcatSessionsGenerator + + # ae data + hparams_ae = copy.copy(hparams) + hparams_ae['experiment_name'] = hparams['ae_experiment_name'] + hparams_ae['model_class'] = hparams['ae_model_class'] + hparams_ae['model_type'] = hparams['ae_model_type'] + + ae_transform, ae_path = get_transforms_paths('ae_latents', hparams_ae, None) + + # ae predictions data + hparams_dec = copy.copy(hparams) + hparams_dec['neural_ae_experiment_name'] = hparams['decoder_experiment_name'] + hparams_dec['neural_ae_model_class'] = hparams['decoder_model_class'] + hparams_dec['neural_ae_model_type'] = hparams['decoder_model_type'] + ae_pred_transform, ae_pred_path = get_transforms_paths( + 'neural_ae_predictions', hparams_dec, None) + + signals = ['ae_latents', 'ae_predictions'] + transforms = [ae_transform, ae_pred_transform] + paths = [ae_path, ae_pred_path] + + data_generator = ConcatSessionsGenerator( + hparams['data_dir'], [hparams], + signals_list=[signals], transforms_list=[transforms], paths_list=[paths], + device='cpu', as_numpy=False, batch_load=True, rng_seed=0) + + if trial is None: + # choose first test trial + trial = data_generator.datasets[0].batch_idxs['test'][0] + + batch = data_generator.datasets[0][trial] + traces_ae = batch['ae_latents'].cpu().detach().numpy() + traces_neural = batch['ae_predictions'].cpu().detach().numpy() + + n_max_lags = hparams.get('n_max_lags', 0) # only plot valid segment of data + if n_max_lags > 0: + fig = plot_neural_reconstruction_traces( + traces_ae[n_max_lags:-n_max_lags], traces_neural[n_max_lags:-n_max_lags], + save_file, xtick_locs, frame_rate, format, **kwargs) + else: + fig = plot_neural_reconstruction_traces( + traces_ae, traces_neural, save_file, xtick_locs, frame_rate, format, **kwargs) + return fig + + +def plot_neural_reconstruction_traces( + traces_ae, traces_neural, save_file=None, xtick_locs=None, frame_rate=None, format='png', + scale=0.5, max_traces=8, add_r2=True, add_legend=True, colored_predictions=True): + """Plot ae latents and their neural reconstructions. + + Parameters + ---------- + traces_ae : :obj:`np.ndarray` + shape (n_frames, n_latents) + traces_neural : :obj:`np.ndarray` + shape (n_frames, n_latents) + save_file : :obj:`str`, optional + full save file (path and filename) + xtick_locs : :obj:`array-like`, optional + tick locations in units of bins + frame_rate : :obj:`float`, optional + frame rate of behavorial video; to properly relabel xticks + format : :obj:`str`, optional + any accepted matplotlib save format, e.g. 'png' | 'pdf' | 'jpeg' + scale : :obj:`int`, optional + scale magnitude of traces + max_traces : :obj:`int`, optional + maximum number of traces to plot, for easier visualization + add_r2 : :obj:`bool`, optional + print R2 value on plot + add_legend : :obj:`bool`, optional + print legend on plot + colored_predictions : :obj:`bool`, optional + color predictions using default seaborn colormap; else predictions are black + + + Returns + ------- + :obj:`matplotlib.figure.Figure` + matplotlib figure handle + + """ + + import seaborn as sns + + sns.set_style('white') + sns.set_context('poster') + + means = np.nanmean(traces_ae, axis=0) + std = np.nanstd(traces_ae) / scale # scale for better visualization + + traces_ae_sc = (traces_ae - means) / std + traces_neural_sc = (traces_neural - means) / std + + traces_ae_sc = traces_ae_sc[:, :max_traces] + traces_neural_sc = traces_neural_sc[:, :max_traces] + + fig = plt.figure(figsize=(12, 8)) + if colored_predictions: + plt.plot(traces_neural_sc + np.arange(traces_neural_sc.shape[1]), linewidth=3) + else: + plt.plot(traces_neural_sc + np.arange(traces_neural_sc.shape[1]), linewidth=3, color='k') + plt.plot( + traces_ae_sc + np.arange(traces_ae_sc.shape[1]), color=[0.2, 0.2, 0.2], linewidth=3, + alpha=0.7) + + # add legend if desired + if add_legend: + # original latents - gray + orig_line = mlines.Line2D([], [], color=[0.2, 0.2, 0.2], linewidth=3, alpha=0.7) + # predicted latents - cycle through some colors + colors = plt.rcParams['axes.prop_cycle'].by_key()['color'] + dls = [] + for c in range(5): + dls.append(mlines.Line2D( + [], [], linewidth=3, linestyle='--', dashes=(0, 3 * c, 20, 1), + color='%s' % colors[c])) + plt.legend( + [orig_line, tuple(dls)], ['Original latents', 'Predicted latents'], + loc='lower right', frameon=True, framealpha=0.7, edgecolor=[1, 1, 1]) + + # add r2 info if desired + if add_r2: + from sklearn.metrics import r2_score + r2 = r2_score(traces_ae, traces_neural, multioutput='variance_weighted') + plt.text( + 0.05, 0.06, '$R^2$=%1.3f' % r2, horizontalalignment='left', verticalalignment='bottom', + transform=plt.gca().transAxes, + bbox=dict(facecolor='white', alpha=0.7, edgecolor=[1, 1, 1])) + + if xtick_locs is not None and frame_rate is not None: + plt.xticks(xtick_locs, (np.asarray(xtick_locs) / frame_rate).astype('int')) + plt.xlabel('Time (s)') + else: + plt.xlabel('Time (bins)') + plt.ylabel('Latent state') + plt.yticks([]) + + if save_file is not None: + make_dir_if_not_exists(save_file) + plt.savefig(save_file + '.' + format, dpi=300, format=format) + + plt.show() + return fig diff --git a/configs/ae_jsons/ae_arch_2.json b/configs/ae_jsons/ae_arch_2.json index 4843491..fbfc7be 100644 --- a/configs/ae_jsons/ae_arch_2.json +++ b/configs/ae_jsons/ae_arch_2.json @@ -59,4 +59,4 @@ "ae_decoding_last_FF_layer": 0 # type: int, help: 0 = False, 1 = True -} \ No newline at end of file +} diff --git a/configs/ae_jsons/ae_arch_default.json b/configs/ae_jsons/ae_arch_default.json index c51173b..db81242 100644 --- a/configs/ae_jsons/ae_arch_default.json +++ b/configs/ae_jsons/ae_arch_default.json @@ -59,4 +59,4 @@ "ae_decoding_last_FF_layer": 0 # type: int, help: 0 = False, 1 = True -} \ No newline at end of file +} diff --git a/configs/ae_jsons/ae_model.json b/configs/ae_jsons/ae_model.json index 6fff2f0..baaf763 100644 --- a/configs/ae_jsons/ae_model.json +++ b/configs/ae_jsons/ae_model.json @@ -42,12 +42,12 @@ "beta_tcvae.beta_anneal_epochs": 100, # type: int, help: number of epochs to linearly increase betatcvae beta -"sss_vae.alpha": 1, # type: int, help: weight on label reconstruction term +"ps_vae.alpha": 1, # type: int, help: weight on label reconstruction term -"sss_vae.beta": 1, # type: int, help: weight on total correlation term +"ps_vae.beta": 1, # type: int, help: weight on total correlation term -"sss_vae.gamma": 1, # type: int, help: weight on subspace overlap term +"ps_vae.gamma": 1, # type: int, help: weight on subspace overlap term -"sss_vae.anneal_epochs": 100 # type: int, help: number of epochs to linearly increase sss beta value +"ps_vae.anneal_epochs": 100 # type: int, help: number of epochs to linearly increase sss beta value } diff --git a/configs/ae_jsons/ae_training.json b/configs/ae_jsons/ae_training.json index d80dfb1..bb98c65 100644 --- a/configs/ae_jsons/ae_training.json +++ b/configs/ae_jsons/ae_training.json @@ -44,4 +44,4 @@ "trial_splits": "8;1;1;0" # type: str, help: i;j;k;l correspond to train;val;test;gap' -} \ No newline at end of file +} diff --git a/configs/arhmm_jsons/arhmm_labels_model.json b/configs/arhmm_jsons/arhmm_labels_model.json index de15778..76a59f9 100644 --- a/configs/arhmm_jsons/arhmm_labels_model.json +++ b/configs/arhmm_jsons/arhmm_labels_model.json @@ -32,4 +32,4 @@ "model_type": null -} \ No newline at end of file +} diff --git a/configs/arhmm_jsons/arhmm_model.json b/configs/arhmm_jsons/arhmm_model.json index 1a19d52..80cb22f 100644 --- a/configs/arhmm_jsons/arhmm_model.json +++ b/configs/arhmm_jsons/arhmm_model.json @@ -34,6 +34,8 @@ "ae_version": "best", +"ae_model_class": "ae", # class of AE, ae, vae, etc + "ae_model_type": "conv", # type of AE, linear or conv "n_ae_latents": 9, # type: int diff --git a/configs/data_default.json b/configs/data_default.json index a39d4be..514d65b 100644 --- a/configs/data_default.json +++ b/configs/data_default.json @@ -31,6 +31,8 @@ "use_output_mask": false, # type: boolean +"use_label_mask": false, # type: boolean + ######################## ## Neural data params ## diff --git a/configs/decoding_jsons/decoding_ae_model.json b/configs/decoding_jsons/decoding_ae_model.json index c95a1f3..d0e484b 100644 --- a/configs/decoding_jsons/decoding_ae_model.json +++ b/configs/decoding_jsons/decoding_ae_model.json @@ -27,6 +27,8 @@ "ae_version": "best", +"ae_model_class": "ae", # class of AE, ae, vae, etc + "ae_model_type": "conv", # type of AE, linear or conv "n_ae_latents": 9, # type: int @@ -47,7 +49,3 @@ "activation": "relu" # type: str, could be linear, relu, lrelu, sigmoid, tanh } - - - - diff --git a/configs/decoding_jsons/decoding_arhmm_model.json b/configs/decoding_jsons/decoding_arhmm_model.json index 60cfd7e..a752aa5 100644 --- a/configs/decoding_jsons/decoding_arhmm_model.json +++ b/configs/decoding_jsons/decoding_arhmm_model.json @@ -23,6 +23,8 @@ # specify which ARHMM to use (should match how you trained the AE) +"ae_model_class": "ae", # class of AE, ae, vae, etc + "ae_model_type": "conv", # type of AE, linear or conv "n_ae_latents": 9, # type: int @@ -55,6 +57,3 @@ "activation": "relu" # type: str, could be linear, relu, lrelu, sigmoid, tanh } - - - diff --git a/configs/decoding_jsons/decoding_compute.json b/configs/decoding_jsons/decoding_compute.json index 44bbdb2..46f0cb6 100644 --- a/configs/decoding_jsons/decoding_compute.json +++ b/configs/decoding_jsons/decoding_compute.json @@ -30,5 +30,4 @@ "tt_n_cpu_workers": 3 # type: int - } diff --git a/configs/decoding_jsons/decoding_data.json b/configs/decoding_jsons/decoding_data.json index 24a5f2e..7d8a3d3 100644 --- a/configs/decoding_jsons/decoding_data.json +++ b/configs/decoding_jsons/decoding_data.json @@ -31,6 +31,8 @@ "use_output_mask": false, # type: boolean +"n_labels": null, # type: int + ######################## ## Neural data params ## @@ -57,4 +59,4 @@ "approx_batch_size": 200 # type: int, help: approximate batch size for memory calculation -} \ No newline at end of file +} diff --git a/configs/decoding_jsons/decoding_labels_model.json b/configs/decoding_jsons/decoding_labels_model.json new file mode 100644 index 0000000..3fd1bd2 --- /dev/null +++ b/configs/decoding_jsons/decoding_labels_model.json @@ -0,0 +1,32 @@ +{ + +############################# +## Commonly changed params ## +############################# + +"experiment_name": "grid_search", # type: str, name of this experiment + +"n_lags": [4], # type: int + +"n_max_lags": 8, # type: int, should match largest n_lags value (so all lags are evaluated on exact same data) + +"l2_reg": [1e-3], # type: float + +"rng_seed_model": 0, # type: int, help: control model initialization + +"model_class": "neural-labels", # type: str + + +######################## +## Model Architecture ## +######################## + +"model_type": "mlp", # type: str, currently mlp only option (mlp with 0 hidden layers is linear) + +"n_hid_layers": [1], # type: int + +"n_hid_units": [32], # type: int + +"activation": "relu" # type: str, could be linear, relu, lrelu, sigmoid, tanh + +} diff --git a/configs/decoding_jsons/decoding_training.json b/configs/decoding_jsons/decoding_training.json index ffdeb18..907de18 100644 --- a/configs/decoding_jsons/decoding_training.json +++ b/configs/decoding_jsons/decoding_training.json @@ -38,7 +38,6 @@ "train_frac": 1.0, # type: float, help: fraction of data -"trial_splits": "8;1;1;0" # type: str, help: i;j;k;l correspond to train;val;test;gap' - +"trial_splits": "8;1;1;0" # type: str, help: i;j;k;l correspond to train;val;test;gap } diff --git a/docs/requirements.txt b/docs/requirements.txt index 0daa7a2..ec5ba6f 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ sphinx-automodapi==0.12 commentjson==0.8.2 h5py==2.9.0 matplotlib==3.0.3 -notebook==6.0.3 +notebook==6.1.5 numpy==1.17.4 requests==2.22.0 scikit-image==0.15.0 diff --git a/docs/source/adv_user_guide.load_model.rst b/docs/source/adv_user_guide.load_model.rst new file mode 100644 index 0000000..5e2e921 --- /dev/null +++ b/docs/source/adv_user_guide.load_model.rst @@ -0,0 +1,139 @@ +.. _load_model: + +Loading a trained model +======================= + +After you've fit one or more models, often you'll want to load these models and their associated +data generator to perform further analyses. BehaveNet provides three methods for doing so: + +* :ref:`Method 1`: load the "best" model from a test-tube experiment +* :ref:`Method 2`: specify the model version in a test-tube experiment +* :ref:`Method 3`: specify the model hyperparameters in a test-tube experiment + +To illustrate these three methods we'll use an autoencoder as an example. Let's assume that we've trained 5 convolutional autoencoders with 10 latents, each with a different random seed for initializing the weights, and these have all been saved in the test-tube experiment ``ae-example``. + +.. _load_best_model: + +Method 1: load best model +------------------------- +The first option is to load the best model from ``ae-example``. The "best" model is defined as the one with the smallest loss value computed on validation data. If you set the parameter ``val_check_interval`` in the ae training json to a nonzero value before fitting, this information has already been computed and saved in a csv file, so this is a relatively fast option. The following code block shows how to load the best model, as well as the associated data generator, from ``ae-example``. + +.. code-block:: python + + # imports + from behavenet import get_user_dir + from behavenet.fitting.utils import get_best_model_and_data + from behavenet.fitting.utils import get_expt_dir + from behavenet.fitting.utils import get_lab_example + from behavenet.fitting.utils import get_session_dir + from behavenet.models import AE as Model + + # define necessary hyperparameters + hparams = { + 'data_dir': get_user_dir('data'), + 'save_dir': get_user_dir('save'), + 'experiment_name': 'ae-example', + 'model_class': 'ae', + 'model_type': 'conv', + 'n_ae_latents': 10, + } + + # programmatically fill out other hparams options + get_lab_example(hparams, 'musall', 'vistrained') + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + + # use helper function to load model and data generator + model, data_generator = get_best_model_and_data(hparams, Model, version='best') + + +.. _specify_version: + +Method 2: specify the model version +----------------------------------- +The next option requires that you know in advance which test-tube version you want to load. In this example, we'll load version 3. All you need to do is replace ``version='best'`` with ``version=3`` in the final line above. + +.. code-block:: python + + # use helper function to load model and data generator + model, data_generator = get_best_model_and_data(hparams, Model, version=3) + + +.. _specify_hparams: + +Method 3: specify model hyperparameters +--------------------------------------- +The final option gives you the most control - you can specify all relevant hyperparameters needed to define the model and the data generator, and load that specific model. + +.. code-block:: python + + # imports + from behavenet import get_user_dir + from behavenet.fitting.utils import experiment_exists + from behavenet.fitting.utils import get_best_model_and_data + from behavenet.fitting.utils import get_expt_dir + from behavenet.fitting.utils import get_lab_example + from behavenet.fitting.utils import get_session_dir + from behavenet.models import AE as Model + + # define necessary hyperparameters + hparams = { + 'data_dir': get_user_dir('data'), + 'save_dir': get_user_dir('save'), + 'experiment_name': 'ae-example', + 'model_class': 'ae', + 'model_type': 'conv', + 'n_ae_latents': 10, + 'rng_seed_data': 0, + 'trial_splits': '8;1;1;0', + 'train_frac': 1, + 'rng_seed_model': 0, + 'fit_sess_io_layers': False, + 'learning_rate': 1e-4, + 'l2_reg': 0, + } + + # programmatically fill out other hparams options + get_lab_example(hparams, 'musall', 'vistrained') + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + + # find the version for these hyperparameters; returns None for version if it doesn't exist + exists, version = experiment_exists(hparams, which_version=True) + + # use helper function to load model and data generator + model, data_generator = get_best_model_and_data(hparams, Model, version=version) + +You will need to specify the following entries in ``hparams`` regardless of the model class: + +* 'rng_seed_data' +* 'trial_splits' +* 'train_frac' +* 'rng_seed_model' +* 'model_class' +* 'model_type' + +For the autencoder, we need to additionally specify ``n_ae_latents``, ``fit_sess_io_layers``, ``learning_rate``, and ``l2_reg``. Check out the source code for :py:func:`behavenet.fitting.utils.get_model_params` to see which entries are required for other model classes. + + +Iterating through the data +-------------------------- + +Below is an example of how to iterate through the data generator and load batches of data: + +.. code-block:: python + + # select data type to load + dtype = 'train' # 'train' | 'val' | 'test' + + # reset data iterator for this data type + data_generator.reset_iterators(dtype) + + # loop through all batches for this data type + for i in range(data_generator.n_tot_batches[dtype]): + + batch, sess = data_generator.next_batch(dtype) + # "batch" is a dict with keys for the relevant signal, e.g. 'images', 'neural', etc. + # "sess" is an integer denoting the dataset this batch comes from + + # ... perform analyses ... diff --git a/docs/source/adv_user_guide.multisession.rst b/docs/source/adv_user_guide.multisession.rst new file mode 100644 index 0000000..9095d24 --- /dev/null +++ b/docs/source/adv_user_guide.multisession.rst @@ -0,0 +1,203 @@ +Training a model with multiple datasets +======================================= + +The statistical models that comprise BehaveNet - autoencoders, ARHMMs, neural network decoders - +often require large amounts of data to avoid overfitting. While the amount of data collected in an +hour long experimental session may suffice, every one of these models will benefit from additional +data. If data is collected from multiple experimental sessions, and these data are similar enough +(e.g. same camera placement/contrast across sessions), then you can train BehaveNet models on all +of this data simultaneously. + +BehaveNet provides two methods for specifying the experimental sessions used to train a model: + +* :ref:`Method 1`: use all sessions from a specified animal, experiment, or lab +* :ref:`Method 2`: specify the sessions in a csv file + +The first method is simpler, while the second method offers greater control. Both of these methods +require modifying the data configuration json before training. We'll use the Musall dataset as an +example; below is the relevant section of the json file located in +``behavenet/configs/data_default.json`` that we will modify below. + +.. code-block:: JSON + + "lab": "musall", # type: str + "expt": "vistrained", # type: str + "animal": "mSM30", # type: str + "session": "10-Oct-2017", # type: str + "sessions_csv": "", # type: str, help: specify multiple sessions + "all_source": "save", # type: str, help: "save" or "data" + +The Musall dataset provided with the repo (see ``behavenet/example/00_data.ipynb``) contains +autoencoders trained on two sessions individually, as well as a single autoencoder trained on both +sessions as an example of this feature. + + +.. _all_keyword: + +Method 1: the "all" keyword +--------------------------- +This method is appropriate if you want to fit a model on all sessions from a specified animal, +experiment, or lab. For example, if we want to fit a model on all sessions from animal +``mSM30``, we would modify the ``session`` parameter value to ``all``: + +.. code-block:: JSON + + "lab": "musall", # type: str + "expt": "vistrained", # type: str + "animal": "mSM30", # type: str + "session": "all", # type: str + "sessions_csv": "", # type: str, help: specify multiple sessions + "all_source": "save", # type: str, help: "save" or "data" + +In this case the resulting models will be stored in the directory +``save_dir/musall/vistrained/mSM30/multisession-xx``, where ``xx`` is selected automatically. +BehaveNet will create a csv file named ``session_info.csv`` inside the multisession directory that +lists the lab, expt, animal, and session for all sessions in that multisession. + + +If we want to fit a model on all sessions from all animals in the ``vistrained`` experiment, we +would modify the ``animal`` parameter value to ``all``: + +.. code-block:: JSON + + "lab": "musall", # type: str + "expt": "vistrained", # type: str + "animal": "all", # type: str + "session": "all", # type: str + "sessions_csv": "", # type: str, help: specify multiple sessions + "all_source": "save", # type: str, help: "save" or "data" + +In this case the resulting models will be stored in the directory +``save_dir/musall/vistrained/multisession-xx``. The string value for ``session`` does not +matter; BehaveNet searches for the ``all`` +keyword starting at the lab level and moves down; once it finds the ``all`` keyword it ignores all +further entries. + +.. note:: + + The ``all_source`` parameter in the json file is included to resolve an ambiguity with the + "all" keyword. For example, let's assume you use ``all`` at the session level for a single + animal. If data for 6 sessions exist for that animal, and BehaveNet models have been fit to 4 + of those 6 sessions, then setting ``"all_source": "data"`` will use all 6 sessions with data. + On the other hand, setting ``"all_source": "save"`` will use all 4 sessions that have been + previously used to fit models. + +.. _sessions_csv: + +Method 2: specify sessions in a csv file +---------------------------------------- +This method is appropriate if you want finer control over which sessions are included; for example, +if you want all sessions from one animal, as well as all but one session from another animal. To +specify these sessions, you can construct a csv file with the four column headers ``lab``, +``expt``, ``animal``, and ``session`` (see below). You can then provide this csv file +(let's say it's called ``data_dir/example_sessions.csv``) as the value for the ``sessions_csv`` +parameter: + +.. code-block:: JSON + + "lab": "musall", # type: str + "expt": "vistrained", # type: str + "animal": "all", # type: str + "session": "all", # type: str + "sessions_csv": "data_dir/example_sessions.csv", # type: str, help: specify multiple sessions + "all_source": "save", # type: str, help: "save" or "data" + +The ``sessions_csv`` parameter takes precedence over any values supplied for ``lab``, ``expt``, +``animal``, ``session``, and ``all_source``. + +Below is an example csv file that includes two sessions from one animal: + +.. code-block:: text + + lab,expt,animal,session + musall,vistrained,mSM36,05-Dec-2017 + musall,vistrained,mSM36,07-Dec-2017 + +Here is another example that include the previous two sessions, as well as a third from a different +animal: + +.. code-block:: text + + lab,expt,animal,session + musall,vistrained,mSM30,12-Oct-2017 + musall,vistrained,mSM36,05-Dec-2017 + musall,vistrained,mSM36,07-Dec-2017 + +Loading a trained multisession model +------------------------------------ + +The approach is almost identical to that laid out in :ref:`Loading a trained model`; +namely, you can either specify the "best" model, the model version, or fully specify all the model +hyperparameters. The one necessary change is to alert BehaveNet that you want to load a +multisession model. As above, you can do this by either using the "all" keyword or a csv file. +The code snippets below illustrate both of these methods when loading the "best" model. + +Method 1: use the "all" keyword to specify all sessions for a particular animal: + +.. code-block:: python + + # imports + from behavenet import get_user_dir + from behavenet.fitting.utils import get_best_model_and_data + from behavenet.fitting.utils import get_expt_dir + from behavenet.fitting.utils import get_lab_example + from behavenet.fitting.utils import get_session_dir + from behavenet.models import AE as Model + + # define necessary hyperparameters + hparams = { + 'data_dir': get_user_dir('data'), + 'save_dir': get_user_dir('save'), + 'lab': 'musall', + 'expt': 'vistrained', + 'animal': 'mSM30', + 'session': 'all', # use all sessions for animal mSM30 + 'experiment_name': 'ae-example', + 'model_class': 'ae', + 'model_type': 'conv', + 'n_ae_latents': 10, + } + + # programmatically fill out other hparams options + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + + # use helper function to load model and data generator + model, data_generator = get_best_model_and_data(hparams, Model, version='best') + +As above, the ``all`` keyword can also be used at the animal or expt level, though not currently at +the lab level. + +Method 2: use a sessions csv file: + +.. code-block:: python + + # imports + from behavenet import get_user_dir + from behavenet.fitting.utils import get_best_model_and_data + from behavenet.fitting.utils import get_expt_dir + from behavenet.fitting.utils import get_lab_example + from behavenet.fitting.utils import get_session_dir + from behavenet.models import AE as Model + + # define necessary hyperparameters + hparams = { + 'data_dir': get_user_dir('data'), + 'save_dir': get_user_dir('save'), + 'sessions_csv': '/path/to/csv/file', + 'experiment_name': 'ae-example', + 'model_class': 'ae', + 'model_type': 'conv', + 'n_ae_latents': 10, + } + + # programmatically fill out other hparams options + hparams['session_dir'], sess_ids = get_session_dir(hparams) + hparams['expt_dir'] = get_expt_dir(hparams) + + # use helper function to load model and data generator + model, data_generator = get_best_model_and_data(hparams, Model, version='best') + +In both cases, iterating through the data proceeds exactly as when using a single session, and the +second return value from ``data_generator.next_batch()`` identifies which session the batch belongs +to. diff --git a/docs/source/adv_user_guide.psvae_hparam_search.rst b/docs/source/adv_user_guide.psvae_hparam_search.rst new file mode 100644 index 0000000..c69af5c --- /dev/null +++ b/docs/source/adv_user_guide.psvae_hparam_search.rst @@ -0,0 +1,157 @@ +.. _psvae_hparams: + +PS-VAE hyperparameter search guide +=================================== + +The PS-VAE objective function :math:`\mathscr{L}_{\text{PS-VAE}}` is comprised of several +different terms: + +.. math:: + + \mathscr{L}_{\text{PS-VAE}} = + \mathscr{L}_{\text{frames}} + + \alpha \mathscr{L}_{\text{labels}} + + \mathscr{L}_{\text{KL-s}} + + \mathscr{L}_{\text{ICMI}} + + \beta \mathscr{L}_{\text{TC}} + + \mathscr{L}_{\text{DWKL}} + + \gamma \mathscr{L}_{\text{orth}} + +where + + * :math:`\mathscr{L}_{\text{frames}}`: log-likelihood of the video frames + * :math:`\mathscr{L}_{\text{labels}}`: log-likelihood of the labels + * :math:`\mathscr{L}_{\text{KL-s}}`: KL divergence of the supervised latents + * :math:`\mathscr{L}_{\text{ICMI}}`: index-code mutual information of the unsupervised latents + * :math:`\mathscr{L}_{\text{TC}}`: total correlation of the unsupervised latents + * :math:`\mathscr{L}_{\text{DWKL}}`: dimension-wise KL of the unsupervised latents + * :math:`\mathscr{L}_{\text{orth}}`: orthogonality of the full latent space (supervised + unsupervised) + +There are three important hyperparameters of the model that we address below: :math:`\alpha`, which +weights the reconstruction of the labels; :math:`\beta`, which weights the factorization of the +unsupervised latent space; and :math:`\gamma`, which weights the orthogonality of the entire latent +space. The purpose of this guide is to propose a series of model fits that efficiently explores +this space of hyperparameters, as well as point out several BehaveNet plotting utilities to assist +in this exploration. + + +How to select :math:`\alpha` +---------------------------- +The hyperparameter :math:`\alpha` controls the strength of the label log-likelihood term, which +needs to be balanced against the frame log-likelihood term. We first recommend z-scoring each +individual label, which removes the scale of the labels as a confound. We then recommend fitting +models with a range of :math:`\alpha` values, while setting the defaults :math:`\beta=1` (no extra +weight on the total correlation term) and :math:`\gamma=0` (no constraint on orthogonality). In our +experience the range :math:`\alpha=[50, 100, 500, 1000]` is a reasonable range to start with. The +"best" value for :math:`\alpha` is subjective because it involves a tradeoff between pixel +log-likelihood (or the related mean square error, MSE) and label log-likelihood (or MSE). +After choosing a suitable value, we will fix :math:`\alpha` and vary :math:`\beta` and +:math:`\gamma`. + + +How to select :math:`\beta` and :math:`\gamma` +---------------------------------------------- +The choice of :math:`\beta` and :math:`\gamma` is more difficult because there does not yet exist +a single robust measure of "disentanglement" that can tell us which models learn a suitable +unsupervised representation. Instead we will fit models with a range of hypeparameters, then use +a quantitative metric to guide a qualitative analysis. + +A reasonable range to start with is :math:`\beta=[1, 5, 10, 20]` and :math:`\gamma=1000`. While it +is possible to extend the range for :math:`\gamma`, we have found :math:`\gamma=1000` to work for +many datasets. How, then, do we choose a good value for :math:`\beta`? Currently our best advice is +to compute the correlation of the training data across all pairs of unsupervised dimensions. The +value of :math:`\beta` that minimizes the average of the pairwise correlations is a good place to +start more qualitative evaluations. + +Ultimately, the choice of the "best" model comes down to a qualitative evaluation, the *latent +traversal*. A latent traversal is the result of changing the value of a latent dimension while +keeping the value of all other latent dimensions fixed. If the model has learned an interpretable +representation then the resulting generated frames should show one single behavioral feature +changing per dimension - an arm, or a jaw, or the chest (see :ref:`below` +for more information on tools +for constructing and visualizing these traversals). In order to choose the "best" model, we perform +these latent traversals for all values of :math:`\beta` and look at the resulting latent traversal +outputs. The model with the (subjectively) most interpretable dimensions is then chosen. + + +A note on model robustness +-------------------------- +We have found the PS-VAE to be somewhat sensitive to initialization of the neural network +parameters. We also recommend choosing the set of hyperparamters with the lowest pairwise +correlations and refitting the model with several random seeds (by changing the ``rng_seed_model`` +parameter of the ``ae_model.json`` file), which may lead to even better results. + +.. _ps_vae_plotting: + +Tools for investigating PS-VAE model fits +------------------------------------------ +The functions listed below are provided in the BehaveNet plotting module ( +:mod:`behavenet.plotting`) to facilitate model checking and comparison at different stages. + +Hyperparameter search visualization +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The function :func:`behavenet.plotting.cond_ae_utils.plot_hyperparameter_search_results` creates +a variety of diagnostic plots after the user has performed the :math:`\alpha` search and the +:math:`\beta/\gamma` search detailed above: + +- pixel mse as a function of :math:`\alpha`, num latents (for fixed :math:`\beta, \gamma`) +- label mse as a function of :math:`\alpha`, num_latents (for fixed :math:`\beta, \gamma`) +- pixel mse as a function of :math:`\beta, \gamma` (for fixed :math:`\alpha`, n_ae_latents) +- label mse as a function of :math:`\beta, \gamma` (for fixed :math:`\alpha`, n_ae_latents) +- index-code mutual information (part of the KL decomposition) as a function of + :math:`\beta, \gamma` (for fixed :math:`\alpha`, n_ae_latents) +- total correlation(part of the KL decomposition) as a function of :math:`\beta, \gamma` + (for fixed :math:`\alpha`, n_ae_latents) +- dimension-wise KL (part of the KL decomposition) as a function of :math:`\beta, \gamma` + (for fixed :math:`\alpha`, n_ae_latents) +- average correlation coefficient across all pairs of unsupervised latent dims as a function of + :math:`\beta, \gamma` (for fixed :math:`\alpha`, n_ae_latents) +- subspace overlap computed as :math:`||[A; B] - I||_2^2` for :math:`A, B` the projections to the + supervised and unsupervised subspaces, respectively, and :math:`I` the identity - as a function + of :math:`\beta, \gamma` (for fixed :math:`\alpha`, n_ae_latents) +- example subspace overlap matrix for :math:`\gamma=0` and :math:`\beta=1`, with fixed + :math:`\alpha`, n_ae_latents +- example subspace overlap matrix for :math:`\gamma=1000` and :math:`\beta=1`, with fixed + :math:`\alpha`, n_ae_latents + +These plots help with the selection of hyperparameter settings. + +Model training curves +^^^^^^^^^^^^^^^^^^^^^ +The function :func:`behavenet.plotting.cond_ae_utils.plot_psvae_training_curves` creates training +plots for each term in the PS-VAE objective function for a *single* model: + +- total loss +- pixel mse +- label R^2 (note the objective function contains the label MSE, but R^2 is easier to parse) +- KL divergence of supervised latents +- index-code mutual information of unsupervised latents +- total correlation of unsupervised latents +- dimension-wise KL of unsupervised latents +- subspace overlap + +A function argument allows the user to plot either training or validation curves. These plots allow +the user to check whether or not models have trained to completion. + +Label reconstruction +^^^^^^^^^^^^^^^^^^^^ +The function :func:`behavenet.plotting.cond_ae_utils.plot_label_reconstructions` creates a series +of plots that show the true labels and their PS-VAE reconstructions for a given list of batches. +These plots are useful for qualitatively evaluating the supervised subspace of the PS-VAE; +a quantitative evaluation (the label MSE) can be found in the ``metrics.csv`` file created in the +model folder during training. + +Latent traversals: plots +^^^^^^^^^^^^^^^^^^^^^^^^ +The function :func:`behavenet.plotting.cond_ae_utils.plot_latent_traversals` displays video frames +representing the traversal of chosen dimensions in the latent space. This function uses a +single base frame to create all traversals. + +Latent traversals: movies +^^^^^^^^^^^^^^^^^^^^^^^^^ +The function :func:`behavenet.plotting.cond_ae_utils.make_latent_traversal_movie` creates a +multi-panel movie with each panel showing traversals of an individual latent dimension. +The traversals will start at a lower bound, increase to an upper bound, then return to a lower +bound; the traversal of each dimension occurs simultaneously. It is also possible to specify +multiple base frames for the traversals; the traversal of each base frame is separated by +several blank frames. diff --git a/docs/source/adv_user_guide.rst b/docs/source/adv_user_guide.rst index 5c5c408..d90a958 100644 --- a/docs/source/adv_user_guide.rst +++ b/docs/source/adv_user_guide.rst @@ -7,4 +7,6 @@ Advanced user guide :caption: Contents: adv_user_guide.slurm - + adv_user_guide.load_model + adv_user_guide.multisession + adv_user_guide.psvae_hparam_search diff --git a/docs/source/data_structure.rst b/docs/source/data_structure.rst index 0f676a2..c725031 100644 --- a/docs/source/data_structure.rst +++ b/docs/source/data_structure.rst @@ -6,7 +6,6 @@ BehaveNet data structure Introduction ============ - In order to quickly and easily fit many models, BehaveNet uses a standardized data structure. "Raw" experimental data such as behavioral videos and (processed) neural data are stored in the `HDF5 file format `_. This file format can @@ -37,14 +36,12 @@ does not require all trials to be of the same length, but does require that for images and neural activity have the same number of frames. This may require you to interpolate/bin video or neural data differently than the rate at which it was acquired. -**Note 1**: for large experiments having all of this data in memory might be infeasible, and more -sophisticated processing will be required - -**Note 2**: neural data is only required for fitting decoding models; it is still possible to fit -autoencoders and ARHMMs when the HDF5 file only contains images +**Notes**: -**Note 3**: the python package ``h5py`` is required for creating the HDF5 file, and is -automatically installed with the BehaveNet package. +* for large experiments, having all of this (video) data in memory to create the HDF5 file might be infeasible, and more sophisticated processing will be required +* neural data is only required for fitting decoding models; it is still possible to fit autoencoders and ARHMMs when the HDF5 file only contains images +* masks should be the same size as images; a value of 0 excludes the pixel from the loss function, a value of 1 includes it +* the python package ``h5py`` is required for creating the HDF5 file, and is automatically installed with the BehaveNet package. .. code-block:: python @@ -85,7 +82,6 @@ automatically installed with the BehaveNet package. Identifying subsets of neurons ============================== - It is possible that the neural data used for encoding and decoding models will have natural partitions - for example, neurons belonging to different brain regions or cell types. In this case you may be interested in, say, decoding behavior from each brain region individually, as well as all together. BehaveNet provides this capability through the addition of another HDF5 group. This group can have any name, but for illustration purposes we will use the name "regions" (this name will be later be provided in the updated data json file). The "regions" group contains a second level of (again user-defined) groups, which will define different index groupings. As a concrete example, let's say we have neural data with 100 neurons: @@ -153,12 +149,18 @@ This HDF5 file will now have the following addtional datasets: * regions/idxs/AUD * regions/idxs/VIS -Just as the top-level group (here named "regions") can have an arbitrary name (later specified in the data json file), the second-level groups (here named "idxs_lr" and "idxs") can also have arbitrary names, and there can be any number of them, as long as the datasets within them contain valid indices into the neural data. The specific set of indices used for any analyses will be specified in the data json file. See the :ref:`decoding documentation` for an example of how to decode behavior using specified subsets of neurons. +Just as the top-level group (here named "regions") can have an arbitrary name (later specified in +the data json file), the second-level groups (here named "idxs_lr" and "idxs") can also have +arbitrary names, and there can be any number of them, as long as the datasets within them contain +valid indices into the neural data. The specific set of indices used for any analyses will be +specified in the data json file. See the :ref:`decoding documentation` for +an example of how to decode behavior using specified subsets of neurons. +.. _data_structure_labels: + Including labels for ARHMMs and conditional autoencoders ======================================================== - In order to fit :ref:`conditional autoencoder models`, you will need to include additional information about labels in the HDF5 file. These labels can be outputs from pose estimation software, or other behavior-related signals such as pupil diameter or lick times. These @@ -179,5 +181,8 @@ data, you simply need to change the ``model_class`` entry of the arhmm model jso .. note:: - The matrix subspace projection model implemented in BehaveNet learns a linear mapping from the original latent space to the predicted labels that **does not contain a bias term**. Therefore you should center each label before adding them to the HDF5 file. Additionally, normalizing each label by its standard deviation can make searching across msp weights less dependent on the size of the input image. - + The matrix subspace projection model implemented in BehaveNet learns a linear mapping from the + original latent space to the predicted labels that **does not contain a bias term**. Therefore + you should center each label before adding them to the HDF5 file. Additionally, normalizing + each label by its standard deviation can make searching across msp weights less dependent on + the size of the input image. diff --git a/docs/source/glossary.rst b/docs/source/glossary.rst index 97941dc..e81e7a2 100644 --- a/docs/source/glossary.rst +++ b/docs/source/glossary.rst @@ -4,7 +4,11 @@ Hyperparameter glossary ####################### -The BehaveNet code requires a diverse array of hyperparameters (hparams) to specify details about the data, computational resources, training algorithms, and the models themselves. This glossary contains a brief description for each of the hparams options. See the `example json files `_ for reasonable hparams defaults. +The BehaveNet code requires a diverse array of hyperparameters (hparams) to specify details about +the data, computational resources, training algorithms, and the models themselves. This glossary +contains a brief description for each of the hparams options. See the +`example json files `_ for reasonable +hparams defaults. Data ==== @@ -21,11 +25,17 @@ Data * **y_pixels** (*int*): number of behavioral video pixels in y dimension * **x_pixels** (*int*): number of behavioral video pixels in x dimension * **use_output_mask** (*bool*): `True`` to apply frame-wise output masks (must be a key ``masks`` in data HDF5 file) +* **use_label_mask** (*bool*): `True`` to apply frame-wise masks to labels in conditional ae models (must be a key ``labels_masks`` in data HDF5 file) +* **n_labels** (*bool*): specify number of labels when model_class is 'neural-labels' or 'labels-neural' * **neural_bin_size** (*float*): bin size of neural/video data (ms) * **neural_type** (*str*): 'spikes' | 'ca' * **approx_batch_size** (*str*): approximate batch size (number of frames) for gpu memory calculation -For encoders/decoders, additional information can be supplied to control which subsets of neurons are used for encoding/decoding. See the :ref:`data structure documentation` for detailed instructions on how to incorporate this information into your HDF5 data file. The following options must be added to the data json file (an example can be found `here `__): +For encoders/decoders, additional information can be supplied to control which subsets of neurons +are used for encoding/decoding. See the :ref:`data structure documentation` +for detailed instructions on how to incorporate this information into your HDF5 data file. The +following options must be added to the data json file (an example can be found +`here `__): * **subsample_idxs_group_0** (*str*): name of the top-level HDF5 group that contains index groups * **subsample_idxs_group_1** (*str*): name of the second-level HDF5 group that contains index datasets @@ -98,13 +108,17 @@ All models: * 'vae': variational autoencoder * 'beta-tcvae': variational autoencoder with beta tc-vae decomposition of elbo * 'cond-ae': conditional autoencoder + * 'cond-vae': conditional variational autoencoder * 'cond-ae-msp': autoencoder with matrix subspace projection loss + * 'ps-vae': partitioned subspace variational autoencoder * 'hmm': hidden Markov model * 'arhmm': autoregressive hidden Markov model * 'neural-ae': decode AE latents from neural activity + * 'neural-ae-me': decode motion energy of AE latents (absolute value of temporal difference) from neural activity * 'neural-arhmm': decode arhmm states from neural activity * 'ae-neural': predict neural activity from AE latents * 'arhmm-neural': predict neural activity from arhmm states + * 'labels-images': decode images from labels with a convolutional decoder * 'bayesian-decoding': baysian decoding of AE latents and arhmm states from neural activity @@ -152,6 +166,7 @@ ARHMM * **ae_experiment_name** (*str*): name of AE test-tube experiment * **ae_version** (*str* or *int*): 'best' to choose best version in AE experiment, otherwise an integer specifying test-tube version number +* **ae_model_class** (*str*): 'ae' | 'vae' | 'beta-tcvae' | ... * **ae_model_type** (*str*): 'conv' | 'linear' * **n_ae_latents** (*int*): number of autoencoder latents; this will be the observation dimension in the ARHMM * **export_train_plots** ('*bool*): ``True`` to automatically export training/validation log probability as a function of epoch upon completion of training @@ -181,6 +196,7 @@ For the continuous decoder: * **ae_experiment_name** (*str*): name of AE test-tube experiment * **ae_version** (*str* or *int*): 'best' to choose best version in AE experiment, otherwise an integer specifying test-tube version number +* **ae_model_class** (*str*): 'ae' | 'vae' | 'beta-tcvae' | ... * **ae_model_type** (*str*): 'conv' | 'linear' * **n_ae_latents** (*int*): number of autoencoder latents; this will be the dimension of the data predicted by the decoder * **ae_multisession** (*int*): use if loading latents from an AE that was trained on multiple datasets @@ -189,6 +205,7 @@ For the continuous decoder: For the discrete decoder: * **n_ae_latents** (*int*): number of autoencoder latents that the ARHMM was trained on +* **ae_model_class** (*str*): 'ae' | 'vae' | 'beta-tcvae' | ... * **ae_model_type** (*str*): 'conv' | 'linear' * **arhmm_experiment_name** (*str*): name of ARHMM test-tube experiment * **n_arhmm_states** (*int*): number of ARHMM discrete states; this will be the number of classes the decoder is trained on diff --git a/docs/source/user_guide.conditional_autoencoders.rst b/docs/source/user_guide.conditional_autoencoders.rst index cc065f0..9469289 100644 --- a/docs/source/user_guide.conditional_autoencoders.rst +++ b/docs/source/user_guide.conditional_autoencoders.rst @@ -3,19 +3,40 @@ Conditional autoencoders ======================== -One drawback to the use of unsupervised dimensionality reduction (performed by the convolutional autoencoder) is that the resulting latents are generally uninterpretable, because any animal movement in a behavioral video will be represented across many (if not all) of the latents. Thus there is no simple way to find an "arm" dimension that is separate from a "pupil" dimension, distinctions that may be important for downstream analyses. - -Semi-supervised approaches to dimensionality reduction offer a partial resolution to this problem. In this framework, the user first collects a set of markers that track body parts of interest over time. These markers can be, for example, the output of standard pose estimation software such as `DeepLabCut `_, `LEAP `_, or `DeepPoseKit `_. These markers can then be used to augment the latent space (using :ref:`conditional autoencoders`) or regularize the latent space (using the :ref:`matrix subspace projection loss`), both of which are described below. - -In order to fit these models, the data HDF5 needs to be augmented to include a new HDF5 group named ``labels``, which contains an hdf5 dataset for each trial. The labels for each trial must match up with the corresponding video frames; for example, if the image data in ``images/trial_0013`` contains 100 frames (a numpy array of shape [100, n_channels, y_pix, x_pix]), the label data in ``labels/trial_0013`` should contain the corresponding labels (a numpy array of shape [100, n_labels]). See the :ref:`data structure documentation` for more information). +One drawback to the use of unsupervised dimensionality reduction (performed by the convolutional +autoencoder) is that the resulting latents are generally uninterpretable, because any animal +movement in a behavioral video will be represented across many (if not all) of the latents. Thus +there is no simple way to find an "arm" dimension that is separate from a "pupil" dimension, +distinctions that may be important for downstream analyses. + +Semi-supervised approaches to dimensionality reduction offer a partial resolution to this problem. +In this framework, the user first collects a set of markers that track body parts of interest over +time. These markers can be, for example, the output of standard pose estimation software such as +`DeepLabCut `_, `LEAP `_, +or `DeepPoseKit `_. These markers can then be used to +augment the latent space (using :ref:`conditional autoencoders`) or regularize the latent +space (using the :ref:`matrix subspace projection loss`), both of which are described +below. + +In order to fit these models, the data HDF5 needs to be augmented to include a new HDF5 group named +``labels``, which contains an HDF5 dataset for each trial. The labels for each trial must match up +with the corresponding video frames; for example, if the image data in ``images/trial_0013`` +contains 100 frames (a numpy array of shape [100, n_channels, y_pix, x_pix]), the label data in +``labels/trial_0013`` should contain the corresponding labels (a numpy array of shape +[100, n_labels]). See the :ref:`data structure documentation` for more +information). .. _cond_ae: Conditional autoencoders ------------------------ -The `conditional autoencoder `_ implemented in BehaveNet is a simple extension of the convolutional autoencoder. Each frame is pushed through the encoder to produce a set of latents, which are concatenated with the corresponding labels; this augmented vector is then used as input to the decoder. +The `conditional autoencoder `_ +implemented in BehaveNet is a simple extension of the convolutional autoencoder. Each frame is +pushed through the encoder to produce a set of latents, which are concatenated with the +corresponding labels; this augmented vector is then used as input to the decoder. -To fit a single conditional autoencoder with the default CAE BehaveNet architecture, edit the ``model_class`` parameter of the ``ae_model.json`` file: +To fit a single conditional autoencoder with the default CAE BehaveNet architecture, edit the +``model_class`` parameter of the ``ae_model.json`` file: .. code-block:: json @@ -31,20 +52,35 @@ To fit a single conditional autoencoder with the default CAE BehaveNet architect "conditional_encoder": false } -Then to fit the model, use the ``ae_grid_search.py`` function using this updated model json. All other input jsons remain unchanged. +Then to fit the model, use the ``ae_grid_search.py`` function using this updated model json. All +other input jsons remain unchanged. -By concatenating the labels to the latents, we are learning a conditional decoder. We can also condition the latents on the labels by learning a conditional encoder. Turning on this feature requires an additional hdf5 group; documentation coming soon. +By concatenating the labels to the latents, we are learning a conditional decoder. We can also +condition the latents on the labels by learning a conditional encoder. Turning on this feature +requires an additional HDF5 group; documentation coming soon. .. _ae_msp: Matrix subspace projection loss ------------------------------- -An alternative way to obtain a more interpretable latent space is to encourage a subspace to predict the labels themselves, rather than appending them to the latents. With appropriate additions to the loss function, we can ensure that the subspace spanned by the label-predicting latents is orthogonal to the subspace spanned by the remaining unconstrained latents. This is the idea of the `matrix subspace projection loss `_. - -For example, imagine we are tracking 4 body parts, each with their own x-y coordinates for each frame. This gives us 8 dimensions of behavior to predict. If we fit a CAE with 10 latent dimensions, we will use 8 of those dimensions to predict the 8 marker dimensions - one latent dimension for each marker dimension. This leaves 2 unconstrained dimensions to predict remaining variability in the images not captured by the labels. The model is trained by minimizing the mean square error between the true and predicted images, as well as the true and predicted labels. Unlike the conditional autoencoder described above, this new loss function has an additional hyperparameter that governs the tradeoff between image reconstruction and label reconstruction. - -To fit a single autoencoder with the matrix subspace projection loss (and the default CAE BehaveNet architecture), edit the ``model_class`` and ``msp.alpha`` parameters of the ``ae_model.json`` file: +An alternative way to obtain a more interpretable latent space is to encourage a subspace to +predict the labels themselves, rather than appending them to the latents. With appropriate +additions to the loss function, we can ensure that the subspace spanned by the label-predicting +latents is orthogonal to the subspace spanned by the remaining unconstrained latents. This is the +idea of the `matrix subspace projection loss `_. + +For example, imagine we are tracking 4 body parts, each with their own x-y coordinates for each +frame. This gives us 8 dimensions of behavior to predict. If we fit a CAE with 10 latent +dimensions, we will use 8 of those dimensions to predict the 8 marker dimensions - one latent +dimension for each marker dimension. This leaves 2 unconstrained dimensions to predict remaining +variability in the images not captured by the labels. The model is trained by minimizing the mean +square error between the true and predicted images, as well as the true and predicted labels. +Unlike the conditional autoencoder described above, this new loss function has an additional +hyperparameter that governs the tradeoff between image reconstruction and label reconstruction. + +To fit a single autoencoder with the matrix subspace projection loss (and the default CAE BehaveNet +architecture), edit the ``model_class`` and ``msp.alpha`` parameters of the ``ae_model.json`` file: .. code-block:: json @@ -61,10 +97,53 @@ To fit a single autoencoder with the matrix subspace projection loss (and the de "conditional_encoder": false } -The ``msp.alpha`` parameter needs to be tuned for each dataset, but ``msp.alpha=1.0`` is a reasonable starting value if the labels have each been z-scored. +The ``msp.alpha`` parameter needs to be tuned for each dataset, but ``msp.alpha=1.0`` is a +reasonable starting value if the labels have each been z-scored. .. note:: - The matrix subspace projection model implemented in BehaveNet learns a linear mapping from the original latent space to the predicted labels that **does not contain a bias term**. Therefore you should center each label before adding them to the HDF5 file. Additionally, normalizing each label by its standard deviation can make searching across msp weights less dependent on the size of the input image. + The matrix subspace projection model implemented in BehaveNet learns a linear mapping from the + original latent space to the predicted labels that **does not contain a bias term**. Therefore + you should center each label before adding them to the HDF5 file. Additionally, normalizing + each label by its standard deviation can make searching across msp weights less dependent on + the size of the input image. + +Then to fit the model, use the ``ae_grid_search.py`` function using this updated model json. All +other input jsons remain unchanged. + + +.. _ps_vae: + +Partitioned subspace variational autoencoder +-------------------------------------------- +One downside to the MSP model introduced in the previous section is that the representation in the +unsupervised latent space may be difficult to interpret. The partitioned subspace VAE (PS-VAE) +attempts to remedy this situation by encouraging the unsupervised representation to be factorized, +which has shown to help with interpretability (see paper `here `_). + +To fit a single PS-VAE (and the default CAE BehaveNet +architecture), edit the ``model_class``, ``ps_vae.alpha``, ``ps_vae.beta`` and ``ps_vae.gamma`` +parameters of the ``ae_model.json`` file: + +.. code-block:: json + + { + "experiment_name": "ae-example", + "model_type": "conv", + "n_ae_latents": 12, + "l2_reg": 0.0, + "rng_seed_model": 0, + "fit_sess_io_layers": false, + "ae_arch_json": null, + "model_class": "ps-vae", + "ps_vae.alpha": 1000, + "ps_vae.beta": 10, + "ps_vae.gamma": 1000, + "conditional_encoder": false + } + +The ``ps_vae.alpha``, ``ps_vae.beta`` and ``ps_vae.gamma`` parameters need to be tuned for +each dataset. See the guidelines for setting these parameters :ref:`here`. -Then to fit the model, use the ``ae_grid_search.py`` function using this updated model json. All other input jsons remain unchanged. +Then to fit the model, use the ``ae_grid_search.py`` function using this updated model json. All +other input jsons remain unchanged. diff --git a/docs/source/user_guide.decoders.rst b/docs/source/user_guide.decoders.rst index 2f16f07..8dfedf0 100644 --- a/docs/source/user_guide.decoders.rst +++ b/docs/source/user_guide.decoders.rst @@ -1,12 +1,20 @@ Decoders ======== -The next step of the BehaveNet pipeline uses the neural activity to decode (or reconstruct) aspects of behavior. In particular, you may decode either the AE latents or the ARHMM states on a frame-by-frame basis given the surrounding window of neural activity. +The next step of the BehaveNet pipeline uses the neural activity to decode (or reconstruct) aspects +of behavior. In particular, you may decode either the AE latents or the ARHMM states on a +frame-by-frame basis given the surrounding window of neural activity. -The architecture options consist of a linear model or feedforward neural network: exact architecture parameters such as number of layers in the neural network can be specified in ``decoding_ae_model.json`` or ``decoding_arhmm_model.json``. The size of the window of neural activity used to reconstruct each frame of AE latents or ARHMM states is set by ``n_lags``: the neural activity from ``t-n_lags:t+n_lags`` will be used to predict the latents or states at time ``t``. +The architecture options consist of a linear model or feedforward neural network: exact +architecture parameters such as number of layers in the neural network can be specified in +``decoding_ae_model.json`` or ``decoding_arhmm_model.json``. The size of the window of neural +activity used to reconstruct each frame of AE latents or ARHMM states is set by ``n_lags``: the +neural activity from ``t-n_lags:t+n_lags`` will be used to predict the latents or states at time +``t``. - -To begin fitting decoding models, copy the example json files ``decoding_ae_model.json``, ``decoding_arhmm_model.json``, ``decoding_compute.json``, and ``decoding_training.json`` into your ``.behavenet`` directory. ``cd`` to the ``behavenet`` directory in the terminal, and run: +To begin fitting decoding models, copy the example json files ``decoding_ae_model.json``, +``decoding_arhmm_model.json``, ``decoding_compute.json``, and ``decoding_training.json`` into your +``.behavenet`` directory. ``cd`` to the ``behavenet`` directory in the terminal, and run: Decoding ARHMM states: @@ -16,22 +24,24 @@ Decoding ARHMM states: or -Decoding AE states: +Decoding AE latents: .. code-block:: console $: python behavenet/fitting/decoding_grid_search.py --data_config ~/.behavenet/musall_vistrained_params.json --model_config ~/.behavenet/decoding_ae_model.json --training_config ~/.behavenet/decoding_training.json --compute_config ~/.behavenet/decoding_compute.json - - - +It is also possible to decode the motion energy of the AE latents, defined as the absolute value of +the difference between neighboring time points; to do so make the following change in the model +json: ``model_class: 'neural-ae-me'`` .. _decoding_with_subsets: Decoding with subsets of neurons -------------------------------- -Continuing with the toy dataset introduced in the :ref:`data structure` documentation, below are some examples for how to modify the decoding data json file to decode from user-specified groups of neurons: +Continuing with the toy dataset introduced in the :ref:`data structure` +documentation, below are some examples for how to modify the decoding data json file to decode from +user-specified groups of neurons: **Example 0**: @@ -72,15 +82,20 @@ Fit separate decoders for each dataset of indices in the HDF5 group ``regions/id "subsample_method": "single" // subsample, use single regions } -In this toy example, these options will fit 4 decoders, each using a different set of indices: ``AUD_R``, ``AUD_L``, ``VIS_L``, and ``VIS_R``. +In this toy example, these options will fit 4 decoders, each using a different set of indices: +``AUD_R``, ``AUD_L``, ``VIS_L``, and ``VIS_R``. .. note:: - At this time the option ``subsample_idxs_dataset`` can only accept a single string as an argument; therefore you can use ``all`` to fit decoders using all datasets in the specified index group, or you can specify a single dataset (e.g. ``AUD_L`` in this example). You cannot, for example, provide a list of strings. + At this time the option ``subsample_idxs_dataset`` can only accept a single string as an + argument; therefore you can use ``all`` to fit decoders using all datasets in the specified + index group, or you can specify a single dataset (e.g. ``AUD_L`` in this example). You cannot, + for example, provide a list of strings. **Example 3**: -Use all indices *except* those in the HDF5 dataset ``regions/idxs_lr/AUD_L`` ("loo" stands for "leave-one-out"): +Use all indices *except* those in the HDF5 dataset ``regions/idxs_lr/AUD_L`` ("loo" stands for +"leave-one-out"): .. code-block:: javascript @@ -91,11 +106,13 @@ Use all indices *except* those in the HDF5 dataset ``regions/idxs_lr/AUD_L`` ("l "subsample_method": "loo" // subsample, use all but specified region } -In this toy example, the combined neurons from ``AUD_R``, ``VIS_L`` and ``VIS_R`` would be used for decoding (i.e. not the neurons in the specified region ``AUD_L``). +In this toy example, the combined neurons from ``AUD_R``, ``VIS_L`` and ``VIS_R`` would be used for +decoding (i.e. not the neurons in the specified region ``AUD_L``). -**Example 3**: +**Example 4**: -For each dataset in ``regions/indxs_lr``, fit a decoder that uses all indices *except* those in the dataset: +For each dataset in ``regions/indxs_lr``, fit a decoder that uses all indices *except* those in the +dataset: .. code-block:: javascript @@ -106,10 +123,31 @@ For each dataset in ``regions/indxs_lr``, fit a decoder that uses all indices *e "subsample_method": "loo" // subsample, use all but specified region } -Again referring to the toy example, these options will fit 4 decoders, each using a different set of indices: +Again referring to the toy example, these options will fit 4 decoders, each using a different set +of indices: 1. ``AUD_L``, ``VIS_L``, and ``VIS_R`` (not ``AUD_R``) 2. ``AUD_R``, ``VIS_L``, and ``VIS_R`` (not ``AUD_L``) 3. ``AUD_R``, ``AUD_L``, and ``VIS_L`` (not ``VIS_R``) 4. ``AUD_R``, ``AUD_L``, and ``VIS_R`` (not ``VIS_L``) + +.. _decoding_labels: + +Decoding arbitrary covariates +----------------------------- +BehaveNet also uses the above decoding infrastructure to allow users to decode an arbitrary set of +labels from neural activity; these could be markers from pose estimation software, stimulus +information, or other task variables. In order to fit these models, the data HDF5 needs to be +augmented to include a new HDF5 group named ``labels``, which contains an HDF5 dataset for each +trial. See the :ref:`data structure documentation ` for more information. + +Once the labels have been added to the data file, you can decode labels as you would CAE latents +above; the only changes that are necessary is the addition of the field ``n_labels`` in the data +json, and changing the model class in the model json from either ``neural-ae`` or ``neural-arhmm`` +to ``neural-labels``. + +.. note:: + + The current BehaveNet implementation only allows for decoding continuous labels using a + Gaussian noise distribution; support for binary and count data forthcoming. diff --git a/docs/source/user_guide.intro.rst b/docs/source/user_guide.intro.rst index ae56305..838695d 100644 --- a/docs/source/user_guide.intro.rst +++ b/docs/source/user_guide.intro.rst @@ -1,20 +1,24 @@ Introduction ============ -BehaveNet is a software package that provides tools for analyzing behavioral video and neural activity. Currently BehaveNet supports: +BehaveNet is a software package that provides tools for analyzing behavioral video and neural +activity. Currently BehaveNet supports: * Video compression using convolutional autoencoders * Video segmentation (and generation) using autoregressive hidden Markov models * Neural network decoding of videos from neural activity * Bayesian decoding of videos from neural activity -BehaveNet automatically saves models using a well-defined and flexible directory structure, allowing for easy management of many models and multiple datasets. +BehaveNet automatically saves models using a well-defined and flexible directory structure, +allowing for easy management of many models and multiple datasets. The command line interface -------------------------- -Users interact with BehaveNet using a command line interface, so all model fitting is done from the terminal. To simplify this process all necessary parameters are defined in four configuration files that can be manually updated using a text editor: +Users interact with BehaveNet using a command line interface, so all model fitting is done from the +terminal. To simplify this process all necessary parameters are defined in four configuration files +that can be manually updated using a text editor: * **data_config** - dataset ids, video frames sizes, etc. You can automatically generate this configuration file for a new dataset by following the instructions in the following section. * **model_config** - model hyperparameters @@ -31,7 +35,9 @@ For example, the command line call to fit an autoencoder would be (using the def $: cd behavenet $: python fitting/ae_grid_search.py --data_config ../configs/data_default.json --model_config ../configs/ae_model.json --training_config ../configs/ae_training.json --compute_config ../configs/ae_compute.json -We recommend that you copy the default config files in the behavenet repo into a separate directory on your local machine and make edits there. For more information on the different hyperparameters, see the :ref:`hyperparameters glossary`. +We recommend that you copy the default config files in the behavenet repo into a separate directory +on your local machine and make edits there. For more information on the different hyperparameters, +see the :ref:`hyperparameters glossary`. .. _add_dataset: @@ -39,7 +45,9 @@ We recommend that you copy the default config files in the behavenet repo into a Adding a new dataset -------------------- -When using BehaveNet with a new dataset you will need to make a new data config json file, which can be automatically generated using a BehaveNet helper function. You will be asked to enter the following information (examples shown for Musall dataset): +When using BehaveNet with a new dataset you will need to make a new data config json file, which +can be automatically generated using a BehaveNet helper function. You will be asked to enter the +following information (examples shown for Musall dataset): * lab or experimenter name (:code:`musall`) * experiment name (:code:`vistrained`) @@ -59,19 +67,36 @@ To enter this information, launch python from the behavenet environment and type from behavenet import add_dataset add_dataset() -This function will create a json file named ``[lab_id]_[expt_id].json`` in the ``.behavenet`` directory in your user home directory, which you can manually update at any point using a text editor. +This function will create a json file named ``[lab_id]_[expt_id].json`` in the ``.behavenet`` +directory in your user home directory, which you can manually update at any point using a text +editor. Organizing model fits with test-tube ------------------------------------ -BehaveNet uses the `test-tube package `_ to organize model fits into user-defined experiments, log meta and training data, and perform grid searches over model hyperparameters. Most of this occurs behind the scenes, but there are a couple of important pieces of information that will improve your model fitting experience. - -BehaveNet organizes model fits using a combination of hyperparameters and user-defined experiment names. For example, let's say you want to fit 5 different convolutional autoencoder architectures, all with 12 latents, to find the best one. Let's call this experiment "arch_search", which you will set in the ``model_config`` json in the ``experiment_name`` field. The results will then be stored in the directory ``results_dir/lab_id/expt_id/animal_id/session_id/ae/conv/12_latents/arch_search/``. - -Each model will automatically be assigned it's own "version" by test-tube, so the ``arch_search`` directory will have subdirectories ``version_0``, ..., ``version_4``. If an additional CAE model is later fit with 12 latents (and using the "arch_search" experiment name), test-tube will add it to the ``arch_search`` directory as ``version_5``. Different versions may have different architectures, learning rates, regularization values, etc. Each model class (autoencoder, arhmm, decoders) has a set of hyperparameters that are used for directory names, and another set that are used to distinguish test-tube versions within the user-defined experiment. - -Within the ``version_x`` directory, there are various files saved during training. Here are some of the files automatically output when training an autoencoder: +BehaveNet uses the `test-tube package `_ to organize +model fits into user-defined experiments, log meta and training data, and perform grid searches +over model hyperparameters. Most of this occurs behind the scenes, but there are a couple of +important pieces of information that will improve your model fitting experience. + +BehaveNet organizes model fits using a combination of hyperparameters and user-defined experiment +names. For example, let's say you want to fit 5 different convolutional autoencoder architectures, +all with 12 latents, to find the best one. Let's call this experiment "arch_search", which you will +set in the ``model_config`` json in the ``experiment_name`` field. The results will then be stored +in the directory +``results_dir/lab_id/expt_id/animal_id/session_id/ae/conv/12_latents/arch_search/``. + +Each model will automatically be assigned it's own "version" by test-tube, so the ``arch_search`` +directory will have subdirectories ``version_0``, ..., ``version_4``. If an additional CAE model is +later fit with 12 latents (and using the "arch_search" experiment name), test-tube will add it to +the ``arch_search`` directory as ``version_5``. Different versions may have different +architectures, learning rates, regularization values, etc. Each model class (autoencoder, arhmm, +decoders) has a set of hyperparameters that are used for directory names, and another set that are +used to distinguish test-tube versions within the user-defined experiment. + +Within the ``version_x`` directory, there are various files saved during training. Here are some of +the files automatically output when training an autoencoder: * **best_val_model.pt**: the best model (not necessarily from the final training epoch) as determined by computing the loss on validation data * **meta_tags.csv**: hyperparameters associated with data, computational resources, training, and model @@ -93,11 +118,15 @@ and if you set ``export_train_plots`` to ``True`` in the training config file, y Grid searching with test-tube ----------------------------- -Beyond organizing model fits, test-tube is also useful for performing grid searches over model hyperparameters, using multiple cpus or gpus. All you as the user need to do is enter the relevant hyperparameter choices as a list instead of a single value in the associated configuration file. +Beyond organizing model fits, test-tube is also useful for performing grid searches over model +hyperparameters, using multiple cpus or gpus. All you as the user need to do is enter the relevant +hyperparameter choices as a list instead of a single value in the associated configuration file. -Again using the autoencoder as an example, let's say you want to fit a single AE architecture using 4 different numbers of latents, all with the same regularization value. In the model config file, you will set these values as: +Again using the autoencoder as an example, let's say you want to fit a single AE architecture using +4 different numbers of latents, all with the same regularization value. In the model config file, +you will set these values as: -.. code-block:: json +.. code-block:: javascript { ... @@ -106,9 +135,10 @@ Again using the autoencoder as an example, let's say you want to fit a single AE ... } -To specify the computing resources for this job, you will next edit the compute config file, which looks like this: +To specify the computing resources for this job, you will next edit the compute config file, which +looks like this: -.. code-block:: json +.. code-block:: javascript { ... @@ -120,13 +150,21 @@ To specify the computing resources for this job, you will next edit the compute ... } -With the ``device`` field set to ``cuda``, test-tube will use gpus to run this job. The ``gpus_viz`` field can further specify which subset of gpus to use. The ``tt_n_gpu_trials`` defines the maximum number of jobs to run. If this number is larger than the total number of hyperparameter configurations, all configurations are fit; if this number is smaller than the total number (say if ``"tt_n_gpu_trials": 2`` in this example) then this number of configurations is randomly sampled from all possible choices. +With the ``device`` field set to ``cuda``, test-tube will use gpus to run this job. The +``gpus_viz`` field can further specify which subset of gpus to use. The ``tt_n_gpu_trials`` defines +the maximum number of jobs to run. If this number is larger than the total number of hyperparameter +configurations, all configurations are fit; if this number is smaller than the total number (say if +``"tt_n_gpu_trials": 2`` in this example) then this number of configurations is randomly sampled +from all possible choices. -To fit models using the cpu instead, set the ``device`` field to ``cpu``; then ``tt_n_cpu_workers`` defines the total number of cpus to run the job (total number of models fitting at any one time) and ``tt_n_cpu_trials`` is analogous to ``tt_n_gpu_trials``. +To fit models using the cpu instead, set the ``device`` field to ``cpu``; then ``tt_n_cpu_workers`` +defines the total number of cpus to run the job (total number of models fitting at any one time) +and ``tt_n_cpu_trials`` is analogous to ``tt_n_gpu_trials``. -Finally, multiple hyperparameters can be searched over simultaneously; for example, to search over both AE latents and regularization values, set these parameters in the model config file like so: +Finally, multiple hyperparameters can be searched over simultaneously; for example, to search over +both AE latents and regularization values, set these parameters in the model config file like so: -.. code-block:: json +.. code-block:: javascript { ... @@ -136,4 +174,3 @@ Finally, multiple hyperparameters can be searched over simultaneously; for examp } This job would then fit a total of 4 latent values x 3 regularization values = 12 models. - diff --git a/example/05_conditional_ae.ipynb b/example/05_conditional_ae.ipynb deleted file mode 100644 index e4c12a5..0000000 --- a/example/05_conditional_ae.ipynb +++ /dev/null @@ -1,427 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Analyze AEs with matrix subspace projection loss\n", - "This notebook is a template showcasing some ways to analyze autoencoders that have been fit with the matrix subspace projection (MSP) loss.\n", - "\n", - "
\n", - " \n", - "### Contents\n", - "* [Plot loss metrics as a function of epochs](#Plot-loss-metrics-as-a-function-of-epoch)\n", - "* [Plot true vs predicted labels](#Plot-true-vs-predicted-labels)\n", - "* [Evaluate orthogonality of projection matrix](#Evaluate-orthogonality-of-projection-matrix)\n", - "* [Explore label/latent space](#Explore-label/latent-space)\n", - " * [explore label space](#Explore-2D-label-space)\n", - " * [explore latent space](#Explore-2D-latent-space)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import copy\n", - "import os\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "\n", - "from behavenet import get_user_dir\n", - "from behavenet import make_dir_if_not_exists\n", - "from behavenet.fitting.utils import get_expt_dir\n", - "from behavenet.fitting.utils import get_session_dir\n", - "from behavenet.fitting.utils import get_best_model_version\n", - "from behavenet.fitting.utils import get_lab_example\n", - "\n", - "save_outputs = False # true to save figures/movies to user's figure directory\n", - "format = 'png' # figure format ('png' | 'jpeg' | 'pdf'); movies saved as mp4" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Plot loss metrics as a function of epoch\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from behavenet.plotting import load_metrics_csv_as_df\n", - "\n", - "# set data info\n", - "lab = ?\n", - "expt = ?\n", - "n_labels = ?\n", - "\n", - "# set model info\n", - "n_ae_latents = ? # n_labels will be added to this\n", - "tt_expt_name = ?\n", - "\n", - "hparams = {\n", - " 'data_dir': get_user_dir('data'),\n", - " 'save_dir': get_user_dir('save'),\n", - " 'experiment_name': tt_expt_name,\n", - " 'model_class': 'cond-ae-msp',\n", - " 'model_type': 'conv',\n", - " 'n_ae_latents': n_ae_latents + n_labels}\n", - "\n", - "metrics_list = ['loss', 'loss_mse', 'loss_msp', 'r2']\n", - "metrics_df = load_metrics_csv_as_df(hparams, lab, expt, metrics_list)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# plot data\n", - "sns.set_style('white')\n", - "sns.set_context('talk')\n", - "\n", - "for y in metrics_list:\n", - " \n", - " data_queried = metrics_df[(metrics_df.epoch > 10) & ~pd.isna(metrics_df.loss)]\n", - " splt = sns.relplot(x='epoch', y=y, hue='dtype', kind='line', data=data_queried)\n", - " splt.ax.set_xlabel('Epoch')\n", - " if y == 'loss':\n", - " splt.ax.set_ylabel('Total loss')\n", - " splt.ax.set_yscale('log')\n", - " elif y == 'loss_mse':\n", - " splt.ax.set_ylabel('MSE per pixel')\n", - " splt.ax.set_yscale('log')\n", - " elif y == 'loss_msp':\n", - " splt.ax.set_ylabel('MSE per label')\n", - " splt.ax.set_yscale('log')\n", - " elif y == 'r2':\n", - " splt.ax.set_ylabel('Label $R^2$')\n", - "\n", - " if save_outputs:\n", - " save_file = os.path.join(get_user_dir('fig'), 'ae', 'loss_vs_epoch')\n", - " make_dir_if_not_exists(save_file)\n", - " plt.savefig(save_file + '.' + format, dpi=300, format=format)\n", - "\n", - " plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Plot true vs predicted labels\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from behavenet.fitting.utils import get_best_model_and_data\n", - "from behavenet.models import AEMSP\n", - "\n", - "# set model info\n", - "version = 0 # 'best' # test-tube version; 'best' finds the version with the lowest mse\n", - "sess_idx = 0 # when using a multisession, this determines which session is used\n", - "hparams = {\n", - " 'data_dir': get_user_dir('data'),\n", - " 'save_dir': get_user_dir('save'),\n", - " 'experiment_name': tt_expt_name,\n", - " 'model_class': 'cond-ae-msp',\n", - " 'model_type': 'conv',\n", - " 'n_ae_latents': n_ae_latents + n_labels}\n", - "\n", - "trial_idxs = [1, 2, 3] # test trials to plot\n", - "\n", - "# programmatically fill out other hparams options\n", - "get_lab_example(hparams, lab, expt) \n", - "\n", - "model, data_generator = get_best_model_and_data(\n", - " hparams, AEMSP, load_data=True, version=version, data_kwargs=None)\n", - "n_labels = model.n_labels\n", - "print(data_generator)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from behavenet.plotting.ae_utils import plot_neural_reconstruction_traces\n", - "\n", - "for trial_idx in trial_idxs:\n", - " trial = data_generator.datasets[sess_idx].batch_idxs['test'][trial_idx]\n", - " batch = data_generator.datasets[sess_idx][trial]\n", - " labels_og = batch['labels'].detach().cpu().numpy()\n", - " labels_pred = model.get_transformed_latents(batch['images'])[:, :n_labels]\n", - " plot = plot_neural_reconstruction_traces(labels_og, labels_pred, scale=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Evaluate orthogonality of projection matrix\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "U = model.U.weight.data.cpu().detach().numpy()\n", - "\n", - "plt.figure(figsize=(6, 6))\n", - "overlap = np.matmul(U, U.T)\n", - "m = np.max(np.abs(overlap))\n", - "plt.imshow(overlap, cmap='RdBu', vmin=-m, vmax=m)\n", - "plt.colorbar()\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Explore label/latent space\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "from behavenet.data.utils import get_data_generator_inputs\n", - "\n", - "from behavenet.fitting.utils import get_best_model_and_data\n", - "from behavenet.fitting.eval import get_reconstruction\n", - "\n", - "from behavenet.plotting.cond_ae_utils import get_crop\n", - "from behavenet.plotting.cond_ae_utils import get_input_range\n", - "from behavenet.plotting.cond_ae_utils import get_labels_2d_for_trial\n", - "from behavenet.plotting.cond_ae_utils import get_model_input\n", - "from behavenet.plotting.cond_ae_utils import interpolate_2d\n", - "from behavenet.plotting.cond_ae_utils import plot_2d_frame_array" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### setup - define model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from behavenet.models import AEMSP as Model\n", - "\n", - "# dataset info\n", - "n_ae_latents = 2 # not including label-related latents\n", - "label_min_p = 15 # minimum percentile for latent/label space interpolation\n", - "label_max_p = 85 # maximum percentile for latent/label space interpolation\n", - "n_frames = 3 # number of frames to plot along each manipulated dim\n", - "trial_idx = 0 # index into trials for base frame\n", - "batch_idx = 0 # index into batch for base frame\n", - "label_idxs = [5, 1] # indices of labels to manipulate; y label first, then x\n", - "latent_idxs = np.array([0, 1]) # indices of latents to manipulate\n", - " \n", - "show_markers = True\n", - " \n", - "# set model info\n", - "version = 0 # test-tube version; 'best' finds the version with the lowest mse\n", - "sess_idx = 0 # when using a multisession, this determines which session is used\n", - "hparams = {\n", - " 'data_dir': get_user_dir('data'),\n", - " 'save_dir': get_user_dir('save'),\n", - " 'experiment_name': tt_expt_name,\n", - " 'model_class': 'cond-ae-msp',\n", - " 'model_type': 'conv',\n", - " 'n_ae_latents': n_ae_latents + n_labels,\n", - " 'rng_seed_data': 0,\n", - " 'trial_splits': '8;1;1;0',\n", - " 'train_frac': 1.0,\n", - " 'rng_seed_model': 0,\n", - " 'conditional_encoder': False,\n", - "}\n", - "\n", - "# programmatically fill out other hparams options\n", - "get_lab_example(hparams, lab, expt)\n", - "hparams['session_dir'], sess_ids = get_session_dir(hparams)\n", - "hparams['expt_dir'] = get_expt_dir(hparams)\n", - "\n", - "# build model\n", - "model_ae, data_generator = get_best_model_and_data(hparams, Model, version=version)\n", - "\n", - "latent_range = get_input_range(\n", - " 'latents', hparams, model=model_ae, data_gen=data_generator)\n", - "label_range = get_input_range(\n", - " 'labels', hparams, sess_ids=sess_ids, sess_idx=sess_idx, \n", - " min_p=label_min_p, max_p=label_max_p)\n", - "label_sc_range = get_input_range(\n", - " 'labels_sc', hparams, sess_ids=sess_ids, sess_idx=sess_idx,\n", - " min_p=label_min_p, max_p=label_max_p)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Explore 2D label space\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \\\n", - " get_model_input(\n", - " data_generator, hparams, model_ae, trial_idx=trial_idx, compute_latents=True, \n", - " compute_scaled_labels=False, compute_2d_labels=True)\n", - "\n", - "ims_label, markers_loc_label, ims_crop_label = interpolate_2d(\n", - " 'labels', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :], \n", - " labels_np[None, batch_idx, :], labels_2d_np[None, batch_idx, :], \n", - " mins=[label_range['min'][label_idxs[0]], label_range['min'][label_idxs[1]]], \n", - " maxes=[label_range['max'][label_idxs[0]], label_range['max'][label_idxs[1]]], \n", - " n_frames=n_frames, input_idxs=label_idxs, crop_type=None, \n", - " mins_sc=[label_sc_range['min'][label_idxs[0]], label_sc_range['min'][label_idxs[1]]], \n", - " maxes_sc=[label_sc_range['max'][label_idxs[0]], label_sc_range['max'][label_idxs[1]]], \n", - " crop_kwargs=None, ch=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "marker_kwargs = {\n", - " 'markersize': 20, 'markeredgewidth': 3, 'markeredgecolor': [1, 1, 0],\n", - " 'fillstyle': 'none'}\n", - "\n", - "if save_outputs:\n", - " save_file = os.path.join(\n", - " get_user_dir('fig'), \n", - " 'ae', 'D=%02i_label-manipulation_%s_%s-crop.png' % \n", - " (hparams['n_ae_latents'], hparams['session'], crop_type))\n", - "else:\n", - " save_file = None\n", - "\n", - "plot_2d_frame_array(\n", - " ims_label, markers=markers_loc_label, marker_kwargs=marker_kwargs, save_file=None,\n", - " figsize=(15, 15))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Explore 2D latent space\n", - "\n", - "[Back to contents](#Contents)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ims_pt, ims_np, latents_np, labels_pt, labels_np, labels_2d_pt, labels_2d_np = \\\n", - " get_model_input(data_generator, hparams, model_ae, trial=None, trial_idx=trial_idx,\n", - " compute_latents=True, compute_scaled_labels=False, compute_2d_labels=True)\n", - "\n", - "latent_idxs += n_labels # first `n_labels` dims are used to reconstruct labels\n", - "\n", - "ims_latent, markers_loc_latent_, ims_crop_latent = interpolate_2d(\n", - " 'latents', model_ae, ims_pt[None, batch_idx, :], latents_np[None, batch_idx, :], \n", - " labels_np[None, batch_idx, :], labels_2d_np[None, batch_idx, :], \n", - " mins=[latent_range['min'][latent_idxs[0]], latent_range['min'][latent_idxs[1]]], \n", - " maxes=[latent_range['max'][latent_idxs[0]], latent_range['max'][latent_idxs[1]]], \n", - " n_frames=n_frames, input_idxs=latent_idxs, crop_type=None, \n", - " mins_sc=None, maxes_sc=None, crop_kwargs=None, marker_idxs=label_idxs, ch=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "marker_kwargs = {\n", - " 'markersize': 20, 'markeredgewidth': 5, 'markeredgecolor': [1, 1, 0],\n", - " 'fillstyle': 'none'}\n", - "\n", - "if save_outputs:\n", - " save_file = os.path.join(\n", - " get_user_dir('fig'), \n", - " 'ae', 'D=%02i_latent-manipulation_%s_%s-crop.png' % \n", - " (hparams['n_ae_latents'], hparams['session'], crop_type))\n", - "else:\n", - " save_file = None\n", - "\n", - "plot_2d_frame_array(\n", - " ims_latent, markers=markers_loc_latent_, marker_kwargs=marker_kwargs, \n", - " save_file=None, figsize=(15, 15))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "behavenet", - "language": "python", - "name": "behavenet" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/example/00_data.ipynb b/examples/00_data.ipynb similarity index 100% rename from example/00_data.ipynb rename to examples/00_data.ipynb diff --git a/example/01_ae.ipynb b/examples/01_ae.ipynb similarity index 100% rename from example/01_ae.ipynb rename to examples/01_ae.ipynb diff --git a/example/02_arhmm.ipynb b/examples/02_arhmm.ipynb similarity index 100% rename from example/02_arhmm.ipynb rename to examples/02_arhmm.ipynb diff --git a/example/03_decoder.ipynb b/examples/03_decoder.ipynb similarity index 100% rename from example/03_decoder.ipynb rename to examples/03_decoder.ipynb diff --git a/example/04_bayesian_decoder.ipynb b/examples/04_bayesian_decoder.ipynb similarity index 100% rename from example/04_bayesian_decoder.ipynb rename to examples/04_bayesian_decoder.ipynb diff --git a/examples/ps-vae/00_data.ipynb b/examples/ps-vae/00_data.ipynb new file mode 100644 index 0000000..fb4f3d2 --- /dev/null +++ b/examples/ps-vae/00_data.ipynb @@ -0,0 +1,163 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting the PS-VAE to an example dataset\n", + "\n", + "This notebook will walk you through how to download an example dataset, including some already trained models; the next notebook shows how to evaluate those models.\n", + "\n", + "Before beginning, first make sure that you have properly installed the BehaveNet package and environment by following the instructions [here](https://behavenet.readthedocs.io/en/latest/source/installation.html). Specifically, (1) set up the Anaconda virtual environment; and (2) install the `BehaveNet` package. You do not need to set user paths at this time (this will be covered below).\n", + "\n", + "To illustrate the use of BehaveNet we will use an example dataset from the [International Brain Lab](https://www.biorxiv.org/content/10.1101/2020.01.17.909838v5).\n", + "\n", + "Briefly, a head-fixed mouse performed a visual decision-making task. Behavioral data was recorded using a single camera at 60 Hz frame rate. Grayscale video frames were downsampled to 192x192 pixels. We labeled the forepaw positions using [Deep Graph Pose](https://papers.nips.cc/paper/2020/file/4379cf00e1a95a97a33dac10ce454ca4-Paper.pdf). Data consists of batches of 100 contiguous frames and their accompanying labels.\n", + "\n", + "The data are stored on the IBL data repository; you will download this data after setting some user paths.\n", + "\n", + "**Note**: make sure that you are running the `behavenet` ipython kernel - you should see the current ipython kernel name in the upper right hand corner of this notebook. If it is not `behavenet` (for example it might be `Python 3`) then change it using the dropdown menus above: `Kernel > Change kernel > behavenet`. If you do not see `behavenet` as an option see [here](https://behavenet.readthedocs.io/en/latest/source/installation.html#environment-setup).\n", + "\n", + "
\n", + "\n", + "### Contents\n", + "* [Set user paths](#0.-Set-user-paths)\n", + "* [Download the data](#1.-Download-the-data)\n", + "* [Add dataset hyperparameters](#2.-Add-dataset-hyperparameters)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 0. Set user paths\n", + "First set the paths to the directories where data, results, and figures will be stored on your local machine. Note that the data is ~3GB, so make sure that your data directory has enough space.\n", + "\n", + "A note about the BehaveNet path structure: every dataset is uniquely identified by a lab id, experiment id, animal id, and session id. Paths to data and results contain directories for each of these id types. For example, a sample data path will look like `/home/user/data/lab_id/expt_id/animal_id/session_id/data.hdf5`. In this case the base data directory is `/home/user/data/`.\n", + "\n", + "The downloaded zip file will automatically be saved as `data_dir/ibl/angelakilab/IBL-T4/2019-04-23-001/data.hdf5`\n", + "\n", + "Additionally, the zip file contains already trained VAE and PS-VAE models, which will automatically be saved in the directories:\n", + "* `results_dir/ibl/angelakilab/IBL-T4/2019-04-23-001/vae/conv/06_latents/demo-run/`\n", + "* `results_dir/ibl/angelakilab/IBL-T4/2019-04-23-001/ps-vae/conv/06_latents/demo-run/`\n", + "\n", + "To set the user paths, run the cell below.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from behavenet import setup\n", + "setup()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The directory file is stored in your user home directory; this is a json file that can be updated in a text editor at any time." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Download the data\n", + "Run the cell below; data and results will be stored in the directories provided in the previous step.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import io\n", + "import shutil\n", + "import requests\n", + "import zipfile as zf\n", + "from behavenet import get_user_dir\n", + "\n", + "url = 'https://ibl.flatironinstitute.org/public/ps-vae_demo_head-fixed.zip'\n", + "\n", + "print('Downloading data - this may take several minutes')\n", + "\n", + "# fetch data from IBL data repository\n", + "print('fetching data from url...', end='')\n", + "r = requests.get(url, stream=True)\n", + "z = zf.ZipFile(io.BytesIO(r.content))\n", + "print('done')\n", + "\n", + "# extract data\n", + "data_dir = get_user_dir('data')\n", + "if not os.path.exists(data_dir):\n", + " os.makedirs(data_dir)\n", + "print('extracting data to %s...' % data_dir, end='')\n", + "for file in z.namelist():\n", + " if file.startswith('ps-vae_demo_head-fixed/data/'):\n", + " z.extract(file, data_dir)\n", + "# clean up paths\n", + "shutil.move(os.path.join(data_dir, 'ps-vae_demo_head-fixed', 'data', 'ibl'), data_dir)\n", + "shutil.rmtree(os.path.join(data_dir, 'ps-vae_demo_head-fixed'))\n", + "print('done')\n", + "\n", + "# extract results\n", + "results_dir = get_user_dir('save')\n", + "if not os.path.exists(results_dir):\n", + " os.makedirs(results_dir)\n", + "print('extracting results to %s...' % results_dir, end='')\n", + "for file in z.namelist():\n", + " if file.startswith('ps-vae_demo_head-fixed/results/'):\n", + " z.extract(file, results_dir)\n", + "# clean up paths\n", + "shutil.move(os.path.join(results_dir, 'ps-vae_demo_head-fixed', 'results', 'ibl'), results_dir)\n", + "shutil.rmtree(os.path.join(results_dir, 'ps-vae_demo_head-fixed'))\n", + "print('done')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Add dataset hyperparameters\n", + "The last step is to save some of the dataset hyperparameters in their own json file. This is used to simplify command line arguments to model fitting functions. This json file has already been provided in the data directory, where the `data.hdf5` file is stored - you should see a file named `ibl_angelakilab_params.json`. Copy and paste this file into the `.behavenet` directory in your home directory:\n", + "\n", + "* In Linux, `~/.behavenet`\n", + "* In MacOS, `/Users/CurrentUser/.behavenet`\n", + "\n", + "The next notebook will now walk you through how to evaluate the downloaded models/data.\n", + "\n", + "[Back to contents](#Contents)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "behavenet", + "language": "python", + "name": "behavenet" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/ps-vae/01_ps-vae.ipynb b/examples/ps-vae/01_ps-vae.ipynb new file mode 100644 index 0000000..f98bbf8 --- /dev/null +++ b/examples/ps-vae/01_ps-vae.ipynb @@ -0,0 +1,616 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Analyze a PS-VAE model\n", + "Because the PS-VAEs currently require significant computation time (generally ~5 hours on a GPU) the data downloaded in the previous notebook also contains already trained PS-VAEs, which we will analyze here.\n", + "\n", + "There are a variety of files that are automatically saved during the fitting of a PS-VAE, which can be used for later analyses such as those below. Some of these files (many of which are common to all BehaveNet models, not just the PS-VAE):\n", + "* `best_val_model.pt`: the best PS-VAE (not necessarily from the final training epoch) as determined by computing the loss on validation data\n", + "* `meta_tags.csv`: hyperparameters associated with data, computational resources, and model\n", + "* `metrics.csv`: metrics computed on dataset as a function of epochs; the default is that metrics are computed on training and validation data every epoch (and reported as a mean over all batches) while metrics are computed on test data only at the end of training using the best model (and reported per batch).\n", + "* `[lab_id]_[expt_id]_[animal_id]_[session_id]_latents.pkl`: list of np.ndarrays of PS-VAE latents (both supervised and unsupervised) computed using the best model\n", + "* `session_info.csv`: sessions used to fit the model\n", + "\n", + "To fit your own PS-VAEs, see additional documentation [here](https://behavenet.readthedocs.io/en/latest/source/user_guide.html).\n", + "\n", + "
\n", + "\n", + "### Contents\n", + "* [Plot validation losses as a function of epochs](#Plot-losses-as-a-function-of-epochs)\n", + "* [Plot label reconstructions](#Plot-label-reconstructions)\n", + "* [Plot latent traversals](#Plot-latent-traversals)\n", + "* [Make latent traversal movie](#Make-latent-traversal-movie)\n", + "* [Make frame reconstruction movie](#Make-reconstruction-movies)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "from behavenet import get_user_dir\n", + "from behavenet.plotting.cond_ae_utils import plot_psvae_training_curves\n", + "from behavenet.plotting.cond_ae_utils import plot_label_reconstructions\n", + "from behavenet.plotting.cond_ae_utils import plot_latent_traversals\n", + "from behavenet.plotting.cond_ae_utils import make_latent_traversal_movie\n", + "\n", + "dataset = 'head-fixed'\n", + "# 'head-fixed': IBL data\n", + "# 'face': dipoppa data\n", + "# 'two-view': musall data\n", + "\n", + "save_outputs = True # true to save figures/movies to user's figure directory\n", + "file_ext = 'pdf' # figure format ('png' | 'jpeg' | 'pdf'); movies saved as mp4" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### define dataset parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# parameters common to all datasets\n", + "n_latents = 2 # number of unsupervised latents\n", + "train_frac = 0.5 # all models trained with 50% of training data to speed up fitting\n", + "experiment_name = 'demo-run' # test-tube exp name\n", + "\n", + "# set dataset-specific parameters\n", + "if dataset == 'head-fixed':\n", + " \n", + " lab = 'ibl'\n", + " expt = 'angelakilab'\n", + " animal = 'IBL-T4'\n", + " session = '2019-04-23-001'\n", + " n_labels = 4\n", + " label_names = ['L paw (x)', 'R paw (x)', 'L paw (y)', 'R paw (y)']\n", + "\n", + " # define \"best\" model\n", + " best_alpha = 1000\n", + " best_beta = 5\n", + " best_gamma = 500\n", + " best_rng = 0\n", + "\n", + " # label reconstructions\n", + " label_recon_trials= [229, 289, 419] # good validation trials; also used for frame recon\n", + " xtick_locs= [0, 30, 60, 90]\n", + " frame_rate= 60\n", + " scale= 0.4\n", + " \n", + " # latent traversal params\n", + " label_min_p = 35 # lower bound of label traversals\n", + " label_max_p = 85 # upper bound of label traversals\n", + " ch = 0 # video channel to display\n", + " n_frames_zs = 4 # n frames for supervised static traversals\n", + " n_frames_zu = 4 # n frames for unsupervised static traversals\n", + " label_idxs = [1, 0] # horizontally move left/right paws\n", + " crop_type = None # no image cropping\n", + " crop_kwargs = None # no image cropping\n", + " # select base frames for traversals\n", + " trial_idxs = [11, 4, 0, None, None, None, None] # trial index wrt to all test trials\n", + " trials = [None, None, None, 169, 129, 429, 339] # trial index wrt to *all* trials\n", + " batch_idxs = [99, 99, 99, 16, 46, 11, 79] # batch index within trial\n", + " n_cols = 3 # width of traversal movie\n", + " text_color = [1, 1, 1] # text color for labels\n", + " \n", + "elif dataset == 'face':\n", + " \n", + " lab = 'dipoppa'\n", + " expt = 'pupil'\n", + " animal = 'MD0ST5'\n", + " session = 'session-3'\n", + " n_labels = 3\n", + " label_names = ['Pupil area', 'Pupil (y)', 'Pupil (x)']\n", + "\n", + " # define \"best\" model\n", + " best_alpha = 1000\n", + " best_beta = 20\n", + " best_gamma = 1000\n", + " best_rng = 0\n", + "\n", + " # label reconstructions\n", + " label_recon_trials= [43, 83, 73] # good validation trials; also used for frame recon\n", + " xtick_locs= [0, 30, 60, 90, 120, 150]\n", + " frame_rate= 30\n", + " scale= 0.45\n", + " \n", + " # latent traversal params\n", + " label_min_p = 5 # lower bound of label traversals\n", + " label_max_p = 95 # upper bound of label traversals\n", + " ch = 0 # video channel to display\n", + " n_frames_zs = 4 # n frames for supervised static traversals\n", + " n_frames_zu = 4 # n frames for unsupervised static traversals\n", + " label_idxs = [1, 2] # pupil location\n", + " crop_type = 'fixed' # crop around eye\n", + " crop_kwargs = {'y_0': 48, 'y_ext': 48, 'x_0': 192, 'x_ext': 64}\n", + " # select base frames for traversals\n", + " trial_idxs = [11, None, 21] # trial index wrt to all test trials\n", + " trials = [None, 393, None] # trial index wrt to *all* trials\n", + " batch_idxs = [60, 27, 99] # batch index within trial\n", + " n_cols = 3 # width of traversal movie\n", + " text_color = [0, 0, 0] # text color for labels\n", + " \n", + "elif dataset == 'two-view':\n", + " \n", + " lab = 'musall'\n", + " expt = 'vistrained'\n", + " animal = 'mSM36'\n", + " session = '05-Dec-2017-wpaw'\n", + " n_labels = 5\n", + " label_names = ['Levers', 'L Spout', 'R Spout', 'R paw (x)', 'R paw (y)']\n", + "\n", + " # define \"best\" model\n", + " best_alpha = 1000\n", + " best_beta = 1\n", + " best_gamma = 1000\n", + " best_rng = 1\n", + "\n", + " # label reconstructions\n", + " label_recon_trials= [9, 19, 29] # good validation trials; also used for frame recon\n", + " xtick_locs= [0, 60, 120, 180]\n", + " frame_rate= 30\n", + " scale= 0.25\n", + "\n", + " # latent traversal params\n", + " label_min_p = 5 # lower bound of label traversals\n", + " label_max_p = 95 # upper bound of label traversals\n", + " ch = 1 # video channel to display\n", + " n_frames_zs = 3 # n frames for supervised static traversals\n", + " n_frames_zu = 3 # n frames for unsupervised static traversals\n", + " label_idxs = [3, 4] # move right paw\n", + " crop_type = None # no image cropping\n", + " crop_kwargs = None # no image cropping\n", + " # select base frames for traversals\n", + " trial_idxs = [11, 11, 11, 5] # trial index wrt to all test trials\n", + " trials = [None, None, None, None] # trial index wrt to *all* trials\n", + " batch_idxs = [99, 0, 50, 180] # batch index within trial\n", + " n_cols = 2 # width of traversal movie\n", + " text_color = [1, 1, 1] # text color for labels\n", + "\n", + "else:\n", + " raise ValueError('Invalid dataset; must choose \"head-fixed\", \"face\", or \"two-view\"')\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot losses as a function of epochs\n", + "The PS-VAE loss function contains many individual terms; this function plots each term separately (as well as the overall loss) to better understand model performance. Note that this function can also be used to plot training curves for multiple models simultaneously; see function documentation. \n", + "\n", + "Panel info (see paper for mathematical descriptions):\n", + "* loss=loss: total PS-VAE loss\n", + "* loss=loss_data_mse: mean square error on frames (actual loss function uses log-likelihood, a scaled version of the MSE)\n", + "* loss=label_r2: $R^2$ (per trial) of the label reconstructions (actual loss function uses log-likelihood)\n", + "* loss=loss_zs_kl: Kullback-Leibler (KL) divergence of supervised latents\n", + "* loss=loss_zu_mi: index-code mutual information of unuspervised latents\n", + "* loss=loss_zu_tc: total correlation of unuspervised latents\n", + "* loss=loss_zu_dwkl: dimension-wise KL of unuspervised latents\n", + "* loss=loss_AB_orth: orthogonality between supervised/unsupervised subspaces\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_file = os.path.join(\n", + " get_user_dir('fig'), lab, expt, animal, session, 'ps-vae', 'training_curves')\n", + "\n", + "save_file_new = save_file + '_alpha={}_beta={}_gamma={}_rng={}_latents={}'.format(\n", + " best_alpha, best_beta, best_gamma, best_rng, n_latents)\n", + "plot_psvae_training_curves(\n", + " lab=lab, expt=expt, animal=animal, session=session, alphas=[best_alpha], \n", + " betas=[best_beta], gammas=[best_gamma], n_ae_latents=[n_latents], \n", + " rng_seeds_model=[best_rng], experiment_name=experiment_name,\n", + " n_labels=n_labels, train_frac=train_frac,\n", + " save_file=save_file_new, format=file_ext)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot label reconstructions\n", + "Plot the original labels and their reconstructions from the supervised subspace of the PS-VAE.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_file = os.path.join(\n", + " get_user_dir('fig'), lab, expt, animal, session, 'ps-vae', 'label_recon')\n", + "\n", + "plot_label_reconstructions(\n", + " lab=lab, expt=expt, animal=animal, session=session, n_ae_latents=n_latents, \n", + " experiment_name=experiment_name,\n", + " n_labels=n_labels, trials=label_recon_trials, version=None,\n", + " alpha=best_alpha, beta=best_beta, gamma=best_gamma, rng_seed_model=best_rng, \n", + " train_frac=train_frac, save_file=save_file, format=file_ext)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot latent traversals\n", + "Latent traversals provide a qualitative way to assess the quality of the learned PS-VAE representation. We generate these traversals by changing the latent representation one dimension at a time and visually compare the outputs. If the representation is sufficiently interpretable we should be able to easily assign semantic meaning to each latent dimension.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_latents = 2\n", + "\n", + "# for trial, trial_idx, batch_idx in zip(trials, trial_idxs, batch_idxs):\n", + "# just plot traversals for single base frame\n", + "trial = trials[0]\n", + "trial_idx = trial_idxs[0]\n", + "batch_idx = batch_idxs[0]\n", + "\n", + "if trial is not None:\n", + " trial_str = 'trial-%i-%i' % (trial, batch_idx)\n", + "else:\n", + " trial_str = 'trial-idx-%i-%i' % (trial_idx, batch_idx)\n", + "\n", + "save_file = os.path.join(\n", + " get_user_dir('fig'), lab, expt, animal, session, 'ps-vae', \n", + " 'traversals_alpha={}_beta={}_gamma={}_rng={}_latents={}_{}'.format(\n", + " best_alpha, best_beta, best_gamma, best_rng, n_latents, trial_str))\n", + "\n", + "plot_latent_traversals(\n", + " lab=lab, expt=expt, animal=animal, session=session, model_class='ps-vae', \n", + " alpha=best_alpha, beta=best_beta, gamma=best_gamma, n_ae_latents=2, \n", + " rng_seed_model=best_rng, experiment_name=experiment_name, \n", + " n_labels=n_labels, label_idxs=label_idxs,\n", + " label_min_p=label_min_p, label_max_p=label_max_p, channel=ch, \n", + " n_frames_zs=n_frames_zs, n_frames_zu=n_frames_zu, trial_idx=trial_idx, \n", + " trial=trial, batch_idx=batch_idx, crop_type=crop_type, crop_kwargs=crop_kwargs,\n", + " train_frac=train_frac, save_file=save_file, format='png')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make latent traversal movie\n", + "A dynamic version of the traversals above; these typically provide a richer look at the traversal results.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "n_frames = 10 # number of sample frames per dimension\n", + "model_class = 'ps-vae' # 'sss-vae' | 'vae'\n", + "\n", + "# NOTE: below I hand label each dimension; semantic labels for unsupervised dims are chosen\n", + "# by looking at the latent traversals above, and are indicated with quotes to distinguish\n", + "# them from the supervised dims\n", + "\n", + "if dataset == 'head-fixed':\n", + " if model_class == 'ps-vae':\n", + " panel_titles = [\n", + " 'L paw (x)', 'R paw (x)', 'L paw (y)', 'R paw (y)', '\"Jaw\"', '\"L paw config\"']\n", + " order_idxs = [0, 1, 4, 2, 3, 5] # reorder nicely\n", + " elif model_class == 'vae':\n", + " panel_titles = [\n", + " 'Latent 0', 'Latent 1', 'Latent 2', 'Latent 3', 'Latent 4', 'Latent 5']\n", + " order_idxs = [0, 1, 2, 3, 4, 5]\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + "elif dataset == 'face':\n", + " crop_kwargs = None\n", + " if model_class == 'ps-vae':\n", + " panel_titles = [\n", + " 'Pupil area', 'Pupil (y)', 'Pupil (x)', '\"Whisker pad\"', '\"Eyelid\"']\n", + " order_idxs = [2, 1, 0, 3, 4]\n", + " elif model_class == 'vae':\n", + " panel_titles = [\n", + " 'Latent 0', 'Latent 1', 'Latent 2', 'Latent 3', 'Latent 4']\n", + " order_idxs = [0, 1, 2, 3, 4]\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + "elif dataset == 'two-view':\n", + "# crop_kwargs_ = None\n", + "# show_markers = True \n", + " if model_class == 'ps-vae':\n", + " panel_titles = [\n", + " 'Lever', 'R spout', 'L spout', 'R paw (x)', 'R paw (y)', '\"Chest\"', \n", + " '\"Jaw\"']\n", + " order_idxs = [1, 2, 3, 4, 0, 5, 6]\n", + " elif model_class == 'vae':\n", + " panel_titles = [\n", + " 'Latent 0', 'Latent 1', 'Latent 2', 'Latent 3', 'Latent 4', 'Latent 5', \n", + " 'Latent 6']\n", + " order_idxs = [0, 1, 2, 3, 4, 5, 6]\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + "else:\n", + " raise NotImplementedError\n", + "\n", + "save_file = os.path.join(\n", + " get_user_dir('fig'), lab, expt, animal, session, model_class, \n", + " 'traversals_alpha={}_beta={}_gamma={}_rng={}_latents={}'.format(\n", + " best_alpha, best_beta, best_gamma, best_rng, n_latents))\n", + "\n", + "make_latent_traversal_movie(\n", + " lab=lab, expt=expt, animal=animal, session=session, model_class=model_class, \n", + " alpha=best_alpha, beta=best_beta, gamma=best_gamma, n_ae_latents=n_latents, \n", + " rng_seed_model=best_rng, experiment_name=experiment_name, \n", + " n_labels=n_labels, trial_idxs=trial_idxs, batch_idxs=batch_idxs, trials=trials, \n", + " panel_titles=panel_titles, label_min_p=label_min_p, \n", + " label_max_p=label_max_p, channel=ch, n_frames=n_frames, crop_kwargs=crop_kwargs, \n", + " n_cols=n_cols, movie_kwargs={'text_color': text_color}, order_idxs=order_idxs,\n", + " train_frac=train_frac, save_file=save_file)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Make reconstruction movies\n", + "Compare original frames to VAE and PS-VAE reconstructions.\n", + "\n", + "[Back to contents](#Contents)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### helper function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "from behavenet.plotting.ae_utils import make_reconstruction_movie\n", + "from behavenet.plotting.cond_ae_utils import get_model_input\n", + "from behavenet.fitting.eval import get_reconstruction\n", + "from behavenet.fitting.utils import get_best_model_and_data, get_lab_example\n", + "from behavenet.plotting import concat, save_movie\n", + "\n", + "def make_reconstruction_movie_wrapper(\n", + " hparams, save_file, model_info, trial_idxs=None, trials=None, sess_idx=0, \n", + " max_frames=400, frame_rate=15, layout_pattern=None):\n", + " \"\"\"Produce movie with original video and reconstructed videos.\n", + "\n", + " This is a high-level function that loads the model described in the hparams dictionary \n", + " and produces the necessary predicted video frames.\n", + "\n", + " Parameters\n", + " ----------\n", + " hparams : :obj:`dict`\n", + " needs to contain enough information to specify an autoencoder\n", + " save_file : :obj:`str`\n", + " full save file (path and filename)\n", + " model_info : :obj:`list`\n", + " each entry is a dict that contains model-specific parameters; must include\n", + " 'title', 'model_class'\n", + " trial_idxs : :obj:`list`, optional\n", + " list of test trials to construct videos from; each element is index into \n", + " test trials only; one of `trial_idxs` or `trials` must be \n", + " specified; `trials` takes precedence over `trial_idxs`\n", + " trials : :obj:`list`, optional\n", + " list of test trials to construct videos from; each element is index into all \n", + " possible trials (train, val, test); one of `trials` or `trial_idxs` must be \n", + " specified; `trials` takes precedence over `trial_idxs`\n", + " sess_idx : :obj:`int`, optional\n", + " session index into data generator\n", + " max_frames : :obj:`int`, optional\n", + " maximum number of frames to animate from a trial\n", + " frame_rate : :obj:`float`, optional\n", + " frame rate of saved movie\n", + " layout_pattern : :obj:`array-like`, optional\n", + " boolean entries specify which panels are used to display frames\n", + " \n", + " \"\"\"\n", + "\n", + " n_labels = hparams['n_labels']\n", + " n_latents = hparams['n_ae_latents']\n", + " expt_name = hparams['experiment_name']\n", + "\n", + " # set up models to fit\n", + " titles = ['Original']\n", + " for model in model_info:\n", + " titles.append(model['title'])\n", + " \n", + " # insert original video at front\n", + " model_info.insert(0, {'model_class': None})\n", + "\n", + " ims_recon = [[] for _ in titles]\n", + " latents = [[] for _ in titles]\n", + " \n", + " if trial_idxs is None:\n", + " trial_idxs = [None] * len(trials)\n", + " if trials is None:\n", + " trials = [None] * len(trial_idxs)\n", + "\n", + " for i, model in enumerate(model_info):\n", + "\n", + " if i == 0:\n", + " continue\n", + " \n", + " # further specify model\n", + " version = model.get('version', 'best')\n", + " hparams['experiment_name'] = model.get('experiment_name', expt_name)\n", + " hparams['model_class'] = model['model_class']\n", + " model_ae, data_generator = get_best_model_and_data(hparams, None, version=version)\n", + "\n", + " # get images\n", + " for trial_idx, trial in zip(trial_idxs, trials):\n", + "\n", + " # get model inputs\n", + " ims_orig_pt, ims_orig_np, _, labels_pt, _, labels_2d_pt, _ = get_model_input(\n", + " data_generator, hparams, model_ae, trial_idx=trial_idx, trial=trial,\n", + " sess_idx=sess_idx, max_frames=max_frames, compute_latents=False, \n", + " compute_2d_labels=False)\n", + " \n", + " # get model outputs\n", + " ims_recon_tmp, latents_tmp = get_reconstruction(\n", + " model_ae, ims_orig_pt, labels=labels_pt, labels_2d=labels_2d_pt,\n", + " return_latents=True)\n", + " ims_recon[i].append(ims_recon_tmp)\n", + " latents[i].append(latents_tmp)\n", + " \n", + " # add a couple black frames to separate trials\n", + " final_trial = True\n", + " if (trial_idx is not None and (trial_idx != trial_idxs[-1])) or \\\n", + " (trial is not None and (trial != trials[-1])):\n", + " final_trial = False\n", + "\n", + " n_buffer = 5\n", + " if not final_trial:\n", + " _, n, y_p, x_p = ims_recon[i][-1].shape\n", + " ims_recon[i].append(np.zeros((n_buffer, n, y_p, x_p)))\n", + " latents[i].append(np.nan * np.zeros((n_buffer, n_latents)))\n", + "\n", + " if i == 1: # deal with original frames only once\n", + " ims_recon[0].append(ims_orig_np)\n", + " latents[0].append([])\n", + " # add a couple black frames to separate trials\n", + " if not final_trial:\n", + " _, n, y_p, x_p = ims_recon[0][-1].shape\n", + " ims_recon[0].append(np.zeros((n_buffer, n, y_p, x_p)))\n", + " \n", + " for i, (ims, zs) in enumerate(zip(ims_recon, latents)):\n", + " ims_recon[i] = np.concatenate(ims, axis=0)\n", + " latents[i] = np.concatenate(zs, axis=0)\n", + " \n", + " if layout_pattern is None:\n", + " if len(titles) < 4:\n", + " n_rows, n_cols = 1, len(titles)\n", + " elif len(titles) == 4:\n", + " n_rows, n_cols = 2, 2\n", + " elif len(titles) > 4:\n", + " n_rows, n_cols = 2, 3\n", + " else:\n", + " raise ValueError('too many models')\n", + " else:\n", + " assert np.sum(layout_pattern) == len(ims_recon)\n", + " n_rows, n_cols = layout_pattern.shape\n", + " count = 0\n", + " for pos_r in layout_pattern:\n", + " for pos_c in pos_r:\n", + " if not pos_c:\n", + " ims_recon.insert(count, [])\n", + " titles.insert(count, [])\n", + " count += 1\n", + "\n", + " make_reconstruction_movie(\n", + " ims=ims_recon, titles=titles, n_rows=n_rows, n_cols=n_cols, \n", + " save_file=save_file, frame_rate=frame_rate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# set model info\n", + "hparams = {\n", + " 'data_dir': get_user_dir('data'),\n", + " 'save_dir': get_user_dir('save'),\n", + " 'n_labels': n_labels,\n", + " 'n_ae_latents': n_latents + n_labels,\n", + " 'experiment_name': None,\n", + " 'model_type': 'conv',\n", + " 'conditional_encoder': False,\n", + "}\n", + "\n", + "# programmatically fill out other hparams options\n", + "get_lab_example(hparams, lab, expt)\n", + "\n", + "# compare vae/ps-vae reconstructions\n", + "model_info = [\n", + " {\n", + " 'model_class': 'ps-vae',\n", + " 'experiment_name': 'demo-run',\n", + " 'title': 'PS-VAE (%i latents)' % n_latents,\n", + " 'version': 0},\n", + " {\n", + " 'model_class': 'vae',\n", + " 'experiment_name': 'demo-run',\n", + " 'title': 'VAE (%i latents)' % n_latents,\n", + " 'version': 0},\n", + "]\n", + "\n", + "save_file = os.path.join(\n", + " get_user_dir('fig'), lab, expt, animal, session, model_class, \n", + " 'reconstructions_alpha={}_beta={}_gamma={}_rng={}_latents={}'.format(\n", + " best_alpha, best_beta, best_gamma, best_rng, n_latents))\n", + "\n", + "make_reconstruction_movie_wrapper(\n", + " hparams, save_file=save_file, trial_idxs=None, trials=label_recon_trials, \n", + " model_info=model_info, frame_rate=15)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "behavenet", + "language": "python", + "name": "behavenet" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index c35d451..4ae695a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ commentjson==0.8.2 h5py==2.9.0 ipykernel==5.1.0 matplotlib==3.0.3 -notebook==6.0.3 +notebook==6.1.5 numpy==1.17.4 requests==2.22.0 scikit-image==0.15.0 diff --git a/test b/test new file mode 100644 index 0000000..d25c715 --- /dev/null +++ b/test @@ -0,0 +1 @@ +“some test file” diff --git a/tests/integration.py b/tests/integration.py index efd7082..edf6ca5 100644 --- a/tests/integration.py +++ b/tests/integration.py @@ -44,17 +44,19 @@ SESSIONS = ['sess-0', 'sess-1'] MODELS_TO_FIT = [ # ['model_file']_grid_search - {'model_class': 'ae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'arhmm', 'model_file': 'arhmm', 'sessions': SESSIONS[0]}, - {'model_class': 'neural-ae', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, - {'model_class': 'neural-arhmm', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, - {'model_class': 'ae', 'model_file': 'ae', 'sessions': 'all'}, - {'model_class': 'vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'beta-tcvae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'cond-ae-msp', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'cond-vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'sss-vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, - {'model_class': 'labels-images', 'model_file': 'label_decoder', 'sessions': SESSIONS[0]}, + {'model_class': 'ae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'arhmm', 'model_file': 'arhmm', 'sessions': SESSIONS[0]}, + {'model_class': 'neural-ae', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, + {'model_class': 'neural-ae-me', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, + {'model_class': 'neural-labels', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, + {'model_class': 'neural-arhmm', 'model_file': 'decoder', 'sessions': SESSIONS[0]}, + {'model_class': 'ae', 'model_file': 'ae', 'sessions': 'all'}, + {'model_class': 'vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'beta-tcvae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'cond-ae-msp', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'cond-vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'ps-vae', 'model_file': 'ae', 'sessions': SESSIONS[0]}, + {'model_class': 'labels-images', 'model_file': 'label_decoder', 'sessions': SESSIONS[0]}, ] """ @@ -120,7 +122,7 @@ def get_model_config_files(model, json_dir): or model == 'cond-vae' \ or model == 'beta-tcvae' \ or model == 'cond-ae-msp' \ - or model == 'sss-vae' \ + or model == 'ps-vae' \ or model == 'labels-images' \ or model == 'arhmm': if model != 'arhmm': @@ -131,9 +133,10 @@ def get_model_config_files(model, json_dir): 'model': os.path.join(model_json_dir, '%s_model.json' % model), 'training': os.path.join(model_json_dir, '%s_training.json' % model), 'compute': os.path.join(model_json_dir, '%s_compute.json' % model)} - elif model == 'neural-ae' or model == 'neural-arhmm': + elif model == 'neural-ae' or model == 'neural-ae-me' or model == 'neural-arhmm' \ + or model == 'neural-labels': m = 'decoding' - s = model.split('-')[-1] + s = model.split('-')[1] # take string after "neural" model_json_dir = os.path.join(json_dir, '%s_jsons' % m) base_config_files = { 'data': os.path.join(model_json_dir, '%s_data.json' % m), @@ -148,19 +151,35 @@ def get_model_config_files(model, json_dir): def define_new_config_values(model, session='sess-0'): # data vals - data_dict = {'session': session, 'all_source': 'data', **DATA_DICT} + data_dict = { + 'session': session, 'all_source': 'data', 'n_labels': TEMP_DATA['n_labels'], **DATA_DICT} # training vals train_frac = 0.5 trial_splits = '8;1;1;1' + training_dict = { + 'export_train_plots': False, + 'export_latents': True, + 'export_predictions': True, + 'min_n_epochs': 1, + 'max_n_epochs': 1, + 'enable_early_stop': False, + 'train_frac': train_frac, + 'trial_splits': trial_splits + } + # compute vals gpu_id = 0 + compute_dict = {'gpus_viz': str(gpu_id), 'tt_n_cpu_workers': 2} + # model vals: ae ae_expt_name = 'ae-expt' + ae_model_class = 'ae' ae_model_type = 'conv' n_ae_latents = 6 + l2_reg = 0.0 # model vals: arhmm arhmm_expt_name = 'arhmm-expt' @@ -169,7 +188,7 @@ def define_new_config_values(model, session='sess-0'): transitions = 'stationary' noise_type = 'gaussian' - if model == 'ae' or model == 'vae' or model == 'beta-tcvae' or model == 'sss-vae': + if model == 'ae' or model == 'vae' or model == 'beta-tcvae' or model == 'ps-vae': new_values = { 'data': data_dict, 'model': { @@ -177,37 +196,21 @@ def define_new_config_values(model, session='sess-0'): 'model_class': model, 'model_type': ae_model_type, 'n_ae_latents': n_ae_latents, - 'l2_reg': 0.0}, - 'training': { - 'export_train_plots': False, - 'export_latents': True, - 'min_n_epochs': 1, - 'max_n_epochs': 1, - 'enable_early_stop': False, - 'train_frac': train_frac, - 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id)}} + 'l2_reg': l2_reg}, + 'training': training_dict, + 'compute': compute_dict} elif model == 'cond-ae-msp': new_values = { 'data': data_dict, 'model': { 'experiment_name': ae_expt_name, - 'model_class': 'cond-ae-msp', + 'model_class': model, 'model_type': ae_model_type, 'n_ae_latents': n_ae_latents + TEMP_DATA['n_labels'], - 'l2_reg': 0.0, + 'l2_reg': l2_reg, 'msp.alpha': 1e-5}, - 'training': { - 'export_train_plots': False, - 'export_latents': True, - 'min_n_epochs': 1, - 'max_n_epochs': 1, - 'enable_early_stop': False, - 'train_frac': train_frac, - 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id)}} + 'training': training_dict, + 'compute': compute_dict} elif model == 'cond-vae': new_values = { 'data': data_dict, @@ -216,18 +219,10 @@ def define_new_config_values(model, session='sess-0'): 'model_class': model, 'model_type': ae_model_type, 'n_ae_latents': n_ae_latents, - 'l2_reg': 0.0, + 'l2_reg': l2_reg, 'conditional_encoder': False}, - 'training': { - 'export_train_plots': False, - 'export_latents': True, - 'min_n_epochs': 1, - 'max_n_epochs': 1, - 'enable_early_stop': False, - 'train_frac': train_frac, - 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id)}} + 'training': training_dict, + 'compute': compute_dict} elif model == 'arhmm': new_values = { 'data': data_dict, @@ -238,6 +233,7 @@ def define_new_config_values(model, session='sess-0'): 'transitions': transitions, 'noise_type': noise_type, 'ae_experiment_name': ae_expt_name, + 'ae_model_class': ae_model_class, 'ae_model_type': ae_model_type, 'n_ae_latents': n_ae_latents}, 'training': { @@ -246,33 +242,57 @@ def define_new_config_values(model, session='sess-0'): 'n_iters': 2, 'train_frac': train_frac, 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id), - 'tt_n_cpu_workers': 2}} + 'compute': compute_dict} elif model == 'neural-ae': new_values = { 'data': data_dict, 'model': { + 'model_class': model, 'n_lags': 4, 'n_max_lags': 8, 'l2_reg': 1e-3, 'ae_experiment_name': ae_expt_name, + 'ae_model_class': ae_model_class, 'ae_model_type': ae_model_type, 'n_ae_latents': n_ae_latents, 'model_type': 'mlp', 'n_hid_layers': 1, 'n_hid_units': 16, 'activation': 'relu'}, - 'training': { - 'export_predictions': True, - 'min_n_epochs': 1, - 'max_n_epochs': 1, - 'enable_early_stop': False, - 'train_frac': train_frac, - 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id), - 'tt_n_cpu_workers': 2}} + 'training': training_dict, + 'compute': compute_dict} + elif model == 'neural-ae-me': + new_values = { + 'data': data_dict, + 'model': { + 'model_class': model, + 'n_lags': 4, + 'n_max_lags': 8, + 'l2_reg': 1e-3, + 'ae_experiment_name': ae_expt_name, + 'ae_model_class': ae_model_class, + 'ae_model_type': ae_model_type, + 'n_ae_latents': n_ae_latents, + 'model_type': 'mlp', + 'n_hid_layers': 1, + 'n_hid_units': 16, + 'activation': 'relu'}, + 'training': training_dict, + 'compute': compute_dict} + elif model == 'neural-labels': + new_values = { + 'data': data_dict, + 'model': { + 'model_class': model, + 'n_lags': 3, + 'n_max_lags': 5, + 'l2_reg': 1e-4, + 'model_type': 'mlp', + 'n_hid_layers': 1, + 'n_hid_units': 16, + 'activation': 'relu'}, + 'training': training_dict, + 'compute': compute_dict} elif model == 'neural-arhmm': new_values = { 'data': data_dict, @@ -280,6 +300,7 @@ def define_new_config_values(model, session='sess-0'): 'n_lags': 2, 'n_max_lags': 8, 'l2_reg': 1e-3, + 'ae_model_class': ae_model_class, 'ae_model_type': ae_model_type, 'n_ae_latents': n_ae_latents, 'arhmm_experiment_name': arhmm_expt_name, @@ -291,16 +312,8 @@ def define_new_config_values(model, session='sess-0'): 'n_hid_layers': 1, 'n_hid_units': [8, 16], 'activation': 'relu'}, - 'training': { - 'export_predictions': True, - 'min_n_epochs': 1, - 'max_n_epochs': 1, - 'enable_early_stop': False, - 'train_frac': train_frac, - 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id), - 'tt_n_cpu_workers': 2}} + 'training': training_dict, + 'compute': compute_dict} elif model == 'labels-images': new_values = { 'data': data_dict, @@ -309,17 +322,16 @@ def define_new_config_values(model, session='sess-0'): 'model_class': 'labels-images', 'model_type': ae_model_type, 'n_ae_latents': 0, - 'l2_reg': 0.0}, + 'l2_reg': l2_reg}, 'training': { 'export_train_plots': False, - 'export_latents': False, + 'export_predictions': False, 'min_n_epochs': 1, 'max_n_epochs': 1, 'enable_early_stop': False, 'train_frac': train_frac, 'trial_splits': trial_splits}, - 'compute': { - 'gpus_viz': str(gpu_id)}} + 'compute': compute_dict} else: raise NotImplementedError diff --git a/tests/test_data/test_transforms.py b/tests/test_data/test_transforms.py index bc5393a..c13d712 100644 --- a/tests/test_data/test_transforms.py +++ b/tests/test_data/test_transforms.py @@ -16,6 +16,43 @@ def test_compose(): assert np.allclose(np.std(s, axis=0), [1, 1], atol=1e-3) +def test_blockshuffle(): + + def get_runs(sample): + + vals = np.unique(sample) + n_time = len(sample) + + # mark first time point of state change with a nonzero number + change = np.where(np.concatenate([[0], np.diff(sample)], axis=0) != 0)[0] + # collect runs + runs = {val: [] for val in vals} + prev_beg = 0 + for curr_beg in change: + runs[sample[prev_beg]].append(curr_beg - prev_beg) + prev_beg = curr_beg + runs[sample[-1]].append(n_time - prev_beg) + return runs + + t = transforms.BlockShuffle(0) + + # signal has changed + signal = np.array([0, 0, 0, 1, 1, 1, 2, 2, 0, 0, 1, 1]) + s = t(signal) + assert not np.all(signal == s) + + # frequency of values unchanged + n_ex_og = np.array([len(np.argwhere(signal == i)) for i in range(3)]) + n_ex_sh = np.array([len(np.argwhere(s == i)) for i in range(3)]) + assert np.all(n_ex_og == n_ex_sh) + + # distribution of runs unchanged + runs_og = get_runs(signal) + runs_sh = get_runs(s) + for key in runs_og.keys(): + assert np.all(np.sort(np.array(runs_og[key])) == np.sort(np.array(runs_sh[key]))) + + def test_clipnormalize(): # raise exception when clip value <= 0 @@ -35,40 +72,6 @@ def test_clipnormalize(): assert np.max(s) == 1 -def test_threshold(): - - # raise exception when bin size <= 0 - with pytest.raises(ValueError): - transforms.Threshold(1, 0) - - # raise exception when threshold < 0 - with pytest.raises(ValueError): - transforms.Threshold(-1, 1) - - # no thresholding with 0 threshold - t = transforms.Threshold(0, 1) - signal = np.random.uniform(0, 4, (5, 4)) - s = t(signal) - assert s.shape == (5, 4) - - # correct thresholding - t = transforms.Threshold(1, 1e3) - signal = np.random.uniform(2, 4, (5, 4)) - signal[:, 0] = 0 - s = t(signal) - assert s.shape == (5, 3) - - -def test_zscore(): - - t = transforms.ZScore() - signal = 10 + 0.3 * np.random.randn(100, 3) - s = t(signal) - assert s.shape == (100, 3) - assert np.allclose(np.mean(s, axis=0), [0, 0, 0], atol=1e-3) - assert np.allclose(np.std(s, axis=0), [1, 1, 1], atol=1e-3) - - def test_makeonehot(): t = transforms.MakeOneHot() @@ -119,42 +122,31 @@ def test_makeonehot2d(): s = t(signal) assert np.all(s == sp) + # correct one-hotting with nans in signal + t = transforms.MakeOneHot2D(4, 4) + signal = np.array([[1, 2, 0, np.nan], [0, 2, 1, 1], [3, 0, np.nan, 2]]) + sp = np.zeros((3, 2, 4, 4)) + sp[0, 0, 0, 1] = 1 + sp[0, 1, 0, 2] = 1 + sp[1, 0, 1, 0] = 1 + sp[1, 1, 1, 2] = 1 + sp[2, 0, 0, 3] = 1 + sp[2, 1, 2, 0] = 1 + s = t(signal) + assert np.all(s == sp) -def test_blockshuffle(): - - def get_runs(sample): - - vals = np.unique(sample) - n_time = len(sample) - - # mark first time point of state change with a nonzero number - change = np.where(np.concatenate([[0], np.diff(sample)], axis=0) != 0)[0] - # collect runs - runs = {val: [] for val in vals} - prev_beg = 0 - for curr_beg in change: - runs[sample[prev_beg]].append(curr_beg - prev_beg) - prev_beg = curr_beg - runs[sample[-1]].append(n_time - prev_beg) - return runs - t = transforms.BlockShuffle(0) +def test_motionenergy(): - # signal has changed - signal = np.array([0, 0, 0, 1, 1, 1, 2, 2, 0, 0, 1, 1]) + T = 100 + D = 4 + t = transforms.MotionEnergy() + signal = np.random.randn(T, D) s = t(signal) - assert not np.all(signal == s) - - # frequency of values unchanged - n_ex_og = np.array([len(np.argwhere(signal == i)) for i in range(3)]) - n_ex_sh = np.array([len(np.argwhere(s == i)) for i in range(3)]) - assert np.all(n_ex_og == n_ex_sh) - - # distribution of runs unchanged - runs_og = get_runs(signal) - runs_sh = get_runs(s) - for key in runs_og.keys(): - assert np.all(np.sort(np.array(runs_og[key])) == np.sort(np.array(runs_sh[key]))) + me = np.vstack([np.zeros((1, signal.shape[1])), np.abs(np.diff(signal, axis=0))]) + assert s.shape == (T, D) + assert np.allclose(s, me, atol=1e-3) + assert np.all(me >= 0) def test_selectindxs(): @@ -166,3 +158,37 @@ def test_selectindxs(): s = t(signal) assert s.shape == (5, 2) assert np.all(signal[:, idxs] == s) + + +def test_threshold(): + + # raise exception when bin size <= 0 + with pytest.raises(ValueError): + transforms.Threshold(1, 0) + + # raise exception when threshold < 0 + with pytest.raises(ValueError): + transforms.Threshold(-1, 1) + + # no thresholding with 0 threshold + t = transforms.Threshold(0, 1) + signal = np.random.uniform(0, 4, (5, 4)) + s = t(signal) + assert s.shape == (5, 4) + + # correct thresholding + t = transforms.Threshold(1, 1e3) + signal = np.random.uniform(2, 4, (5, 4)) + signal[:, 0] = 0 + s = t(signal) + assert s.shape == (5, 3) + + +def test_zscore(): + + t = transforms.ZScore() + signal = 10 + 0.3 * np.random.randn(100, 3) + s = t(signal) + assert s.shape == (100, 3) + assert np.allclose(np.mean(s, axis=0), [0, 0, 0], atol=1e-3) + assert np.allclose(np.std(s, axis=0), [1, 1, 1], atol=1e-3) diff --git a/tests/test_data/test_utils_data.py b/tests/test_data/test_utils_data.py index 2509325..a2947d4 100644 --- a/tests/test_data/test_utils_data.py +++ b/tests/test_data/test_utils_data.py @@ -76,16 +76,16 @@ def test_get_data_generator_inputs(): hparams['use_output_mask'] = False # ----------------- - # sss-vae + # ps-vae # ----------------- - hparams['model_class'] = 'sss-vae' + hparams['model_class'] = 'ps-vae' hparams_, signals, transforms, paths = utils.get_data_generator_inputs( hparams, sess_ids, check_splits=False) assert signals[0] == ['images', 'labels'] assert transforms[0] == [None, None] assert paths[0] == [hdf5_path, hdf5_path] - hparams['model_class'] = 'sss-vae' + hparams['model_class'] = 'ps-vae' hparams['use_output_mask'] = True hparams_, signals, transforms, paths = utils.get_data_generator_inputs( hparams, sess_ids, check_splits=False) @@ -94,6 +94,15 @@ def test_get_data_generator_inputs(): assert paths[0] == [hdf5_path, hdf5_path, hdf5_path] hparams['use_output_mask'] = False + hparams['model_class'] = 'ps-vae' + hparams['use_label_mask'] = True + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['images', 'labels', 'labels_masks'] + assert transforms[0] == [None, None, None] + assert paths[0] == [hdf5_path, hdf5_path, hdf5_path] + hparams['use_label_mask'] = False + # ----------------- # cond-vae # ----------------- @@ -114,7 +123,7 @@ def test_get_data_generator_inputs(): hparams['use_output_mask'] = False # ----------------- - # cond-ae [-msp] + # cond-ae # ----------------- hparams['model_class'] = 'cond-ae' hparams_, signals, transforms, paths = utils.get_data_generator_inputs( @@ -145,6 +154,9 @@ def test_get_data_generator_inputs(): assert paths[0] == [hdf5_path, hdf5_path, hdf5_path] hparams['conditional_encoder'] = False + # ----------------- + # cond-ae-msp + # ----------------- hparams['model_class'] = 'cond-ae-msp' hparams_, signals, transforms, paths = utils.get_data_generator_inputs( hparams, sess_ids, check_splits=False) @@ -152,11 +164,21 @@ def test_get_data_generator_inputs(): assert transforms[0] == [None, None] assert paths[0] == [hdf5_path, hdf5_path] + hparams['model_class'] = 'cond-ae-msp' + hparams['use_label_mask'] = True + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['images', 'labels', 'labels_masks'] + assert transforms[0] == [None, None, None] + assert paths[0] == [hdf5_path, hdf5_path, hdf5_path] + hparams['use_label_mask'] = False + # ----------------- # ae_latents # ----------------- hparams['model_class'] = 'ae_latents' hparams['session_dir'] = session_dir + hparams['ae_model_class'] = 'ae' hparams['ae_model_type'] = 'conv' hparams['n_ae_latents'] = 8 hparams['ae_experiment_name'] = 'tt_expt_ae' @@ -187,6 +209,29 @@ def test_get_data_generator_inputs(): hparams, sess_ids, check_splits=False) assert hparams_['noise_dist'] == 'gaussian-full' + # ----------------- + # neural-ae-me + # ----------------- + hparams['model_class'] = 'neural-ae-me' + hparams['model_type'] = 'linear' + hparams['session_dir'] = session_dir + hparams['neural_type'] = 'spikes' + hparams['neural_thresh'] = 0 + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['neural', 'ae_latents'] + assert transforms[0][0] is None + assert transforms[0][1].__repr__().find('MotionEnergy') > -1 + assert hparams_['input_signal'] == 'neural' + assert hparams_['output_signal'] == 'ae_latents' + assert hparams_['output_size'] == hparams['n_ae_latents'] + assert hparams_['noise_dist'] == 'gaussian' + + hparams['model_type'] = 'linear-mv' + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert hparams_['noise_dist'] == 'gaussian-full' + # ----------------- # ae-neural # ----------------- @@ -215,6 +260,57 @@ def test_get_data_generator_inputs(): hparams, sess_ids, check_splits=False) assert hparams_['noise_dist'] == 'gaussian-full' + # ----------------- + # neural-labels + # ----------------- + hparams['model_class'] = 'neural-labels' + hparams['model_type'] = 'linear' + hparams['n_labels'] = 4 + hparams['session_dir'] = session_dir + hparams['neural_type'] = 'spikes' + hparams['neural_thresh'] = 0 + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['neural', 'labels'] + assert hparams_['input_signal'] == 'neural' + assert hparams_['output_signal'] == 'labels' + assert hparams_['output_size'] == hparams['n_labels'] + assert hparams_['noise_dist'] == 'gaussian' + + hparams['model_type'] = 'linear-mv' + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert hparams_['noise_dist'] == 'gaussian-full' + + # ----------------- + # labels-neural + # ----------------- + hparams['model_class'] = 'labels-neural' + hparams['model_type'] = 'linear' + hparams['n_labels'] = 4 + hparams['session_dir'] = session_dir + hparams['neural_type'] = 'spikes' + hparams['neural_thresh'] = 0 + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['neural', 'labels'] + assert hparams_['input_signal'] == 'labels' + assert hparams_['output_signal'] == 'neural' + assert hparams_['output_size'] is None + assert hparams_['noise_dist'] == 'poisson' + + hparams['model_type'] = 'linear' + hparams['neural_type'] = 'ca' + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert hparams_['noise_dist'] == 'gaussian' + + hparams['model_type'] = 'linear-mv' + hparams['neural_type'] = 'ca' + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert hparams_['noise_dist'] == 'gaussian-full' + # ----------------- # arhmm # ----------------- @@ -377,6 +473,24 @@ def test_get_data_generator_inputs(): assert transforms[0] == [None] assert paths[0] == [hdf5_path] + hparams['use_label_mask'] = True + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['labels', 'labels_masks'] + assert transforms[0] == [None, None] + assert paths[0] == [hdf5_path, hdf5_path] + hparams['use_label_mask'] = False + + # ----------------- + # labels_masks + # ----------------- + hparams['model_class'] = 'labels_masks' + hparams_, signals, transforms, paths = utils.get_data_generator_inputs( + hparams, sess_ids, check_splits=False) + assert signals[0] == ['labels_masks'] + assert transforms[0] == [None] + assert paths[0] == [hdf5_path] + # ----------------- # other # ----------------- @@ -444,6 +558,7 @@ def test_get_transforms_paths(): # ae latents # ------------------------ hparams['session_dir'] = session_dir + hparams['ae_model_class'] = 'ae' hparams['ae_model_type'] = 'conv' hparams['n_ae_latents'] = 8 hparams['ae_experiment_name'] = 'tt_expt_ae' @@ -451,7 +566,7 @@ def test_get_transforms_paths(): ae_path = os.path.join( hparams['data_dir'], hparams['lab'], hparams['expt'], hparams['animal'], - hparams['session'], 'ae', hparams['ae_model_type'], + hparams['session'], hparams['ae_model_class'], hparams['ae_model_type'], '%02i_latents' % hparams['n_ae_latents'], hparams['ae_experiment_name']) # user-defined latent path @@ -469,6 +584,13 @@ def test_get_transforms_paths(): ae_path, 'version_%i' % hparams['ae_version'], '%slatents.pkl' % sess_id_str) assert transform is None + # get correct transform + transform, path = utils.get_transforms_paths( + 'ae_latents_me', hparams, sess_id=None, check_splits=False) + assert path == os.path.join( + ae_path, 'version_%i' % hparams['ae_version'], '%slatents.pkl' % sess_id_str) + assert transform.__repr__().find('MotionEnergy') > -1 + # TODO: use get_best_model_version() # ------------------------ @@ -600,7 +722,7 @@ def test_get_region_list(tmpdir): 'i2': np.array([6, 7, 8])} with h5py.File(path, 'w') as f: group0 = f.create_group('group0') - groupa = f.create_group('groupa') + # groupa = f.create_group('groupa') group1 = group0.create_group('group1') group1.create_dataset('i0', data=idx_data['i0']) group1.create_dataset('i1', data=idx_data['i1']) diff --git a/tests/test_fitting/test_hyperparam_utils.py b/tests/test_fitting/test_hyperparam_utils.py index 0c6ca59..7ae3872 100644 --- a/tests/test_fitting/test_hyperparam_utils.py +++ b/tests/test_fitting/test_hyperparam_utils.py @@ -37,7 +37,7 @@ def test_get_all_params(): training_config = os.path.join( os.getcwd(), 'configs', 'arhmm_jsons', 'arhmm_training.json') compute_config = os.path.join( - os.getcwd(), 'configs', 'arhmm_jsons', 'arhmm_compute.json') + os.getcwd(), 'configs', 'arhmm_jsons', 'arhmm_compute.json') args = [ '--data_config', data_config, '--model_config', model_config, @@ -133,7 +133,7 @@ def test_add_dependent_params(tmpdir): 'i2': np.array([6, 7, 8])} with h5py.File(path, 'w') as f: group0 = f.create_group('regions') - groupa = f.create_group('neural') + # groupa = f.create_group('neural') group1 = group0.create_group('indxs') group1.create_dataset('i0', data=idx_data['i0']) group1.create_dataset('i1', data=idx_data['i1']) diff --git a/tests/test_fitting/test_utils_fitting.py b/tests/test_fitting/test_utils_fitting.py index 0653093..4cbbd8e 100644 --- a/tests/test_fitting/test_utils_fitting.py +++ b/tests/test_fitting/test_utils_fitting.py @@ -189,7 +189,7 @@ def test_get_subdirs(self): assert sorted(subdirs) == ['expt0', 'expt1', 'multisession-00'] # raise exception when not a path - with pytest.raises(ValueError): + with pytest.raises(NotADirectoryError): utils.get_subdirs('/ZzZtestingZzZ') def test_get_multisession_paths(self): @@ -528,9 +528,9 @@ def test_get_expt_dir(self): assert expt_dir == model_path # ------------------------- - # sss-vae + # ps-vae # ------------------------- - hparams['model_class'] = 'sss-vae' + hparams['model_class'] = 'ps-vae' hparams['model_type'] = 'conv' hparams['n_ae_latents'] = 10 hparams['experiment_name'] = 'tt_expt' @@ -547,7 +547,7 @@ def test_get_expt_dir(self): assert expt_dir == model_path # ------------------------- - # neural-ae/ae-neural + # neural-ae/neural-ae-me/ae-neural # ------------------------- hparams['model_class'] = 'neural-ae' hparams['model_type'] = 'mlp' @@ -562,6 +562,16 @@ def test_get_expt_dir(self): expt_name=hparams['experiment_name']) assert expt_dir == model_path + hparams['model_class'] = 'neural-ae-me' + model_path = os.path.join( + session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'], + hparams['model_type'], 'all', hparams['experiment_name']) + + expt_dir = utils.get_expt_dir( + hparams, model_class=hparams['model_class'], model_type=hparams['model_type'], + expt_name=hparams['experiment_name']) + assert expt_dir == model_path + hparams['model_class'] = 'ae-neural' model_path = os.path.join( session_dir, hparams['model_class'], '%02i_latents' % hparams['n_ae_latents'], @@ -571,6 +581,30 @@ def test_get_expt_dir(self): expt_name=hparams['experiment_name']) assert expt_dir == model_path + # ------------------------- + # neural-labels/labels-neural + # ------------------------- + hparams['model_class'] = 'neural-labels' + hparams['model_type'] = 'mlp' + hparams['experiment_name'] = 'tt_expt' + model_path = os.path.join( + session_dir, hparams['model_class'], hparams['model_type'], 'all', + hparams['experiment_name']) + + expt_dir = utils.get_expt_dir( + hparams, model_class=hparams['model_class'], model_type=hparams['model_type'], + expt_name=hparams['experiment_name']) + assert expt_dir == model_path + + hparams['model_class'] = 'labels-neural' + model_path = os.path.join( + session_dir, hparams['model_class'], hparams['model_type'], 'all', + hparams['experiment_name']) + expt_dir = utils.get_expt_dir( + hparams, model_class=hparams['model_class'], model_type=hparams['model_type'], + expt_name=hparams['experiment_name']) + assert expt_dir == model_path + # ------------------------- # neural-arhmm/arhmm-neural # ------------------------- @@ -891,17 +925,17 @@ def test_get_model_params(self): ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) assert ret_hparams == {**base_hparams, **model_hparams} - # sss-vae + # ps-vae model_hparams = { - 'model_class': 'sss-vae', + 'model_class': 'ps-vae', 'model_type': 'conv', 'n_ae_latents': 6, 'fit_sess_io_layers': False, 'learning_rate': 1e-4, 'l2_reg': 1e-2, - 'sss_vae.alpha': 1, - 'sss_vae.beta': 2, - 'sss_vae.gamma': 3, + 'ps_vae.alpha': 1, + 'ps_vae.beta': 2, + 'ps_vae.gamma': 3, # 'beta_tcvae.beta_anneal_epochs': 100 } ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) @@ -918,6 +952,7 @@ def test_get_model_params(self): 'transitions': 'stationary', 'ae_experiment_name': 'ae_expt', 'ae_version': 4, + 'ae_model_class': 'ae', 'ae_model_type': 'conv', 'n_ae_latents': 5} ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) @@ -932,6 +967,7 @@ def test_get_model_params(self): 'kappa': 100, 'ae_experiment_name': 'ae_expt', 'ae_version': 4, + 'ae_model_class': 'ae', 'ae_model_type': 'conv', 'n_ae_latents': 5} ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) @@ -950,19 +986,44 @@ def test_get_model_params(self): assert ret_hparams == {**base_hparams, **model_hparams} # ----------------- - # neural-ae/ae-neural + # neural-ae/neural-ae-me/ae-neural # ----------------- model_hparams = { 'model_class': 'neural-ae', 'model_type': 'mlp', 'ae_experiment_name': 'ae_expt', 'ae_version': 4, + 'ae_model_class': 'ae', 'ae_model_type': 'conv', 'n_ae_latents': 5, 'n_lags': 3, 'l2_reg': 1, 'n_hid_layers': 0, 'activation': 'relu', + 'learning_rate': 1e-3, + 'subsample_method': 'none'} + ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) + assert ret_hparams == {**base_hparams, **model_hparams} + + model_hparams['model_class'] = 'neural-ae-me' + ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) + assert ret_hparams == {**base_hparams, **model_hparams} + + model_hparams['model_class'] = 'ae-neural' + ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) + assert ret_hparams == {**base_hparams, **model_hparams} + + # ----------------- + # neural-labels/labels-neural + # ----------------- + model_hparams = { + 'model_class': 'neural-labels', + 'model_type': 'mlp', + 'n_lags': 3, + 'l2_reg': 1, + 'n_hid_layers': 0, + 'activation': 'relu', + 'learning_rate': 1e-3, 'subsample_method': 'none'} ret_hparams = utils.get_model_params({**misc_hparams, **base_hparams, **model_hparams}) assert ret_hparams == {**base_hparams, **model_hparams} @@ -980,6 +1041,7 @@ def test_get_model_params(self): 'noise_type': 'gaussian', 'transitions': 'sticky', 'kappa': 10, + 'ae_model_class': 'ae', 'ae_model_type': 'conv', 'n_ae_latents': 5, 'n_lags': 3, @@ -987,6 +1049,7 @@ def test_get_model_params(self): 'n_hid_layers': 2, 'n_hid_units': 10, 'activation': 'relu', + 'learning_rate': 1e-3, 'subsample_method': 'single', 'subsample_idxs_name': 'a', 'subsample_idxs_group_0': 'b', diff --git a/tests/test_models/test_ae_model_architecture_generator.py b/tests/test_models/test_ae_model_architecture_generator.py index 7bc7172..7119549 100644 --- a/tests/test_models/test_ae_model_architecture_generator.py +++ b/tests/test_models/test_ae_model_architecture_generator.py @@ -131,9 +131,9 @@ def test_get_decoding_conv_block(): assert arch['ae_decoding_n_channels'][-1] == input_dim[0] for i in range(len(arch['ae_decoding_n_channels']) - 1): assert arch['ae_decoding_layer_type'][i] in ['convtranspose'] - assert arch['ae_decoding_n_channels'][i] == arch['ae_encoding_n_channels'][-2-i] - assert arch['ae_decoding_kernel_size'][i] == arch['ae_encoding_kernel_size'][-1-i] - assert arch['ae_decoding_stride_size'][i] == arch['ae_encoding_stride_size'][-1-i] + assert arch['ae_decoding_n_channels'][i] == arch['ae_encoding_n_channels'][-2 - i] + assert arch['ae_decoding_kernel_size'][i] == arch['ae_encoding_kernel_size'][-1 - i] + assert arch['ae_decoding_stride_size'][i] == arch['ae_encoding_stride_size'][-1 - i] # using correct options (with maxpool) np.random.seed(16) @@ -143,9 +143,9 @@ def test_get_decoding_conv_block(): print(arch) for i in range(len(arch['ae_decoding_n_channels']) - 1): assert arch['ae_decoding_layer_type'][i] in ['convtranspose', 'unpool'] - assert arch['ae_decoding_n_channels'][i] == arch['ae_encoding_n_channels'][-2-i] - assert arch['ae_decoding_kernel_size'][i] == arch['ae_encoding_kernel_size'][-1-i] - assert arch['ae_decoding_stride_size'][i] == arch['ae_encoding_stride_size'][-1-i] + assert arch['ae_decoding_n_channels'][i] == arch['ae_encoding_n_channels'][-2 - i] + assert arch['ae_decoding_kernel_size'][i] == arch['ae_encoding_kernel_size'][-1 - i] + assert arch['ae_decoding_stride_size'][i] == arch['ae_encoding_stride_size'][-1 - i] # using correct options (with final ff layer) arch['ae_decoding_last_FF_layer'] = True @@ -377,14 +377,14 @@ def test_get_handcrafted_dims(): arch0 = utils.load_default_arch() arch0['ae_input_dim'] = [2, 128, 128] arch0 = utils.get_handcrafted_dims(arch0, symmetric=True) - assert arch0['ae_encoding_x_dim'] == [64, 32, 16, 8] - assert arch0['ae_encoding_y_dim'] == [64, 32, 16, 8] - assert arch0['ae_encoding_x_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] - assert arch0['ae_encoding_y_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] - assert arch0['ae_decoding_x_dim'] == [16, 32, 64, 128] - assert arch0['ae_decoding_y_dim'] == [16, 32, 64, 128] - assert arch0['ae_decoding_x_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] - assert arch0['ae_decoding_y_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] + assert arch0['ae_encoding_x_dim'] == [64, 32, 16, 8, 2] + assert arch0['ae_encoding_y_dim'] == [64, 32, 16, 8, 2] + assert arch0['ae_encoding_x_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2), (1, 1)] + assert arch0['ae_encoding_y_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2), (1, 1)] + assert arch0['ae_decoding_x_dim'] == [8, 16, 32, 64, 128] + assert arch0['ae_decoding_y_dim'] == [8, 16, 32, 64, 128] + assert arch0['ae_decoding_x_padding'] == [(1, 1), (1, 2), (1, 2), (1, 2), (1, 2)] + assert arch0['ae_decoding_y_padding'] == [(1, 1), (1, 2), (1, 2), (1, 2), (1, 2)] # asymmetric arch (TODO: source code not updated) arch1 = utils.load_default_arch() @@ -395,10 +395,10 @@ def test_get_handcrafted_dims(): arch1['ae_decoding_layer_type'] = ['conv', 'conv', 'conv'] arch1['ae_decoding_starting_dim'] = [1, 8, 8] arch1 = utils.get_handcrafted_dims(arch1, symmetric=False) - assert arch1['ae_encoding_x_dim'] == [64, 32, 16, 8] - assert arch1['ae_encoding_y_dim'] == [64, 32, 16, 8] - assert arch1['ae_encoding_x_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] - assert arch1['ae_encoding_y_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2)] + assert arch1['ae_encoding_x_dim'] == [64, 32, 16, 8, 2] + assert arch1['ae_encoding_y_dim'] == [64, 32, 16, 8, 2] + assert arch1['ae_encoding_x_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2), (1, 1)] + assert arch1['ae_encoding_y_padding'] == [(1, 2), (1, 2), (1, 2), (1, 2), (1, 1)] assert arch1['ae_decoding_x_dim'] == [15, 29, 57] assert arch1['ae_decoding_y_dim'] == [15, 29, 57] assert arch1['ae_decoding_x_padding'] == [(2, 2), (2, 2), (2, 2)] @@ -425,7 +425,7 @@ def test_load_handcrafted_arch(): assert arch['x_pixels'] == input_dim[2] assert arch['ae_input_dim'] == input_dim assert arch['n_ae_latents'] == n_ae_latents - assert arch['ae_encoding_n_channels'] == [32, 64, 256, 512] + assert arch['ae_encoding_n_channels'] == [32, 64, 128, 256, 512] # load arch from json ae_arch_json = os.path.join( @@ -447,7 +447,7 @@ def test_load_handcrafted_arch(): assert arch['x_pixels'] == input_dim[2] assert arch['ae_input_dim'] == input_dim assert arch['n_ae_latents'] == n_ae_latents - assert arch['ae_encoding_n_channels'] == [32, 64, 256, 512] + assert arch['ae_encoding_n_channels'] == [32, 64, 128, 256, 512] # check memory runs ae_arch_json = None @@ -458,7 +458,7 @@ def test_load_handcrafted_arch(): assert arch['x_pixels'] == input_dim[2] assert arch['ae_input_dim'] == input_dim assert arch['n_ae_latents'] == n_ae_latents - assert arch['ae_encoding_n_channels'] == [32, 64, 256, 512] + assert arch['ae_encoding_n_channels'] == [32, 64, 128, 256, 512] # raise exception when not enough gpu memory ae_arch_json = None