From 1f00886a9379a7f66a1adbf3e7961af512b0e181 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francesc=20Mart=C3=AD=20Escofet?= Date: Fri, 12 Jul 2024 13:28:50 +0200 Subject: [PATCH] More general censored column check --- docs/examples/example_survival.ipynb | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/examples/example_survival.ipynb b/docs/examples/example_survival.ipynb index 8c0f5d79..f84a8131 100644 --- a/docs/examples/example_survival.ipynb +++ b/docs/examples/example_survival.ipynb @@ -100,7 +100,8 @@ " # X contains \"censored\" column\n", " features = set(X.columns) - {\"censored\"}\n", " dtrain = xgb.DMatrix(X[list(features)], enable_categorical=True)\n", - " if set(np.unique(X[\"censored\"])) != {0, 1}:\n", + "\n", + " if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n", " raise ValueError(\"censored column should be binary.\")\n", "\n", " y_upper_bound = np.where(X[\"censored\"], +np.inf, y)\n", @@ -122,7 +123,7 @@ " check_is_fitted(self)\n", " # X contains \"censored\" column\n", " features = set(X.columns) - {\"censored\"}\n", - " if set(np.unique(X[\"censored\"])) != {0, 1}:\n", + " if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n", " raise ValueError(\"censored column should be binary.\")\n", " dtest = xgb.DMatrix(X[list(features)], enable_categorical=True)\n", " return self.bst_.predict(dtest)"