Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 8, 2024
1 parent b9b89f9 commit 2db627a
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 1 deletion.
5 changes: 5 additions & 0 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ def check_spox_installed() -> None:


def infer_dtype_and_shape_onnx(tensor) -> tuple[np.dtype, tuple]:
"""Returns the ``np.dtype`` and shape of an ONNX tensor."""
check_onnx_installed()
import onnx

Expand All @@ -534,6 +535,8 @@ def infer_dtype_and_shape_onnx(tensor) -> tuple[np.dtype, tuple]:


def infer_probabilities_output(model) -> tuple[int, str]:
"""Returns the index and name of the output which contains the probabilities outcome
in a ONNX classifier."""
check_onnx_installed()
for i, output in enumerate(model.graph.output):
if output.name in ["probabilities", "output_probability"]:
Expand All @@ -542,6 +545,8 @@ def infer_probabilities_output(model) -> tuple[int, str]:


def infer_input_dict(model) -> dict:
"""Returns a dict where the keys are the input names of the model and the values are
``spox.Var`` with the corresponding shape and type."""
check_onnx_installed()
check_spox_installed()
from spox import Tensor, Var, argument
Expand Down
3 changes: 3 additions & 0 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
check_onnx_installed,
check_spox_installed,
clip_element_absolute_value_to_epsilon,
copydoc,
get_one,
get_predict,
get_predict_proba,
Expand Down Expand Up @@ -324,7 +325,9 @@ def _pseudo_outcome(

return pseudo_outcome

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the DRLearner case, the necessary models are: ``"treatment_model"``."""
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
19 changes: 18 additions & 1 deletion metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,15 @@ def _default_scoring() -> Scoring:
def _validate_onnx_models(
self, models: Mapping[str, Sequence], necessary_models: Set[str]
):
"""Validates that the converted ONNX models are correct.
Specifically it validates the following:
* The ``necessary_models`` are present in the ``models``` dictionary
* The number of models for each model matches the cardinality in the MetaLearner
* All ONNX have the same input format
* The models with ``"predict"`` as ``predict_method`` have only one output
* The models with ``"predict_proba"`` as ``predict_method`` have a probabilities output
"""
if set(models.keys()) != necessary_models:
raise ValueError(f"{necessary_models} should be present in models keys.")
specs_look_up = (
Expand Down Expand Up @@ -1089,7 +1098,15 @@ def _validate_feature_set_all(self):
)

@abstractmethod
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""Convert the MetaLearner to an ONNX model.
``output_name`` can be used to change the output name of the ONNX model.
``models`` should be a dictionary of sequences with the necessary base models converted to
ONNX.
"""
...


class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):
Expand Down
2 changes: 2 additions & 0 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,9 @@ def _pseudo_outcome_and_weights(

return pseudo_outcomes, weights

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the RLearner case, the necessary models are: ``"treatment_model"``."""
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
6 changes: 6 additions & 0 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from metalearners._utils import (
check_onnx_installed,
check_spox_installed,
copydoc,
index_matrix,
infer_input_dict,
infer_probabilities_output,
Expand Down Expand Up @@ -136,7 +137,12 @@ def evaluate(
is_treatment_model=False,
)

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the TLearner case, the necessary models are:
``"variant_outcome_model"``.
"""
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down
4 changes: 4 additions & 0 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from metalearners._utils import (
check_onnx_installed,
check_spox_installed,
copydoc,
get_one,
get_predict,
get_predict_proba,
Expand Down Expand Up @@ -406,7 +407,10 @@ def _pseudo_outcome(

return imputed_te_control, imputed_te_treatment

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the XLearner case, the necessary models are: ``"propensity_model"``,
``"control_effect_model"`` and ``"treatment_effect_model"``."""
check_onnx_installed()
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
Expand Down

0 comments on commit 2db627a

Please sign in to comment.