Skip to content

Commit 317e2e5

Browse files
authored
feat: add internal unsqueeze operation in forward of all classifiers (#136)
1 parent 365690a commit 317e2e5

12 files changed

+46
-17
lines changed

CHANGELOG.rst

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ Changelog
1818
Ver 0.1.*
1919
---------
2020

21+
* |Feature| Add internal :meth:`unsqueeze` operation in :meth:`forward` of all classifiers | `@xuyxu <https://github.com/xuyxu>`__
2122
* |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg <https://github.com/LukasGardberg>`__
2223
* |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe <https://github.com/SunHaozhe>`__
2324
* |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu <https://github.com/xuyxu>`__

torchensemble/_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def get_doc(item):
2929
__doc = {
3030
"model": const.__model_doc,
3131
"seq_model": const.__seq_model_doc,
32-
"tree_ensmeble_model": const.__tree_ensemble_doc,
32+
"tree_ensemble_model": const.__tree_ensemble_doc,
3333
"fit": const.__fit_doc,
3434
"predict": const.__predict_doc,
3535
"set_optimizer": const.__set_optimizer_doc,

torchensemble/adversarial_training.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,8 @@ class AdversarialTrainingClassifier(_BaseAdversarialTraining, BaseClassifier):
226226
def forward(self, *x):
227227
# Take the average over class distributions from all base estimators.
228228
outputs = [
229-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
229+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
230+
for estimator in self.estimators_
230231
]
231232
proba = op.average(outputs)
232233

torchensemble/bagging.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ class BaggingClassifier(BaseClassifier):
9494
def forward(self, *x):
9595
# Average over class distributions from all base estimators.
9696
outputs = [
97-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
97+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
98+
for estimator in self.estimators_
9899
]
99100
proba = op.average(outputs)
100101

torchensemble/fast_geometric.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ class FastGeometricClassifier(_BaseFastGeometric, BaseClassifier):
176176
"classifier_forward",
177177
)
178178
def forward(self, *x):
179-
proba = self._forward(*x)
179+
proba = op.unsqueeze_tensor(self._forward(*x))
180180

181181
return F.softmax(proba, dim=1)
182182

torchensemble/fusion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _forward(self, *x):
3939
"classifier_forward",
4040
)
4141
def forward(self, *x):
42-
output = self._forward(*x)
42+
output = op.unsqueeze_tensor(self._forward(*x))
4343
proba = F.softmax(output, dim=1)
4444

4545
return proba

torchensemble/gradient_boosting.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,10 @@ def fit(
420420
"classifier_forward",
421421
)
422422
def forward(self, *x):
423-
output = [estimator(*x) for estimator in self.estimators_]
423+
output = [
424+
op.unsqueeze_tensor(estimator(*x))
425+
for estimator in self.estimators_
426+
]
424427
output = op.sum_with_multiplicative(output, self.shrinkage_rate)
425428
proba = F.softmax(output, dim=1)
426429

torchensemble/snapshot_ensemble.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ class SnapshotEnsembleClassifier(_BaseSnapshotEnsemble, BaseClassifier):
212212
def __init__(self, voting_strategy="soft", **kwargs):
213213
super().__init__(**kwargs)
214214

215+
implemented_strategies = {"soft", "hard"}
216+
if voting_strategy not in implemented_strategies:
217+
msg = (
218+
"Voting strategy {} is not implemented, "
219+
"please choose from {}."
220+
)
221+
raise ValueError(
222+
msg.format(voting_strategy, implemented_strategies)
223+
)
224+
215225
self.voting_strategy = voting_strategy
216226

217227
@torchensemble_model_doc(
@@ -221,13 +231,13 @@ def __init__(self, voting_strategy="soft", **kwargs):
221231
def forward(self, *x):
222232

223233
outputs = [
224-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
234+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
235+
for estimator in self.estimators_
225236
]
226237

227238
if self.voting_strategy == "soft":
228239
proba = op.average(outputs)
229-
230-
elif self.voting_strategy == "hard":
240+
else:
231241
proba = op.majority_vote(outputs)
232242

233243
return proba

torchensemble/soft_gradient_boosting.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,10 @@ def fit(
406406
"classifier_forward",
407407
)
408408
def forward(self, *x):
409-
output = [estimator(*x) for estimator in self.estimators_]
409+
output = [
410+
op.unsqueeze_tensor(estimator(*x))
411+
for estimator in self.estimators_
412+
]
410413
output = op.sum_with_multiplicative(output, self.shrinkage_rate)
411414
proba = F.softmax(output, dim=1)
412415

torchensemble/tests/test_all_models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(self):
4949
self.linear2 = nn.Linear(2, 2)
5050

5151
def forward(self, X):
52-
X = X.view(X.size()[0], -1)
52+
X = X.view(X.size(0), -1)
5353
output = self.linear1(X)
5454
output = self.linear2(output)
5555
return output

torchensemble/utils/operator.py

+9
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"pseudo_residual_classification",
1414
"pseudo_residual_regression",
1515
"majority_vote",
16+
"unsqueeze_tensor",
1617
]
1718

1819

@@ -73,3 +74,11 @@ def majority_vote(outputs: List[torch.Tensor]) -> torch.Tensor:
7374
majority_one_hots = proba.scatter_(1, votes.view(-1, 1), 1)
7475

7576
return majority_one_hots
77+
78+
79+
def unsqueeze_tensor(tensor: torch.Tensor, dim=1) -> torch.Tensor:
80+
"""Reshape 1-D tensor to 2-D for downstream operations."""
81+
if tensor.ndim == 1:
82+
tensor = torch.unsqueeze(tensor, dim)
83+
84+
return tensor

torchensemble/voting.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -114,13 +114,13 @@ def __init__(self, voting_strategy="soft", **kwargs):
114114
def forward(self, *x):
115115

116116
outputs = [
117-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
117+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
118+
for estimator in self.estimators_
118119
]
119120

120121
if self.voting_strategy == "soft":
121122
proba = op.average(outputs)
122-
123-
elif self.voting_strategy == "hard":
123+
else:
124124
proba = op.majority_vote(outputs)
125125

126126
return proba
@@ -309,7 +309,7 @@ def predict(self, *x):
309309

310310

311311
@torchensemble_model_doc(
312-
"""Implementation on the NeuralForestClassifier.""", "tree_ensmeble_model"
312+
"""Implementation on the NeuralForestClassifier.""", "tree_ensemble_model"
313313
)
314314
class NeuralForestClassifier(BaseTreeEnsemble, VotingClassifier):
315315
def __init__(self, voting_strategy="soft", **kwargs):
@@ -324,7 +324,8 @@ def __init__(self, voting_strategy="soft", **kwargs):
324324
def forward(self, *x):
325325
# Average over class distributions from all base estimators.
326326
outputs = [
327-
F.softmax(estimator(*x), dim=1) for estimator in self.estimators_
327+
F.softmax(op.unsqueeze_tensor(estimator(*x)), dim=1)
328+
for estimator in self.estimators_
328329
]
329330
proba = op.average(outputs)
330331

@@ -561,7 +562,7 @@ def predict(self, *x):
561562

562563

563564
@torchensemble_model_doc(
564-
"""Implementation on the NeuralForestRegressor.""", "tree_ensmeble_model"
565+
"""Implementation on the NeuralForestRegressor.""", "tree_ensemble_model"
565566
)
566567
class NeuralForestRegressor(BaseTreeEnsemble, VotingRegressor):
567568
@torchensemble_model_doc(

0 commit comments

Comments (0)