-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #60 from mmschlk/development
Development
- Loading branch information
Showing
23 changed files
with
601 additions
and
412 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
__version__ = "0.1.1" | ||
__version__ = "0.1.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,12 @@ | ||
""" | ||
This module gathers base Explanation Methods | ||
""" | ||
|
||
# Authors: Maximilian Muschalik <[email protected]> | ||
# Fabian Fumagalli <[email protected]> | ||
# Rohit Jagtani | ||
|
||
import copy | ||
import math | ||
import abc | ||
import typing | ||
from typing import Union, Sequence, Dict, List, Callable, Any, Optional | ||
|
||
from river.metrics.base import Metric | ||
|
||
from ixai.imputer import BaseImputer, MarginalImputer | ||
from ixai.storage import GeometricReservoirStorage, UniformReservoirStorage | ||
|
@@ -35,8 +32,8 @@ class BaseIncrementalExplainer(metaclass=abc.ABCMeta): | |
@abc.abstractmethod | ||
def __init__( | ||
self, | ||
model_function: typing.Callable, | ||
feature_names: list | ||
model_function: Callable[[Any], Any], | ||
feature_names: Sequence[Union[str, int, float]] | ||
): | ||
""" | ||
Args: | ||
|
@@ -62,13 +59,13 @@ class BaseIncrementalFeatureImportance(BaseIncrementalExplainer): | |
@abc.abstractmethod | ||
def __init__( | ||
self, | ||
model_function, | ||
loss_function, | ||
feature_names: list, | ||
storage: typing.Optional[BaseStorage] = None, | ||
imputer: typing.Optional[BaseImputer] = None, | ||
model_function: Callable[[Any], Any], | ||
loss_function: Union[Metric, Callable[[Any, Dict], float]], | ||
feature_names: Sequence[Union[str, int, float]], | ||
storage: Optional[BaseStorage] = None, | ||
imputer: Optional[BaseImputer] = None, | ||
dynamic_setting: bool = False, | ||
smoothing_alpha: typing.Optional[float] = None | ||
smoothing_alpha: Optional[float] = None | ||
): | ||
super().__init__(model_function, feature_names) | ||
self._loss_function = validate_loss_function(loss_function) | ||
|
@@ -157,12 +154,16 @@ def _normalize_importance_values(importance_values: dict, mode: str = 'sum') -> | |
except ZeroDivisionError: | ||
return {feature: 0.0 for feature, importance_value in importance_values.items()} | ||
|
||
def update_storage(self, x_i: dict, y_i: typing.Optional[typing.Any] = None): | ||
"""Manually updates the data storage with the given observation.""" | ||
def update_storage(self, x_i: dict, y_i: Optional[Any] = None): | ||
"""Manually updates the data storage with the given observation. | ||
Args: | ||
x_i (dict): The input features of the current observation. | ||
y_i (Any, optional): Target label of the current observation. Defaults to `None` | ||
""" | ||
self._storage.update(x=x_i, y=y_i) | ||
|
||
|
||
def _get_mean_model_output(model_outputs: typing.List[dict]) -> dict: | ||
def _get_mean_model_output(model_outputs: List[dict]) -> dict: | ||
"""Calculates the mean values of a list of dict model outputs. | ||
Args: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,17 @@ | ||
""" | ||
This module gathers PFI Explanation Methods | ||
""" | ||
from typing import Optional, Union, Callable, Any | ||
|
||
# Authors: Maximilian Muschalik <[email protected]> | ||
# Fabian Fumagalli <[email protected]> | ||
# Rohit Jagtani | ||
|
||
from typing import Optional, Union, Callable, List, Any | ||
from .base import BaseIncrementalFeatureImportance | ||
import numpy as np | ||
from river.metrics.base import Metric | ||
from ..imputer import BaseImputer | ||
from ..storage.base import BaseStorage | ||
|
||
from ixai.imputer import BaseImputer | ||
from ixai.storage.base import BaseStorage | ||
from .base import BaseIncrementalFeatureImportance | ||
|
||
|
||
__all__ = [ | ||
"IncrementalPFI", | ||
] | ||
__all__ = ["IncrementalPFI"] | ||
|
||
|
||
class IncrementalPFI(BaseIncrementalFeatureImportance): | ||
|
@@ -35,7 +30,7 @@ def __init__( | |
feature_names: list, | ||
storage: Optional[BaseStorage] = None, | ||
imputer: Optional[BaseImputer] = None, | ||
n_inner_samples: int = 5, | ||
n_inner_samples: int = 1, | ||
smoothing_alpha: float = 0.001, | ||
dynamic_setting: bool = True | ||
): | ||
|
@@ -66,7 +61,7 @@ def __init__( | |
adaptive explanation) or a static modelling setting `False` (all observations contribute equally to the | ||
final importance) is assumed. Defaults to `True`. | ||
""" | ||
super(IncrementalPFI, self).__init__( | ||
super().__init__( | ||
model_function=model_function, | ||
loss_function=loss_function, | ||
feature_names=feature_names, | ||
|
Oops, something went wrong.