Skip to content

Commit

Permalink
Merge branch '21637-add-thresholds-auto-ablation' of github.com:howso…
Browse files Browse the repository at this point in the history
…ai/howso-engine-py into 21637-add-thresholds-auto-ablation
  • Loading branch information
jdbeel committed Oct 17, 2024
2 parents 2241ae6 + 8af0499 commit e716e43
Show file tree
Hide file tree
Showing 18 changed files with 424 additions and 377 deletions.
10 changes: 5 additions & 5 deletions howso/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1848,7 +1848,7 @@ def react( # noqa: C901
If set to True, will scale influence weights by each case's
`weight_feature` weight. If unspecified, case weights
will be used if the Trainee has them.
case_indices : Iterable of Sequence[Union[str, int]], defaults to None
case_indices : Iterable of Sequence[str | int], defaults to None
An Iterable of Sequences, of session id and index, where
index is the original 0-based index of the case as it was trained
into the session. If this case does not exist, discriminative react
Expand Down Expand Up @@ -1956,7 +1956,7 @@ def react( # noqa: C901
action -> pandas.DataFrame
A data frame of action values.
details -> Dict or List
details -> dict or list
An aggregated list of any requested details.
Raises
Expand Down Expand Up @@ -2695,7 +2695,7 @@ def react_series( # noqa: C901
action -> pandas.DataFrame
A data frame of action values.
details -> Dict or List
details -> dict or list
An aggregated list of any requested details.
Raises
Expand Down Expand Up @@ -4062,10 +4062,10 @@ def set_auto_ablation_params(
residual_prediction_features : Optional[List[str]], optional
For each of the features specified, will ablate a case if
abs(prediction - case value) / prediction <= feature residual.
tolerance_prediction_threshold_map : Optional[Dict[str, Tuple[float, float]]], optional
tolerance_prediction_threshold_map : Optional[dict[str, tuple[float, float]]], optional
For each of the features specified, will ablate a case if the prediction >= (case value - MIN)
and the prediction <= (case value + MAX).
relative_prediction_threshold_map : Optional[Dict[str, float]], optional
relative_prediction_threshold_map : Optional[dict[str, float]], optional
For each of the features specified, will ablate a case if
abs(prediction - case value) / prediction <= relative threshold.
conviction_lower_threshold : Optional[float], optional
Expand Down
8 changes: 5 additions & 3 deletions howso/client/feature_flags.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as t
import warnings

Expand All @@ -13,9 +15,9 @@ class FeatureFlags:
"""

# Define obsolete flags here to raise a warning when defined
_obsolete_flags: t.Union[t.Set[str], None] = None
_obsolete_flags: set[str] | None = None

def __init__(self, flags: t.Optional[t.Dict[str, t.Any]]):
def __init__(self, flags: t.Optional[dict[str, t.Any]]):
self._store = dict()
if flags is not None:
obsolete = set()
Expand Down Expand Up @@ -55,7 +57,7 @@ def parse_flag(cls, flag: str) -> str:
"""Parse the flag name."""
return flag.replace('-', '_').lower()

def __iter__(self) -> t.Generator[t.Tuple[str, bool], None, None]:
def __iter__(self) -> t.Generator[tuple[str, bool], None, None]:
"""Iterate over flags."""
return ((key, value) for key, value in self._store.items())

Expand Down
6 changes: 3 additions & 3 deletions howso/client/pandas/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Optional
from collections.abc import Collection
import typing as t

import pandas as pd
from pandas import DataFrame, Index
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_extreme_cases(
trainee_id: str,
num: int,
sort_feature: str,
features: Optional[Iterable[str]] = None
features: t.Optional[Collection[str]] = None
) -> DataFrame:
"""
Base: :func:`howso.client.AbstractHowsoClient.get_extreme_cases`.
Expand Down
47 changes: 26 additions & 21 deletions howso/client/schemas/reaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections import abc
from functools import singledispatchmethod
from pprint import pformat
Expand Down Expand Up @@ -29,11 +31,11 @@ class Reaction(abc.MutableMapping):
Parameters
----------
action : Union[pandas.DataFrame, list, dict], default None
action : pandas.DataFrame or list or dict, default None
(Optional) A DataFrame with columns representing the requested
features of ``react`` or ``react_series`` cases.
details : List or None
details : list or None
(Optional) The details of results from ``react`` or ``react_series``
when providing a ``details`` parameter.
"""
Expand All @@ -60,8 +62,8 @@ class Reaction(abc.MutableMapping):
}

def __init__(self,
action: t.Optional[t.Union[pd.DataFrame, list, dict]] = None,
details: t.Optional[t.MutableMapping[str, t.Any]] = None
action: t.Optional[pd.DataFrame | list | dict] = None,
details: t.Optional[abc.MutableMapping[str, t.Any]] = None
):
"""Initialize the dictionary with the allowed keys."""
self._data = {
Expand All @@ -79,7 +81,7 @@ def __init__(self,

self._reorganized_details = None

def _validate_key(self, key) -> str:
def _validate_key(self, key: str) -> str:
"""
Raise KeyError if key is not one of the allowed keys.
Expand Down Expand Up @@ -115,18 +117,18 @@ def _validate_key(self, key) -> str:

return key

def __getitem__(self, key):
def __getitem__(self, key: str):
"""Get an item by key if the key is allowed."""
key = self._validate_key(key)
return self._data[key]

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: t.Any):
"""Set an item by key if the key is allowed."""
key = self._validate_key(key)
self._reorganized_details = None
self._data[key] = value

def __delitem__(self, key):
def __delitem__(self, key: str):
"""Delete an item by key if the key is allowed."""
key = self._validate_key(key)
self._reorganized_details = None
Expand All @@ -148,7 +150,7 @@ def __repr__(self) -> str:

@singledispatchmethod
def add_reaction(self, action: pd.DataFrame,
details: t.MutableMapping[str, t.Any]):
details: abc.MutableMapping[str, t.Any]):
"""
Add more data to the instance.
Expand Down Expand Up @@ -201,18 +203,18 @@ def add_reaction(self, action: pd.DataFrame,
self._reorganized_details = None

@add_reaction.register
def _(self, action: dict, details: t.MutableMapping[str, t.Any]):
"""Add Dict[List, Dict] to Reaction."""
def _(self, action: dict, details: abc.MutableMapping[str, t.Any]):
"""Add dict[list, dict] to Reaction."""
action_df = pd.DataFrame.from_dict(action)
return self.add_reaction(action_df, details)

@add_reaction.register
def _(self, action: list, details: t.MutableMapping[str, t.Any]):
"""Add list[Dict] to Reaction."""
def _(self, action: list, details: abc.MutableMapping[str, t.Any]):
"""Add list[dict] to Reaction."""
action_df = pd.DataFrame(action)
return self.add_reaction(action_df, details)

def gen_cases(self) -> t.Generator[t.Dict, None, None]:
def gen_cases(self) -> t.Generator[dict, None, None]:
"""
Yield dict containing DetailedCase items for a single case.
Expand Down Expand Up @@ -240,8 +242,7 @@ def reorganized_details(self):
return self._reorganized_details

@classmethod
def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
) -> t.List[t.Dict]:
def _reorganize_details(cls, details: abc.MutableMapping[str, list]) -> list[dict]:
"""
Re-organize `details` to be a list of dicts. One dict per case.
Expand All @@ -261,11 +262,15 @@ def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
{k1: v1m, k2: v2m, ... kn: vnm}
]
Parameters:
details : Dict of Lists
Parameters
----------
details : dict of list
The reaction details.
Returns:
List of Dicts, one Dict per case
Returns
-------
list of dict
One dict per case.
"""
if isinstance(details, list):
return details
Expand All @@ -282,7 +287,7 @@ def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
k: v for k, v in details.items()
if k in cls.KNOWN_KEYS and v
}
# Transform Dict[List] -> List[Dict]
# Transform dict[list] -> list[dict]
per_case_details = [
dict(zip([key for key in cleaned_details.keys()], values))
for values in zip(*cleaned_details.values())
Expand Down
7 changes: 5 additions & 2 deletions howso/direct/_utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from importlib import metadata
from pathlib import Path
import sysconfig
from typing import Union


def get_file_in_distribution(file_path) -> Union[Path, None]:
def get_file_in_distribution(file_path: str) -> Path | None:
"""
Locate the file with the given name in the distribution of this package.
Expand All @@ -20,6 +21,8 @@ def get_file_in_distribution(file_path) -> Union[Path, None]:
"""
purelib_path = sysconfig.get_path('purelib')
dist = metadata.distribution('howso-engine')
if dist.files is None:
raise AssertionError("The package howso-engine is not installed correctly, please reinstall.")
for fp in dist.files:
if fp.name == file_path:
return Path(purelib_path, fp)
62 changes: 34 additions & 28 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def needs_analyze(self) -> bool:
return self._needs_analyze

@property
def calculated_matrices(self) -> t.Optional[dict[str, DataFrame]]:
def calculated_matrices(self) -> dict[str, DataFrame] | None:
"""
The calculated matrices.
Expand Down Expand Up @@ -2651,7 +2651,7 @@ def react_group(
use_case_weights: t.Optional[bool] = None,
features: t.Optional[Collection[str]] = None,
weight_feature: t.Optional[str] = None,
) -> DataFrame | dict:
) -> DataFrame:
"""
Computes specified data for a **set** of cases.
Expand Down Expand Up @@ -2702,23 +2702,26 @@ def react_group(
Returns
-------
DataFrame or dict
DataFrame
The conviction of grouped cases.
"""
return self.client.react_group(
trainee_id=self.id,
new_cases=new_cases,
features=features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
kl_divergence_addition=kl_divergence_addition,
kl_divergence_removal=kl_divergence_removal,
p_value_of_addition=p_value_of_addition,
p_value_of_removal=p_value_of_removal,
distance_contributions=distance_contributions,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
if isinstance(self.client, HowsoPandasClientMixin):
return self.client.react_group(
trainee_id=self.id,
new_cases=new_cases,
features=features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
kl_divergence_addition=kl_divergence_addition,
kl_divergence_removal=kl_divergence_removal,
p_value_of_addition=p_value_of_addition,
p_value_of_removal=p_value_of_removal,
distance_contributions=distance_contributions,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
else:
raise AssertionError("Client must have the 'react_group' method.")

def get_feature_conviction(
self,
Expand All @@ -2729,7 +2732,7 @@ def get_feature_conviction(
action_features: t.Optional[Collection[str]] = None,
features: t.Optional[Collection[str]] = None,
weight_feature: t.Optional[str] = None,
) -> DataFrame | dict:
) -> DataFrame:
"""
Get familiarity conviction for features in the model.
Expand Down Expand Up @@ -2761,19 +2764,22 @@ def get_feature_conviction(
Returns
-------
DataFrame or dict
DataFrame
A DataFrame containing the familiarity conviction rows to feature
columns.
"""
return self.client.get_feature_conviction(
trainee_id=self.id,
action_features=action_features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
features=features,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
if isinstance(self.client, HowsoPandasClientMixin):
return self.client.get_feature_conviction(
trainee_id=self.id,
action_features=action_features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
features=features,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
else:
raise AssertionError("Client must have the 'get_feature_conviction' method.")

def get_marginal_stats(
self, *,
Expand Down
Loading

0 comments on commit e716e43

Please sign in to comment.