Commit 45e7c90: Allow callables

RAMitchell committed Feb 13, 2025 (parent: 6633a95)
Showing 5 changed files with 97 additions and 25 deletions.
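In summary, `Tree.feature_fraction` now accepts either a float fraction or a callable producing an explicit feature mask. A minimal sketch of the new usage, adapted directly from the tests added below, so only the particular mask is illustrative:

    import cupynumeric as cn
    import legateboost as lb

    def select_features():
        # A callable feature_fraction must return a boolean mask of shape
        # (n_features,); this fixed mask restricts every tree to feature 1.
        return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=bool)

    model = lb.LBRegressor(
        base_models=(lb.models.Tree(feature_fraction=select_features),),
        random_state=0,
    )

Fitting then proceeds as usual via `model.fit(X, y)`; the validation added in `fit()` below rejects masks with the wrong dtype or shape.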
22 changes: 16 additions & 6 deletions legateboost/models/tree.py
@@ -1,6 +1,6 @@
import copy
from enum import IntEnum
from typing import Any, List, Sequence, cast
from typing import Any, Callable, List, Sequence, Union, cast

import cupynumeric as cn
from legate.core import TaskTarget, get_legate_runtime, types
@@ -32,8 +32,10 @@ class Tree(BaseModel):
split_samples : int
The number of data points to sample for each split decision.
Max value is 2048 due to constraints on shared memory in GPU kernels.
feature_fraction : float
The subsampled fraction of features considered in building this model.
feature_fraction : float or callable
    If a float, the subsampled fraction of features considered in building this model.
    Alternatively, a callable returning a cupynumeric boolean array of shape
    `(n_features,)` may be supplied to specify the feature subset directly.
alpha : float
The L2 regularization parameter.
"""
@@ -49,7 +51,7 @@ def __init__(
max_depth: int = 8,
split_samples: int = 256,
alpha: float = 1.0,
feature_fraction: float = 1.0,
feature_fraction: Union[float, Callable[..., cn.array]] = 1.0,
) -> None:
self.max_depth = max_depth
if split_samples > 2048:
@@ -85,7 +87,6 @@ def fit(
task.add_scalar_arg(self.split_samples, types.int32)
task.add_scalar_arg(self.random_state.randint(0, 2**31), types.int32)
task.add_scalar_arg(X.shape[0], types.int64)
task.add_scalar_arg(self.feature_fraction, types.float64)

task.add_input(X_)
task.add_broadcast(X_, 1)
@@ -95,7 +96,16 @@
task.add_alignment(g_, X_)

# sample features
if self.feature_fraction < 1.0:
if isinstance(self.feature_fraction, Callable):
feature_set = self.feature_fraction()
if feature_set.shape != (X.shape[1],) or feature_set.dtype != bool:
raise ValueError(
"feature_fraction must return a boolean array of"
" shape (n_features,)"
)
task.add_input(get_store(feature_set))
task.add_broadcast(get_store(feature_set))
elif self.feature_fraction < 1.0:
cn.random.seed(self.random_state.randint(0, 2**31))
feature_set = cn.random.binomial(
1, self.feature_fraction, size=(X.shape[1],), dtype=bool
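The float path above seeds cupynumeric's RNG and draws a Bernoulli mask per tree fit. A callable can reproduce that behaviour while leaving room for custom schemes. The following factory is a hedged sketch; the name `make_random_subset` and its signature are invented for illustration and are not part of this commit:

    import cupynumeric as cn

    def make_random_subset(n_features: int, fraction: float, seed: int = 0):
        # Seed once, as the built-in float path in fit() does; each call to
        # the returned sampler then draws an independent Bernoulli mask.
        cn.random.seed(seed)

        def sample():
            # Same construction as the built-in float path: a boolean mask
            # of shape (n_features,).
            return cn.random.binomial(1, fraction, size=(n_features,), dtype=bool)

        return sample

Anything the callable computes is acceptable as long as it returns a boolean cupynumeric array of shape `(n_features,)`; the validation in `fit()` above enforces exactly that contract.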
58 changes: 50 additions & 8 deletions legateboost/test/models/test_tree.py
@@ -76,19 +76,20 @@ def test_alpha():
assert np.isclose(model.predict(X)[0], y.sum() / (y.size + alpha))


def get_feature_distribution(model):
histogram = cn.zeros(model.n_features_in_)
for m in model:
histogram += cn.histogram(
m.feature, bins=model.n_features_in_, range=(0, model.n_features_in_)
)[0]
return histogram / histogram.sum()


def test_feature_sample():
X, y = make_regression(
n_samples=100, n_features=10, n_informative=2, random_state=0
)

def get_feature_distribution(model):
histogram = cn.zeros(X.shape[1])
for m in model:
histogram += cn.histogram(
m.feature, bins=X.shape[1], range=(0, X.shape[1])
)[0]
return histogram / histogram.sum()

# We have a distribution of how often each feature is used in the model
# Hypothesis: the baseline model should use the best features more often and the
# sampled model should use other features more often as it won't always see the
@@ -116,3 +117,44 @@ def get_feature_distribution(model):
).fit(X, y)
for m in no_features_model:
assert m.num_nodes() == 1


def test_callable_feature_sample():
def feature_fraction():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0], dtype=bool)

rng = np.random.RandomState(0)
X = rng.randn(100, 10)
y = rng.randn(100)
model = lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=feature_fraction),),
random_state=0,
).fit(X, y)

assert get_feature_distribution(model)[1] == 1.0

def feature_fraction_int():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0])

with pytest.raises(
ValueError,
match=r"feature_fraction must return a boolean array of shape \(n_features,\)",
):
lb.LBRegressor(
base_models=(lb.models.Tree(feature_fraction=feature_fraction_int),),
random_state=0,
).fit(X, y)

def feature_fraction_wrong_shape():
return cn.array([0, 1, 0, 0, 0, 0, 0, 0, 0], dtype=bool)

with pytest.raises(
ValueError,
match=r"feature_fraction must return a boolean array of shape \(n_features,\)",
):
lb.LBRegressor(
base_models=(
lb.models.Tree(feature_fraction=feature_fraction_wrong_shape),
),
random_state=0,
).fit(X, y)
8 changes: 7 additions & 1 deletion legateboost/test/test_with_hypothesis.py
@@ -30,7 +30,13 @@ def tree_strategy(draw):
max_depth = draw(st.integers(1, 6))
alpha = draw(st.floats(0.0, 1.0))
split_samples = draw(st.integers(1, 500))
return lb.models.Tree(max_depth=max_depth, alpha=alpha, split_samples=split_samples)
feature_fraction = draw(st.sampled_from([0.5, 1.0]))
return lb.models.Tree(
max_depth=max_depth,
alpha=alpha,
split_samples=split_samples,
feature_fraction=feature_fraction,
)


@st.composite
19 changes: 17 additions & 2 deletions src/models/tree/build_tree.cc
@@ -298,13 +298,20 @@ struct TreeBuilder {
}
}
}
void PerformBestSplit(Tree& tree, Histogram<GPair> histogram, double alpha, NodeBatch batch)
void PerformBestSplit(Tree& tree,
Histogram<GPair> histogram,
double alpha,
NodeBatch batch,
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set)
{
for (int node_id = batch.node_idx_begin; node_id < batch.node_idx_end; node_id++) {
double best_gain = 0;
int best_feature = -1;
int best_bin = -1;
for (int feature = 0; feature < num_features; feature++) {
if (optional_feature_set.has_value() && !optional_feature_set.value()[feature]) {
continue;
}
auto [feature_begin, feature_end] = split_proposals.FeatureRange(feature);
for (int bin_idx = feature_begin; bin_idx < feature_end; bin_idx++) {
double gain = 0;
@@ -477,6 +484,14 @@ struct build_tree_fn {
auto seed = context.scalars().at(4).value<int>();
auto dataset_rows = context.scalars().at(5).value<int64_t>();

// Get feature sample if it exists
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set;
if (context.inputs().size() == 4) {
auto [feature_set, feature_set_shape, feature_set_accessor] =
GetInputStore<bool, 1>(context.input(3).data());
optional_feature_set = feature_set_accessor;
}

Tree tree(max_nodes, narrow<int>(num_outputs));
SparseSplitProposals<T> const split_proposals =
SelectSplitSamples(context, X_accessor, X_shape, split_samples, seed, dataset_rows);
@@ -494,7 +509,7 @@
builder.ComputeHistogram(
histogram, context, tree, X_accessor, X_shape, g_accessor, h_accessor, batch);

builder.PerformBestSplit(tree, histogram, alpha, batch);
builder.PerformBestSplit(tree, histogram, alpha, batch, optional_feature_set);
}
// Update position of entire level
// Don't bother updating positions for the last level
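Note the protocol change shared by both backends: the `feature_fraction` scalar argument is dropped from the task signature, and the feature subset instead arrives as an optional fourth input store. The CPU task above and the GPU task below detect its presence the same way, via `context.inputs().size() == 4`, matching the conditional `task.add_input(...)` in `tree.py`.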
15 changes: 7 additions & 8 deletions src/models/tree/build_tree.cu
@@ -1280,17 +1280,16 @@ struct build_tree_fn {
EXPECT_AXIS_ALIGNED(1, g_shape, h_shape);

// Scalars
auto max_depth = context.scalars().at(0).value<int>();
auto max_nodes = context.scalars().at(1).value<int>();
auto alpha = context.scalars().at(2).value<double>();
auto split_samples = context.scalars().at(3).value<int>();
auto seed = context.scalars().at(4).value<int>();
auto dataset_rows = context.scalars().at(5).value<int64_t>();
auto feature_fraction = context.scalars().at(6).value<double>();
auto max_depth = context.scalars().at(0).value<int>();
auto max_nodes = context.scalars().at(1).value<int>();
auto alpha = context.scalars().at(2).value<double>();
auto split_samples = context.scalars().at(3).value<int>();
auto seed = context.scalars().at(4).value<int>();
auto dataset_rows = context.scalars().at(5).value<int64_t>();

// Get feature sample if it exists
std::optional<legate::AccessorRO<bool, 1>> optional_feature_set;
if (feature_fraction < 1.0) {
if (context.inputs().size() == 4) {
auto [feature_set, feature_set_shape, feature_set_accessor] =
GetInputStore<bool, 1>(context.input(3).data());
optional_feature_set = feature_set_accessor;
