From e19ad776e0529c605c599f16ecac55868bffaef0 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 17 Feb 2025 04:31:27 -0800 Subject: [PATCH] Fix test --- legateboost/test/test_estimator.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/legateboost/test/test_estimator.py b/legateboost/test/test_estimator.py index d5c3249b..fd293bc5 100644 --- a/legateboost/test/test_estimator.py +++ b/legateboost/test/test_estimator.py @@ -142,11 +142,8 @@ def test_classifier(num_class, objective, base_models): metric = model._metrics[0] proba = model.predict_proba(X) assert cn.all(proba >= 0) and cn.all(proba <= 1) - if num_class == 2: - assert cn.all(proba > 0.5, model.predict(X)) - else: - assert cn.all(cn.argmax(proba, axis=1) == model.predict(X)) - assert cn.allclose(proba.sum(axis=1), cn.ones(X.shape[0])) + assert cn.all(cn.argmax(proba, axis=1) == model.predict(X)) + assert cn.allclose(proba.sum(axis=1), cn.ones(X.shape[0])) loss = metric.metric(y, proba, cn.ones(y.shape[0])) train_loss = next(iter(eval_result["train"].values()))