Skip to content

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Feb 17, 2025
1 parent 31c3229 commit e19ad77
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions legateboost/test/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,8 @@ def test_classifier(num_class, objective, base_models):
metric = model._metrics[0]
proba = model.predict_proba(X)
assert cn.all(proba >= 0) and cn.all(proba <= 1)
if num_class == 2:
assert cn.all(proba > 0.5, model.predict(X))
else:
assert cn.all(cn.argmax(proba, axis=1) == model.predict(X))
assert cn.allclose(proba.sum(axis=1), cn.ones(X.shape[0]))
assert cn.all(cn.argmax(proba, axis=1) == model.predict(X))
assert cn.allclose(proba.sum(axis=1), cn.ones(X.shape[0]))

loss = metric.metric(y, proba, cn.ones(y.shape[0]))
train_loss = next(iter(eval_result["train"].values()))
Expand Down

0 comments on commit e19ad77

Please sign in to comment.