Merge pull request #60 from mmschlk/development
Development
mmschlk authored Jan 3, 2023
2 parents 287d09a + ea1c178 commit d43d7b6
Showing 23 changed files with 601 additions and 412 deletions.
4 changes: 2 additions & 2 deletions examples/agrawal_accuracy_loss.py
@@ -14,7 +14,7 @@

 if __name__ == "__main__":

-    # Get Data ---------------------------------------------------------------------------------------------------------
+    # Get Data -------------------------------------------------------------------------------------
     stream = Agrawal(classification_function=1, seed=42)
     feature_names = list([x_0 for x_0, _ in stream.take(1)][0].keys())

@@ -24,7 +24,7 @@
     model = AdaptiveRandomForestClassifier(n_models=15, max_depth=10, leaf_prediction='mc')
     model_function = RiverWrapper(model.predict_one)

-    # Get imputer and explainers ---------------------------------------------------------------------------------------
+    # Get imputer and explainers -------------------------------------------------------------------
     storage = GeometricReservoirStorage(
         size=200,
         store_targets=False
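For context, the two shortened comment rulers sit inside the example's setup block. A condensed sketch of that block is shown below; the import paths for Agrawal, AdaptiveRandomForestClassifier, GeometricReservoirStorage, and RiverWrapper are assumptions about the river and ixai package layout at this release, not part of the diff.

# Condensed, illustrative sketch of the example's setup block (not part of the commit).
from river.datasets.synth import Agrawal                    # assumed river import path
from river.ensemble import AdaptiveRandomForestClassifier   # assumed river import path
from ixai.storage import GeometricReservoirStorage
from ixai.utils.wrappers import RiverWrapper                 # assumed ixai import path

# Data stream and feature names, exactly as in the example.
stream = Agrawal(classification_function=1, seed=42)
feature_names = list([x_0 for x_0, _ in stream.take(1)][0].keys())

# Incremental model and a callable wrapper around its per-sample predict method.
model = AdaptiveRandomForestClassifier(n_models=15, max_depth=10, leaf_prediction='mc')
model_function = RiverWrapper(model.predict_one)

# Reservoir of recent observations used later by the explainer's imputer.
storage = GeometricReservoirStorage(size=200, store_targets=False)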
4 changes: 2 additions & 2 deletions examples/agrawal_cross_entropy_loss.py
@@ -14,7 +14,7 @@

 if __name__ == "__main__":

-    # Get Data ---------------------------------------------------------------------------------------------------------
+    # Get Data -------------------------------------------------------------------------------------
     stream = Agrawal(classification_function=1, seed=42)
     feature_names = list([x_0 for x_0, _ in stream.take(1)][0].keys())

@@ -24,7 +24,7 @@
     model = AdaptiveRandomForestClassifier(n_models=15, max_depth=10, leaf_prediction='mc')
     model_function = RiverWrapper(model.predict_proba_one)

-    # Get imputer and explainers ---------------------------------------------------------------------------------------
+    # Get imputer and explainers -------------------------------------------------------------------
     storage = GeometricReservoirStorage(
         size=200,
         store_targets=False
2 changes: 1 addition & 1 deletion ixai/__version__.py
@@ -1 +1 @@
-__version__ = "0.1.1"
+__version__ = "0.1.2"
35 changes: 18 additions & 17 deletions ixai/explainer/base.py
@@ -1,15 +1,12 @@
 """
 This module gathers base Explanation Methods
 """
-
-# Authors: Maximilian Muschalik <[email protected]>
-# Fabian Fumagalli <[email protected]>
-# Rohit Jagtani
-
 import copy
 import math
 import abc
-import typing
+from typing import Union, Sequence, Dict, List, Callable, Any, Optional
+
 from river.metrics.base import Metric
+
 from ixai.imputer import BaseImputer, MarginalImputer
 from ixai.storage import GeometricReservoirStorage, UniformReservoirStorage
@@ -35,8 +32,8 @@ class BaseIncrementalExplainer(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def __init__(
             self,
-            model_function: typing.Callable,
-            feature_names: list
+            model_function: Callable[[Any], Any],
+            feature_names: Sequence[Union[str, int, float]]
     ):
         """
         Args:
@@ -62,13 +59,13 @@ class BaseIncrementalFeatureImportance(BaseIncrementalExplainer):
     @abc.abstractmethod
     def __init__(
             self,
-            model_function,
-            loss_function,
-            feature_names: list,
-            storage: typing.Optional[BaseStorage] = None,
-            imputer: typing.Optional[BaseImputer] = None,
+            model_function: Callable[[Any], Any],
+            loss_function: Union[Metric, Callable[[Any, Dict], float]],
+            feature_names: Sequence[Union[str, int, float]],
+            storage: Optional[BaseStorage] = None,
+            imputer: Optional[BaseImputer] = None,
             dynamic_setting: bool = False,
-            smoothing_alpha: typing.Optional[float] = None
+            smoothing_alpha: Optional[float] = None
     ):
         super().__init__(model_function, feature_names)
         self._loss_function = validate_loss_function(loss_function)
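The tightened annotation spells out what the base explainer accepts as a loss: either a river Metric instance or a callable mapping a target and a prediction dictionary to a float. A small sketch of both forms follows; the zero-one loss shown is purely illustrative and not part of the library.

# Two loss functions compatible with Union[Metric, Callable[[Any, Dict], float]].
from typing import Any, Dict
from river import metrics

# 1) A river Metric instance can be passed directly.
accuracy_loss = metrics.Accuracy()

# 2) A plain callable taking (y_true, prediction_dict) and returning a float.
#    Illustrative only; it assumes the prediction dict holds class probabilities keyed by label.
def zero_one_loss(y_true: Any, y_prediction: Dict) -> float:
    predicted_label = max(y_prediction, key=y_prediction.get)
    return float(predicted_label != y_true)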
@@ -157,12 +154,16 @@ def _normalize_importance_values(importance_values: dict, mode: str = 'sum') ->
         except ZeroDivisionError:
             return {feature: 0.0 for feature, importance_value in importance_values.items()}

-    def update_storage(self, x_i: dict, y_i: typing.Optional[typing.Any] = None):
-        """Manually updates the data storage with the given observation."""
+    def update_storage(self, x_i: dict, y_i: Optional[Any] = None):
+        """Manually updates the data storage with the given observation.
+        Args:
+            x_i (dict): The input features of the current observation.
+            y_i (Any, optional): Target label of the current observation. Defaults to `None`
+        """
         self._storage.update(x=x_i, y=y_i)


-def _get_mean_model_output(model_outputs: typing.List[dict]) -> dict:
+def _get_mean_model_output(model_outputs: List[dict]) -> dict:
     """Calculates the mean values of a list of dict model outputs.
     Args:
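The expanded docstring documents the manual-update path, which lets observations be pushed into the explainer's storage without computing an explanation, for example to warm up the reservoir. A rough sketch, assuming an explainer instance of any concrete subclass and a stream of (x, y) pairs as in the examples above:

# Warm-up: fill the storage with observations before any explanation is computed.
for (x_i, y_i) in stream.take(100):
    explainer.update_storage(x_i, y_i)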
21 changes: 8 additions & 13 deletions ixai/explainer/pfi.py
@@ -1,22 +1,17 @@
 """
 This module gathers PFI Explanation Methods
 """
+from typing import Optional, Union, Callable, Any

-# Authors: Maximilian Muschalik <[email protected]>
-# Fabian Fumagalli <[email protected]>
-# Rohit Jagtani
-
-from typing import Optional, Union, Callable, List, Any
-from .base import BaseIncrementalFeatureImportance
 import numpy as np
 from river.metrics.base import Metric
-from ..imputer import BaseImputer
-from ..storage.base import BaseStorage
+
+from ixai.imputer import BaseImputer
+from ixai.storage.base import BaseStorage
+from .base import BaseIncrementalFeatureImportance


-__all__ = [
-    "IncrementalPFI",
-]
+__all__ = ["IncrementalPFI"]


 class IncrementalPFI(BaseIncrementalFeatureImportance):
@@ -35,7 +30,7 @@ def __init__(
             feature_names: list,
             storage: Optional[BaseStorage] = None,
             imputer: Optional[BaseImputer] = None,
-            n_inner_samples: int = 5,
+            n_inner_samples: int = 1,
             smoothing_alpha: float = 0.001,
             dynamic_setting: bool = True
     ):
@@ -66,7 +61,7 @@ def __init__(
                 adaptive explanation) or a static modelling setting `False` (all observations contribute equally to the
                 final importance) is assumed. Defaults to `True`.
         """
-        super(IncrementalPFI, self).__init__(
+        super().__init__(
             model_function=model_function,
             loss_function=loss_function,
             feature_names=feature_names,
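Taken together, the changed defaults and signatures would be used roughly as in the sketch below. The explain_one call and the importance_values attribute mirror how ixai explainers are used in the repository's examples, but their exact names are an assumption here rather than something shown in this diff; model_function, feature_names, stream, and model refer to the earlier sketch.

# Illustrative training-and-explaining loop (not part of the commit).
from river import metrics
from ixai.explainer import IncrementalPFI          # assumed import path

incremental_pfi = IncrementalPFI(
    model_function=model_function,                  # wrapped predict function from the sketch above
    loss_function=metrics.Accuracy(),               # a river Metric is accepted directly
    feature_names=feature_names,
    smoothing_alpha=0.001,                          # defaults shown in the diff above
    n_inner_samples=1,
)

for (n, (x_i, y_i)) in enumerate(stream, start=1):
    incremental_pfi.explain_one(x_i, y_i)           # update the incremental PFI estimates
    model.learn_one(x_i, y_i)                       # keep training the underlying model
    if n % 1000 == 0:
        print(n, incremental_pfi.importance_values)
    if n >= 5000:
        break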