diff --git a/docs/examples/example_survival.ipynb b/docs/examples/example_survival.ipynb index 8c0f5d7..f84a813 100644 --- a/docs/examples/example_survival.ipynb +++ b/docs/examples/example_survival.ipynb @@ -100,7 +100,8 @@ " # X contains \"censored\" column\n", " features = set(X.columns) - {\"censored\"}\n", " dtrain = xgb.DMatrix(X[list(features)], enable_categorical=True)\n", - " if set(np.unique(X[\"censored\"])) != {0, 1}:\n", + "\n", + " if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n", " raise ValueError(\"censored column should be binary.\")\n", "\n", " y_upper_bound = np.where(X[\"censored\"], +np.inf, y)\n", @@ -122,7 +123,7 @@ " check_is_fitted(self)\n", " # X contains \"censored\" column\n", " features = set(X.columns) - {\"censored\"}\n", - " if set(np.unique(X[\"censored\"])) != {0, 1}:\n", + " if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n", " raise ValueError(\"censored column should be binary.\")\n", " dtest = xgb.DMatrix(X[list(features)], enable_categorical=True)\n", " return self.bst_.predict(dtest)"