Skip to content

Commit

Permalink
fix nan
Browse files Browse the repository at this point in the history
Signed-off-by: xadupre <[email protected]>
  • Loading branch information
xadupre committed Dec 20, 2024
1 parent 51130ec commit 62b24b5
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions skl2onnx/operator_converters/power_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
)


def get_nan():
if hasattr(np, "nan"):
return np.nan
return np.NAN


def convert_powertransformer(
scope: Scope, operator: Operator, container: ModelComponentContainer
):
Expand Down Expand Up @@ -78,7 +84,7 @@ def convert_powertransformer(
y_gr0 = OnnxImputer(
y_gr0,
imputed_value_floats=[0.0],
replaced_value_float=getattr(np, "nan", getattr(np, "NAN")), # noqa: B009
replaced_value_float=get_nan(),
op_version=opv,
)
y_gr0 = OnnxMul(y_gr0, greater_mask, op_version=opv)
Expand All @@ -104,7 +110,7 @@ def convert_powertransformer(
y_le0 = OnnxImputer(
y_le0,
imputed_value_floats=[0.0],
replaced_value_float=getattr(np, "nan", getattr(np, "NAN")), # noqa: B009
replaced_value_float=get_nan(),
op_version=opv,
)
y_le0 = OnnxMul(y_le0, less_mask, op_version=opv)
Expand All @@ -130,7 +136,7 @@ def convert_powertransformer(
y_gr0_l_eq0 = OnnxImputer(
y_gr0_l_eq0,
imputed_value_floats=[0.0],
replaced_value_float=getattr(np, "nan", getattr(np, "NAN")), # noqa: B009
replaced_value_float=get_nan(),
op_version=opv,
)
y_gr0_l_eq0 = OnnxMul(y_gr0_l_eq0, lambda_zero_mask, op_version=opv)
Expand Down

0 comments on commit 62b24b5

Please sign in to comment.