diff --git a/pyltr/models/lambdamart.py b/pyltr/models/lambdamart.py index 969ff3a..0e7e757 100644 --- a/pyltr/models/lambdamart.py +++ b/pyltr/models/lambdamart.py @@ -315,7 +315,7 @@ def _update_terminal_regions(self, tree, X, y, lambdas, deltas, y_pred, terminal_region = np.where(masked_terminal_regions == leaf) suml = np.sum(lambdas[terminal_region]) sumd = np.sum(deltas[terminal_region]) - tree.value[leaf, 0, 0] = 0.0 if sumd == 0.0 else (suml / sumd) + tree.value[leaf, 0, 0] = 0.0 if abs(sumd) < 1e-300 else (suml / sumd) y_pred += tree.value[terminal_regions, 0, 0] * self.learning_rate