From 2616e78bf9bd9c95739d383714bc1b21b16ce044 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 13 Feb 2025 02:49:07 -0800 Subject: [PATCH] Use keyword args only for models. Small change to callback. --- legateboost/legateboost.py | 4 ++-- legateboost/models/krr.py | 1 + legateboost/models/linear.py | 1 + legateboost/models/nn.py | 1 + legateboost/models/tree.py | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/legateboost/legateboost.py b/legateboost/legateboost.py index e8d848d6..cd983e46 100644 --- a/legateboost/legateboost.py +++ b/legateboost/legateboost.py @@ -294,10 +294,10 @@ def _partial_fit( # callbacks after iteration if any( - ( + [ c.after_iteration(self, model_idx, eval_result) for c in self.callbacks - ) + ] ): break diff --git a/legateboost/models/krr.py b/legateboost/models/krr.py index 2a41eb98..65249fd0 100644 --- a/legateboost/models/krr.py +++ b/legateboost/models/krr.py @@ -83,6 +83,7 @@ class KRR(BaseModel): def __init__( self, + *, n_components: int = 100, alpha: Any = "deprecated", l2_regularization: float = 1e-5, diff --git a/legateboost/models/linear.py b/legateboost/models/linear.py index bd1f1595..d660bc23 100644 --- a/legateboost/models/linear.py +++ b/legateboost/models/linear.py @@ -41,6 +41,7 @@ class Linear(BaseModel): def __init__( self, + *, l2_regularization: float = 1e-5, alpha: Any = "deprecated", solver: str = "direct", diff --git a/legateboost/models/nn.py b/legateboost/models/nn.py index 222cada6..eb499b0d 100644 --- a/legateboost/models/nn.py +++ b/legateboost/models/nn.py @@ -33,6 +33,7 @@ class NN(BaseModel): def __init__( self, + *, max_iter: int = 100, hidden_layer_sizes: Tuple[int] = (100,), alpha: Any = "deprecated", diff --git a/legateboost/models/tree.py b/legateboost/models/tree.py index 4e0c98a3..01819830 100644 --- a/legateboost/models/tree.py +++ b/legateboost/models/tree.py @@ -54,6 +54,7 @@ class Tree(BaseModel): def __init__( self, + *, max_depth: int = 8, split_samples: int = 256, l1_regularization: float = 0.0,