Skip to content

Commit

Permalink
change indies to mask in drlearner and update CHANGELOG.rst
Browse files Browse the repository at this point in the history
  • Loading branch information
kyracho committed Sep 4, 2024
1 parent b248794 commit dae28f2
Show file tree
Hide file tree
Showing 17 changed files with 125 additions and 7 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Changelog

**New features**

* Rename ``_treatment_variants_indices`` to ``_treatment_variants_mask``in ``metalearner``, ``xlearner``, ``tlearner``, and ``test_learner``.
* Rename ``_treatment_variants_indices`` to ``_treatment_variants_mask``in ``metalearner``, ``xlearner``, ``tlearner``, ``drlearner``, and ``test_learner``.


0.11.0 (2024-09-xx)
Expand Down
7 changes: 7 additions & 0 deletions docs/api/metalearners.cross_fit_estimator.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.cross\_fit\_estimator module
=========================================

.. automodule:: metalearners.cross_fit_estimator
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.data_generation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.data\_generation module
====================================

.. automodule:: metalearners.data_generation
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.drlearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.drlearner module
=============================

.. automodule:: metalearners.drlearner
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.explainer.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.explainer module
=============================

.. automodule:: metalearners.explainer
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.grid_search.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.grid\_search module
================================

.. automodule:: metalearners.grid_search
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.metalearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.metalearner module
===============================

.. automodule:: metalearners.metalearner
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.outcome_functions.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.outcome\_functions module
======================================

.. automodule:: metalearners.outcome_functions
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.rlearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.rlearner module
============================

.. automodule:: metalearners.rlearner
:members:
:undoc-members:
:show-inheritance:
29 changes: 29 additions & 0 deletions docs/api/metalearners.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
metalearners package
====================

Submodules
----------

.. toctree::
:maxdepth: 4

metalearners.cross_fit_estimator
metalearners.data_generation
metalearners.drlearner
metalearners.explainer
metalearners.grid_search
metalearners.metalearner
metalearners.outcome_functions
metalearners.rlearner
metalearners.slearner
metalearners.tlearner
metalearners.utils
metalearners.xlearner

Module contents
---------------

.. automodule:: metalearners
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.slearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.slearner module
============================

.. automodule:: metalearners.slearner
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.tlearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.tlearner module
============================

.. automodule:: metalearners.tlearner
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.utils module
=========================

.. automodule:: metalearners.utils
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/metalearners.xlearner.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners.xlearner module
============================

.. automodule:: metalearners.xlearner
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/api/modules.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
metalearners
============

.. toctree::
:maxdepth: 4

metalearners
Binary file added docs/examples/model.onnx
Binary file not shown.
10 changes: 4 additions & 6 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ def fit_all_nuisance(
self._validate_treatment(w)
self._validate_outcome(y, w)

self._treatment_variants_indices = []
self._treatment_variants_mask = []

qualified_fit_params = self._qualified_fit_params(fit_params)

for treatment_variant in range(self.n_variants):
self._treatment_variants_indices.append(w == treatment_variant)
self._treatment_variants_mask.append(w == treatment_variant)

self._cv_split_indices: SplitIndices | None

Expand All @@ -168,10 +168,8 @@ def fit_all_nuisance(
for treatment_variant in range(self.n_variants):
nuisance_jobs.append(
self._nuisance_joblib_specifications(
X=index_matrix(
X, self._treatment_variants_indices[treatment_variant]
),
y=y[self._treatment_variants_indices[treatment_variant]],
X=index_matrix(X, self._treatment_variants_mask[treatment_variant]),
y=y[self._treatment_variants_mask[treatment_variant]],
model_kind=VARIANT_OUTCOME_MODEL,
model_ord=treatment_variant,
n_jobs_cross_fitting=n_jobs_cross_fitting,
Expand Down

0 comments on commit dae28f2

Please sign in to comment.