Skip to content

Commit

Permalink
Fix issue in building factors from factors dict
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed Jun 5, 2024
1 parent 8498524 commit 5b7bbb9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
17 changes: 13 additions & 4 deletions src/pydvl/influence/base_influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import wraps
from typing import Generic, Optional, Type, cast
from typing import Generic, Iterable, Optional, Type, cast

from ..utils.progress import log_duration
from .types import BatchType, BlockMapperType, DataLoaderType, InfluenceMode, TensorType
Expand Down Expand Up @@ -433,10 +433,19 @@ def influences_from_factors_by_block(
)

def _influence_factors(self, x: TensorType, y: TensorType) -> TensorType:
transformed_grads = self.block_mapper.transformed_grads(
self._create_batch(x, y)
transformed_grads = self.influence_factors_by_block(x, y)
transformed_grads = (
self._flatten_trailing_dim(t) for t in transformed_grads.values()
)
return cast(TensorType, sum(transformed_grads.values()))
return cast(TensorType, self._concat(transformed_grads, dim=-1))

@abstractmethod
def _concat(self, tensors: Iterable[TensorType], dim: int):
"""Implement this to concat tensors at a specified dimension"""

@abstractmethod
def _flatten_trailing_dim(self, tensor: TensorType):
"""Implement this to flatten all but the first dimension"""

def _influences(
self,
Expand Down
9 changes: 8 additions & 1 deletion src/pydvl/influence/torch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, TypeVar, Union, cast
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union, cast

import torch
from torch.func import functional_call
Expand All @@ -17,6 +17,7 @@
GradientProvider,
Operator,
OperatorGradientComposition,
TensorType,
)
from .util import (
BlockMode,
Expand Down Expand Up @@ -635,6 +636,12 @@ def __init__(

super().__init__(model)

def _concat(self, tensors: Iterable[torch.Tensor], dim: int):
return torch.cat(list(tensors), dim=dim)

def _flatten_trailing_dim(self, tensor: torch.Tensor):
return tensor.reshape((tensor.shape[0], -1))

@property
def block_names(self) -> List[str]:
return list(self.parameter_dict.keys())
Expand Down

0 comments on commit 5b7bbb9

Please sign in to comment.