Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

21860: Improved type hinting #309

Merged
merged 2 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions howso/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,7 +1830,7 @@ def react( # noqa: C901
If set to True, will scale influence weights by each case's
`weight_feature` weight. If unspecified, case weights
will be used if the Trainee has them.
case_indices : Iterable of Sequence[Union[str, int]], defaults to None
case_indices : Iterable of Sequence[str | int], defaults to None
An Iterable of Sequences, of session id and index, where
index is the original 0-based index of the case as it was trained
into the session. If this case does not exist, discriminative react
Expand Down Expand Up @@ -1938,7 +1938,7 @@ def react( # noqa: C901
action -> pandas.DataFrame
A data frame of action values.

details -> Dict or List
details -> dict or list
An aggregated list of any requested details.

Raises
Expand Down Expand Up @@ -2677,7 +2677,7 @@ def react_series( # noqa: C901
action -> pandas.DataFrame
A data frame of action values.

details -> Dict or List
details -> dict or list
An aggregated list of any requested details.

Raises
Expand Down Expand Up @@ -4041,10 +4041,10 @@ def set_auto_ablation_params(
residual_prediction_features : Optional[List[str]], optional
For each of the features specified, will ablate a case if
abs(prediction - case value) / prediction <= feature residual.
tolerance_prediction_threshold_map : Optional[Dict[str, Tuple[float, float]]], optional
tolerance_prediction_threshold_map : Optional[dict[str, tuple[float, float]]], optional
For each of the features specified, will ablate a case if the prediction >= (case value - MIN)
and the prediction <= (case value + MAX).
relative_prediction_threshold_map : Optional[Dict[str, float]], optional
relative_prediction_threshold_map : Optional[dict[str, float]], optional
For each of the features specified, will ablate a case if
abs(prediction - case value) / prediction <= relative threshold
conviction_lower_threshold : Optional[float], optional
Expand Down
8 changes: 5 additions & 3 deletions howso/client/feature_flags.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing as t
import warnings

Expand All @@ -13,9 +15,9 @@ class FeatureFlags:
"""

# Define obsolete flags here to raise a warning when defined
_obsolete_flags: t.Union[t.Set[str], None] = None
_obsolete_flags: set[str] | None = None

def __init__(self, flags: t.Optional[t.Dict[str, t.Any]]):
def __init__(self, flags: t.Optional[dict[str, t.Any]]):
self._store = dict()
if flags is not None:
obsolete = set()
Expand Down Expand Up @@ -55,7 +57,7 @@ def parse_flag(cls, flag: str) -> str:
"""Parse the flag name."""
return flag.replace('-', '_').lower()

def __iter__(self) -> t.Generator[t.Tuple[str, bool], None, None]:
def __iter__(self) -> t.Generator[tuple[str, bool], None, None]:
"""Iterate over flags."""
return ((key, value) for key, value in self._store.items())

Expand Down
6 changes: 3 additions & 3 deletions howso/client/pandas/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Iterable
from typing import Optional
from collections.abc import Collection
import typing as t

import pandas as pd
from pandas import DataFrame, Index
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_extreme_cases(
trainee_id: str,
num: int,
sort_feature: str,
features: Optional[Iterable[str]] = None
features: t.Optional[Collection[str]] = None
) -> DataFrame:
"""
Base: :func:`howso.client.AbstractHowsoClient.get_extreme_cases`.
Expand Down
47 changes: 26 additions & 21 deletions howso/client/schemas/reaction.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections import abc
from functools import singledispatchmethod
from pprint import pformat
Expand Down Expand Up @@ -29,11 +31,11 @@ class Reaction(abc.MutableMapping):

Parameters
----------
action : Union[pandas.DataFrame, list, dict], default None
action : pandas.DataFrame or list or dict, default None
(Optional) A DataFrame with columns representing the requested
features of ``react`` or ``react_series`` cases.

details : List or None
details : list or None
(Optional) The details of results from ``react`` or ``react_series``
when providing a ``details`` parameter.
"""
Expand All @@ -60,8 +62,8 @@ class Reaction(abc.MutableMapping):
}

def __init__(self,
action: t.Optional[t.Union[pd.DataFrame, list, dict]] = None,
details: t.Optional[t.MutableMapping[str, t.Any]] = None
action: t.Optional[pd.DataFrame | list | dict] = None,
details: t.Optional[abc.MutableMapping[str, t.Any]] = None
):
"""Initialize the dictionary with the allowed keys."""
self._data = {
Expand All @@ -79,7 +81,7 @@ def __init__(self,

self._reorganized_details = None

def _validate_key(self, key) -> str:
def _validate_key(self, key: str) -> str:
"""
Raise KeyError if key is not one of the allowed keys.

Expand Down Expand Up @@ -115,18 +117,18 @@ def _validate_key(self, key) -> str:

return key

def __getitem__(self, key):
def __getitem__(self, key: str):
"""Get an item by key if the key is allowed."""
key = self._validate_key(key)
return self._data[key]

def __setitem__(self, key, value):
def __setitem__(self, key: str, value: t.Any):
"""Set an item by key if the key is allowed."""
key = self._validate_key(key)
self._reorganized_details = None
self._data[key] = value

def __delitem__(self, key):
def __delitem__(self, key: str):
"""Delete an item by key if the key is allowed."""
key = self._validate_key(key)
self._reorganized_details = None
Expand All @@ -148,7 +150,7 @@ def __repr__(self) -> str:

@singledispatchmethod
def add_reaction(self, action: pd.DataFrame,
details: t.MutableMapping[str, t.Any]):
details: abc.MutableMapping[str, t.Any]):
"""
Add more data to the instance.

Expand Down Expand Up @@ -201,18 +203,18 @@ def add_reaction(self, action: pd.DataFrame,
self._reorganized_details = None

@add_reaction.register
def _(self, action: dict, details: t.MutableMapping[str, t.Any]):
"""Add Dict[List, Dict] to Reaction."""
def _(self, action: dict, details: abc.MutableMapping[str, t.Any]):
"""Add dict[list, dict] to Reaction."""
action_df = pd.DataFrame.from_dict(action)
return self.add_reaction(action_df, details)

@add_reaction.register
def _(self, action: list, details: t.MutableMapping[str, t.Any]):
"""Add list[Dict] to Reaction."""
def _(self, action: list, details: abc.MutableMapping[str, t.Any]):
"""Add list[dict] to Reaction."""
action_df = pd.DataFrame(action)
return self.add_reaction(action_df, details)

def gen_cases(self) -> t.Generator[t.Dict, None, None]:
def gen_cases(self) -> t.Generator[dict, None, None]:
"""
Yield dict containing DetailedCase items for a single case.

Expand Down Expand Up @@ -240,8 +242,7 @@ def reorganized_details(self):
return self._reorganized_details

@classmethod
def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
) -> t.List[t.Dict]:
def _reorganize_details(cls, details: abc.MutableMapping[str, list]) -> list[dict]:
"""
Re-organize `details` to be a list of dicts. One dict per case.

Expand All @@ -261,11 +262,15 @@ def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
{k1: v1m, k2: v2m, ... kn: vnm}
]

Parameters:
details : Dict of Lists
Parameters
----------
details : dict of list
The reaction details.

Returns:
List of Dicts, one Dict per case
Returns
-------
List of dicts
One dict per case.
"""
if isinstance(details, list):
return details
Expand All @@ -282,7 +287,7 @@ def _reorganize_details(cls, details: t.MutableMapping[str, t.List]
k: v for k, v in details.items()
if k in cls.KNOWN_KEYS and v
}
# Transform Dict[List] -> List[Dict]
# Transform dict[list] -> list[dict]
per_case_details = [
dict(zip([key for key in cleaned_details.keys()], values))
for values in zip(*cleaned_details.values())
Expand Down
7 changes: 5 additions & 2 deletions howso/direct/_utilities.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from importlib import metadata
from pathlib import Path
import sysconfig
from typing import Union


def get_file_in_distribution(file_path) -> Union[Path, None]:
def get_file_in_distribution(file_path: str) -> Path | None:
"""
Locate the LICENSE.txt file in the distribution of this package.

Expand All @@ -20,6 +21,8 @@ def get_file_in_distribution(file_path) -> Union[Path, None]:
"""
purelib_path = sysconfig.get_path('purelib')
dist = metadata.distribution('howso-engine')
if dist.files is None:
raise AssertionError("The package howso-engine is not installed correctly, please reinstall.")
for fp in dist.files:
if fp.name == file_path:
return Path(purelib_path, fp)
62 changes: 34 additions & 28 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def needs_analyze(self) -> bool:
return self._needs_analyze

@property
def calculated_matrices(self) -> t.Optional[dict[str, DataFrame]]:
def calculated_matrices(self) -> dict[str, DataFrame] | None:
"""
The calculated matrices.

Expand Down Expand Up @@ -2610,7 +2610,7 @@ def react_group(
use_case_weights: t.Optional[bool] = None,
features: t.Optional[Collection[str]] = None,
weight_feature: t.Optional[str] = None,
) -> DataFrame | dict:
) -> DataFrame:
"""
Computes specified data for a **set** of cases.

Expand Down Expand Up @@ -2661,23 +2661,26 @@ def react_group(

Returns
-------
DataFrame or dict
DataFrame
The conviction of grouped cases.
"""
return self.client.react_group(
trainee_id=self.id,
new_cases=new_cases,
features=features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
kl_divergence_addition=kl_divergence_addition,
kl_divergence_removal=kl_divergence_removal,
p_value_of_addition=p_value_of_addition,
p_value_of_removal=p_value_of_removal,
distance_contributions=distance_contributions,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
if isinstance(self.client, HowsoPandasClientMixin):
return self.client.react_group(
trainee_id=self.id,
new_cases=new_cases,
features=features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
kl_divergence_addition=kl_divergence_addition,
kl_divergence_removal=kl_divergence_removal,
p_value_of_addition=p_value_of_addition,
p_value_of_removal=p_value_of_removal,
distance_contributions=distance_contributions,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
else:
raise AssertionError("Client must have the 'react_group' method.")

def get_feature_conviction(
self,
Expand All @@ -2688,7 +2691,7 @@ def get_feature_conviction(
action_features: t.Optional[Collection[str]] = None,
features: t.Optional[Collection[str]] = None,
weight_feature: t.Optional[str] = None,
) -> DataFrame | dict:
) -> DataFrame:
"""
Get familiarity conviction for features in the model.

Expand Down Expand Up @@ -2720,19 +2723,22 @@ def get_feature_conviction(

Returns
-------
DataFrame or dict
DataFrame
A DataFrame containing the familiarity conviction rows to feature
columns.
"""
return self.client.get_feature_conviction(
trainee_id=self.id,
action_features=action_features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
features=features,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
if isinstance(self.client, HowsoPandasClientMixin):
return self.client.get_feature_conviction(
trainee_id=self.id,
action_features=action_features,
familiarity_conviction_addition=familiarity_conviction_addition,
familiarity_conviction_removal=familiarity_conviction_removal,
features=features,
use_case_weights=use_case_weights,
weight_feature=weight_feature,
)
else:
raise AssertionError("Client must have the 'get_feature_conviction' method.")

def get_marginal_stats(
self, *,
Expand Down
Loading
Loading