Fix handling of the min_samples_leaf hyperparameter #35
Merged

Changes from 2 of the 11 commits are shown below.

Commits:
56d83ca  Collect sample counts in node.split_info  (ogrisel)
1c6e62b  Add tests  (ogrisel)
d2e2902  merged master + test for min_samples_leaf at histogram level  (NicolasHug)
0497360  Merge branch 'master' into fix-min_samples_leaf  (NicolasHug)
7e91c2c  Test for min_gain_to_split at histogram level  (NicolasHug)
b5b6eca  removed plotting from tests/test_predictor.py  (NicolasHug)
f60a344  Make test_predictor_from_grower stricter  (ogrisel)
fb26aa7  Only test one stopping param at a time  (ogrisel)
f83d1e9  Make min_gain_to_split splitting level filtering consistent with grow…  (ogrisel)
e5c0fdc  Better hyperparam for the boston test  (ogrisel)
1da31d7  Cover edge case when n_samples is too small  (ogrisel)
@@ -18,19 +18,24 @@
     ('hessian_left', float32),
     ('gradient_right', float32),
     ('hessian_right', float32),
+    ('n_samples_left', uint32),
+    ('n_samples_right', uint32),
     ('histogram', typeof(HISTOGRAM_DTYPE)[:]),  # array of size n_bins
 ])
 class SplitInfo:
     def __init__(self, gain=0, feature_idx=0, bin_idx=0,
                  gradient_left=0., hessian_left=0.,
-                 gradient_right=0., hessian_right=0.):
+                 gradient_right=0., hessian_right=0.,
+                 n_samples_left=0, n_samples_right=0):
         self.gain = gain
         self.feature_idx = feature_idx
         self.bin_idx = bin_idx
         self.gradient_left = gradient_left
         self.hessian_left = hessian_left
         self.gradient_right = gradient_right
         self.hessian_right = hessian_right
+        self.n_samples_left = n_samples_left
+        self.n_samples_right = n_samples_right


 @jitclass([
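The two new uint32 fields record how many samples end up on each side of a candidate split. As a rough sketch of how a grower could use them to enforce min_samples_leaf (the helper below is hypothetical and not part of this diff):

    def split_is_allowed(split_info, min_samples_leaf):
        # Reject any candidate split that would create a child with fewer
        # samples than min_samples_leaf on either side.
        return (split_info.n_samples_left >= min_samples_leaf and
                split_info.n_samples_right >= min_samples_leaf)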
@@ -236,17 +241,16 @@ def find_node_split_subtraction(context, sample_indices, parent_histograms,
     # be to compute an average but it's probably not worth it.
     sum_gradients = (parent_histograms[0]['sum_gradients'].sum() -
                      sibling_histograms[0]['sum_gradients'].sum())
-
+    n_samples = sample_indices.shape[0]
     if context.constant_hessian:
-        n_samples = sample_indices.shape[0]
         sum_hessians = context.constant_hessian_value * float32(n_samples)
     else:
         sum_hessians = (parent_histograms[0]['sum_hessians'].sum() -
                         sibling_histograms[0]['sum_hessians'].sum())

     return _parallel_find_split_subtraction(
         context, parent_histograms, sibling_histograms,
-        sum_gradients, sum_hessians)
+        sum_gradients, sum_hessians, n_samples)


 @njit
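For reference, find_node_split_subtraction relies on the identity stated later in this file, hist(parent) = hist(left) + hist(right): a sibling's histogram, including its per-bin sample counts, can be obtained by subtraction instead of rebinning the samples. A simplified numpy illustration using plain count arrays rather than the structured HISTOGRAM_DTYPE:

    import numpy as np

    # Per-bin sample counts of a parent node and of its left child.
    parent_counts = np.array([10, 4, 6, 0], dtype=np.uint32)
    left_counts = np.array([7, 1, 2, 0], dtype=np.uint32)

    # The right child's counts follow bin by bin, without touching the data.
    right_counts = parent_counts - left_counts
    assert right_counts.sum() == parent_counts.sum() - left_counts.sum()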
@@ -280,7 +284,7 @@ def _parallel_find_split(splitter, sample_indices, ordered_gradients,
     """
     # Pre-allocate the results datastructure to be able to use prange:
     # numba jitclass do not seem to properly support default values for kwargs.
-    split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0.)
+    split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0., 0, 0)
                    for i in range(splitter.n_features)]
     for feature_idx in prange(splitter.n_features):
         split_info = _find_histogram_split(
@@ -295,7 +299,7 @@ def _parallel_find_split(splitter, sample_indices, ordered_gradients,
 @njit(parallel=True)
 def _parallel_find_split_subtraction(context, parent_histograms,
                                      sibling_histograms,
-                                     sum_gradients, sum_hessians):
+                                     sum_gradients, sum_hessians, n_samples):
     """For each feature, find the best bin to split by histogram substraction

     This in turn calls _find_histogram_split_subtraction that does not need

@@ -307,12 +311,12 @@ def _parallel_find_split_subtraction(context, parent_histograms,
     histograms by substraction.
     """
     # Pre-allocate the results datastructure to be able to use prange
-    split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0.)
+    split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0., 0, 0)
                    for i in range(context.n_features)]

Review comment (on the split_infos allocation above): This data structure could probably also be stored as an attribute on the context to avoid reallocating it over and over again.

     for feature_idx in prange(context.n_features):
         split_info = _find_histogram_split_subtraction(
             context, feature_idx, parent_histograms, sibling_histograms,
-            sum_gradients, sum_hessians)
+            sum_gradients, sum_hessians, n_samples)
         split_infos[feature_idx] = split_info

     return _find_best_feature_to_split_helper(
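A rough sketch of what the reviewer's suggestion might look like, in plain Python without the numba jitclass machinery (the class and attribute names below are assumed for illustration and are not part of this PR): the per-feature SplitInfo buffer would be built once with the splitting context and reused across calls.

    # Hypothetical illustration only.
    class SplittingContext:
        def __init__(self, n_features):
            self.n_features = n_features
            # Reusable buffer of candidate splits, one slot per feature,
            # filled in place by the prange loops instead of being
            # re-allocated on every call.
            self.split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0., 0, 0)
                                for _ in range(n_features)]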
@@ -324,8 +328,9 @@ def _find_histogram_split(context, feature_idx, sample_indices,
                           ordered_gradients, ordered_hessians,
                           sum_gradients, sum_hessians):
     """Compute the histogram for a given feature and return the best bin."""
+    n_samples = sample_indices.shape[0]
     binned_feature = context.binned_features.T[feature_idx]
-    root_node = binned_feature.shape[0] == sample_indices.shape[0]
+    root_node = binned_feature.shape[0] == n_samples

     if root_node:
         if context.constant_hessian:
@@ -346,13 +351,14 @@ def _find_histogram_split(context, feature_idx, sample_indices,
                                       ordered_gradients, ordered_hessians)

     return _find_best_bin_to_split_helper(
-        context, feature_idx, histogram, sum_gradients, sum_hessians)
+        context, feature_idx, histogram, sum_gradients, sum_hessians,
+        sample_indices.shape[0])


 @njit(fastmath=True)
 def _find_histogram_split_subtraction(context, feature_idx,
                                       parent_histograms, sibling_histograms,
-                                      sum_gradients, sum_hessians):
+                                      sum_gradients, sum_hessians, n_samples):
     """Compute the histogram by substraction of parent and sibling

     Uses the identity: hist(parent) = hist(left) + hist(right)
@@ -362,26 +368,29 @@ def _find_histogram_split_subtraction(context, feature_idx,
                                            sibling_histograms[feature_idx])

     return _find_best_bin_to_split_helper(
-        context, feature_idx, histogram, sum_gradients, sum_hessians)
+        context, feature_idx, histogram, sum_gradients, sum_hessians,
+        n_samples)


-@njit(locals={'gradient_left': float32, 'hessian_left': float32},
+@njit(locals={'gradient_left': float32, 'hessian_left': float32,
+              'n_samples_left': uint32},
       fastmath=True)
 def _find_best_bin_to_split_helper(context, feature_idx, histogram,
-                                   sum_gradients, sum_hessians):
+                                   sum_gradients, sum_hessians, n_samples):
     """Find best bin to split on and return the corresponding SplitInfo"""
     # Allocate the structure for the best split information. It can be
     # returned as such (with a negative gain) if the min_hessian_to_split
     # condition is not satisfied. Such invalid splits are later discarded by
     # the TreeGrower.
-    best_split = SplitInfo(-1., 0, 0, 0., 0., 0., 0.)
+    best_split = SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)
+    gradient_left, hessian_left, n_samples_left = 0., 0., 0

-    gradient_left, hessian_left = 0., 0.
     for bin_idx in range(context.n_bins):
         gradient_left += histogram[bin_idx]['sum_gradients']
+        n_samples_left += histogram[bin_idx]['count']
         if context.constant_hessian:
-            hessian_left += (histogram[bin_idx]['count'] *
-                             context.constant_hessian_value)
+            hessian_left += (histogram[bin_idx]['count']
+                             * context.constant_hessian_value)
         else:
             hessian_left += histogram[bin_idx]['sum_hessians']
         if hessian_left < context.min_hessian_to_split:
@@ -400,8 +409,10 @@ def _find_best_bin_to_split_helper(context, feature_idx, histogram,
             best_split.bin_idx = bin_idx
             best_split.gradient_left = gradient_left
             best_split.hessian_left = hessian_left
+            best_split.n_samples_left = n_samples_left
             best_split.gradient_right = gradient_right
             best_split.hessian_right = hessian_right
+            best_split.n_samples_right = n_samples - n_samples_left

     best_split.histogram = histogram
     return best_split
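Since the helper now fills in both counts, a histogram-level test (such as the min_samples_leaf test mentioned in commit d2e2902) can check the bookkeeping directly: the samples of the first bin_idx + 1 bins go to the left child and the rest to the right, so the two counts must always add up to n_samples. A small self-contained illustration in plain Python, with the histogram reduced to per-bin counts:

    # Per-bin sample counts of a toy histogram and a candidate split bin.
    counts = [5, 3, 0, 7, 2]
    n_samples = sum(counts)
    bin_idx = 1  # samples in bins 0..bin_idx go to the left child

    n_samples_left = sum(counts[:bin_idx + 1])
    n_samples_right = n_samples - n_samples_left

    assert (n_samples_left, n_samples_right) == (8, 9)
    assert n_samples_left + n_samples_right == n_samples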
Review comment: It would be great to add a test for this case in test_grower.py.
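A rough sketch of the kind of grower-level test the reviewer is asking for. All names here (the pygbm.grower module, TreeGrower and its constructor arguments, finalized_leaves, and the per-leaf sample_indices attribute) are assumed from context rather than taken from this excerpt, and the exact edge case under discussion is not visible here:

    import numpy as np
    from pygbm.grower import TreeGrower  # assumed project layout

    def test_min_samples_leaf_is_respected():
        rng = np.random.RandomState(0)
        # Small pre-binned dataset with a few features.
        X_binned = rng.randint(0, 256, size=(100, 3)).astype(np.uint8)
        gradients = rng.normal(size=100).astype(np.float32)
        hessians = np.ones(100, dtype=np.float32)

        grower = TreeGrower(X_binned, gradients, hessians,
                            min_samples_leaf=20)
        grower.grow()

        # No finalized leaf should contain fewer samples than the limit.
        for leaf in grower.finalized_leaves:
            assert leaf.sample_indices.shape[0] >= 20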