diff --git a/docs/conf.py b/docs/conf.py index a372a8f..e4a549f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,7 @@ author = 'Kristian Georgiev' # The full version, including alpha/beta/rc tags -release = '0.1.0' +release = '0.1.1' # -- General configuration --------------------------------------------------- diff --git a/docs/html/.buildinfo b/docs/html/.buildinfo index 7f9c650..a4d6b90 100644 --- a/docs/html/.buildinfo +++ b/docs/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: def77ad9866373dc2cf63179c7913729 +config: 1a000b88aa8f9cf78e81853015c446f0 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/html/.doctrees/environment.pickle b/docs/html/.doctrees/environment.pickle index 4321e10..698c812 100644 Binary files a/docs/html/.doctrees/environment.pickle and b/docs/html/.doctrees/environment.pickle differ diff --git a/docs/html/.doctrees/trak.doctree b/docs/html/.doctrees/trak.doctree index 316a952..ef578a5 100644 Binary files a/docs/html/.doctrees/trak.doctree and b/docs/html/.doctrees/trak.doctree differ diff --git a/docs/html/_modules/index.html b/docs/html/_modules/index.html index 1a62708..6acb32f 100644 --- a/docs/html/_modules/index.html +++ b/docs/html/_modules/index.html @@ -5,7 +5,7 @@ -
@abstractmethod
def __init__(self,
save_dir: Union[Path, str],
- metadata: Iterable) -> None:
+ metadata: Iterable,
+ load_from_save_dir: bool) -> None:
""" Creates the save directory if it doesn't already exist.
If the save directory already exists, it validates that the current
TRAKer class has the same hyperparameters (metadata) as the one
@@ -234,14 +235,19 @@ Source code for trak.savers
intermediate values, and metadata
metadata (Iterable): a dictionary containing metadata related to the
TRAKer class
+ load_from_save_dir (bool): If True, the Saver instance will attempt
+ to load existing metadata from save_dir. May lead to I/O issues
+ if multiple Saver instances ran in parallel have this flag set
+ to True. See the SLURM tutorial for more details.
"""
self.metadata = metadata
self.save_dir = Path(save_dir).resolve()
+ self.load_from_save_dir = load_from_save_dir
os.makedirs(self.save_dir, exist_ok=True)
# init TRAKer metadata
self.metadata_file = self.save_dir.joinpath('metadata.json')
- if os.path.exists(self.metadata_file):
+ if os.path.exists(self.metadata_file) and self.load_from_save_dir:
with open(self.metadata_file, 'r') as f:
existsing_metadata = json.load(f)
existing_jl_dim = int(existsing_metadata['JL dimension'])
@@ -254,20 +260,22 @@ Source code for trak.savers
f"In {self.save_dir} there are models using a {existing_matrix_type} JL matrix\
, and this TRAKer instance uses a {self.metadata['JL matrix type']} JL matrix."
- else:
+ elif self.load_from_save_dir:
with open(self.metadata_file, 'w') as f:
json.dump(self.metadata, f)
self.model_ids = {}
- # check if there are existing model ids in the save_dir
- self.model_ids_files = self.save_dir.rglob('id_*.json')
-
- for existing_model_id_file in self.model_ids_files:
- with open(existing_model_id_file, 'r') as f:
- existing_id = json.load(f)
- existing_id = {int(model_id): metadata
- for model_id, metadata in existing_id.items()}
- self.model_ids.update(existing_id)
+ if self.load_from_save_dir:
+ # check if there are existing model ids in the save_dir
+ self.model_ids_files = self.save_dir.rglob('id_*.json')
+
+ for existing_model_id_file in self.model_ids_files:
+ with open(existing_model_id_file, 'r') as f:
+ existing_id = json.load(f)
+ existing_id = {int(model_id): metadata
+ for model_id, metadata in existing_id.items()}
+ self.model_ids.update(existing_id)
+
# wlog set num_targets to those of a random model_id we could raise an
# error here if different model_ids have different num_targets but this
# could be a bit too stringent in some cases
@@ -357,8 +365,10 @@ Source code for trak.savers
into memory.
"""
- def __init__(self, save_dir, metadata, train_set_size, proj_dim) -> None:
- super().__init__(save_dir=save_dir, metadata=metadata)
+ def __init__(self, save_dir, metadata, train_set_size, proj_dim, load_from_save_dir) -> None:
+ super().__init__(save_dir=save_dir,
+ metadata=metadata,
+ load_from_save_dir=load_from_save_dir)
self.train_set_size = train_set_size
self.proj_dim = proj_dim
diff --git a/docs/html/_modules/trak/traker.html b/docs/html/_modules/trak/traker.html
index 85ef668..d20e42c 100644
--- a/docs/html/_modules/trak/traker.html
+++ b/docs/html/_modules/trak/traker.html
@@ -218,6 +218,7 @@ Source code for trak.traker
task: Union[AbstractModelOutput, str],
train_set_size: int,
save_dir: str = './trak_results',
+ load_from_save_dir: bool = True,
device: Union[str, torch.device] = 'cuda',
gradient_computer: AbstractGradientComputer = FunctionalGradientComputer,
projector: Optional[AbstractProjector] = None,
@@ -237,6 +238,10 @@ Source code for trak.traker
train_set_size (int): Size of the train set that TRAK is featurizing
save_dir (str, optional): Directory to save final TRAK scores,
intermediate results, and metadata. Defaults to './trak_results'.
+ load_from_save_dir (bool, optional): If True, the TRAKer instance
+ will attempt to load existing metadata from save_dir. May lead
+ to I/O issues if multiple TRAKer instances ran in parallel have
+ this flag set to True. See the SLURM tutorial for more details.
device (Union[str, torch.device], optional): torch device on which
to do computations. Defaults to 'cuda'.
gradient_computer (AbstractGradientComputer, optional):
@@ -261,6 +266,7 @@ Source code for trak.traker
self.init_projector(projector, proj_dim) # inits self.projector
self.save_dir = Path(save_dir).resolve()
+ self.load_from_save_dir = load_from_save_dir
if type(self.task) is str:
self.modelout_fn = TASK_TO_MODELOUT[(self.task, gradient_computer.is_functional)]
@@ -278,7 +284,8 @@ Source code for trak.traker
self.saver = MmapSaver(save_dir=self.save_dir,
metadata=metadata,
train_set_size=self.train_set_size,
- proj_dim=self.proj_dim)
+ proj_dim=self.proj_dim,
+ load_from_save_dir=self.load_from_save_dir)
[docs] def init_projector(self, projector, proj_dim) -> None:
""" Initialize the projector for a traker class
@@ -511,8 +518,7 @@ Source code for trak.traker
model_ids = self.saver.model_ids
_completed = [False] * len(model_ids)
- _scores = ch.empty(len(model_ids),
- self.train_set_size,
+ _scores = ch.zeros(self.train_set_size,
self.saver.num_targets,
device=self.device)
_avg_out_to_losses = ch.zeros(self.saver.train_set_size, 1, device=self.device)
@@ -530,7 +536,7 @@ Source code for trak.traker
g = ch.as_tensor(self.saver.current_features, device=self.device)
g_target = ch.as_tensor(self.saver.current_target_grads, device=self.device)
- _scores[j] = self.score_computer.get_scores(g, g_target)
+ _scores += self.score_computer.get_scores(g, g_target)
_avg_out_to_losses += ch.as_tensor(self.saver.current_out_to_loss, device=self.device)
_completed[j] = True
@@ -539,10 +545,8 @@ Source code for trak.traker
else:
self.saver.clear_target_grad_count(model_id)
- _scores = _scores[_completed].mean(dim=0)
-
_num_models_used = float(sum(_completed))
- self.scores = _scores * (_avg_out_to_losses / _num_models_used)
+ self.scores = (_scores / _num_models_used) * (_avg_out_to_losses / _num_models_used)
self.saver.save_scores(self.scores.cpu().numpy(), exp_name)
return self.scores
diff --git a/docs/html/_static/documentation_options.js b/docs/html/_static/documentation_options.js
index 87d8188..1f7b0d1 100644
--- a/docs/html/_static/documentation_options.js
+++ b/docs/html/_static/documentation_options.js
@@ -1,6 +1,6 @@
var DOCUMENTATION_OPTIONS = {
URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'),
- VERSION: '0.1.0',
+ VERSION: '0.1.1',
LANGUAGE: 'None',
COLLAPSE_INDEX: false,
BUILDER: 'html',
diff --git a/docs/html/bert.html b/docs/html/bert.html
index 913d52c..3d47df4 100644
--- a/docs/html/bert.html
+++ b/docs/html/bert.html
@@ -6,7 +6,7 @@
- Add a task to TRAKer (subclassing ModelOutput) — BERT-base - TRAK 0.1.0 documentation
+ Add a task to TRAKer (subclassing ModelOutput) — BERT-base - TRAK 0.1.1 documentation
@@ -123,7 +123,7 @@
@@ -146,7 +146,7 @@
@@ -146,7 +146,7 @@
@@ -144,7 +144,7 @@