diff --git a/hisel/hsic.py b/hisel/hsic.py index afabb8e..9454a43 100644 --- a/hisel/hsic.py +++ b/hisel/hsic.py @@ -1,10 +1,14 @@ +from typing import Optional from hisel import kernels +from hisel.kernels import KernelType import numpy as np def hsic_b( x: np.ndarray, y: np.ndarray, + xkerneltype: Optional[KernelType] = None, + ykerneltype: Optional[KernelType] = None, ): assert x.ndim == 2 assert y.ndim == 2 @@ -14,12 +18,22 @@ def hsic_b( dy: int = y.shape[1] lx: float = np.sqrt(dx) ly: float = np.sqrt(dy) + if xkerneltype is None: + if x.dtype == int: + xkerneltype = KernelType.DELTA + else: + xkerneltype = KernelType.RBF + if ykerneltype is None: + if y.dtype == int: + ykerneltype = KernelType.DELTA + else: + ykerneltype = KernelType.RBF xgram: np.ndarray = kernels.multivariate_phi( - x.T, lx + x.T, lx, xkerneltype ) k = xgram[0, :, :] ygram: np.ndarray = kernels.multivariate_phi( - y.T, ly + y.T, ly, ykerneltype ) l = kernels._center_gram(ygram)[0, :, :] return np.trace(k @ l) / (n*n) diff --git a/notebooks/ensemble-example.ipynb b/notebooks/ensemble-example.ipynb new file mode 100644 index 0000000..237743c --- /dev/null +++ b/notebooks/ensemble-example.ipynb @@ -0,0 +1,302 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "802e8c73", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "import itertools\n", + "from sklearn.metrics import adjusted_mutual_info_score\n", + "\n", + "\n", + "from hisel import select, hsic\n", + "from hisel.select import FeatureType, HSICSelector as Selector" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "798f7c6d", + "metadata": {}, + "outputs": [], + "source": [ + "k = 5\n", + "n = 2000\n", + "d = 20" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50b99be8", + "metadata": {}, + "outputs": [], + "source": [ + "x0 = np.random.randint(k, size=(n, 1))\n", + "x1 = np.random.randint(k, size=(n, 1))\n", + "ms = np.random.randint(low=2, high=20, size = d-2)\n", + "others = [np.random.choice(m, size=(n, 1)) for m in ms]\n", + "all_ = np.concatenate(\n", + " [x0, x1] + others,\n", + " axis=1\n", + ")\n", + "y = np.asarray(x0 == x1, dtype=int) # k + x0 - x1 # np.asarray(x0 == x1, dtype=int)\n", + "permuter = np.random.permutation(np.eye(d, dtype=int).T).T\n", + "x = np.array(all_ @ permuter, dtype=int)\n", + "expected_features = [np.argmax(permuter[0, :]), np.argmax(permuter[1, :])]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6236e9e", + "metadata": {}, + "outputs": [], + "source": [ + "assert np.all(x[:, expected_features[0]] == x0[:, 0])\n", + "assert np.all(x[:, expected_features[1]] == x1[:, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f83edaef", + "metadata": {}, + "outputs": [], + "source": [ + "sns.scatterplot(x = x0[:, 0] - x1[:, 0], y = y[:, 0])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "140b9f88", + "metadata": {}, + "outputs": [], + "source": [ + "xdf = pd.DataFrame(x, columns = [f'x{i}' for i in range(d)])\n", + "ydf = pd.Series(y[:, 0], name='y')" + ] + }, + { + "cell_type": "markdown", + "id": "e37502d7", + "metadata": {}, + "source": [ + "### Selection with marginal 1D ksg mutual info" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "139b18ff", + "metadata": {}, + "outputs": [], + "source": [ + "ksgselection, mis = select.ksgmi(xdf, ydf, threshold=0.01)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5ffca204", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Expected features: {sorted(expected_features)}')\n", + "print(f'Marginal KSG selection: {sorted(ksgselection)}')" + ] + }, + { + "cell_type": "markdown", + "id": "c8906000", + "metadata": {}, + "source": [ + "### Selection with HSIC Lasso" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1487ff0e", + "metadata": {}, + "outputs": [], + "source": [ + "selector = Selector(x, y, xfeattype=FeatureType.DISCR, yfeattype=FeatureType.DISCR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "afab6f16", + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = n // 10\n", + "minibatch_size = 200\n", + "number_of_epochs = 3\n", + "threshold = .0\n", + "device = None # run on CPU" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "01efe57c", + "metadata": {}, + "outputs": [], + "source": [ + "hsiclasso_selection = selector.select(\n", + " number_of_features=2,\n", + " batch_size=batch_size,\n", + " minibatch_size=minibatch_size,\n", + " number_of_epochs=number_of_epochs,\n", + " device=device\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "97929ada", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Expected features: {sorted(expected_features)}')\n", + "print(f'HSIC Lasso selection: {sorted(hsiclasso_selection)}')" + ] + }, + { + "cell_type": "markdown", + "id": "d88d85c5", + "metadata": {}, + "source": [ + "### Confirm that HSIC_b correctly assigns highest dependence to the correct selection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38056f04", + "metadata": {}, + "outputs": [], + "source": [ + "correct_dependence = n * n * hsic.hsic_b(\n", + " x[:, list(expected_features)],\n", + " y\n", + ")\n", + "nsel = np.random.randint(low=1, high=d)\n", + "random_selection = np.random.choice(list(range(d)), replace=False, size=nsel)\n", + "random_dependence = n * n * hsic.hsic_b(\n", + " x[:, list(random_selection)],\n", + " y\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92bc809f", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'HSIC-estimated dependence between correct selection and target: {correct_dependence}')\n", + "print(f'HSIC-estimated dependence between random selection and target: {random_dependence}')" + ] + }, + { + "cell_type": "markdown", + "id": "beb34ecd", + "metadata": {}, + "source": [ + "### Selection with 2D discrete mutual information" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3d1459fb", + "metadata": {}, + "outputs": [], + "source": [ + "def onedimlabel(x):\n", + " assert x.ndim == 2\n", + " ns = np.amax(x, axis=0)\n", + " res = np.array(x[:, 0], copy=True)\n", + " m = 1\n", + " for i in range(1, x.shape[1]):\n", + " m *= max(1, ns[i-1])\n", + " res += (1+m) * x[:, i]\n", + " return res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16a8e7f5", + "metadata": {}, + "outputs": [], + "source": [ + "l = 2\n", + "miscores = {subset: \n", + " adjusted_mutual_info_score(onedimlabel(x[:, list(subset)]), y[:, 0])\n", + " for subset in itertools.combinations(list(range(d)), l)\n", + " \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "168eb38b", + "metadata": {}, + "outputs": [], + "source": [ + "s = (0,1)\n", + "mi = 0\n", + "for k, v in miscores.items():\n", + " if v > mi:\n", + " s = k\n", + " mi = v\n", + "twod_mi_selection = s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a14eb4e9", + "metadata": {}, + "outputs": [], + "source": [ + "print(f'Expected features: {sorted(expected_features)}')\n", + "print(f'2D discrete MI selection: {sorted(twod_mi_selection)}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hiselc", + "language": "python", + "name": "hiselc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}