Skip to content

Commit

Permalink
Comprehensive API for selection - to be used in tw-experimentation (#30)
Browse files Browse the repository at this point in the history
* Comprehensive API for selection - to be used in tw-experimentation

* fix hsic search

* categorical tests
  • Loading branch information
claudio-tw authored May 12, 2023
1 parent 8009ca0 commit 64ccbde
Show file tree
Hide file tree
Showing 10 changed files with 525 additions and 55 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/run_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jobs:
poetry run pytest tests/lar_test.py --disable-warnings
poetry run pytest tests/select_test.py --disable-warnings
poetry run pytest tests/hsic_test.py --disable-warnings
poetry run pytest tests/categorical_test.py --disable-warnings
poetry run pytest tests/feature_selection_test.py --disable-warnings
run_trufflehog:
name: "Run trufflehog to catch credential leaks"
Expand Down
2 changes: 1 addition & 1 deletion hisel/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from . import kernels, select, hsic # NOQA
from . import kernels, select, permutohedron, hsic, categorical, feature_selection # NOQA
try:
import torch # NOQA
from . import torchkernels # NOQA
Expand Down
211 changes: 211 additions & 0 deletions hisel/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from typing import Optional, Set, Tuple, Callable, Union, List
import numpy as np
import pandas as pd
from dataclasses import dataclass
from sklearn.metrics import adjusted_mutual_info_score
from joblib import Parallel, delayed


from hisel import permutohedron


def _discretise(
y: np.ndarray,
num_quantiles: int = 10,
) -> np.ndarray:
assert y.ndim < 3
qs = np.linspace(0 + 1. / num_quantiles, 1 - 1. /
num_quantiles, num=num_quantiles)

def _build(cont):
assert cont.ndim == 1
discr = np.zeros(shape=cont.shape, dtype=int)
threshold = np.amin(cont)
for q in qs:
quant = np.quantile(cont, q)
if quant > threshold:
threshold = quant
discr += np.array(cont > threshold, dtype=int)
return discr

res = np.zeros(shape=y.shape)
if y.ndim == 2:
for d in range(y.shape[1]):
res[:, d] = _build(y[:, d])
else:
res = _build(y)
return res


def _preprocess_datatypes(
y: Union[pd.DataFrame, pd.Series],
) -> Union[pd.DataFrame, pd.Series]:
if isinstance(y, pd.DataFrame):
for col in y.columns:
if y[col].dtype == bool:
y[col] = y[col].astype(int)
elif y.dtypes == bool:
y = y.astype(int)
ydtypes = y.dtypes if isinstance(y, pd.DataFrame) else [y.dtypes]
for dtype in ydtypes:
assert dtype == int or dtype == float
return y


@dataclass
class Selection:
indexes: np.ndarray
features: List[str]


def select(
xdf: pd.DataFrame,
ydf: Union[pd.DataFrame, pd.Series],
num_permutations: Optional[int] = None,
im_ratio: float = .05,
max_iter: int = 1,
parallel: bool = False,
random_state: Optional[int] = None,
) -> Selection:
xdf = _preprocess_datatypes(xdf)
x = xdf.values
ydf = _preprocess_datatypes(ydf)
allfeatures: List[np.ndarray] = []
if isinstance(ydf, pd.Series):
if ydf.dtypes == float:
y = _discretise(ydf.values)
else:
y = ydf.values
allfeatures.append(
search(
x, y,
num_permutations=num_permutations,
im_ratio=im_ratio,
max_iter=max_iter,
parallel=parallel,
random_state=random_state,
)
)
else:
for col in ydf.columns:
if ydf[col].dtypes == float:
y = _discretise(ydf[col].values)
else:
y = ydf[col].values
allfeatures.append(
search(
x, y,
num_permutations=num_permutations,
im_ratio=im_ratio,
max_iter=max_iter,
parallel=parallel,
random_state=random_state,
)
)
fs = np.concatenate(allfeatures)
indexes = np.array(list(set(fs)), dtype=int)
features = list(xdf.columns[indexes])
return Selection(indexes=indexes, features=features)


def search(
x: np.ndarray,
y: np.ndarray,
num_permutations: Optional[int] = None,
im_ratio: float = .05,
max_iter: int = 1,
parallel: bool = False,
random_state: Optional[int] = None,
) -> np.ndarray:
assert x.ndim == 2
assert y.ndim == 1
assert x.shape[0] == y.shape[0]
n, d = x.shape
assert x.dtype == int
assert y.dtype == int
if num_permutations is None:
num_permutations = 3 * d
x = x - np.amin(x, axis=0, keepdims=True)
y = y - np.amin(y, axis=0, keepdims=True)
active_set = set(range(d))
sel = np.arange(d, dtype=int)
features = np.array([], dtype=int)
imall = .0
n_iter = 0
while len(active_set) > 1 and n_iter < max_iter:
active = np.array(list(active_set))
num_active = len(active)
num_haar_samples = min(
max(1, num_permutations // num_active),
2**num_active // num_active
)
permutations = permutohedron.haar_sampling(
num_active,
size=num_haar_samples,
random_state=random_state
)
if parallel:
tries = Parallel(n_jobs=-1)([
delayed(_try_permutation)(
ami, x, y, active, list(permutation))
for permutation in permutations
])
else:
tries = [_try_permutation(
ami, x, y, active, list(permutation)) for permutation in permutations]

im = .0
for im_, sel_ in tries:
if im_ > im:
sel = sel_
im = im_
if im < im_ratio * imall:
print('im < im_ratio * imall')
print(f'{im} < {im_ratio} * {imall}')
break
elif im > imall:
imall = im

features = np.concatenate((features, sel))
active_set = active_set.difference(set(features))
n_iter += 1
return features


def ami(
x: np.ndarray,
y: np.ndarray,
) -> np.ndarray:
n, d = x.shape
assert n == y.shape[0]
z = _encode(x)
im = np.empty(shape=(d, ), dtype=float)
for i in range(d):
im[i] = adjusted_mutual_info_score(z[:, i], y)
return im


def _encode(x: np.ndarray) -> np.ndarray:
assert x.ndim == 2
ns = 1 + np.amax(x, axis=0, keepdims=True)
res = np.array(x, copy=True)
ms = np.roll(ns, 1, axis=1)
ms[0, 0] = 1
ms = np.cumprod(ms, axis=1)
res = np.cumsum(ms * x, axis=1)
return res


def _try_permutation(
metric: Callable[[np.ndarray, np.ndarray], np.ndarray],
x: np.ndarray,
y: np.ndarray,
active: np.ndarray,
permutation: Union[List[int], np.ndarray],
) -> Tuple[float, np.ndarray]:
sel = active[permutation]
ims = metric(x[:, sel], y)
s = np.argmax(ims)
im = ims[s]
selection = sel[:s+1]
return im, selection
94 changes: 94 additions & 0 deletions hisel/feature_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Optional, Union
import numpy as np
import pandas as pd
from dataclasses import dataclass
from hisel import hsic, select, categorical
from collections.abc import Mapping

LassoSelection = select.Selection


class Parameters(Mapping):
def __iter__(self):
if not hasattr(self, '__dataclass_fields__'):
raise StopIteration
for v in self.__dataclass_fields__:
yield v

def __len__(self):
if not hasattr(self, '__dataclass_fields__'):
return 0
return len(self.__dataclass_fields__)

def __getitem__(self, item):
return getattr(self, item)


@dataclass
class SearchParameters(Parameters):
num_permutations: Optional[int] = None
im_ratio: float = .05
max_iter: int = 2
parallel: bool = True
random_state: Optional[int] = None


@dataclass
class HSICLassoParameters(Parameters):
mi_threshold: float = .0001
hsic_threshold: float = .01
batch_size = 5000
minibatch_size: int = 200
number_of_epochs: int = 4
use_preselection: bool = True
device: Optional[str] = None


continuous_dtypes = [
float,
np.float32,
np.float64,
]

discrete_dtypes = [
bool,
int,
np.int32,
np.int64,
]


def select_features(
xdf: pd.DataFrame,
ydf: Union[pd.DataFrame, pd.Series],
hsiclasso_parameters: Optional[HSICLassoParameters] = None,
categorical_search_parameters: Optional[SearchParameters] = None,
):
n, d = xdf.shape
continuous_features = [
col for col in xdf.columns
if xdf[col].dtype in continuous_dtypes
]
discrete_features = [
col for col in xdf.columns
if xdf[col].dtype in discrete_dtypes
]

if hsiclasso_parameters is None:
hsiclasso_parameters = HSICLassoParameters()
if categorical_search_parameters is None:
categorical_search_parameters = SearchParameters()
continuous_lasso_selection: LassoSelection = select.select(
xdf[continuous_features], ydf, **hsiclasso_parameters)

categorical_search_selection: categorical.Selection = categorical.select(
xdf[discrete_features], ydf, **categorical_search_parameters)

selected_features = categorical_search_selection.features + \
continuous_lasso_selection.features
results = dict(
continuous_lasso_selection=continuous_lasso_selection,
categorical_search_selection=categorical_search_selection,
selected_features=selected_features
)
return results
Loading

0 comments on commit 64ccbde

Please sign in to comment.