Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing #285

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions shapiq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@
# dataset functions
from .datasets import load_adult_census, load_bike_sharing, load_california_housing

# exact computer classes
from .exact import ExactComputer

# explainer classes
from .explainer import Explainer, TabularExplainer, TreeExplainer

# exact computer classes
from .game_theory.exact import ExactComputer

# game classes
# imputer classes
from .games import BaselineImputer, ConditionalImputer, Game, MarginalImputer
Expand Down
6 changes: 3 additions & 3 deletions shapiq/approximator/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np

from shapiq.approximator.sampling import CoalitionSampler
from shapiq.indices import (
from shapiq.game_theory.indices import (
AVAILABLE_INDICES_FOR_APPROXIMATION,
get_computation_index,
is_empty_value_the_baseline,
Expand Down Expand Up @@ -318,7 +318,7 @@ def aggregate_interaction_values(
Returns:
The aggregated interaction values.
"""
from ..aggregation import aggregate_interaction_values
from shapiq.game_theory.aggregation import aggregate_interaction_values

if player_set is not None:
raise NotImplementedError(
Expand All @@ -339,6 +339,6 @@ def aggregate_to_one_dimension(
Returns:
tuple[np.ndarray, np.ndarray]: The positive and negative aggregated values.
"""
from ..aggregation import aggregate_to_one_dimension
from shapiq.game_theory.aggregation import aggregate_to_one_dimension

return aggregate_to_one_dimension(interaction_values)
2 changes: 1 addition & 1 deletion shapiq/approximator/montecarlo/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from scipy.special import binom, factorial

from shapiq.approximator._base import Approximator
from shapiq.indices import AVAILABLE_INDICES_MONTE_CARLO
from shapiq.game_theory.indices import AVAILABLE_INDICES_MONTE_CARLO
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset

Expand Down
2 changes: 1 addition & 1 deletion shapiq/approximator/regression/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from scipy.special import bernoulli, binom

from shapiq.approximator._base import Approximator
from shapiq.indices import AVAILABLE_INDICES_REGRESSION
from shapiq.game_theory.indices import AVAILABLE_INDICES_REGRESSION
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset

Expand Down
2 changes: 2 additions & 0 deletions shapiq/explainer/tree/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def __init__(
self.baseline_value = self._compute_baseline_value()

def explain(self, x: np.ndarray) -> InteractionValues:
if len(x.shape) != 1:
raise TypeError("explain expects a single instance, not a batch.")
# run treeshapiq for all trees
interaction_values: list[InteractionValues] = []
for explainer in self._treeshapiq_explainers:
Expand Down
9 changes: 5 additions & 4 deletions shapiq/explainer/tree/treeshapiq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import numpy as np
import scipy as sp

from ...aggregation import aggregate_interaction_values
from ...indices import get_computation_index
from ...interaction_values import InteractionValues
from ...utils.sets import generate_interaction_lookup, powerset
from shapiq.game_theory.aggregation import aggregate_interaction_values
from shapiq.game_theory.indices import get_computation_index
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import generate_interaction_lookup, powerset

from .base import EdgeTree, TreeModel
from .conversion.edges import create_edge_tree
from .validation import validate_tree_model
Expand Down
1 change: 1 addition & 0 deletions shapiq/explainer/tree/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
"lightgbm.sklearn.LGBMRegressor",
"lightgbm.sklearn.LGBMClassifier",
"lightgbm.basic.Booster",
# xboost?
}


Expand Down
30 changes: 30 additions & 0 deletions shapiq/game_theory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""conversions of interaction values to different indices
"""

from .aggregation import aggregate_interaction_values
from .core import egalitarian_least_core
from .exact import ExactComputer, get_bernoulli_weights
from .indices import (
ALL_AVAILABLE_CONCEPTS,
get_computation_index,
index_generalizes_bv,
index_generalizes_sv,
is_empty_value_the_baseline,
is_index_aggregated,
)
from .moebius_converter import MoebiusConverter

__all__ = [
"ExactComputer",
"aggregate_interaction_values",
"get_bernoulli_weights",
"ALL_AVAILABLE_CONCEPTS",
"index_generalizes_sv",
"index_generalizes_bv",
"get_computation_index",
"is_index_aggregated",
"is_empty_value_the_baseline",
"egalitarian_least_core",
"MoebiusConverter",
]
# todo complete list
4 changes: 2 additions & 2 deletions shapiq/aggregation.py → shapiq/game_theory/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
import scipy as sp

from .interaction_values import InteractionValues
from .utils.sets import powerset
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset


def _change_index(index: str) -> str:
Expand Down
File renamed without changes.
59 changes: 45 additions & 14 deletions shapiq/exact.py → shapiq/game_theory/exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,22 @@
like interaction indices or generalized values."""

import copy
from typing import Callable, Union
from typing import Callable, Optional, Union

import numpy as np
from scipy.special import bernoulli, binom

from shapiq.indices import ALL_AVAILABLE_CONCEPTS
from shapiq.interaction_values import InteractionValues
from shapiq.utils import powerset

from .indices import ALL_AVAILABLE_CONCEPTS

__all__ = ["ExactComputer", "get_bernoulli_weights"]


class ExactComputer:
"""Computes exact Shapley Interactions for specified game by evaluating the powerset of all
:math:`2^n` coalitions.
:math:`2^n` coalitions.

The ExactComputer class computes a variety of game theoretic concepts like interaction indices
or generalized values. Currently, the following indices and values are supported:
Expand All @@ -27,6 +28,7 @@ class ExactComputer:
n_players: The number of players in the game.
game_fun: A callable game that takes a binary matrix of shape ``(n_coalitions, n_players)``
and returns a numpy array of shape ``(n_coalitions,)`` containing the game values.
evaluate_game: whether to compute the values at init (if True) or first call (False)

Attributes:
n: The number of players.
Expand All @@ -40,6 +42,7 @@ def __init__(
self,
n_players: int,
game_fun: Callable[[np.ndarray], np.ndarray[float]],
evaluate_game: bool = False,
) -> None:
# set parameter attributes
self.n: int = n_players
Expand All @@ -52,12 +55,15 @@ def __init__(
self._n_interactions: np.ndarray = self.get_n_interactions(self.n)
self._computed: dict[tuple[str, int], InteractionValues] = {} # will store all computations
self._elc_stability_subsidy: float = -1
self._game_is_computed: bool = False

self._baseline_value: Optional[float] = None
self._game_values: Optional[np.ndarray] = None
self._coalition_lookup: Optional[dict[tuple[int], int]] = None

# evaluate the game on the powerset
computed_game = self.compute_game_values(game_fun)
self.baseline_value: float = computed_game[0]
self.game_values: np.ndarray[float] = computed_game[1]
self.coalition_lookup: dict[tuple[int], int] = computed_game[2]
if evaluate_game:
# evaluate the game on the powerset
self._evaluate_game()

# setup callable mapping from index to computation
self._index_mapping: dict[str, Callable[[str, int], InteractionValues]] = {
Expand Down Expand Up @@ -124,9 +130,32 @@ def __call__(self, index: str, order: int = None) -> InteractionValues:
else:
raise ValueError(f"Index {index} not supported.")

def compute_game_values(
self, game_fun: Callable[[np.ndarray], np.ndarray[float]]
) -> tuple[float, np.ndarray[float], dict[tuple[int], int]]:
@property
def baseline_value(self) -> float:
if not self._game_is_computed:
self._evaluate_game()
return self._baseline_value

@property
def coalition_lookup(self) -> dict[tuple[int], int]:
if not self._game_is_computed:
self._evaluate_game()
return self._coalition_lookup

@property
def game_values(self) -> np.ndarray[float]:
if not self._game_is_computed:
self._evaluate_game()
return self._game_values

def _evaluate_game(self):
computed_game = self.compute_game_values()
self._baseline_value = computed_game[0]
self._game_values = computed_game[1]
self._coalition_lookup = computed_game[2]
self._game_is_computed = True

def compute_game_values(self) -> tuple[float, np.ndarray[float], dict[tuple[int], int]]:
"""Evaluates the game on the powerset of all coalitions.

Args:
Expand All @@ -141,8 +170,9 @@ def compute_game_values(
for i, T in enumerate(powerset(self._grand_coalition_set, min_size=0, max_size=self.n)):
coalition_lookup[T] = i # set lookup for the coalition
coalition_matrix[i, T] = True # one-hot-encode the coalition
game_values = game_fun(coalition_matrix) # compute the game values
game_values = self.game_fun(coalition_matrix) # compute the game values
baseline_value = float(game_values[0]) # set the baseline value
coalition_lookup = coalition_lookup
return baseline_value, game_values, coalition_lookup

def moebius_transform(self, *args, **kwargs) -> InteractionValues:
Expand All @@ -158,6 +188,7 @@ def moebius_transform(self, *args, **kwargs) -> InteractionValues:
return self._computed[("Moebius", self.n)]
except KeyError: # if not computed yet, just continue
pass

# compute the Moebius transform
moebius_transform = np.zeros(2**self.n)
coalition_lookup = {}
Expand Down Expand Up @@ -836,7 +867,7 @@ def probabilistic_value(self, index: str, *args, **kwargs) -> InteractionValues:

def compute_egalitarian_least_core(self, *args, **kwargs):

from shapiq.core import egalitarian_least_core
from shapiq.game_theory.core import egalitarian_least_core

order = 1

Expand All @@ -856,7 +887,7 @@ def compute_egalitarian_least_core(self, *args, **kwargs):

def get_bernoulli_weights(order: int) -> np.ndarray:
"""Returns the bernoulli weights in the k-additive approximation via SII, e.g. used in
kADD-SHAP.
kADD-SHAP.

Args:
order: The highest order of interactions
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import numpy as np
from scipy.special import binom

from .interaction_values import InteractionValues
from .utils.sets import powerset
from shapiq.interaction_values import InteractionValues
from shapiq.utils.sets import powerset


class MoebiusConverter:
Expand Down
2 changes: 1 addition & 1 deletion shapiq/games/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def exact_values(self, index: str, order: int) -> InteractionValues:
Returns:
InteractionValues: The exact interaction values.
"""
from ..exact import ExactComputer
from shapiq.game_theory.exact import ExactComputer

# raise warning if the game is not precomputed and n_players > 16
if not self.precomputed and self.n_players > 16:
Expand Down
4 changes: 2 additions & 2 deletions shapiq/games/benchmark/synthetic/soum.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(
normalize: bool = False,
verbose: bool = False,
):
from ....moebius_converter import MoebiusConverter
from shapiq.game_theory.moebius_converter import MoebiusConverter

self._rng = np.random.default_rng(random_state)

Expand Down Expand Up @@ -160,7 +160,7 @@ def exact_values(self, index: str, order: int) -> InteractionValues:
Returns:
The exact values for the given index and order.
"""
from ....moebius_converter import MoebiusConverter
from shapiq.game_theory.moebius_converter import MoebiusConverter

if self.converter is None:
self.converter = MoebiusConverter(self.moebius_coefficients)
Expand Down
6 changes: 5 additions & 1 deletion shapiq/interaction_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
import matplotlib.pyplot as plt
import numpy as np

from shapiq.indices import ALL_AVAILABLE_INDICES, index_generalizes_bv, index_generalizes_sv
from shapiq.game_theory.indices import (
ALL_AVAILABLE_INDICES,
index_generalizes_bv,
index_generalizes_sv,
)
from shapiq.utils.sets import count_interactions, generate_interaction_lookup, powerset


Expand Down
2 changes: 1 addition & 1 deletion shapiq/plot/stacked_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def stacked_bar_plot(
This stacked bar plot can be used to visualize the amount of interaction between the features
for a given instance. The n-SII values are plotted as stacked bars with positive and negative
parts stacked on top of each other. The colors represent the order of the n-SII values. For a
detailed explanation of this plot, refer to `Bordt and von Luxburg (2023) <https://doi.org/10.48550/arXiv.2209.0401>`_.
detailed explanation of this plot, refer to `Bordt and von Luxburg (2023) <https://proceedings.mlr.press/v206/bordt23a.html>`_.

An example of the plot is shown below.

Expand Down
Loading
Loading