diff --git a/AVCDecisionTreeForest.m b/AVCDecisionTreeForest.m index 7f8e4c1c..3535b09a 100644 --- a/AVCDecisionTreeForest.m +++ b/AVCDecisionTreeForest.m @@ -495,25 +495,34 @@ Mathematica is (C) Copyright 1988-2012 Wolfram Research, Inc. ]; (* Classify by forest *) - Clear[DecisionForestClassify] -DecisionForestClassify[forest_, record_] := - Block[{res}, - res = DecisionTreeClassify[#, record] & /@ forest; +Options[DecisionForestClassify] = {"Weighted" -> False}; +DecisionForestClassify[forest_, record_, opts : OptionsPattern[]] := + Block[{res, weightedQ = OptionValue[DecisionForestClassify, "Weighted"]}, + res = TreeClassify[#, record] & /@ forest; res = Flatten[res, 1]; - res = GatherBy[res, #[[2]] &]; - res = Map[{Total[#[[All, 1]]], #[[1, 2]]} &, res]; - SortBy[res, -#[[1]] &] + If[TrueQ[weightedQ], + res = GatherBy[res, #[[2]] &]; + res = Map[{Total[#[[All, 1]]], #[[1, 2]]} &, res]; + SortBy[res, -#[[1]] &], + (*ELSE*) + SortBy[Reverse /@ Tally[res[[All, 2]]], -#[[1]] &] + ] ]; Clear[ParallelDecisionForestClassify] -ParallelDecisionForestClassify[forest_, record_] := - Block[{res}, - res = ParallelMap[DecisionTreeClassify[#, record] &, forest, Method -> "CoarsestGrained", DistributedContexts -> Automatic]; +Options[ParallelDecisionForestClassify] = {"Weighted" -> False}; +ParallelDecisionForestClassify[forest_, record_] := + Block[{res, weightedQ = OptionValue[ParallelDecisionForestClassify, "Weighted"]}, + res = ParallelMap[TreeClassify[#, record] &, forest, Method -> "CoarsestGrained", DistributedContexts -> Automatic]; res = Flatten[res, 1]; - res = GatherBy[res, #[[2]] &]; - res = Map[{Total[#[[All, 1]]], #[[1, 2]]} &, res]; - SortBy[res, -#[[1]] &] + If[TrueQ[weightedQ], + res = GatherBy[res, #[[2]] &]; + res = Map[{Total[#[[All, 1]]], #[[1, 2]]} &, res]; + SortBy[res, -#[[1]] &], + (*ELSE*) + SortBy[Reverse /@ Tally[res[[All, 2]]], -#[[1]] &] + ] ]; (* Convert to rules *)