Commit

Refactor

FrancescMartiEscofetQC committed Jul 3, 2024
1 parent 6107b87 commit fdffa3d
Showing 4 changed files with 37 additions and 35 deletions.
8 changes: 8 additions & 0 deletions metalearners/metalearner.py
@@ -1077,6 +1077,14 @@ def _validate_onnx_models(
                     "with name 'probabilities' or 'output_probability'."
                 )
 
+    def _validate_feature_set_all(self):
+        for feature_set in self.feature_set.values():
+            if feature_set is not None:
+                raise ValueError(
"ONNX conversion can only be used if all base models use all the "
"features."
)


class _ConditionalAverageOutcomeMetaLearner(MetaLearner, ABC):

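The new `_validate_feature_set_all` helper centralizes a check that was previously inlined in `XLearner.build_onnx` (see the removed loop in metalearners/xlearner.py below): ONNX export is only supported when every base model was fitted on the full feature set, i.e. every entry of `self.feature_set` is None. A minimal standalone sketch of that guard, assuming `feature_set` maps each base-model kind to either None ("use all features") or an explicit column subset; the dictionary keys here are illustrative, not the library's constants:

def validate_feature_set_all(feature_set: dict[str, list[str] | None]) -> None:
    # Mirrors the added method: any non-None entry means a base model was
    # restricted to a column subset, so a single shared input tensor would
    # not match every inlined base model.
    for columns in feature_set.values():
        if columns is not None:
            raise ValueError(
                "ONNX conversion can only be used if all base models use all the "
                "features."
            )


validate_feature_set_all({"variant_outcome_model": None})       # passes
# validate_feature_set_all({"propensity_model": ["x0", "x2"]})  # raises ValueError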
36 changes: 16 additions & 20 deletions metalearners/tlearner.py
@@ -137,53 +137,49 @@ def evaluate(
         )
 
     # TODO: Fix typing without importing onnx
-    def build_onnx(self, models: Mapping[str, Sequence]):
+    def build_onnx(
+        self,
+        models: Mapping[str, Sequence],
+        input_name: str = "input",
+        output_name: str = "tau",
+    ):
         check_onnx_installed()
         check_spox_installed()
         import spox.opset.ai.onnx.v21 as op
         from onnx.checker import check_model
         from spox import Tensor, argument, build, inline
 
         self._validate_onnx_models(models, {VARIANT_OUTCOME_MODEL})
+        self._validate_feature_set_all()
 
         input_dtype, input_shape = infer_dtype_and_shape_onnx(
             models[VARIANT_OUTCOME_MODEL][0].graph.input[0]
         )
 
         if not self.is_classification:
-            output_index = 0
-            output_name = models[VARIANT_OUTCOME_MODEL][0].graph.output[0].name
+            model_output_name = models[VARIANT_OUTCOME_MODEL][0].graph.output[0].name
         else:
-            output_index, output_name = infer_probabilities_output(
+            _, model_output_name = infer_probabilities_output(
                 models[VARIANT_OUTCOME_MODEL][0]
             )
 
-        output_dtype, output_shape = infer_dtype_and_shape_onnx(
-            models[VARIANT_OUTCOME_MODEL][0].graph.output[output_index]
-        )
-
         input_tensor = argument(Tensor(input_dtype, input_shape))
 
-        a = argument(Tensor(output_dtype, output_shape))
-        b = argument(Tensor(output_dtype, output_shape))
-        subtraction = op.sub(a, b)
-        sub_model = build({"a": a, "b": b}, {"subtraction": subtraction})
-
-        output_0 = inline(models[VARIANT_OUTCOME_MODEL][0])(input_tensor)
+        output_0 = inline(models[VARIANT_OUTCOME_MODEL][0])(input_tensor)[
+            model_output_name
+        ]
 
         variant_cates = []
         for m in models[VARIANT_OUTCOME_MODEL][1:]:
-            variant_output = inline(m)(input_tensor)
+            variant_output = inline(m)(input_tensor)[model_output_name]
+            variant_cate = op.sub(variant_output, output_0)
             variant_cates.append(
                 op.unsqueeze(
-                    inline(sub_model)(
-                        variant_output[output_name],
-                        output_0[output_name],
-                    )["subtraction"],
+                    variant_cate,
                     axes=op.constant(value_int=1),
                 )
             )
         cate = op.concat(variant_cates, axis=1)
-        final_model = build({"input": input_tensor}, {"tau": cate})
+        final_model = build({input_name: input_tensor}, {output_name: cate})
         check_model(final_model, full_check=True)
         return final_model
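The core of the tlearner.py change is that the per-variant CATE is now computed by subtracting the inlined variant model's named output from the control model's output directly with `op.sub`, instead of building a separate two-input subtraction model and inlining it, and the graph's input/output names are now configurable. A minimal sketch of that spox pattern, assuming two single-output regression ONNX models `model_0` and `model_1` whose output tensor is named "variable"; the names and the symbolic shape are illustrative, not taken from the library:

import numpy as np
import spox.opset.ai.onnx.v21 as op
from onnx.checker import check_model
from spox import Tensor, argument, build, inline


def combine_difference(model_0, model_1, model_output_name: str = "variable"):
    # A single symbolic input shared by both inlined models.
    input_tensor = argument(Tensor(np.float32, ("N", "num_features")))

    # inline(...) returns the model's outputs keyed by their ONNX output names,
    # so the difference can be taken directly, without an auxiliary sub-model.
    output_0 = inline(model_0)(input_tensor)[model_output_name]
    output_1 = inline(model_1)(input_tensor)[model_output_name]
    cate = op.unsqueeze(op.sub(output_1, output_0), axes=op.constant(value_ints=[1]))

    final_model = build({"input": input_tensor}, {"tau": cate})
    check_model(final_model, full_check=True)
    return final_model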
26 changes: 12 additions & 14 deletions metalearners/xlearner.py
@@ -407,7 +407,12 @@ def _pseudo_outcome(

         return imputed_te_control, imputed_te_treatment
 
-    def build_onnx(self, models: Mapping[str, Sequence]):
+    def build_onnx(
+        self,
+        models: Mapping[str, Sequence],
+        input_name: str = "input",
+        output_name: str = "tau",
+    ):
         check_onnx_installed()
         check_spox_installed()
         import spox.opset.ai.onnx.v21 as op
@@ -417,31 +422,24 @@ def build_onnx(self, models: Mapping[str, Sequence]):
         self._validate_onnx_models(
             models, {PROPENSITY_MODEL, CONTROL_EFFECT_MODEL, TREATMENT_EFFECT_MODEL}
         )
-
-        # TODO: move this validation to metalearner level as it will be common for all
-        for model_kind, feature_set in self.feature_set.items():
-            if feature_set is not None:
-                raise ValueError(
-                    "ONNX conversion can only be used if all base models use all the "
-                    "features."
-                )
+        self._validate_feature_set_all()
 
         # All models should have the same input dtype and shape
         input_dtype, input_shape = infer_dtype_and_shape_onnx(
             models[PROPENSITY_MODEL][0].graph.input[0]
         )
         input_tensor = argument(Tensor(input_dtype, input_shape))
 
-        output_name = models[CONTROL_EFFECT_MODEL][0].graph.output[0].name
+        treatment_output_name = models[CONTROL_EFFECT_MODEL][0].graph.output[0].name
 
         tau_hat_control: list[Var] = []
         for m in models[CONTROL_EFFECT_MODEL]:
-            tau_hat_control.append(inline(m)(input_tensor)[output_name])
+            tau_hat_control.append(inline(m)(input_tensor)[treatment_output_name])
         tau_hat_effect: list[Var] = []
         for m in models[TREATMENT_EFFECT_MODEL]:
-            tau_hat_effect.append(inline(m)(input_tensor)[output_name])
+            tau_hat_effect.append(inline(m)(input_tensor)[treatment_output_name])
 
-        propensity_output_index, propensity_output_name = infer_probabilities_output(
+        _, propensity_output_name = infer_probabilities_output(
             models[PROPENSITY_MODEL][0]
         )
 
@@ -479,6 +477,6 @@ def build_onnx(self, models: Mapping[str, Sequence]):
             tau_hat.append(tau_hat_tv)
 
         cate = op.concat(tau_hat, axis=1)
-        final_model = build({"input": input_tensor}, {"tau": cate})
+        final_model = build({input_name: input_tensor}, {output_name: cate})
         check_model(final_model, full_check=True)
         return final_model
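With the new parameters the caller can choose the graph's input and output names; the defaults keep the previous "input"/"tau" behavior that the test below relies on. A hedged sketch of checking an exported model against the fitted learner with onnxruntime, following the comparison in tests/test_xlearner.py below; `xlearner` (a fitted XLearner), `onnx_models` (the mapping of already-converted base models that `build_onnx` expects), and `X` (the feature matrix) are assumed to be provided by the caller:

import numpy as np
import onnxruntime as ort


def check_onnx_export(xlearner, onnx_models, X):
    # `xlearner`, `onnx_models`, and `X` are assumptions; see the lead-in above.
    final_model = xlearner.build_onnx(onnx_models, input_name="input", output_name="tau")

    sess = ort.InferenceSession(final_model.SerializeToString())
    (pred_onnx,) = sess.run(["tau"], {"input": X.astype(np.float32)})

    # Same comparison as in the test below, with the relaxed tolerance.
    np.testing.assert_allclose(xlearner.predict(X, True, "overall"), pred_onnx, atol=1e-5)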
2 changes: 1 addition & 1 deletion tests/test_xlearner.py
@@ -142,4 +142,4 @@ def test_xlearner_onnx(
["tau", "Div_1_C"],
{"input": X.astype(np.float32)},
)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=1e-6)
np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=1e-5)
