Implement onnx RLearner and tests
FrancescMartiEscofetQC committed Jul 3, 2024
1 parent 75d841d commit 5cdd015
Showing 2 changed files with 154 additions and 1 deletion.
40 changes: 40 additions & 0 deletions metalearners/rlearner.py
@@ -2,20 +2,25 @@
# SPDX-License-Identifier: BSD-3-Clause


from collections.abc import Mapping, Sequence

import numpy as np
from joblib import Parallel, delayed
from sklearn.metrics import root_mean_squared_error
from typing_extensions import Self

from metalearners._typing import Matrix, OosMethod, Scoring, Vector
from metalearners._utils import (
    check_onnx_installed,
    check_spox_installed,
    clip_element_absolute_value_to_epsilon,
    copydoc,
    function_has_argument,
    get_one,
    get_predict,
    get_predict_proba,
    index_matrix,
    infer_dtype_and_shape_onnx,
    validate_all_vectors_same_index,
    validate_valid_treatment_variant_not_control,
)
@@ -504,3 +509,38 @@ def _pseudo_outcome_and_weights(
        weights = np.square(w_residuals)

        return pseudo_outcomes, weights

    def build_onnx(
        self,
        models: Mapping[str, Sequence],
        input_name: str = "input",
        output_name: str = "tau",
    ):
        check_onnx_installed()
        check_spox_installed()
        import spox.opset.ai.onnx.v21 as op
        from onnx.checker import check_model
        from spox import Tensor, Var, argument, build, inline

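        # The combined graph has a single shared input, so validate that the
        # expected treatment models were passed and that every model stage was
        # fitted on the full feature set.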
        self._validate_onnx_models(models, {TREATMENT_MODEL})
        self._validate_feature_set_all()

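        # Derive the shared input dtype and shape from the first treatment
        # model's graph input.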
        input_dtype, input_shape = infer_dtype_and_shape_onnx(
            models[TREATMENT_MODEL][0].graph.input[0]
        )
        input_tensor = argument(Tensor(input_dtype, input_shape))

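        # All per-variant models are assumed to expose the same output name,
        # so it is taken from the first one.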
        treatment_output_name = models[TREATMENT_MODEL][0].graph.output[0].name

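        # Inline each per-variant model on the shared input and append a
        # trailing axis; for classification, mirror the estimate as
        # (-tau, tau) to cover both class columns.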
        tau_hat: list[Var] = []
        for m in models[TREATMENT_MODEL]:
            tau_hat_tv = inline(m)(input_tensor)[treatment_output_name]
            tau_hat_tv = op.unsqueeze(tau_hat_tv, axes=op.constant(value_int=2))
            if self.is_classification:
                tau_hat_tv = op.concat([op.neg(tau_hat_tv), tau_hat_tv], axis=-1)
            tau_hat.append(tau_hat_tv)

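        # Stack the per-variant estimates along axis 1, then build and fully
        # validate the final model.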
        cate = op.concat(tau_hat, axis=1)
        final_model = build({input_name: input_tensor}, {output_name: cate})
        check_model(final_model, full_check=True)
        return final_model
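
A minimal sketch of the intended call pattern, assuming a fitted RLearner named rl and a feature matrix X with n_features columns (rl, X, and n_features are hypothetical names) and using skl2onnx as the converter; the new test below exercises the same flow end to end:

import numpy as np
import onnxruntime as rt
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType

from metalearners.metalearner import TREATMENT_MODEL

# One converted ONNX model per treatment variant.
onnx_models = [
    convert_sklearn(
        model._overall_estimator,
        initial_types=[("X", FloatTensorType([None, n_features]))],
    )
    for model in rl._treatment_models[TREATMENT_MODEL]
]
final = rl.build_onnx({TREATMENT_MODEL: onnx_models})
sess = rt.InferenceSession(
    final.SerializeToString(), providers=rt.get_available_providers()
)
tau_hat = sess.run(["tau"], {"input": X.astype(np.float32)})[0]
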
115 changes: 114 additions & 1 deletion tests/test_rlearner.py
@@ -2,10 +2,21 @@
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import onnxruntime as rt
import pandas as pd
import pytest
from lightgbm import LGBMClassifier, LGBMRegressor
from onnxmltools import convert_lightgbm, convert_xgboost
from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import FloatTensorType
from sklearn.linear_model import LinearRegression, LogisticRegression
from xgboost import XGBRegressor

-from metalearners.rlearner import r_loss
+from metalearners._utils import function_has_argument
+from metalearners.metalearner import TREATMENT_MODEL
+from metalearners.rlearner import RLearner, r_loss

from .conftest import all_sklearn_regressors


@pytest.mark.parametrize("use_pandas", [True, False])
@@ -26,3 +37,105 @@ def test_r_loss(use_pandas):
        propensity_scores=propensity_scores,
    )
    assert result == pytest.approx(2, abs=1e-4, rel=1e-4)


@pytest.mark.parametrize(
    "treatment_model_factory, onnx_converter",
    (
        list(
            zip(
                all_sklearn_regressors,
                [convert_sklearn] * len(all_sklearn_regressors),
            )
        )
        + [
            (LGBMRegressor, convert_lightgbm),
            (XGBRegressor, convert_xgboost),
        ]
    ),
)
@pytest.mark.parametrize("is_classification", [True, False])
def test_rlearner_onnx(treatment_model_factory, onnx_converter, is_classification, rng):
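    # The RLearner fits its treatment models with sample weights, so skip
    # regressors whose fit() does not accept sample_weight.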
    if not function_has_argument(treatment_model_factory.fit, "sample_weight"):
        pytest.skip()

    supports_categoricals = treatment_model_factory in [
        LGBMRegressor,
        # convert_sklearn does not support categoricals https://github.com/onnx/sklearn-onnx/issues/1051
        # HistGradientBoostingRegressor,
        # convert_xgboost does not support categoricals https://github.com/onnx/onnxmltools/issues/469#issuecomment-1993880910
        # XGBRegressor,
    ]

    n_samples = 300
    n_numerical_features = 5
    n_variants = 3

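    # Draw numerical features; for converters with categorical support, append
    # two categorical columns with deliberately awkward category values.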
    X = rng.standard_normal((n_samples, n_numerical_features))
    if supports_categoricals:
        n_categorical_features = 2
        X = pd.DataFrame(X)
        X[n_numerical_features] = pd.Series(
            rng.integers(10, 13, n_samples), dtype="category"
        )  # categories do not start at 0
        X[n_numerical_features + 1] = pd.Series(
            rng.choice([-5, 4, -10, -32], size=n_samples), dtype="category"
        )  # category values are not consecutive
    else:
        n_categorical_features = 0

    if is_classification:
        n_classes = 2
        y = rng.integers(0, n_classes, size=n_samples)
        nuisance_model_factory = LogisticRegression
    else:
        y = rng.standard_normal(n_samples)
        nuisance_model_factory = LinearRegression
    w = rng.integers(0, n_variants, n_samples)

    ml = RLearner(
        is_classification,
        n_variants,
        nuisance_model_factory=nuisance_model_factory,
        propensity_model_factory=LGBMClassifier,
        treatment_model_factory=treatment_model_factory,
        propensity_model_params={"n_estimators": 1},
        n_folds=2,
    )
    ml.fit(X, y, w)

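    # Convert each per-variant overall treatment model to ONNX with the
    # converter under test.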
    onnx_models = []
    for tv in range(n_variants - 1):
        model = ml._treatment_models[TREATMENT_MODEL][tv]._overall_estimator
        onnx_model = onnx_converter(
            model,
            initial_types=[
                (
                    "X",
                    FloatTensorType(
                        [None, n_numerical_features + n_categorical_features]
                    ),
                )
            ],
        )
        onnx_models.append(onnx_model)

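    # Combine the per-variant ONNX models into a single graph and load it into
    # onnxruntime.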
    final = ml.build_onnx({TREATMENT_MODEL: onnx_models})
    sess = rt.InferenceSession(
        final.SerializeToString(), providers=rt.get_available_providers()
    )

    if supports_categoricals:
        onnx_X = X.to_numpy(np.float32)
        # This is needed because LGBM consumes the categorical codes directly;
        # when other converters support categoricals this may need to change.
        onnx_X[:, n_numerical_features] = X[n_numerical_features].cat.codes
        onnx_X[:, n_numerical_features + 1] = X[n_numerical_features + 1].cat.codes
    else:
        onnx_X = X.astype(np.float32)

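    # The combined graph's output must match the in-memory predictions.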
    pred_onnx = sess.run(
        ["tau"],
        {"input": onnx_X},
    )
    np.testing.assert_allclose(ml.predict(X, True, "overall"), pred_onnx[0], atol=1e-5)
