Skip to content

Commit

Permalink
Dependence on ensemble - failure of HSIC (#24)
Browse files Browse the repository at this point in the history
* Dependence on ensemble - failure of HSIC

* Update notebook

* Notebook cosmetics
  • Loading branch information
claudio-tw authored May 2, 2023
1 parent dfe1518 commit f78492b
Show file tree
Hide file tree
Showing 2 changed files with 318 additions and 2 deletions.
18 changes: 16 additions & 2 deletions hisel/hsic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import Optional
from hisel import kernels
from hisel.kernels import KernelType
import numpy as np


def hsic_b(
x: np.ndarray,
y: np.ndarray,
xkerneltype: Optional[KernelType] = None,
ykerneltype: Optional[KernelType] = None,
):
assert x.ndim == 2
assert y.ndim == 2
Expand All @@ -14,12 +18,22 @@ def hsic_b(
dy: int = y.shape[1]
lx: float = np.sqrt(dx)
ly: float = np.sqrt(dy)
if xkerneltype is None:
if x.dtype == int:
xkerneltype = KernelType.DELTA
else:
xkerneltype = KernelType.RBF
if ykerneltype is None:
if y.dtype == int:
ykerneltype = KernelType.DELTA
else:
ykerneltype = KernelType.RBF
xgram: np.ndarray = kernels.multivariate_phi(
x.T, lx
x.T, lx, xkerneltype
)
k = xgram[0, :, :]
ygram: np.ndarray = kernels.multivariate_phi(
y.T, ly
y.T, ly, ykerneltype
)
l = kernels._center_gram(ygram)[0, :, :]
return np.trace(k @ l) / (n*n)
302 changes: 302 additions & 0 deletions notebooks/ensemble-example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "802e8c73",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import seaborn as sns\n",
"import itertools\n",
"from sklearn.metrics import adjusted_mutual_info_score\n",
"\n",
"\n",
"from hisel import select, hsic\n",
"from hisel.select import FeatureType, HSICSelector as Selector"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "798f7c6d",
"metadata": {},
"outputs": [],
"source": [
"k = 5\n",
"n = 2000\n",
"d = 20"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "50b99be8",
"metadata": {},
"outputs": [],
"source": [
"x0 = np.random.randint(k, size=(n, 1))\n",
"x1 = np.random.randint(k, size=(n, 1))\n",
"ms = np.random.randint(low=2, high=20, size = d-2)\n",
"others = [np.random.choice(m, size=(n, 1)) for m in ms]\n",
"all_ = np.concatenate(\n",
" [x0, x1] + others,\n",
" axis=1\n",
")\n",
"y = np.asarray(x0 == x1, dtype=int) # k + x0 - x1 # np.asarray(x0 == x1, dtype=int)\n",
"permuter = np.random.permutation(np.eye(d, dtype=int).T).T\n",
"x = np.array(all_ @ permuter, dtype=int)\n",
"expected_features = [np.argmax(permuter[0, :]), np.argmax(permuter[1, :])]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6236e9e",
"metadata": {},
"outputs": [],
"source": [
"assert np.all(x[:, expected_features[0]] == x0[:, 0])\n",
"assert np.all(x[:, expected_features[1]] == x1[:, 0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f83edaef",
"metadata": {},
"outputs": [],
"source": [
"sns.scatterplot(x = x0[:, 0] - x1[:, 0], y = y[:, 0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "140b9f88",
"metadata": {},
"outputs": [],
"source": [
"xdf = pd.DataFrame(x, columns = [f'x{i}' for i in range(d)])\n",
"ydf = pd.Series(y[:, 0], name='y')"
]
},
{
"cell_type": "markdown",
"id": "e37502d7",
"metadata": {},
"source": [
"### Selection with marginal 1D ksg mutual info"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "139b18ff",
"metadata": {},
"outputs": [],
"source": [
"ksgselection, mis = select.ksgmi(xdf, ydf, threshold=0.01)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5ffca204",
"metadata": {},
"outputs": [],
"source": [
"print(f'Expected features: {sorted(expected_features)}')\n",
"print(f'Marginal KSG selection: {sorted(ksgselection)}')"
]
},
{
"cell_type": "markdown",
"id": "c8906000",
"metadata": {},
"source": [
"### Selection with HSIC Lasso"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1487ff0e",
"metadata": {},
"outputs": [],
"source": [
"selector = Selector(x, y, xfeattype=FeatureType.DISCR, yfeattype=FeatureType.DISCR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afab6f16",
"metadata": {},
"outputs": [],
"source": [
"batch_size = n // 10\n",
"minibatch_size = 200\n",
"number_of_epochs = 3\n",
"threshold = .0\n",
"device = None # run on CPU"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "01efe57c",
"metadata": {},
"outputs": [],
"source": [
"hsiclasso_selection = selector.select(\n",
" number_of_features=2,\n",
" batch_size=batch_size,\n",
" minibatch_size=minibatch_size,\n",
" number_of_epochs=number_of_epochs,\n",
" device=device\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97929ada",
"metadata": {},
"outputs": [],
"source": [
"print(f'Expected features: {sorted(expected_features)}')\n",
"print(f'HSIC Lasso selection: {sorted(hsiclasso_selection)}')"
]
},
{
"cell_type": "markdown",
"id": "d88d85c5",
"metadata": {},
"source": [
"### Confirm that HSIC_b correctly assigns highest dependence to the correct selection"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "38056f04",
"metadata": {},
"outputs": [],
"source": [
"correct_dependence = n * n * hsic.hsic_b(\n",
" x[:, list(expected_features)],\n",
" y\n",
")\n",
"nsel = np.random.randint(low=1, high=d)\n",
"random_selection = np.random.choice(list(range(d)), replace=False, size=nsel)\n",
"random_dependence = n * n * hsic.hsic_b(\n",
" x[:, list(random_selection)],\n",
" y\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92bc809f",
"metadata": {},
"outputs": [],
"source": [
"print(f'HSIC-estimated dependence between correct selection and target: {correct_dependence}')\n",
"print(f'HSIC-estimated dependence between random selection and target: {random_dependence}')"
]
},
{
"cell_type": "markdown",
"id": "beb34ecd",
"metadata": {},
"source": [
"### Selection with 2D discrete mutual information"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d1459fb",
"metadata": {},
"outputs": [],
"source": [
"def onedimlabel(x):\n",
" assert x.ndim == 2\n",
" ns = np.amax(x, axis=0)\n",
" res = np.array(x[:, 0], copy=True)\n",
" m = 1\n",
" for i in range(1, x.shape[1]):\n",
" m *= max(1, ns[i-1])\n",
" res += (1+m) * x[:, i]\n",
" return res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16a8e7f5",
"metadata": {},
"outputs": [],
"source": [
"l = 2\n",
"miscores = {subset: \n",
" adjusted_mutual_info_score(onedimlabel(x[:, list(subset)]), y[:, 0])\n",
" for subset in itertools.combinations(list(range(d)), l)\n",
" \n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "168eb38b",
"metadata": {},
"outputs": [],
"source": [
"s = (0,1)\n",
"mi = 0\n",
"for k, v in miscores.items():\n",
" if v > mi:\n",
" s = k\n",
" mi = v\n",
"twod_mi_selection = s"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a14eb4e9",
"metadata": {},
"outputs": [],
"source": [
"print(f'Expected features: {sorted(expected_features)}')\n",
"print(f'2D discrete MI selection: {sorted(twod_mi_selection)}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hiselc",
"language": "python",
"name": "hiselc"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f78492b

Please sign in to comment.