Skip to content

Commit

Permalink
Make methods private
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 25, 2024
1 parent 8f1e9e0 commit d32d27f
Show file tree
Hide file tree
Showing 13 changed files with 35 additions and 35 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ Changelog

**New features**

* Add ``build_onnx`` to :class:`metalearners.MetaLearner` abstract class and implement it
* Add ``_build_onnx`` to :class:`metalearners.MetaLearner` abstract class and implement it
for :class:`metalearners.TLearner`, :class:`metalearners.XLearner`, :class:`metalearners.RLearner`
and :class:`metalearners.DRLearner`.

* Add ``necessary_onnx_models`` to :class:`metalearners.MetaLearner`.
* Add ``_necessary_onnx_models`` to :class:`metalearners.MetaLearner`.

0.8.0 (2024-07-22)
------------------
Expand Down
6 changes: 3 additions & 3 deletions docs/examples/example_onnx.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
"\n",
"Before being able to convert the MetaLearner to ONXX we need to manually convert the necessary\n",
"base models for the prediction. To get a list of the necessary base models that need to be\n",
"converted we can use :meth:`~metalearners.MetaLearner.necessary_onnx_models`."
"converted we can use :meth:`~metalearners.MetaLearner._necessary_onnx_models`."
]
},
{
Expand All @@ -156,7 +156,7 @@
"metadata": {},
"outputs": [],
"source": [
"xlearner.necessary_onnx_models()"
"xlearner._necessary_onnx_models()"
]
},
{
Expand Down Expand Up @@ -213,7 +213,7 @@
" )\n",
" onnx_models[\"treatment_effect_model\"].append(onnx_model)\n",
"\n",
"onnx_model = xlearner.build_onnx(onnx_models)"
"onnx_model = xlearner._build_onnx(onnx_models)"
]
},
{
Expand Down
10 changes: 5 additions & 5 deletions metalearners/drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,23 +406,23 @@ def _pseudo_outcome(
return pseudo_outcome

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
return {TREATMENT_MODEL}

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
@copydoc(MetaLearner._build_onnx, sep="")
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the DRLearner case, the necessary models are:
* ``"treatment_model"``
"""
warning_experimental_feature("build_onnx")
warning_experimental_feature("_build_onnx")
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
from onnx.checker import check_model
from spox import Var, build, inline

self._validate_feature_set_none()
self._validate_onnx_models(models, self.necessary_onnx_models())
self._validate_onnx_models(models, self._necessary_onnx_models())

input_dict = infer_input_dict(models[TREATMENT_MODEL][0])

Expand Down
4 changes: 2 additions & 2 deletions metalearners/metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,12 +1185,12 @@ def _validate_feature_set_none(self):

@classmethod
@abstractmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
"""Return a set with the necessary models to convert the MetaLearner to ONNX."""
...

Check warning on line 1190 in metalearners/metalearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/metalearner.py#L1190

Added line #L1190 was not covered by tests

@abstractmethod
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""Convert the MetaLearner to an ONNX model.
.. warning::
Expand Down
10 changes: 5 additions & 5 deletions metalearners/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,23 +526,23 @@ def _pseudo_outcome_and_weights(
return pseudo_outcomes, weights

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
return {TREATMENT_MODEL}

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
@copydoc(MetaLearner._build_onnx, sep="")
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the RLearner case, the necessary models are:
* ``"treatment_model"``
"""
warning_experimental_feature("build_onnx")
warning_experimental_feature("_build_onnx")
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
from onnx.checker import check_model
from spox import Var, build, inline

self._validate_feature_set_none()
self._validate_onnx_models(models, self.necessary_onnx_models())
self._validate_onnx_models(models, self._necessary_onnx_models())

input_dict = infer_input_dict(models[TREATMENT_MODEL][0])

Expand Down
4 changes: 2 additions & 2 deletions metalearners/slearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ def predict_conditional_average_outcomes(
)

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
raise ValueError(

Check warning on line 304 in metalearners/slearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/slearner.py#L304

Added line #L304 was not covered by tests
"The SLearner does not implement this method. Please refer to comment in the tutorial."
)

def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
raise ValueError(

Check warning on line 309 in metalearners/slearner.py

View check run for this annotation

Codecov / codecov/patch

metalearners/slearner.py#L309

Added line #L309 was not covered by tests
"The SLearner does not implement this method. Please refer to comment in the tutorial."
)
10 changes: 5 additions & 5 deletions metalearners/tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,23 +151,23 @@ def evaluate(
)

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
return {VARIANT_OUTCOME_MODEL}

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
@copydoc(MetaLearner._build_onnx, sep="")
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the TLearner case, the necessary models are:
* ``"variant_outcome_model"``
"""
warning_experimental_feature("build_onnx")
warning_experimental_feature("_build_onnx")
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
from onnx.checker import check_model
from spox import build, inline

self._validate_feature_set_none()
self._validate_onnx_models(models, self.necessary_onnx_models())
self._validate_onnx_models(models, self._necessary_onnx_models())

input_dict = infer_input_dict(models[VARIANT_OUTCOME_MODEL][0])

Expand Down
10 changes: 5 additions & 5 deletions metalearners/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,25 +439,25 @@ def _pseudo_outcome(
return imputed_te_control, imputed_te_treatment

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
return {PROPENSITY_MODEL, CONTROL_EFFECT_MODEL, TREATMENT_EFFECT_MODEL}

@copydoc(MetaLearner.build_onnx, sep="")
def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
@copydoc(MetaLearner._build_onnx, sep="")
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"):
"""In the XLearner case, the necessary models are:
* ``"propensity_model"``
* ``"control_effect_model"``
* ``"treatment_effect_model"``
"""
warning_experimental_feature("build_onnx")
warning_experimental_feature("_build_onnx")
check_spox_installed()
import spox.opset.ai.onnx.v21 as op
from onnx.checker import check_model
from spox import Var, build, inline

self._validate_feature_set_none()
self._validate_onnx_models(models, self.necessary_onnx_models())
self._validate_onnx_models(models, self._necessary_onnx_models())

input_dict = infer_input_dict(models[PROPENSITY_MODEL][0])

Expand Down
2 changes: 1 addition & 1 deletion tests/test_drlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_drlearner_onnx(
)
onnx_models.append(onnx_model)

final = ml.build_onnx({TREATMENT_MODEL: onnx_models})
final = ml._build_onnx({TREATMENT_MODEL: onnx_models})
sess = rt.InferenceSession(
final.SerializeToString(), providers=rt.get_available_providers()
)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ def evaluate(self, X, y, w, is_oos, oos_method=None):
def predict_conditional_average_outcomes(self, X, is_oos, oos_method=None):
return np.zeros((len(X), 2, 1))

def build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...
def _build_onnx(self, models: Mapping[str, Sequence], output_name: str = "tau"): ...

@classmethod
def necessary_onnx_models(cls) -> set[str]:
def _necessary_onnx_models(cls) -> set[str]:
return set()


Expand Down
2 changes: 1 addition & 1 deletion tests/test_rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_rlearner_onnx(
)
onnx_models.append(onnx_model)

final = ml.build_onnx({TREATMENT_MODEL: onnx_models})
final = ml._build_onnx({TREATMENT_MODEL: onnx_models})
sess = rt.InferenceSession(
final.SerializeToString(), providers=rt.get_available_providers()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def test_tlearner_onnx(
)
onnx_models.append(onnx_model)

final = ml.build_onnx({VARIANT_OUTCOME_MODEL: onnx_models})
final = ml._build_onnx({VARIANT_OUTCOME_MODEL: onnx_models})
sess = rt.InferenceSession(
final.SerializeToString(), providers=rt.get_available_providers()
)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_xlearner_onnx(
)
onnx_models[PROPENSITY_MODEL].append(onnx_model)

final = ml.build_onnx(onnx_models)
final = ml._build_onnx(onnx_models)

sess = rt.InferenceSession(
final.SerializeToString(), providers=rt.get_available_providers()
Expand Down

0 comments on commit d32d27f

Please sign in to comment.