diff --git a/docs/conf.py b/docs/conf.py index a372a8f..e4a549f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,7 @@ author = 'Kristian Georgiev' # The full version, including alpha/beta/rc tags -release = '0.1.0' +release = '0.1.1' # -- General configuration --------------------------------------------------- diff --git a/docs/html/.buildinfo b/docs/html/.buildinfo index 7f9c650..a4d6b90 100644 --- a/docs/html/.buildinfo +++ b/docs/html/.buildinfo @@ -1,4 +1,4 @@ # Sphinx build info version 1 # This file hashes the configuration used when building these files. When it is not found, a full rebuild will be done. -config: def77ad9866373dc2cf63179c7913729 +config: 1a000b88aa8f9cf78e81853015c446f0 tags: 645f666f9bcd5a90fca523b33c5a78b7 diff --git a/docs/html/.doctrees/environment.pickle b/docs/html/.doctrees/environment.pickle index 4321e10..698c812 100644 Binary files a/docs/html/.doctrees/environment.pickle and b/docs/html/.doctrees/environment.pickle differ diff --git a/docs/html/.doctrees/trak.doctree b/docs/html/.doctrees/trak.doctree index 316a952..ef578a5 100644 Binary files a/docs/html/.doctrees/trak.doctree and b/docs/html/.doctrees/trak.doctree differ diff --git a/docs/html/_modules/index.html b/docs/html/_modules/index.html index 1a62708..6acb32f 100644 --- a/docs/html/_modules/index.html +++ b/docs/html/_modules/index.html @@ -5,7 +5,7 @@ - Overview: module code - TRAK 0.1.0 documentation + Overview: module code - TRAK 0.1.1 documentation @@ -122,7 +122,7 @@
-
TRAK 0.1.0 documentation
+
TRAK 0.1.1 documentation
@@ -145,7 +145,7 @@ +
[docs]class TextClassificationModelOutput(AbstractModelOutput): + """ + Margin for text classification models. This assumes that the model takes in + input_ids, token_type_ids, and attention_mask. + .. math:: + \text{logit}[\text{correct}] - \log\left(\sum_{i \neq \text{correct}} + \exp(\text{logit}[i])\right) + Version of margin proposed in 'Understanding Influence Functions + and Datamodels via Harmonic Analysis' + """ + + def __init__(self, temperature=1.) -> None: + super().__init__() + self.softmax = ch.nn.Softmax(-1) + self.loss_temperature = temperature + +
[docs] @staticmethod + def get_output(func_model, + weights: Iterable[Tensor], + buffers: Iterable[Tensor], + input_id: Tensor, + token_type_id: Tensor, + attention_mask: Tensor, + label: Tensor, + ) -> Tensor: + logits = func_model(weights, buffers, input_id.unsqueeze(0), + token_type_id.unsqueeze(0), + attention_mask.unsqueeze(0)) + bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False) + logits_correct = logits[bindex, label.unsqueeze(0)] + + cloned_logits = logits.clone() + cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf).to(logits.device) + + margins = logits_correct - cloned_logits.logsumexp(dim=-1) + return margins.sum()
+ +
[docs] def forward(self, model: Module, batch: Iterable[Tensor]) -> Tensor: + input_ids, token_type_ids, attention_mask, _ = batch + return model(input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask)
+ +
[docs] def get_out_to_loss_grad(self, func_model, weights, buffers, batch: Iterable[Tensor]) -> Tensor: + input_ids, token_type_ids, attention_mask, labels = batch + logits = func_model(weights, buffers, input_ids, token_type_ids, attention_mask) + ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels] + return (1 - ps).clone().detach().unsqueeze(-1)
+ + TASK_TO_MODELOUT = { ('image_classification', True): ImageClassificationModelOutput, ('image_classification', False): IterImageClassificationModelOutput, + ('text_classification', True): TextClassificationModelOutput, ('clip', True): CLIPModelOutput, }
diff --git a/docs/html/_modules/trak/savers.html b/docs/html/_modules/trak/savers.html index 6f0fb9c..f450540 100644 --- a/docs/html/_modules/trak/savers.html +++ b/docs/html/_modules/trak/savers.html @@ -220,7 +220,8 @@

Source code for trak.savers

     @abstractmethod
     def __init__(self,
                  save_dir: Union[Path, str],
-                 metadata: Iterable) -> None:
+                 metadata: Iterable,
+                 load_from_save_dir: bool) -> None:
         """ Creates the save directory if it doesn't already exist.
         If the save directory already exists, it validates that the current
         TRAKer class has the same hyperparameters (metadata) as the one
@@ -234,14 +235,19 @@ 

Source code for trak.savers

                 intermediate values, and metadata
             metadata (Iterable): a dictionary containing metadata related to the
                 TRAKer class
+            load_from_save_dir (bool): If True, the Saver instance will attempt
+                to load existing metadata from save_dir. May lead to I/O issues
+                if multiple Saver instances ran in parallel have this flag set
+                to True. See the SLURM tutorial for more details.
         """
         self.metadata = metadata
         self.save_dir = Path(save_dir).resolve()
+        self.load_from_save_dir = load_from_save_dir
         os.makedirs(self.save_dir, exist_ok=True)
 
         # init TRAKer metadata
         self.metadata_file = self.save_dir.joinpath('metadata.json')
-        if os.path.exists(self.metadata_file):
+        if os.path.exists(self.metadata_file) and self.load_from_save_dir:
             with open(self.metadata_file, 'r') as f:
                 existsing_metadata = json.load(f)
             existing_jl_dim = int(existsing_metadata['JL dimension'])
@@ -254,20 +260,22 @@ 

Source code for trak.savers

                    f"In {self.save_dir} there are models using a {existing_matrix_type} JL matrix\
                    , and this TRAKer instance uses a {self.metadata['JL matrix type']} JL matrix."
 
-        else:
+        elif self.load_from_save_dir:
             with open(self.metadata_file, 'w') as f:
                 json.dump(self.metadata, f)
 
         self.model_ids = {}
-        # check if there are existing model ids in the save_dir
-        self.model_ids_files = self.save_dir.rglob('id_*.json')
-
-        for existing_model_id_file in self.model_ids_files:
-            with open(existing_model_id_file, 'r') as f:
-                existing_id = json.load(f)
-                existing_id = {int(model_id): metadata
-                               for model_id, metadata in existing_id.items()}
-            self.model_ids.update(existing_id)
+        if self.load_from_save_dir:
+            # check if there are existing model ids in the save_dir
+            self.model_ids_files = self.save_dir.rglob('id_*.json')
+
+            for existing_model_id_file in self.model_ids_files:
+                with open(existing_model_id_file, 'r') as f:
+                    existing_id = json.load(f)
+                    existing_id = {int(model_id): metadata
+                                   for model_id, metadata in existing_id.items()}
+                self.model_ids.update(existing_id)
+
         # wlog set num_targets to those of a random model_id we could raise an
         # error here if different model_ids have different num_targets but this
         # could be a bit too stringent in some cases
@@ -357,8 +365,10 @@ 

Source code for trak.savers

     into memory.
 
     """
-    def __init__(self, save_dir, metadata, train_set_size, proj_dim) -> None:
-        super().__init__(save_dir=save_dir, metadata=metadata)
+    def __init__(self, save_dir, metadata, train_set_size, proj_dim, load_from_save_dir) -> None:
+        super().__init__(save_dir=save_dir,
+                         metadata=metadata,
+                         load_from_save_dir=load_from_save_dir)
         self.train_set_size = train_set_size
         self.proj_dim = proj_dim
 
diff --git a/docs/html/_modules/trak/traker.html b/docs/html/_modules/trak/traker.html
index 85ef668..d20e42c 100644
--- a/docs/html/_modules/trak/traker.html
+++ b/docs/html/_modules/trak/traker.html
@@ -218,6 +218,7 @@ 

Source code for trak.traker

                  task: Union[AbstractModelOutput, str],
                  train_set_size: int,
                  save_dir: str = './trak_results',
+                 load_from_save_dir: bool = True,
                  device: Union[str, torch.device] = 'cuda',
                  gradient_computer: AbstractGradientComputer = FunctionalGradientComputer,
                  projector: Optional[AbstractProjector] = None,
@@ -237,6 +238,10 @@ 

Source code for trak.traker

             train_set_size (int): Size of the train set that TRAK is featurizing
             save_dir (str, optional): Directory to save final TRAK scores,
                 intermediate results, and metadata. Defaults to './trak_results'.
+            load_from_save_dir (bool, optional): If True, the TRAKer instance
+                will attempt to load existing metadata from save_dir. May lead
+                to I/O issues if multiple TRAKer instances ran in parallel have
+                this flag set to True. See the SLURM tutorial for more details.
             device (Union[str, torch.device], optional): torch device on which
                 to do computations. Defaults to 'cuda'.
             gradient_computer (AbstractGradientComputer, optional):
@@ -261,6 +266,7 @@ 

Source code for trak.traker

         self.init_projector(projector, proj_dim)  # inits self.projector
 
         self.save_dir = Path(save_dir).resolve()
+        self.load_from_save_dir = load_from_save_dir
 
         if type(self.task) is str:
             self.modelout_fn = TASK_TO_MODELOUT[(self.task, gradient_computer.is_functional)]
@@ -278,7 +284,8 @@ 

Source code for trak.traker

         self.saver = MmapSaver(save_dir=self.save_dir,
                                metadata=metadata,
                                train_set_size=self.train_set_size,
-                               proj_dim=self.proj_dim)
+                               proj_dim=self.proj_dim,
+                               load_from_save_dir=self.load_from_save_dir)
 
 
[docs] def init_projector(self, projector, proj_dim) -> None: """ Initialize the projector for a traker class @@ -511,8 +518,7 @@

Source code for trak.traker

             model_ids = self.saver.model_ids
 
         _completed = [False] * len(model_ids)
-        _scores = ch.empty(len(model_ids),
-                           self.train_set_size,
+        _scores = ch.zeros(self.train_set_size,
                            self.saver.num_targets,
                            device=self.device)
         _avg_out_to_losses = ch.zeros(self.saver.train_set_size, 1, device=self.device)
@@ -530,7 +536,7 @@ 

Source code for trak.traker

             g = ch.as_tensor(self.saver.current_features, device=self.device)
             g_target = ch.as_tensor(self.saver.current_target_grads, device=self.device)
 
-            _scores[j] = self.score_computer.get_scores(g, g_target)
+            _scores += self.score_computer.get_scores(g, g_target)
             _avg_out_to_losses += ch.as_tensor(self.saver.current_out_to_loss, device=self.device)
             _completed[j] = True
 
@@ -539,10 +545,8 @@ 

Source code for trak.traker

             else:
                 self.saver.clear_target_grad_count(model_id)
 
-        _scores = _scores[_completed].mean(dim=0)
-
         _num_models_used = float(sum(_completed))
-        self.scores = _scores * (_avg_out_to_losses / _num_models_used)
+        self.scores = (_scores / _num_models_used) * (_avg_out_to_losses / _num_models_used)
         self.saver.save_scores(self.scores.cpu().numpy(), exp_name)
 
         return self.scores
diff --git a/docs/html/_static/documentation_options.js b/docs/html/_static/documentation_options.js index 87d8188..1f7b0d1 100644 --- a/docs/html/_static/documentation_options.js +++ b/docs/html/_static/documentation_options.js @@ -1,6 +1,6 @@ var DOCUMENTATION_OPTIONS = { URL_ROOT: document.getElementById("documentation_options").getAttribute('data-url_root'), - VERSION: '0.1.0', + VERSION: '0.1.1', LANGUAGE: 'None', COLLAPSE_INDEX: false, BUILDER: 'html', diff --git a/docs/html/bert.html b/docs/html/bert.html index 913d52c..3d47df4 100644 --- a/docs/html/bert.html +++ b/docs/html/bert.html @@ -6,7 +6,7 @@ - Add a task to TRAKer (subclassing ModelOutput) — BERT-base - TRAK 0.1.0 documentation + Add a task to TRAKer (subclassing ModelOutput) — BERT-base - TRAK 0.1.1 documentation @@ -123,7 +123,7 @@
@@ -146,7 +146,7 @@
@@ -146,7 +146,7 @@
@@ -144,7 +144,7 @@