Skip to content

Commit

Permalink
Merge pull request #67 from delve-team/feature/layer_filter
Browse files Browse the repository at this point in the history
Feature/layer filter
  • Loading branch information
MLRichter authored Mar 29, 2023
2 parents 55bec2c + 79657ee commit 0d413d5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
2 changes: 1 addition & 1 deletion delve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
import delve.logger

name = "delve"
__version__ = "0.1.49"
__version__ = "0.1.50"
5 changes: 3 additions & 2 deletions delve/pca_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,9 @@ def _compute_autorcorrelation(self) -> torch.Tensor:
return cov_mtx

def _compute_eigenspace(self):
self.eigenvalues.data, self.eigenvectors.data = self._compute_autorcorrelation(
).symeig(True) #.type(self.data_dtype)
self.eigenvalues.data, self.eigenvectors.data = torch.linalg.eigh(
self._compute_autorcorrelation(), UPLO='U'
) #.type(self.data_dtype)
self.eigenvalues.data, idx = self.eigenvalues.sort(descending=True)
# correct numerical error, matrix must be positivly semi-definitie
self.eigenvalues[self.eigenvalues < 0] = 0
Expand Down
13 changes: 10 additions & 3 deletions delve/torchcallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import torch
from torch.nn.functional import interpolate
from torch.nn.modules import LSTM
from torch.nn.modules import LSTM, Module
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.linear import Linear

Expand All @@ -16,6 +16,8 @@
from delve.torch_utils import TorchCovarianceMatrix
from delve.writers import STATMAP, WRITERS, CompositWriter, NPYWriter

from typing import Callable


class SaturationTracker(object):
"""Takes PyTorch module and records layer saturation,
Expand Down Expand Up @@ -49,6 +51,10 @@ class SaturationTracker(object):
Per default, only Conv2D,
Linear and LSTM-Cells
are recorded
layer_filter (func): A filter function that is used to avoid layers from being tracked.
This is function receiving a dictionary as input and returning
it with undesired entries removed. Default: Identity function.
The dictionary contains string keys mapping to torch.nn.Module objects.
writers_args (dict) : contains additional arguments passed over to the
writers. This is only used, when a writer is
initialized through a string-key.
Expand Down Expand Up @@ -168,7 +174,8 @@ class SaturationTracker(object):
def __init__(self,
savefile: str,
save_to: Union[str, delve.writers.AbstractWriter],
modules: torch.nn.Module,
modules: Module,
layer_filter: Callable[[Dict[str, Module]], Dict[str, Module]] = lambda x: x,
writer_args: Optional[Dict[str, Any]] = None,
log_interval=1,
max_samples=None,
Expand All @@ -195,7 +202,7 @@ def __init__(self,

self.timeseries_method = timeseries_method
self.threshold = sat_threshold
self.layers = self.get_layers_recursive(modules)
self.layers = layer_filter(self.get_layers_recursive(modules))
self.max_samples = max_samples
self.log_interval = log_interval
self.reset_covariance = reset_covariance
Expand Down

0 comments on commit 0d413d5

Please sign in to comment.