Implement basic feature sampling in tree model #212

Open · wants to merge 5 commits into main
4 changes: 4 additions & 0 deletions legateboost/legateboost.py
@@ -455,6 +455,7 @@ def __len__(self) -> int:
Returns:
int: The number of models in the `models_` attribute.
"""
check_is_fitted(self, "is_fitted_")
return len(self.models_)

def __getitem__(self, i: int) -> BaseModel:
@@ -466,6 +467,7 @@ def __getitem__(self, i: int) -> BaseModel:
Returns:
BaseModel: The model at the specified index.
"""
check_is_fitted(self, "is_fitted_")
return self.models_[i]

def __iter__(self) -> Any:
@@ -474,6 +476,7 @@ def __iter__(self) -> Any:
Yields:
Any: An iterator over the models in the `models_` attribute.
"""
check_is_fitted(self, "is_fitted_")
return iter(self.models_)

def __mul__(self, scalar: Any) -> Self:
@@ -496,6 +499,7 @@ def __mul__(self, scalar: Any) -> Self:
ValueError
If the provided scalar is not a numeric value.
"""
check_is_fitted(self, "is_fitted_")

if not np.isscalar(scalar):
raise ValueError("Can only multiply by scalar")
29 changes: 28 additions & 1 deletion legateboost/models/tree.py
@@ -1,7 +1,7 @@
import copy
import warnings
from enum import IntEnum
from typing import Any, List, Sequence, cast
from typing import Any, Callable, List, Sequence, Union, cast

import cupynumeric as cn
from legate.core import TaskTarget, get_legate_runtime, types
@@ -33,6 +33,10 @@ class Tree(BaseModel):
split_samples : int
The number of data points to sample for each split decision.
Max value is 2048 due to constraints on shared memory in GPU kernels.
feature_fraction :
Contributor:
Seems like you wanted to put this. (I'll note that since we do have type hints, I think it is also completely fine to just put nothing / remove the `:` even.)

If float, the subsampled fraction of features considered in building this model.
Users may implement an arbitrary function returning a cupynumeric array of
booleans of shape `(n_features,)` to specify the feature subset.
Contributor:
Suggested change:
booleans of shape `(n_features,)` to specify the feature subset.
booleans of shape ``(n_features,)`` to specify the feature subset.

Should we say how this is rounded?

l1_regularization : float
The L1 regularization parameter applied to leaf weights.
l2_regularization : float
@@ -57,6 +61,7 @@ def __init__(
*,
max_depth: int = 8,
split_samples: int = 256,
feature_fraction: Union[float, Callable[..., cn.array]] = 1.0,
Contributor:
Suggested change:
feature_fraction: Union[float, Callable[..., cn.array]] = 1.0,
feature_fraction: Union[float, Callable[[], cn.array]] = 1.0,

l1_regularization: float = 0.0,
l2_regularization: float = 1.0,
min_split_gain: float = 0.0,
@@ -67,6 +72,7 @@ def __init__(
raise ValueError("split_samples must be <= 2048")
self.split_samples = split_samples
self.alpha = alpha
self.feature_fraction = feature_fraction
self.l1_regularization = l1_regularization
self.l2_regularization = l2_regularization
self.min_split_gain = min_split_gain
@@ -79,6 +85,9 @@ def __init__(
)
self.l2_regularization = alpha

def num_nodes(self) -> int:
return int(cn.sum(self.hessian > 0.0))

def fit(
self,
X: cn.ndarray,
@@ -113,6 +122,24 @@ def fit(
task.add_alignment(g_, h_)
task.add_alignment(g_, X_)

# sample features
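# A callable must return a boolean mask of shape (n_features,); a float
# fraction keeps each feature independently with that probability.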
if callable(self.feature_fraction):
feature_set = self.feature_fraction()
if feature_set.shape != (X.shape[1],) or feature_set.dtype != bool:
raise ValueError(
"feature_fraction must return a boolean array of"
" shape (n_features,)"
)
task.add_input(get_store(feature_set))
task.add_broadcast(get_store(feature_set))
elif self.feature_fraction < 1.0:
cn.random.seed(self.random_state.randint(0, 2**31))
feature_set = cn.random.binomial(
1, self.feature_fraction, size=(X.shape[1],), dtype=bool
)
task.add_input(get_store(feature_set))
task.add_broadcast(get_store(feature_set))

# outputs
leaf_value = get_legate_runtime().create_store(
types.float64, (max_nodes, num_outputs)
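
A minimal usage sketch for the new feature_fraction parameter, assuming the estimator API exercised in the tests below (lb.LBRegressor with Tree base models); the ten-feature mask is purely illustrative:

import cupynumeric as cn
import legateboost as lb

# Float form: each feature is kept independently with probability 0.5
# when building each tree.
reg = lb.LBRegressor(base_models=(lb.models.Tree(feature_fraction=0.5),), random_state=0)

# Callable form: return an explicit boolean mask of shape (n_features,).
def first_two_of_ten():
    mask = cn.zeros(10, dtype=bool)
    mask[:2] = True
    return mask

reg = lb.LBRegressor(
    base_models=(lb.models.Tree(feature_fraction=first_two_of_ten),),
    random_state=0,
)
# Fit with data that has exactly 10 features so the mask shape matches.
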
86 changes: 86 additions & 0 deletions legateboost/test/models/test_tree.py
@@ -1,5 +1,7 @@
import numpy as np
import pytest
import scipy.stats as stats
from sklearn.datasets import make_regression

import cupynumeric as cn
import legateboost as lb
@@ -76,6 +78,90 @@ def test_l2_regularization():
assert np.isclose(model.predict(X)[0], y.sum() / (y.size + l2_regularization))


def get_feature_distribution(model):
histogram = cn.zeros(model.n_features_in_)
for m in model:
histogram += cn.histogram(
m.feature, bins=model.n_features_in_, range=(0, model.n_features_in_)
)[0]
return histogram / histogram.sum()


def test_feature_sample():
X, y = make_regression(
n_samples=100, n_features=10, n_informative=2, random_state=0
)

# We have a distribution of how often each feature is used in the model
# Hypothesis: the baseline model should use the best features more often and the
# sampled model should use other features more often as it won't always see the
# best features. So we expect the entropy of the baseline model's feature
# distribution to be lower than that of the sampled model,
# i.e. the sampled model should be closer to a uniform distribution.
baseline_samples = []
sampled_samples = []
for trial in range(5):
baseline_model = lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=1.0),), random_state=trial
).fit(X, y)
sampled_model = lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=0.5),), random_state=trial
).fit(X, y)
baseline_samples.append(stats.entropy(get_feature_distribution(baseline_model)))
sampled_samples.append(stats.entropy(get_feature_distribution(sampled_model)))

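# One-sided test: baseline entropies are expected to be stochastically smaller
# than the sampled-model entropies.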
_, p = stats.mannwhitneyu(baseline_samples, sampled_samples, alternative="less")
assert p < 0.05

# the no-features model contains only the bias term - no splits
no_features_model = lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=0.0),), random_state=0
).fit(X, y)
for m in no_features_model:
assert m.num_nodes() == 1


def test_callable_feature_sample():
def feature_fraction():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=bool)

rng = np.random.RandomState(0)
X = rng.randn(100, 10)
y = rng.randn(100)
model = lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=feature_fraction),),
random_state=0,
).fit(X, y)

assert get_feature_distribution(model)[1] == 1.0

def feature_fraction_int():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0])

with pytest.raises(
ValueError,
match=r"feature_fraction must return a boolean array of shape \(n_features,\)",
):
lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=feature_fraction_int),),
random_state=0,
).fit(X, y)

def feature_fraction_wrong_shape():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=bool)

with pytest.raises(
ValueError,
match=r"feature_fraction must return a boolean array of shape \(n_features,\)",
):
lb.LBRegressor(
base_models=(
lb.models.Tree(feature_fraction=feature_fraction_wrong_shape),
),
random_state=0,
).fit(X, y)


def test_l1_regularization():
X = cn.array([[0.0], [0.0]])
y = cn.array([500.0, 500.0])
8 changes: 7 additions & 1 deletion legateboost/test/test_with_hypothesis.py
@@ -30,7 +30,13 @@ def tree_strategy(draw):
max_depth = draw(st.integers(1, 6))
alpha = draw(st.floats(0.0, 1.0))
split_samples = draw(st.integers(1, 500))
return lb.models.Tree(max_depth=max_depth, alpha=alpha, split_samples=split_samples)
feature_fraction = draw(st.sampled_from([0.5, 1.0]))
return lb.models.Tree(
max_depth=max_depth,
alpha=alpha,
split_samples=split_samples,
feature_fraction=feature_fraction,
)


@st.composite
23 changes: 20 additions & 3 deletions src/models/tree/build_tree.cc
@@ -303,13 +303,17 @@ struct TreeBuilder {
NodeBatch batch,
double l1_regularization,
double l2_regularization,
double min_split_gain)
double min_split_gain,
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set)
{
for (int node_id = batch.node_idx_begin; node_id < batch.node_idx_end; node_id++) {
double best_gain = 0;
int best_feature = -1;
int best_bin = -1;
for (int feature = 0; feature < num_features; feature++) {
if (optional_feature_set.has_value() && !optional_feature_set.value()[feature]) {
continue;
}
auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
double gain = 0;
@@ -483,6 +487,14 @@ struct build_tree_fn {
auto l2_regularization = context.scalars().at(6).value<double>();
auto min_split_gain = context.scalars().at(7).value<double>();

// Get feature sample if it exists
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set;
if (context.inputs().size() == 4) {
auto [feature_set, feature_set_shape, feature_set_accessor] =
GetInputStore<bool, 1>(context.input(3).data());
optional_feature_set = feature_set_accessor;
}

Tree tree(max_nodes, narrow<int>(num_outputs));
SparseSplitProposals<T> const split_proposals =
SelectSplitSamples(context, X_accessor, X_shape, split_samples, seed, dataset_rows);
@@ -501,8 +513,13 @@ struct build_tree_fn {
builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch);

builder.PerformBestSplit(
tree, histogram, batch, l1_regularization, l2_regularization, min_split_gain);
builder.PerformBestSplit(tree,
histogram,
batch,
l1_regularization,
l2_regularization,
min_split_gain,
optional_feature_set);
}
// Update position of entire level
// Don't bother updating positions for the last level
32 changes: 27 additions & 5 deletions src/models/tree/build_tree.cu
@@ -690,7 +690,8 @@
legate::Buffer<double, 1> tree_split_value,
legate::Buffer<double, 1> tree_gain,
NodeBatch batch,
GradientQuantiser quantiser)
GradientQuantiser quantiser,
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set)
{
// using one block per (level) node to have blockwise reductions
int const node_id = narrow<int>(batch.node_idx_begin + blockIdx.x);
@@ -708,6 +709,12 @@

for (int bin_idx = narrow_cast<int>(threadIdx.x); bin_idx < split_proposals.histogram_size;
bin_idx += BLOCK_THREADS) {
// Check if this feature is in the feature set
if (optional_feature_set.has_value() &&
!optional_feature_set.value()[split_proposals.FindFeature(bin_idx)]) {
continue;
}

double gain = 0;
for (int output = 0; output < n_outputs; ++output) {
auto node_sum = vectorised_load(&node_sums[{node_id, output}]);
@@ -1111,7 +1118,8 @@ struct TreeBuilder {
NodeBatch batch,
double l1_regularization,
double l2_regularization,
double min_split_gain)
double min_split_gain,
const std::optional<legate::AccessorRO<bool, 1>>& optional_feature_set)
{
const int kBlockThreads = 512;
perform_best_split<T, kBlockThreads>
Expand All @@ -1127,7 +1135,8 @@ struct TreeBuilder {
tree.split_value,
tree.gain,
batch,
quantiser);
quantiser,
optional_feature_set);
CHECK_CUDA_STREAM(stream);
}
void InitialiseRoot(legate::TaskContext context,
@@ -1281,6 +1290,14 @@ struct build_tree_fn {
auto l2_regularization = context.scalars().at(6).value<double>();
auto min_split_gain = context.scalars().at(7).value<double>();

// Get feature sample if it exists
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set;
if (context.inputs().size() == 4) {
auto [feature_set, feature_set_shape, feature_set_accessor] =
GetInputStore<bool, 1>(context.input(3).data());
optional_feature_set = feature_set_accessor;
}

auto* stream = context.get_task_stream();
auto thrust_alloc = ThrustAllocator(legate::Memory::GPU_FB_MEM);
auto thrust_exec_policy = DEFAULT_POLICY(thrust_alloc).on(stream);
@@ -1313,8 +1330,13 @@ struct build_tree_fn {
builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch, seed);

builder.PerformBestSplit(
tree, histogram, batch, l1_regularization, l2_regularization, min_split_gain);
builder.PerformBestSplit(tree,
histogram,
batch,
l1_regularization,
l2_regularization,
min_split_gain,
optional_feature_set);
}
// Update position of entire level
// Don't bother updating positions for the last level