Skip to content

Commit

Permalink
Fix type-checking issues
Browse files Browse the repository at this point in the history
  • Loading branch information
schroedk committed May 21, 2024
1 parent 07f80bd commit 3b7289c
Showing 1 changed file with 6 additions and 13 deletions.
19 changes: 6 additions & 13 deletions src/pydvl/influence/torch/influence_function_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
TorchPerSampleAutoGrad,
TorchPerSampleGradientProvider,
)
from .operator.solve import InverseHarmonicMeanOperator, LowRankOperator
from .operator.solve import InverseHarmonicMeanOperator
from .pre_conditioner import PreConditioner
from .util import (
BlockMode,
Expand Down Expand Up @@ -1851,7 +1851,7 @@ def __init__(
block_structure: Union[
BlockMode, OrderedDict[str, OrderedDict[str, torch.nn.Parameter]]
] = BlockMode.FULL,
regularization: Optional[Union[float, Dict[str, float]]] = None,
regularization: Optional[Union[float, Dict[str, Optional[float]]]] = None,
):
if isinstance(block_structure, BlockMode):
self.parameter_dict = ModelParameterDictBuilder(model).build(
Expand All @@ -1864,21 +1864,13 @@ def __init__(

super().__init__(model)

@property
def regularization(self) -> Dict[str, float]:
return self._regularization_dict

@regularization.setter
def regularization(self, value: Union[float, Dict[str, float]]):
self._regularization_dict = self._build_regularization_dict(value)

@property
def block_names(self) -> List[str]:
return list(self.parameter_dict.keys())

@abstractmethod
def with_regularization(
self, regularization: Union[float, Dict[str, float]]
self, regularization: Union[float, Dict[str, Optional[float]]]
) -> TorchComposableInfluence:
pass

Expand Down Expand Up @@ -1948,7 +1940,7 @@ def __init__(
self,
model: torch.nn.Module,
loss: LossType,
regularization: Union[float, Dict[str, float]],
regularization: Union[float, Dict[str, Optional[float]]],
block_structure: Union[
BlockMode, OrderedDict[str, OrderedDict[str, torch.Tensor]]
] = BlockMode.FULL,
Expand All @@ -1974,6 +1966,7 @@ def _create_block(
data: DataLoader,
regularization: Optional[float],
) -> TorchOperatorGradientComposition:
assert regularization is not None
op = InverseHarmonicMeanOperator(
self.model,
self.loss,
Expand All @@ -1988,7 +1981,7 @@ def _create_block(
return TorchOperatorGradientComposition(op, gp)

def with_regularization(
self, regularization: Union[float, Dict[str, float]]
self, regularization: Union[float, Dict[str, Optional[float]]]
) -> TorchComposableInfluence:
self._regularization_dict = self._build_regularization_dict(regularization)
for k, reg in self._regularization_dict.items():
Expand Down

0 comments on commit 3b7289c

Please sign in to comment.