Skip to content

Commit

Permalink
More general censored column check
Browse files Browse the repository at this point in the history
  • Loading branch information
FrancescMartiEscofetQC committed Jul 12, 2024
1 parent 2a3cd76 commit 1f00886
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions docs/examples/example_survival.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@
" # X contains \"censored\" column\n",
" features = set(X.columns) - {\"censored\"}\n",
" dtrain = xgb.DMatrix(X[list(features)], enable_categorical=True)\n",
" if set(np.unique(X[\"censored\"])) != {0, 1}:\n",
"\n",
" if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n",
" raise ValueError(\"censored column should be binary.\")\n",
"\n",
" y_upper_bound = np.where(X[\"censored\"], +np.inf, y)\n",
Expand All @@ -122,7 +123,7 @@
" check_is_fitted(self)\n",
" # X contains \"censored\" column\n",
" features = set(X.columns) - {\"censored\"}\n",
" if set(np.unique(X[\"censored\"])) != {0, 1}:\n",
" if (set(np.unique(X[\"censored\"])) - {0, 1}) != set():\n",
" raise ValueError(\"censored column should be binary.\")\n",
" dtest = xgb.DMatrix(X[list(features)], enable_categorical=True)\n",
" return self.bst_.predict(dtest)"
Expand Down

0 comments on commit 1f00886

Please sign in to comment.