Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Do not store histogram on SplitInfo #36

Merged
merged 2 commits into from
Nov 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions pygbm/splitting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# from collections import namedtuple
import numpy as np
from numba import (njit, jitclass, prange, float32, uint8, uint32, typeof,
optional)
from numba import njit, jitclass, prange, float32, uint8, uint32, optional
import numba
from .histogram import _build_histogram
from .histogram import _subtract_histograms
Expand All @@ -21,7 +20,6 @@
('hessian_right', float32),
('n_samples_left', uint32),
('n_samples_right', uint32),
('histogram', typeof(HISTOGRAM_DTYPE)[:]), # array of size n_bins
])
class SplitInfo:
def __init__(self, gain=-1., feature_idx=0, bin_idx=0,
Expand Down Expand Up @@ -266,13 +264,18 @@ def find_node_split(context, sample_indices):
# numba jitclass do not seem to properly support default values for kwargs.
split_infos = [SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)
for i in range(context.n_features)]
histograms = np.empty(
shape=(np.int64(context.n_features), np.int64(context.n_bins)),
dtype=HISTOGRAM_DTYPE
)
for feature_idx in prange(context.n_features):
split_info = _find_histogram_split(context, feature_idx,
sample_indices)
split_info, histogram = _find_histogram_split(
context, feature_idx, sample_indices)
split_infos[feature_idx] = split_info
histograms[feature_idx, :] = histogram

return _find_best_feature_to_split_helper(context.n_features,
context.n_bins, split_infos)
split_info = _find_best_feature_to_split_helper(split_infos)
return split_info, histograms


@njit(parallel=True)
Expand Down Expand Up @@ -307,31 +310,30 @@ def find_node_split_subtraction(context, sample_indices, parent_histograms,
# Pre-allocate the results datastructure to be able to use prange
split_infos = [SplitInfo(-1., 0, 0, 0., 0., 0., 0., 0, 0)
for i in range(context.n_features)]
histograms = np.empty(
shape=(np.int64(context.n_features), np.int64(context.n_bins)),
dtype=HISTOGRAM_DTYPE
)
for feature_idx in prange(context.n_features):
split_info = _find_histogram_split_subtraction(
context, feature_idx, parent_histograms, sibling_histograms,
n_samples)
split_info, histogram = _find_histogram_split_subtraction(
context, feature_idx, parent_histograms,
sibling_histograms, n_samples)
split_infos[feature_idx] = split_info
histograms[feature_idx, :] = histogram

return _find_best_feature_to_split_helper(
context.n_features, context.n_bins, split_infos)
split_info = _find_best_feature_to_split_helper(split_infos)
return split_info, histograms


@njit
def _find_best_feature_to_split_helper(n_features, n_bins, split_infos):
def _find_best_feature_to_split_helper(split_infos):
best_gain = None
# need to convert to int64, it's a numba bug. See issue #2756
histograms = np.empty(
shape=(np.int64(n_features), np.int64(n_bins)),
dtype=HISTOGRAM_DTYPE
)
for i, split_info in enumerate(split_infos):
histograms[i, :] = split_info.histogram
gain = split_info.gain
if best_gain is None or gain > best_gain:
best_gain = gain
best_split_info = split_info
return best_split_info, histograms
return best_split_info


@njit(fastmath=True)
Expand Down Expand Up @@ -435,8 +437,7 @@ def _find_best_bin_to_split_helper(context, feature_idx, histogram, n_samples):
best_split.hessian_right = hessian_right
best_split.n_samples_right = n_samples_right

best_split.histogram = histogram
return best_split
return best_split, histogram


@njit(fastmath=False)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_histogram_split(n_bins):
min_hessian_to_split,
min_samples_leaf, min_gain_to_split)

split_info = _find_histogram_split(context, feature_idx,
sample_indices)
split_info, _ = _find_histogram_split(context, feature_idx,
sample_indices)

assert split_info.bin_idx == true_bin
assert split_info.gain >= 0
Expand Down