Merge branch 'main' into fix_slearner_feature_set
FrancescMartiEscofetQC committed Jun 25, 2024
2 parents e871ada + ed92851 commit 1f0b995
Showing 41 changed files with 23,176 additions and 553 deletions.
20 changes: 7 additions & 13 deletions .github/workflows/benchmarks.yml
@@ -13,20 +13,14 @@ jobs:
name: Run benchmarks
runs-on: ubuntu-latest
steps:
- name: Checkout
- name: Checkout branch
uses: actions/checkout@v4
with:
ref: main
fetch-depth: 0
- name: Setup conda
uses: mamba-org/setup-micromamba@v1
with:
condarc-file: .github/assets/.condarc
environment-file: benchmarks/environment.yml
- name: Install package
run: python -m pip install -e .
- name: Run benchmark.py
run: python benchmarks/benchmark.py
- name: Set up pixi
uses: prefix-dev/setup-pixi@v0.8.1
- name: Install repository
run: |
pixi run -e default postinstall
pixi run -e default python benchmarks/benchmark.py
- name: Update readme.md
run: |
line_number=`grep "| T-learner" benchmarks/readme.md -n | cut -f1 -d:`
67 changes: 20 additions & 47 deletions .github/workflows/ci.yml
@@ -18,68 +18,41 @@ jobs:
steps:
- name: Checkout branch
uses: actions/checkout@v4
- name: Set up pixi
uses: prefix-dev/setup-pixi@v0.8.1
with:
ref: ${{ github.head_ref }}
# needed for 'pre-commit-mirrors-insert-license'
fetch-depth: 0
- name: Run pre-commit-conda
uses: quantco/pre-commit-conda@v1

mypy-type-checks:
name: Mypy Type Checks
runs-on: ubuntu-latest
steps:
- name: Checkout branch
uses: actions/checkout@v4
with:
ref: ${{ github.head_ref }}
fetch-depth: 0
- name: Set up Conda env
uses: mamba-org/setup-micromamba@422500192359a097648154e8db4e39bdb6c6eed7
with:
condarc-file: .github/assets/.condarc
environment-file: environment.yml
cache-environment: true
- name: Install repository
run: python -m pip install --no-build-isolation --no-deps --disable-pip-version-check -e .
- name: Run mypy
run: mypy .
environments: default lint
- name: pre-commit
run: pixi run pre-commit-run --color=always --show-diff-on-failure

unit-tests:
name: Unit Tests - ${{ matrix.os == 'ubuntu-latest' && 'Linux' || 'Windows' }} - Python ${{ matrix.python-version }}
name: Unit Tests
timeout-minutes: 30
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: ["3.10", "3.11", "3.12"]
os: [ubuntu-latest, macos-latest]
env: ["py310", "py311", "py312"]
steps:
- name: Checkout branch
uses: actions/checkout@v4
- name: Set up pixi
uses: prefix-dev/setup-pixi@v0.8.1
with:
ref: ${{ github.head_ref }}
fetch-depth: 0
- name: Set up Conda env
uses: mamba-org/setup-micromamba@422500192359a097648154e8db4e39bdb6c6eed7
with:
condarc-file: .github/assets/.condarc
environment-file: environment.yml
cache-environment: true
create-args: >-
python=${{ matrix.python-version }}
pytest-md
pytest-emoji
environments: ${{ matrix.env }}
- name: Install repository
run: python -m pip install --no-build-isolation --no-deps --disable-pip-version-check -e .
- name: Run unittests
uses: quantco/pytest-action@v2
run: |
pixi run -e ${{ matrix.env }} postinstall
pixi run -e ${{ matrix.env }} coverage
- name: Generate code coverage report
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@v3.1.3
with:
report-title: Unit Tests - ${{ matrix.os == 'ubuntu-latest' && 'Linux' || 'Windows' }} - Python ${{ matrix.python-version }}
custom-arguments: --cov=metalearners --cov-report=xml --cov-report term-missing --color=yes
file: ./coverage.xml
- name: Upload coverage reports to Codecov
if: matrix.python-version == '3.12'
uses: codecov/codecov-action@125fc84a9a348dbcf27191600683ec096ec9021c
if: matrix.env == 'py312'
uses: codecov/codecov-action@v3.1.3
env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
file: ./coverage.xml
102 changes: 69 additions & 33 deletions .pre-commit-config.yaml
@@ -1,45 +1,81 @@
exclude: ^\.pixi
repos:
- repo: https://github.com/Quantco/pre-commit-mirrors-insert-license
rev: 1.3.0
- repo: local
hooks:
# ensure pixi environments are up to date
# workaround for https://github.com/prefix-dev/pixi/issues/1482
- id: pixi-install
name: pixi-install
entry: pixi install -e default -e lint
language: system
always_run: true
require_serial: true
pass_filenames: false
- id: insert-license
name: insert-license
entry: pixi run -e lint insert-license
types: [python]
language: system
args:
- --license-base64
- IyBDb3B5cmlnaHQgKGMpIFF1YW50Q28gMjAyNC0yMDI0CiMgU1BEWC1MaWNlbnNlLUlkZW50aWZpZXI6IEJTRC0zLUNsYXVzZQo=
- Q29weXJpZ2h0IChjKSBRdWFudENvIDIwMjQtMjAyNApTUERYLUxpY2Vuc2UtSWRlbnRpZmllcjogQlNELTMtQ2xhdXNl
- --dynamic-years
- --comment-style
- "#"
- repo: https://github.com/Quantco/pre-commit-mirrors-docformatter
rev: 1.7.5
hooks:
- id: docformatter-conda
- repo: https://github.com/Quantco/pre-commit-mirrors-ruff
rev: 0.4.7
hooks:
- id: ruff-conda
- repo: https://github.com/Quantco/pre-commit-mirrors-black
rev: 24.4.2
hooks:
- id: docformatter
name: docformatter
entry: pixi run -e lint docformatter
args: [-i]
types: [python]
language: system
- id: ruff
name: ruff
entry: pixi run -e lint ruff check --fix --exit-non-zero-on-fix --force-exclude
language: system
types_or: [python, pyi]
require_serial: true
- id: black-conda
- repo: https://github.com/Quantco/pre-commit-mirrors-mypy
rev: 1.10.0
hooks:
- id: mypy-conda
additional_dependencies: [-c, conda-forge, types-setuptools]
- repo: https://github.com/Quantco/pre-commit-mirrors-prettier
rev: 3.2.5
hooks:
- id: prettier-conda
name: black-conda
entry: pixi run -e lint black
language: system
require_serial: true
types: [python]
- id: mypy
name: mypy
entry: pixi run -e default mypy
language: system
types: [python]
args: ["--ignore-missing-imports", "--scripts-are-modules"]
require_serial: true
- id: prettier
name: prettier
entry: pixi run -e lint prettier
language: system
files: \.(md|yml|yaml)$
- repo: https://github.com/Quantco/pre-commit-mirrors-pre-commit-hooks
rev: 4.6.0
hooks:
- id: trailing-whitespace-conda
- id: end-of-file-fixer-conda
- id: check-merge-conflict-conda
types: [text]
args: ["--write", "--list-different", "--ignore-unknown"]
- id: trailing-whitespace
name: trim trailing whitespace
language: system
entry: pixi run -e lint trailing-whitespace-fixer
types: [text]
stages: [commit, push, manual]
- id: end-of-file-fixer
name: fix end of files
language: system
entry: pixi run -e lint end-of-file-fixer
types: [text]
stages: [commit, push, manual]
- id: check-merge-conflict
name: check for merge conflicts
language: system
entry: pixi run -e lint check-merge-conflict
types: [text]
args: ["--assume-in-merge"]
- repo: https://github.com/Quantco/pre-commit-mirrors-typos
rev: 1.21.0
hooks:
- id: typos-conda
- id: typos
name: typos
entry: pixi run -e lint typos --force-exclude
language: system
types: [text]
exclude: "\\.csv$"
require_serial: true
16 changes: 6 additions & 10 deletions .readthedocs.yml
@@ -2,17 +2,13 @@ version: 2
build:
os: ubuntu-20.04
tools:
python: mambaforge-4.10
python: mambaforge-latest
commands:
- mamba install -c conda-forge -c nodefaults pixi
- pixi run -e docs postinstall
- pixi run -e docs docs
- pixi run -e docs readthedocs
sphinx:
configuration: docs/conf.py
python:
install:
- method: pip
path: .
extra_requirements:
- doc
- test
conda:
environment: environment.yml
formats:
- pdf
4 changes: 2 additions & 2 deletions benchmarks/benchmark.py
@@ -1,5 +1,5 @@
# # Copyright (c) QuantCo 2024-2024
# # SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause
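The doubled `# #` prefix in the old header appears to be linked to the two `--license-base64` values in `.pre-commit-config.yaml`: one payload already carries `#` comment markers and the other does not. The sketch below decodes both payloads and prepends a comment prefix to each line. `render_header` is a hypothetical helper that only models the effect of `--comment-style "#"`; it is not the actual `insert-license` implementation.

```python
import base64

# The two payloads exactly as they appear in the .pre-commit-config.yaml diff.
WITH_COMMENT_MARKERS = (
    "IyBDb3B5cmlnaHQgKGMpIFF1YW50Q28gMjAyNC0yMDI0CiMgU1BEWC1MaWNlbnNlLUlkZW50aWZpZXI6IEJTRC0zLUNsYXVzZQo="
)
WITHOUT_COMMENT_MARKERS = (
    "Q29weXJpZ2h0IChjKSBRdWFudENvIDIwMjQtMjAyNApTUERYLUxpY2Vuc2UtSWRlbnRpZmllcjogQlNELTMtQ2xhdXNl"
)


def render_header(license_b64: str, comment_prefix: str = "#") -> str:
    """Decode a base64 license text and prefix every line with a comment marker.

    Hypothetical simplification of what a license-insertion hook does.
    """
    text = base64.b64decode(license_b64).decode("utf-8")
    return "\n".join(f"{comment_prefix} {line}" for line in text.splitlines())


if __name__ == "__main__":
    print(render_header(WITH_COMMENT_MARKERS))
    print(render_header(WITHOUT_COMMENT_MARKERS))
```

Under this assumption, the payload that already contains `#` markers yields the `# # Copyright` form, while the bare payload yields the single-`#` header seen in the updated files.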

import json
from pathlib import Path
4 changes: 2 additions & 2 deletions data/download_twins.py
@@ -1,5 +1,5 @@
# # Copyright (c) QuantCo 2024-2024
# # SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

import pandas as pd
from git_root import git_root
24 changes: 15 additions & 9 deletions docs/background.rst
@@ -358,15 +358,15 @@ It is an extension of the T-Learner and consists of three stages:
\widetilde{D}_1^i &:= Y^i_1 - \hat{\mu}_0(X^i_1) \\
\widetilde{D}_0^i &:= \hat{\mu}_1(X^i_0) - Y^i_0
Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X]` and
:math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X]` using the observations in the
Then estimate :math:`\tau_1(x) := \mathbb{E}[\widetilde{D}^i_1 | X=x]` and
:math:`\tau_0(x) := \mathbb{E}[\widetilde{D}^i_0 | X=x]` using the observations in the
treatment group and the ones in the control group respectively.
#. Define the CATE estimate by a weighted average of the two estimates in stage 2:

.. math::
\hat{\tau}^X(x) := g(x)\hat{\tau}_0(x) + (1-g(x))\hat{\tau}_1(x)
where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X]` to be
where :math:`g(x) \in [0,1]`. We take :math:`g(x) := \mathbb{E}[W = 1 | X=x]` to be
the propensity score.
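
The third stage can be illustrated with a small sketch. It assumes the stage-two estimates and the propensity scores have already been computed for each point; the function name `combine_x_learner_estimates` and its arguments are hypothetical and are not part of the `metalearners` API.

```python
from typing import List, Sequence


def combine_x_learner_estimates(
    tau_0_hat: Sequence[float],
    tau_1_hat: Sequence[float],
    propensity: Sequence[float],
) -> List[float]:
    """Pointwise g(x) * tau_0_hat(x) + (1 - g(x)) * tau_1_hat(x).

    Hypothetical helper: each input holds one value per evaluation point x.
    """
    if not (len(tau_0_hat) == len(tau_1_hat) == len(propensity)):
        raise ValueError("All inputs must have the same length.")
    combined = []
    for t0, t1, g in zip(tau_0_hat, tau_1_hat, propensity):
        if not 0.0 <= g <= 1.0:
            raise ValueError("Propensity scores must lie in [0, 1].")
        # A high propensity puts more weight on the control-group estimate.
        combined.append(g * t0 + (1.0 - g) * t1)
    return combined
```

With this weighting, a point that is very likely to be treated relies mostly on the estimate fitted on the control group, and vice versa.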

More than binary treatment
@@ -388,8 +388,8 @@ In the case of multiple discrete treatments the stages are similar to the binary
\widetilde{D}_k^i &:= Y^i_k - \hat{\mu}_0(X^i_k) \\
\widetilde{D}_{0,k}^i &:= \hat{\mu}_k(X^i_0) - Y^i_0
Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X]` is estimated using the
observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X]`
Then :math:`\tau_k(x) := \mathbb{E}[\widetilde{D}^i_k | X=x]` is estimated using the
observations which received treatment :math:`k` and :math:`\tau_{0,k}(x) := \mathbb{E}[\widetilde{D}^i_{0,k} | X=x]`
using the observations in the control group.

#. Finally the CATE for each variant is estimated as a weighted average:
@@ -419,9 +419,15 @@ It consists of two stages:

.. math::
\DeclareMathOperator*{\argmin}{arg\,min}
\hat{\tau}^R (x) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\
\hat{\tau}^R (\cdot) &:= \argmin_{\tau}\Bigg\{\mathbb{E}\Bigg[\bigg(\left\{Y^i - \hat{m}(X^i)\right\} - \left\{W^i - \hat{e}(X^i)\right\}\tau(X^i)\bigg)^2\Bigg]\Bigg\} \\
&=\argmin_{\tau}\left\{\mathbb{E}\left[\left\{W^i - \hat{e}(X^i)\right\}^2\bigg(\frac{\left\{Y^i - \hat{m}(X^i)\right\}}{\left\{W^i - \hat{e}(X^i)\right\}} - \tau(X^i)\bigg)^2\right]\right\} \\
&= \argmin_{\tau}\left\{\mathbb{E}\left[{\tilde{W}^i}^2\bigg(\frac{\tilde{Y}^i}{\tilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\}
&= \argmin_{\tau}\left\{\mathbb{E}\left[{\widetilde{W}^i}^2\bigg(\frac{\widetilde{Y}^i}{\widetilde{W}^i} - \tau(X^i)\bigg)^2\right]\right\}
Where

.. math::
\widetilde{W}^i &= W^i - \hat{e}(X^i) \\
\widetilde{Y}^i &= Y^i - \hat{m}(X^i)
And therefore any ML model which supports weighting each observation differently can be used for the final model.
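
A minimal numeric sketch of this rewriting, restricted to a constant :math:`\tau` so that the weighted regression has a closed form. The residuals are assumed to be precomputed and nonzero for the treatment; all function names are hypothetical and the code does not use the `metalearners` API.

```python
from typing import Sequence


def r_loss(tau: float, y_res: Sequence[float], w_res: Sequence[float]) -> float:
    """Empirical R-loss: mean of (y_res - w_res * tau)^2."""
    return sum((y - w * tau) ** 2 for y, w in zip(y_res, w_res)) / len(y_res)


def weighted_pseudo_outcome_loss(
    tau: float, y_res: Sequence[float], w_res: Sequence[float]
) -> float:
    """Same loss written as a weighted squared error on pseudo-outcomes.

    Weights are w_res^2 and pseudo-outcomes are y_res / w_res, so every
    treatment residual must be nonzero.
    """
    return sum(
        w**2 * (y / w - tau) ** 2 for y, w in zip(y_res, w_res)
    ) / len(y_res)


def constant_tau_minimizer(y_res: Sequence[float], w_res: Sequence[float]) -> float:
    """Weighted mean of the pseudo-outcomes, i.e. the best constant tau."""
    weights = [w**2 for w in w_res]
    pseudo_outcomes = [y / w for y, w in zip(y_res, w_res)]
    return sum(wt * p for wt, p in zip(weights, pseudo_outcomes)) / sum(weights)


if __name__ == "__main__":
    y_res = [1.0, -0.5, 2.0]  # illustrative outcome residuals
    w_res = [0.5, -0.5, 0.25]  # illustrative treatment residuals
    tau_hat = constant_tau_minimizer(y_res, w_res)
    print(tau_hat, r_loss(tau_hat, y_res, w_res))
```

In practice the constant is replaced by a flexible regressor fitted on the pseudo-outcomes with per-observation weights, which is why the final model needs to support sample weighting.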

@@ -484,7 +490,7 @@ It consists of two stages:
#. Estimate the CATE by regressing :math:`\varphi` on :math:`X`:

.. math::
\hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i]
\hat{\tau}^{DR}(x) := \mathbb{E}[\varphi(X^i, W^i, Y^i) | X^i=x]
More than binary treatment
**************************
@@ -508,4 +514,4 @@ In the case of multiple discrete treatments the stages are similar to the binary
treatment variant, :math:`\forall k \in \{1,\dots, K-1\}`:

.. math::
\hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i]
\hat{\tau}_k^{DR}(x) := \mathbb{E}[\varphi_k(X^i, W^i, Y^i) | X^i=x]
4 changes: 2 additions & 2 deletions docs/conf.py
@@ -1,5 +1,5 @@
# # Copyright (c) QuantCo 2024-2024
# # SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) QuantCo 2024-2024
# SPDX-License-Identifier: BSD-3-Clause

# Configuration file for the Sphinx documentation builder.
#
23 changes: 20 additions & 3 deletions docs/examples/example_basic.ipynb
@@ -115,7 +115,9 @@
"* We need to specify the observed treatment assignment ``w`` in the call to the\n",
" ``fit`` method.\n",
"* We need to specify whether we want in-sample or out-of-sample\n",
" estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``."
" CATE estimates in the {meth}`~metalearners.TLearner.predict` call via ``is_oos``. In the\n",
" case of in-sample predictions, the data passed to {meth}`~metalearners.TLearner.predict`\n",
" must be exactly the same as the data that was used to call {meth}`~metalearners.TLearner.fit`."
]
},
{
@@ -176,7 +178,7 @@
"Using a MetaLearner with two stages\n",
"-----------------------------------\n",
"\n",
"Instead of using a T-Learner, we can of course also some other\n",
"Instead of using a T-Learner, we can of course also use some other\n",
"MetaLearner, such as the {class}`~metalearners.RLearner`.\n",
"The R-Learner's documentation tells us that two more instantiation\n",
"parameters are necessary: ``propensity_model_factory`` and\n",
@@ -209,7 +211,22 @@
"metadata": {},
"source": [
"where we choose a classifier class to serve as a blueprint for our\n",
"eventual propensity model.\n",
"eventual propensity model. It is important to notice that although we consider the propensity\n",
"model a nuisance model, the initialization parameters for it are separated from the other\n",
"nuisance parameters to allow a more understandable user interface, see the next code prompt.\n",
"\n",
"In general, when initializing a MetaLearner, the ``nuisance_model_factory`` parameter will\n",
"be used to create all the nuisance models which are not a propensity model, the\n",
"``propensity_model_factory`` will be used for the propensity model if the MetaLearner\n",
"contains one, and the ``treatment_model_factory`` will be used for the models predicting\n",
"the CATE. To see the models present in each MetaLearner type see\n",
"{meth}`~metalearners.metalearner.MetaLearner.nuisance_model_specifications` and\n",
"{meth}`~metalearners.metalearner.MetaLearner.treatment_model_specifications`.\n",
"\n",
"In the {class}`~metalearners.RLearner` case, the ``nuisance_model_factory`` parameter will\n",
"be used to create the outcome model, the ``propensity_model_factory`` will be used for the\n",
"propensity model and the ``treatment_model_factory`` will be used for the model predicting\n",
"the CATE.\n",
"\n",
"If we want to make sure these models are initialized in a specific\n",
"way, e.g. with a specific value for the hyperparameter ``n_estimators``, we can do that\n",