Skip to content

Commit

Permalink
Use pandas' check for int.
Browse files Browse the repository at this point in the history
  • Loading branch information
kklein committed Jun 26, 2024
1 parent db3d29d commit 87b3179
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions metalearners/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,10 @@ def convert_treatment(treatment: Vector) -> np.ndarray:
new_treatment = treatment.to_numpy()
if new_treatment.dtype == bool:
return new_treatment.astype(int)
elif new_treatment.dtype == float and all(x.is_integer() for x in new_treatment):
if new_treatment.dtype == float and all(x.is_integer() for x in new_treatment):
return new_treatment.astype(int)
elif new_treatment.dtype != int:

if not pd.api.types.is_integer_dtype(new_treatment):
raise TypeError(
"Treatment must be boolean, integer or float with integer values."
)
Expand Down

0 comments on commit 87b3179

Please sign in to comment.