From c45bbb88ae5c09022a31b80ddcd38da42c9b2179 Mon Sep 17 00:00:00 2001
From: Adam Li
Date: Wed, 19 Apr 2023 09:47:07 -0400
Subject: [PATCH] [ENH] Add Fisher-Z test (#7)

Towards #5

Changes proposed in this pull request:

- Adds partial correlation test
- Sets up initial API design
- Includes Sphinx docs

---------

Signed-off-by: Adam Li
---
 .circleci/config.yml                          |   6 +-
 doc/_templates/autosummary/class.rst          |   1 -
 doc/api.rst                                   |  40 +++++-
 doc/conditional_independence.rst              |  51 ++-----
 doc/conf.py                                   |  21 ++-
 doc/index.rst                                 |   3 +-
 doc/installation.md                           |  22 +--
 doc/use.rst                                   |   2 +-
 doc/whats_new/v0.1.rst                        |   1 +
 poetry.lock                                   |  40 +++---
 pywhy_stats/__init__.py                       |   2 +
 pywhy_stats/fisherz.py                        | 126 ++++++++++++++++++
 pywhy_stats/independence.py                   |  54 ++++++--
 .../{p_value_result.py => pvalue_result.py}   |  12 +-
 tests/test_fisherz_test.py                    |  44 ++++++
 15 files changed, 330 insertions(+), 95 deletions(-)
 create mode 100644 pywhy_stats/fisherz.py
 rename pywhy_stats/{p_value_result.py => pvalue_result.py} (74%)
 create mode 100644 tests/test_fisherz_test.py

diff --git a/.circleci/config.yml b/.circleci/config.yml
index c8ad75f..3a7653a 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -63,7 +63,7 @@ jobs:
             sudo apt install libspatialindex-dev xdg-utils
       - python/install-packages:
           pkg-manager: poetry
-          args: "-E graph_func -E viz --with docs"
+          args: "--with docs"
           cache-version: "v1" # change to clear cache
       - run:
           name: Check poetry package versions
@@ -145,12 +145,12 @@ jobs:
       - run:
          name: make linkcheck
          command: |
-            make -C doc linkcheck
+            poetry run make -C doc linkcheck
      - run:
          name: make linkcheck-grep
          when: always
          command: |
-            make -C doc linkcheck-grep
+            poetry run make -C doc linkcheck-grep
      - store_artifacts:
          path: doc/_build/linkcheck
          destination: linkcheck
diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst
index 6056ea9..9c1db65 100644
--- a/doc/_templates/autosummary/class.rst
+++ b/doc/_templates/autosummary/class.rst
@@ -4,7 +4,6 @@
 .. currentmodule:: {{ module }}
 
 .. autoclass:: {{ objname }}
-  :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__,__hash__
   :members:
 
 .. include:: {{module}}.{{objname}}.examples
diff --git a/doc/api.rst b/doc/api.rst
index 8015759..66a2fba 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -22,10 +22,44 @@ Pyhy-Stats experimentally provides an interface for conditional independence
 testing and conditional discrepancy testing (also known as k-sample conditional
 independence testing).
 
-Conditional Independence Testing
-================================
+High-level Independence Testing
+===============================
+
+The easiest way to run a (conditional) independence test is to use the
+:py:func:`independence_test` function. This function inspects its inputs and
+tries to automatically pick an appropriate test.
+
+Note: this is only meant for beginners, and the result should be interpreted
+with caution because the ability to choose the optimal test is limited.
+Using the wrong test for the type of data and assumptions at hand typically
+reduces statistical power.
+
+.. currentmodule:: pywhy_stats
+.. autosummary::
+   :toctree: generated/
+
+   independence_test
+   Methods
+
+
+All independence tests return a ``PValueResult`` object, which contains the
+p-value, the test statistic, and optionally additional information.
+
+.. currentmodule:: pywhy_stats.pvalue_result
+.. 
autosummary:: + :toctree: generated/ + + PValueResult + +(Conditional) Independence Testing +================================== Testing for conditional independence among variables is a core part of many data analysis procedures. -TBD... \ No newline at end of file +.. currentmodule:: pywhy_stats +.. autosummary:: + :toctree: generated/ + + fisherz + diff --git a/doc/conditional_independence.rst b/doc/conditional_independence.rst index d256f13..6b506f6 100644 --- a/doc/conditional_independence.rst +++ b/doc/conditional_independence.rst @@ -4,7 +4,7 @@ Independence ============ -.. currentmodule:: pywhy_stats.ci +.. currentmodule:: pywhy_stats Probabilistic independence among two random variables is when the realization of one variable does not affect the distribution of the other variable. It is a fundamental notion @@ -42,10 +42,10 @@ with certain assumptions on the underlying data distribution. Conditional Mutual Information ------------------------------ -Conditional mutual information (CMI) is a general formulation of CI, where CMI is defined as -:math:: +Conditional mutual information (CMI) is a general formulation of CI, where CMI is defined as: - \\int log \frac{p(x, y | z)}{p(x | z) p(y | z)} + .. math:: + \int log \frac{p(x, y | z)}{p(x | z) p(y | z)} As we can see, CMI is equal to 0, if and only if :math:`p(x, y | z) = p(x | z) p(y | z)`, which is exactly the definition of CI. CMI is completely non-parametric and thus requires no assumptions @@ -70,24 +70,18 @@ various proposals in the literature for estimating CMI, which we summarize here: one can use variants of Random Forests to generate adaptive nearest-neighbor estimates in high-dimensions or on manifolds, such that the KSG estimator is still powerful. -.. autosummary:: - :toctree: generated/ - - CMITest + - The Classifier Divergence approach estimates CMI using a classification model. -.. autosummary:: - :toctree: generated/ - - ClassifierCMITest + - Direct posterior estimates can be implemented with a classification model by directly estimating :math:`P(y|x)` and :math:`P(y|x,z)`, which can be used as plug-in estimates to the equation for CMI. -Partial (Pearson) Correlation ------------------------------ +:mod:`pywhy_stats.fisherz` Partial (Pearson) Correlation +-------------------------------------------------------- Partial correlation based on the Pearson correlation is equivalent to CMI in the setting of normally distributed data. Computing partial correlation is fast and efficient and thus attractive to use. However, this **relies on the assumption that the variables are Gaussiany**, @@ -96,7 +90,7 @@ which may be unrealistic in certain datasets. .. autosummary:: :toctree: generated/ - FisherZCITest + fisherz Discrete, Categorical and Binary Data ------------------------------------- @@ -105,10 +99,6 @@ class of tests will construct a contingency table based on the number of levels each discrete variable. An exponential amount of data is needed for increasing levels for a discrete variable. -.. autosummary:: - :toctree: generated/ - - GSquareCITest Kernel-Approaches ----------------- @@ -118,10 +108,6 @@ that computes a test statistic from kernels of the data and uses permutation tes generate samples from the null distribution :footcite:`Zhang2011`, which are then used to estimate a pvalue. -.. 
autosummary:: - :toctree: generated/ - - KernelCITest Classifier-based Approaches --------------------------- @@ -142,16 +128,11 @@ helps maintain dependence between (X, Z) and (Y, Z) (if it exists), but generate conditionally independent dataset. -.. autosummary:: - :toctree: generated/ - - ClassifierCITest - ======================= Conditional Discrepancy ======================= -.. currentmodule:: pywhy_stats.cd +.. currentmodule:: pywhy_stats Conditional discrepancy (CD) is another form of conditional invariances that may be exhibited by data. The general question is whether or not the following two distributions are equal: @@ -181,10 +162,6 @@ that computes a test statistic from kernels of the data and uses a weighted perm based on the estimated propensity scores to generate samples from the null distribution :footcite:`Park2021conditional`, which are then used to estimate a pvalue. -.. autosummary:: - :toctree: generated/ - - KernelCDTest Bregman-Divergences ------------------- @@ -193,7 +170,7 @@ that computes a test statistic from estimated Von-Neumann divergences of the dat weighted permutation testing based on the estimated propensity scores to generate samples from the null distribution :footcite:`Yu2020Bregman`, which are then used to estimate a pvalue. -.. autosummary:: - :toctree: generated/ - - BregmanCDTest +========== +References +========== +.. footbibliography:: diff --git a/doc/conf.py b/doc/conf.py index de6cd2d..b4fab8d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -3,11 +3,13 @@ # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html +from __future__ import annotations import os import sys from datetime import datetime +import numpy.typing import sphinx_gallery # noqa: F401 from sphinx_gallery.sorting import ExampleTitleSortKey @@ -70,7 +72,11 @@ autosummary_generate = True autodoc_default_options = {"inherited-members": None} -autodoc_typehints = "signature" + +# whether to expand type hints in function/class signatures +autodoc_typehints = "none" + +add_module_names = False # -- numpydoc # Below is needed to prevent errors @@ -109,9 +115,6 @@ "dictionary", "no", "attributes", - # numpy - "ScalarType", - "ArrayLike", # shapes "n_times", "obj", @@ -123,7 +126,6 @@ "n_samples", "n_variables", "n_classes", - "NDArray", "n_samples_X", "n_samples_Y", "n_features_x", @@ -141,11 +143,20 @@ "pgmpy.models.BayesianNetwork": "pgmpy.models.BayesianNetwork", # joblib "joblib.Parallel": "joblib.Parallel", + "PValueResult": "pywhy_stats.pvalue_result.PValueResult", # numpy "NDArray": "numpy.ndarray", + # "ArrayLike": "numpy.typing.ArrayLike", "ArrayLike": ":term:`array_like`", + "fisherz": "pywhy_stats.fisherz", } +autodoc_typehints_format = "short" +# from __future__ import annotations +# autodoc_type_aliases = { +# 'Iterable': 'Iterable', +# 'ArrayLike': 'ArrayLike' +# } default_role = "literal" # Tell myst-parser to assign header anchors for h1-h3. diff --git a/doc/index.rst b/doc/index.rst index 3a9a875..2d2a17f 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -25,7 +25,6 @@ Contents Reference API Simple Examples User Guide - tutorials/index whats_new .. toctree:: @@ -33,7 +32,7 @@ Contents :caption: Development License - Contributing + Contributing Team ---- diff --git a/doc/installation.md b/doc/installation.md index 2cac027..ede7e19 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -3,26 +3,30 @@ Installation **pywhy-stats** supports Python >= 3.8. 
-## Installing with ``pip``, or ``poetry``. +Installing with ``pip``, or ``poetry`` +-------------------------------------- **pywhy-stats** is available [on PyPI](https://pypi.org/project/pywhy-stats/). Just run - pip install pywhy-stats + >>> pip install pywhy-stats - # or via poetry (recommended) - poetry add pywhy-stats + >>> # or via poetry (recommended) + >>> poetry add pywhy-stats -## Installing from source +Installing from source +---------------------- To install **pywhy-stats** from source, first clone [the repository](https://github.com/pywhy/pywhy-stats): - git clone https://github.com/py-why/pywhy-stats.git - cd pywhy-stats + + >>> git clone https://github.com/py-why/pywhy-stats.git + >>> cd pywhy-stats Then run installation via poetry (recommended) - poetry install + + >>> poetry install or via pip - pip install -e . + >>> pip install -e . diff --git a/doc/use.rst b/doc/use.rst index 6950947..3cc2c8c 100644 --- a/doc/use.rst +++ b/doc/use.rst @@ -1,7 +1,7 @@ :orphan: Examples and Tutorials using pywhy-stats -======================================= +======================================== To be able to effectively use pywhy-stats, you can look at some of the basic examples here to learn everything you need from concepts to explicit code examples. diff --git a/doc/whats_new/v0.1.rst b/doc/whats_new/v0.1.rst index a3607d1..0baaa9d 100644 --- a/doc/whats_new/v0.1.rst +++ b/doc/whats_new/v0.1.rst @@ -26,6 +26,7 @@ Version 0.1 Changelog --------- +- |Feature| Implement partial correlation test :func:`pywhy_stats.fisherz`, by `Adam Li`_ (:pr:`7`) Code and Documentation Contributors diff --git a/poetry.lock b/poetry.lock index 0c4596b..2eba68e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -430,29 +430,29 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cmake" -version = "3.26.1" +version = "3.26.3" description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software" category = "dev" optional = false python-versions = "*" files = [ - {file = "cmake-3.26.1-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:d8a7e0cc8677677a732aff3e3fd0ad64eeff43cac772614b03c436912247d0d8"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:f2f721f5aebe304c281ee4b1d2dfbf7f4a52fca003834b2b4a3ba838aeded63c"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:63a012b72836702eadfe4fba9642aeb17337f26861f4768e837053f40e98cb46"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2b72be88b7bfaa6ae59566cbb9d6a5553f19b2a8d14efa6ac0cf019a29860a1b"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1278354f7210e22458aa9137d46a56da1f115a7b76ad2733f0bf6041fb40f1dc"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:de96a5522917fba0ab0da2d01d9dd9462fa80f365218bf27162d539c2335758f"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:449928ad7dfcd41e4dcff64c7d44f86557883c70577666a19e79e22d783bbbd0"}, - {file = "cmake-3.26.1-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:19fa3e457afecf2803265f71652ef17c3f1d317173c330ba46767a0853d38fa0"}, - {file = "cmake-3.26.1-py2.py3-none-musllinux_1_1_aarch64.whl", hash = 
"sha256:43360650d60d177d979e4ad0a5f31afa286e6d88f5350f7a38c29d94514900eb"}, - {file = "cmake-3.26.1-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:16aac10363bc926da5109a59ef8fe46ddcd7e3d421de61f871b35524eef2f1ae"}, - {file = "cmake-3.26.1-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:e460ba5070be4dcac9613cb526a46db4e5fa19d8b909a8d8d5244c6cc3c777e1"}, - {file = "cmake-3.26.1-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:fd2ecc0899f7939a014bd906df85e8681bd63ce457de3ab0b5d9e369fa3bdf79"}, - {file = "cmake-3.26.1-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:22781a23e274ba9bf380b970649654851c1b4b9d83b65fec12ee2e2e03b6ffc4"}, - {file = "cmake-3.26.1-py2.py3-none-win32.whl", hash = "sha256:7b4e81de30ac1fb2f1eb5287063e140b53f376fd9ed7e2060c1c7b5917bd5f83"}, - {file = "cmake-3.26.1-py2.py3-none-win_amd64.whl", hash = "sha256:90845b6c87a25be07e9220f67dd7f6c891c6ec14d764d37335218d97f9ea4520"}, - {file = "cmake-3.26.1-py2.py3-none-win_arm64.whl", hash = "sha256:43bd96327e2631183bb4829ba20cb810e20b4b0c68f852fcd7082fbb5359d57c"}, - {file = "cmake-3.26.1.tar.gz", hash = "sha256:4e0eb3c03dcf2d459f78d96cc85f7482476aeb1ae5ada65150b1db35c0f70cc7"}, + {file = "cmake-3.26.3-py2.py3-none-macosx_10_10_universal2.macosx_10_10_x86_64.macosx_11_0_arm64.macosx_11_0_universal2.whl", hash = "sha256:9d38ea5b4999f8f042a071bea3e213f085bac26d7ab54cb5a4c6a193c4baf132"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2010_i686.manylinux_2_12_i686.whl", hash = "sha256:6e5fcd1cfaac33d015e2709e0dd1b7ad352a315367012ac359c9adc062cf075b"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:4d3185738a6405aa15801e684f8d589b00570da4cc676cb1b5bbc902e3023e53"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b20f7f7ea316ce7bb158df0e3c3453cfab5048939f1291017d16a8a36ad33ae6"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:46aa385e19c9e4fc95d7d6ce5ee0bbe0d69bdeac4e9bc95c61f78f3973c2f626"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:71e1df5587ad860b9829211380c42fc90ef2413363f12805b1fa2d87769bf876"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:543b6958d1615327f484a07ab041029b1740918a8baa336adc9f5f0cbcd8fbd8"}, + {file = "cmake-3.26.3-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1bc7b47456256bdcc41069f5c658f232bd6e15bf4796d115f6ec98800793daff"}, + {file = "cmake-3.26.3-py2.py3-none-musllinux_1_1_aarch64.whl", hash = "sha256:2ae3db2c2be50fdaf0c9f3a23b2206e9dcd55ca124f16486a841b939f50b595e"}, + {file = "cmake-3.26.3-py2.py3-none-musllinux_1_1_i686.whl", hash = "sha256:1798547b23b89030518c5668dc55aed0e1d01867cf91d7a94e15d33f62a56fd0"}, + {file = "cmake-3.26.3-py2.py3-none-musllinux_1_1_ppc64le.whl", hash = "sha256:d3017a08e6ba53ec2486d89a7953a81d4c4a068fc9f29d83e209f295dd9c59f3"}, + {file = "cmake-3.26.3-py2.py3-none-musllinux_1_1_s390x.whl", hash = "sha256:a922a6f6c1580d0db17b0b75f82e619441dd43c7f1d6a35f7d27e709db48bdbb"}, + {file = "cmake-3.26.3-py2.py3-none-musllinux_1_1_x86_64.whl", hash = "sha256:e0ed796530641c8a21a423f9bb7882117dbbeee11ec78dbc335402a678d937ae"}, + {file = "cmake-3.26.3-py2.py3-none-win32.whl", hash = "sha256:27a6fa1b97744311a7993d6a1e0ce14bd73696dab9ceb96701f1ec11edbd5053"}, + {file = "cmake-3.26.3-py2.py3-none-win_amd64.whl", hash = 
"sha256:cf910bbb488659d300c86b1dac77e44eeb0457bde2cf76a42d7e51f691544b21"}, + {file = "cmake-3.26.3-py2.py3-none-win_arm64.whl", hash = "sha256:24741a304ada699b339034958777d9a1472ac8ddb9b6194d74f814287ca091ae"}, + {file = "cmake-3.26.3.tar.gz", hash = "sha256:b54cde1f1c0573321b22382bd2ffaf5d08f65188572d128cd4867fb9669723c5"}, ] [package.extras] @@ -1517,13 +1517,13 @@ six = ">=1.4.1" [[package]] name = "lit" -version = "16.0.0" +version = "16.0.1" description = "A Software Testing Tool" category = "dev" optional = false python-versions = "*" files = [ - {file = "lit-16.0.0.tar.gz", hash = "sha256:3c4ac372122a1de4a88deb277b956f91b7209420a0bef683b1ab2d2b16dabe11"}, + {file = "lit-16.0.1.tar.gz", hash = "sha256:630a47291b714cb115015df23ab04267c24fe59aec7ecd7e637d5c75cdb45c91"}, ] [[package]] diff --git a/pywhy_stats/__init__.py b/pywhy_stats/__init__.py index 3a8d6d5..97e6c4c 100644 --- a/pywhy_stats/__init__.py +++ b/pywhy_stats/__init__.py @@ -1 +1,3 @@ +from . import fisherz from ._version import __version__ # noqa: F401 +from .independence import Methods, independence_test diff --git a/pywhy_stats/fisherz.py b/pywhy_stats/fisherz.py new file mode 100644 index 0000000..eb286e7 --- /dev/null +++ b/pywhy_stats/fisherz.py @@ -0,0 +1,126 @@ +"""Independence test using Fisher-Z's test. + +This test is also known as the partial correlation independence test. +It works on Gaussian random variables. + +When the data is not Gaussian, this test is not valid. In this case, we recommend +using the Kernel independence test at . + +Examples +-------- +>>> import pywhy_stats as ps +>>> res = ps.fisherz.ind([1, 2, 3], [4, 5, 6]) +>>> print(res.pvalue) +>>> 1.0 +""" + +from math import log, sqrt +from typing import Optional + +import numpy as np +from numpy.typing import ArrayLike +from scipy.stats import norm + +from .pvalue_result import PValueResult + + +def ind(X: ArrayLike, Y: ArrayLike, correlation_matrix: Optional[ArrayLike] = None) -> PValueResult: + """Perform an independence test using Fisher-Z's test. + + Works on Gaussian random variables. This test is also known as the + correlation test. + + Parameters + ---------- + X : ArrayLike of shape (n_samples,) + The first node variable. + Y : ArrayLike of shape (n_samples,) + The second node variable. + correlation_matrix : ArrayLike of shape (2, 2), optional + The precomputed correlation matrix between X and Y., by default None. + + Returns + ------- + statistic : float + The test statistic. + pvalue : float + The p-value of the test. + """ + return _fisherz(X, Y, condition_on=None, correlation_matrix=correlation_matrix) + + +def condind( + X: ArrayLike, + Y: ArrayLike, + condition_on: ArrayLike, + correlation_matrix: Optional[ArrayLike] = None, +) -> PValueResult: + """Perform a conditional independence test using Fisher-Z's test. + + Parameters + ---------- + X : ArrayLike of shape (n_samples,) + The first node variable. + Y : ArrayLike of shape (n_samples,) + The second node variable. + condition_on : ArrayLike of shape (n_samples, n_variables) + The conditioning set. + correlation_matrix : ArrayLike of shape (2 + n_variables, 2 + n_variables), optional + The precomputed correlation matrix between X, Y and ``condition_on``, by default None. + + Returns + ------- + statistic : float + The test statistic. + pvalue : float + The p-value of the test. 
+    """
+    return _fisherz(X, Y, condition_on=condition_on, correlation_matrix=correlation_matrix)
+
+
+def _fisherz(
+    X: ArrayLike,
+    Y: ArrayLike,
+    condition_on: Optional[ArrayLike] = None,
+    correlation_matrix: Optional[ArrayLike] = None,
+) -> PValueResult:
+    """Perform an independence test using Fisher-Z's test.
+
+    Parameters
+    ----------
+    X : ArrayLike of shape (n_samples, 1)
+        The first node variable.
+    Y : ArrayLike of shape (n_samples, 1)
+        The second node variable.
+    condition_on : ArrayLike of shape (n_samples, n_variables), optional
+        The conditioning set. If ``None`` (default), a marginal independence test is run.
+    correlation_matrix : ArrayLike of shape (n_variables, n_variables), optional
+        The precomputed correlation matrix among X, Y and ``condition_on``.
+        If ``None`` (default), the correlation matrix is computed from the data.
+
+    Returns
+    -------
+    statistic : float
+        The test statistic.
+    pvalue : float
+        The p-value of the test.
+    """
+    if condition_on is None:
+        condition_on = np.empty((X.shape[0], 0))
+
+    # compute the correlation matrix within the specified data
+    data = np.hstack((X, Y, condition_on))
+    sample_size = data.shape[0]
+    if correlation_matrix is None:
+        correlation_matrix = np.corrcoef(data.T)
+
+    inv = np.linalg.pinv(correlation_matrix)
+    r = -inv[0, 1] / sqrt(inv[0, 0] * inv[1, 1])
+
+    # apply the Fisher Z-transformation
+    Z = 0.5 * log((1 + r) / (1 - r))
+
+    # compute the test statistic and the corresponding two-sided p-value
+    statistic = sqrt(sample_size - condition_on.shape[1] - 3) * abs(Z)
+    p = 2 * (1 - norm.cdf(abs(statistic)))
+    return PValueResult(statistic=statistic, pvalue=p)
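[Editor's note, not part of the patch] A brief usage sketch of the new ``fisherz`` module above, mirroring the
simulation used in ``tests/test_fisherz_test.py`` further below. It assumes this branch is installed; the
variable names are illustrative only.

```python
import numpy as np

from pywhy_stats import fisherz

rng = np.random.default_rng(0)
# X -> Y <- X1: X and X1 are marginally independent, but dependent given Y.
X = rng.standard_normal((300, 1))
X1 = rng.standard_normal((300, 1))
Y = X + X1 + 0.5 * rng.standard_normal((300, 1))

res = fisherz.ind(X, X1)  # marginal test: expect a large p-value
print(res.statistic, res.pvalue)

res = fisherz.condind(X, X1, condition_on=Y)  # conditional test given the collider Y: expect a small p-value
print(res.statistic, res.pvalue)
```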
diff --git a/pywhy_stats/independence.py b/pywhy_stats/independence.py
index 4d35b4f..0b35b2e 100644
--- a/pywhy_stats/independence.py
+++ b/pywhy_stats/independence.py
@@ -1,15 +1,24 @@
 from enum import Enum
+from types import ModuleType
 from typing import Optional
+from warnings import warn
 
-from numpy.testing import ArrayLike
+import scipy.stats
+from numpy.typing import ArrayLike
 
-from .p_value_result import PValueResult
+from pywhy_stats import fisherz
+
+from .pvalue_result import PValueResult
 
 
 class Methods(Enum):
     """Methods for independence testing."""
 
     AUTO = 0
+    """Choose an automatic method based on the data."""
+
+    FISHERZ = fisherz
+    """:py:mod:`~pywhy_stats.fisherz`: Fisher's Z test for independence"""
 
 
 def independence_test(
@@ -26,16 +35,16 @@ def independence_test(
     Parameters
     ----------
-    X : numpy.ndarray, shape (n, d)
+    X : ArrayLike, shape (n_samples, n_features_x)
         Data matrix for X.
-    Y : numpy.ndarray, shape (n, m)
+    Y : ArrayLike, shape (n_samples, n_features_y)
         Data matrix for Y.
-    condition_on : numpy.ndarray or None, shape (n, k), optional
+    condition_on : ArrayLike or None, shape (n_samples, n_features_z), optional
         Data matrix for the conditioning variables. If None is given, an unconditional test
         is performed.
     method : Methods, optional
-        Independence test method from the Methods enum. Default is Methods.AUTO, which will
-        automatically select an appropriate method.
+        Independence test method from the :class:`pywhy_stats.Methods` enum. Default is
+        `Methods.AUTO`, which will automatically select an appropriate method.
     **kwargs : dict or None, optional
         Additional keyword arguments to be passed to the specific test method
 
@@ -44,5 +53,34 @@
     result : PValueResult
         An instance of the PValueResult data class, containing the p-value, test statistic,
         and any additional information related to the independence test.
+
+    See Also
+    --------
+    fisherz : Fisher's Z test for independence
     """
-    pass
+    method_module: ModuleType
+    if method == Methods.AUTO:
+        method_module = Methods.FISHERZ.value
+    else:
+        method_module = method.value
+
+    if method_module is fisherz:
+        if condition_on is None:
+            data = [X, Y]
+        else:
+            data = [X, Y, condition_on]
+        for _data in data:
+            _, pval = scipy.stats.normaltest(_data)
+
+            # XXX: we should add pingouin as an optional dependency for multiple-comparison corrections
+            if pval < 0.05:
+                warn(
+                    "The provided data does not seem to be Gaussian, but the Fisher-Z test "
+                    "assumes that the data follows a Gaussian distribution. The result should "
+                    "be interpreted carefully or consider a different independence test method."
+                )
+
+    if condition_on is None:
+        return method_module.ind(X, Y, **kwargs)
+    else:
+        return method_module.condind(X, Y, condition_on, **kwargs)
diff --git a/pywhy_stats/p_value_result.py b/pywhy_stats/pvalue_result.py
similarity index 74%
rename from pywhy_stats/p_value_result.py
rename to pywhy_stats/pvalue_result.py
index 0509f47..60b5469 100644
--- a/pywhy_stats/p_value_result.py
+++ b/pywhy_stats/pvalue_result.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Optional, Union
 
-from numpy.testing import ArrayLike
+from numpy.typing import ArrayLike
 
 
 @dataclass
@@ -10,16 +10,16 @@ class PValueResult:
 
     Attributes
     ----------
-    p_value: float
+    pvalue : float
        The p-value represents the probability of observing the given test statistic, or more
        extreme results, under a certain null hypothesis.
-    test_statistic: float or numpy.ndarray or None
+    statistic : float or ArrayLike or None
        The test statistic of the hypothesis test, which might not always be available.
-    additional_information: object or None
+    additional_information : object or None
        Any additional information or metadata relevant to the specific test conducted. These
        could also be a state of the method to re-use it.
""" - p_value: float - test_statistic: Optional[Union[float, ArrayLike]] = None + pvalue: float + statistic: Optional[Union[float, ArrayLike]] = None additional_information: Optional[object] = None diff --git a/tests/test_fisherz_test.py b/tests/test_fisherz_test.py new file mode 100644 index 0000000..cc5e052 --- /dev/null +++ b/tests/test_fisherz_test.py @@ -0,0 +1,44 @@ +import flaky +import numpy as np + +from pywhy_stats import fisherz + + +@flaky.flaky(max_runs=3, min_passes=1) +def test_fisherz_marg_ind(): + """Test FisherZ marginal independence test for Gaussian data.""" + rng = np.random.default_rng() + + # We construct a SCM where X1 -> Y <- X and Y -> Z + # so X1 is independent from X, but conditionally dependent + # given Y or Z + X = rng.standard_normal((300, 1)) + X1 = rng.standard_normal((300, 1)) + Y = X + X1 + 0.5 * rng.standard_normal((300, 1)) + Z = Y + 0.1 * rng.standard_normal((300, 1)) + + res = fisherz.ind(X, X1) + assert res.pvalue > 0.05 + res = fisherz.ind(X, Z) + assert res.pvalue < 0.05 + + +@flaky.flaky(max_runs=3, min_passes=1) +def test_fisherz_cond_ind(): + """Test FisherZ conditional independence test for Gaussian data.""" + rng = np.random.default_rng() + + # We construct a SCM where X1 -> Y <- X and Y -> Z + # so X1 is independent from X, but conditionally dependent + # given Y or Z + X = rng.standard_normal((300, 1)) + X1 = rng.standard_normal((300, 1)) + Y = X + X1 + 0.5 * rng.standard_normal((300, 1)) + Z = Y + 0.1 * rng.standard_normal((300, 1)) + + res = fisherz.condind(X, X1, Z) + assert res.pvalue < 0.05 + res = fisherz.condind(X, X1, Y) + assert res.pvalue < 0.05 + res = fisherz.condind(X, Z, Y) + assert res.pvalue > 0.05