Skip to content

Commit

Permalink
Implemented decision tree pruning. Added functions for querying deci…
Browse files Browse the repository at this point in the history
…sion trees and leaves combination.
  • Loading branch information
antononcube committed Mar 9, 2014
1 parent 268bd34 commit 8099a60
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions AVCDecisionTreeForest.m
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ Mathematica is (C) Copyright 1988-2012 Wolfram Research, Inc.

DecisionForestClassificationSuccess::usage = "DecisionForestClassificationSuccess[dForest, testDataArray, lbls] finds the classification success using dForest over the test data testDataArray for each classification label in lbls. If the last argument, lbls, is omitted then Union[testDataArray[[All,-1]]] is taken as the set of labels. The returned result is a set of rules {{_,True|False}->_?NumberQ..}. The rules {_,True}->_ are for the fractions of correct guesses; the rules {_,False}->_ are for the fractions of incorrect guesses. The rules {_,All}->_ are for the classification success fractions using all records of testDataArray."

DecisionTreeNumberOfNodesAndLeaves::usage = "DecisionTreeNumberOfNodesAndLeaves[dTree] gives a list of two numbers: the number of internal nodes and the number of leaves of dTree."

DecisionTreeLeafQ::usage = "DecisionTreeLeafQ[rec] tests whether rec is a decision tree leaf."

DecisionTreeLabels::usage = "DecisionTreeLabels[dTree] gives the labels used in the leaves of the decision tree dTree."

DecisionTreeCombinedLeaves::usage = "DecisionTreeCombinedLeaves[dTree] combines all the leaves of dTree into one leaf. The frequencies of the corresponding labels are summed."

PruneDecisionTree::usage = "PruneDecisionTree[dTree] prunes branches of the decision tree dTree using the minimum description length principle."


Begin["`Private`"]

Expand Down Expand Up @@ -700,6 +710,119 @@ Mathematica is (C) Copyright 1988-2012 Wolfram Research, Inc.
DecisionTreeOrForestClassificationSuccess[DecisionForestClassify, dTreeOrForest, dataArr, x];
(* Catch-all definition: the pattern ___ matches any argument sequence, so a call that is not handled by a more specific definition ends up here and issues the DecisionForestClassificationSuccess::wsig message. NOTE(review): the text of ::wsig is not visible in this part of the file -- presumably it reports a wrong call signature; confirm. *)
DecisionForestClassificationSuccess[___]:=Message[DecisionForestClassificationSuccess::wsig];


(* Pruning *)

Clear[DecisionTreeNumberOfNodesAndLeaves, DecisionTreeNumberOfNodesAndLeavesRec]

(* DecisionTreeNumberOfNodesAndLeaves[dTree] returns {number of internal nodes, number of leaves}.
   The two counters NNODES and NLEAVES are dynamically scoped with Block, so the recursive
   helper below increments the very same counters while it walks the tree. *)
DecisionTreeNumberOfNodesAndLeaves[dTree_] :=
  Block[{NNODES = 0, NLEAVES = 0},
   DecisionTreeNumberOfNodesAndLeavesRec[dTree];
   {NNODES, NLEAVES}
  ];

(* Recursive walker. Anything that is a two-column matrix, or whose first element is a
   two-column matrix, is counted as a leaf; everything else is counted as an internal node
   whose second and third parts are visited as sub-trees. *)
DecisionTreeNumberOfNodesAndLeavesRec[dTree_] :=
  With[{twoColumnMatrixQ = (MatrixQ[#] && Dimensions[#][[2]] == 2) &},
   If[twoColumnMatrixQ[dTree] || twoColumnMatrixQ[dTree[[1]]],
    NLEAVES++,
    (* internal node: count it, then descend into both branches *)
    NNODES++;
    DecisionTreeNumberOfNodesAndLeavesRec[dTree[[2]]];
    DecisionTreeNumberOfNodesAndLeavesRec[dTree[[3]]];
   ]
  ];

Clear[DecisionTreeLeafQ]

(* DecisionTreeLeafQ[leaf] gives True when leaf is a matrix whose first column consists of
   numbers (the label frequencies) and whose second column consists of non-numeric entries
   (the labels); otherwise it gives False.
   The column-count guard is evaluated before the columns are extracted: without it a
   one-column or empty-row matrix would make leaf[[All, 1]] / leaf[[All, 2]] emit Part
   messages. That matters because this predicate is applied to every sub-expression of a
   tree by Cases[..., Infinity]. Matrices with at least two columns are judged exactly as
   before. *)
DecisionTreeLeafQ[leaf_] :=
  MatrixQ[leaf] &&
   Dimensions[leaf][[2]] >= 2 &&
   VectorQ[leaf[[All, 1]], NumberQ] &&
   VectorQ[leaf[[All, 2]], Not[NumberQ[#]] &];

Clear[DecisionTreeLabels]

(* DecisionTreeLabels[dtree] collects every sub-expression of dtree that passes
   DecisionTreeLeafQ, takes the second column (the labels) of each such leaf, and returns
   the sorted list of distinct labels. *)
DecisionTreeLabels[dtree_] :=
  Module[{leafMatrices},
   leafMatrices = Cases[dtree, _?DecisionTreeLeafQ, Infinity];
   Union[Flatten[leafMatrices[[All, All, 2]]]]
  ];

Clear[DecisionTreeCombinedLeaves]

(* DecisionTreeCombinedLeaves[dtree] merges all leaves found inside dtree into a single
   leaf: the {frequency, label} rows of every leaf are pooled, grouped by label, and the
   frequencies within each group are summed. The result is a list of {total, label} rows,
   in order of first appearance of each label. *)
DecisionTreeCombinedLeaves[dtree_] :=
  Module[{rows, rowsByLabel},
   (* pool the rows of all leaf matrices into one flat list of rows *)
   rows = Flatten[Cases[dtree, _?DecisionTreeLeafQ, Infinity], 1];
   rowsByLabel = GatherBy[rows, #[[2]] &];
   Table[{Total[grp[[All, 1]]], grp[[1, 2]]}, {grp, rowsByLabel}]
  ];

Clear[MDLLeafCost, MDLLeavesCost]

(* MDLLeafCost[leaf, numberOfLabels] gives the description-length cost of a single leaf.
   With n the total of the leaf frequencies n_i and k = numberOfLabels the value is
     Sum_i n_i Log[n/n_i] + ((k - 1)/2) Log[n/2] + Log[Pi^(k/2)/Gamma[k/2]]. *)
MDLLeafCost[leaf_?DecisionTreeLeafQ, numberOfLabels_Integer] :=
  Block[{n = Total[leaf[[All, 1]]], k = numberOfLabels},
   Total[(#1*Log[n/#1] &) /@ leaf[[All, 1]]] + ((k - 1)/2.)*Log[n/2.] + Log[Pi^(k/2)/Gamma[k/2]]
  ];

(* MDLLeafCost[dtree, numberOfLabels] for a (sub)tree: the cost the subtree would have if all
   of its leaves were combined into one leaf.
   If no leaf is recognized inside dtree the combined result is not itself a leaf; calling
   MDLLeafCost on it again would match this same definition and recurse until
   $RecursionLimit. In that case Infinity is returned instead, so that a caller comparing
   costs never prefers collapsing such a subtree into a leaf. *)
MDLLeafCost[dtree_, numberOfLabels_Integer] :=
  Block[{combined = DecisionTreeCombinedLeaves[dtree]},
   If[DecisionTreeLeafQ[combined],
    MDLLeafCost[combined, numberOfLabels],
    Infinity
   ]
  ];

(* MDLLeavesCost[dtree, numberOfLabels] gives the sum of the individual costs of all leaves
   found inside dtree. *)
MDLLeavesCost[dtree_, numberOfLabels_Integer] :=
  Block[{leaves},
   leaves = Cases[dtree, s_ /; DecisionTreeLeafQ[s], \[Infinity]];
   Total[MDLLeafCost[#, numberOfLabels] & /@ leaves]
  ];

Clear[DecisionTreeNValuesPerIndex];

(* DecisionTreeNValuesPerIndex[dtree] tabulates, per splitting variable, how many distinct
   split values occur in dtree. Split specifications are the length-5 lists inside dtree
   that contain one of the symbols Symbol, Number or Dot; part 4 is the split type, part 3
   the variable specification and part 2 the split value.
   The result is a Dispatch table with the rules
     {index, Number} -> number of distinct split values,
     {index, Symbol} -> number of distinct split values,
     {firstOfPart3, Dot} -> {number of distinct entries in the second elements of part 3,
                             number of distinct split values}. *)
DecisionTreeNValuesPerIndex[dtree_] :=
  Module[{splitSpecs, specsOfType, nDistinct, numberRules, symbolRules, dotRules},
   splitSpecs = Cases[dtree, s : {___, Symbol | Number | Dot, ___} /; Length[s] == 5, Infinity];
   specsOfType = Function[type, Select[splitSpecs, #[[4]] === type &]];
   nDistinct = Function[vals, Length[Union[vals]]];
   numberRules =
    Table[{g[[1, 3]], Number} -> nDistinct[g[[All, 2]]],
     {g, GatherBy[specsOfType[Number], #[[3]] &]}];
   symbolRules =
    Table[{g[[1, 3]], Symbol} -> nDistinct[g[[All, 2]]],
     {g, GatherBy[specsOfType[Symbol], #[[3]] &]}];
   dotRules =
    Table[{g[[1, 3, 1]], Dot} -> {nDistinct[Flatten[g[[All, 3, 2]]]], nDistinct[g[[All, 2]]]},
     {g, GatherBy[specsOfType[Dot], #[[3, 1]] &]}];
   Dispatch[Join[numberRules, symbolRules, dotRules]]
  ];

Clear[MDLSplitCost]

(* MDLSplitCost[dtree, varIndexToNValuesRules] gives the description-length cost of the
   split at the root of dtree. dtree[[1]] is the split specification: part 4 is the split
   type (Number, Symbol or Dot) and part 3 the variable specification.
   varIndexToNValuesRules is the Dispatch table that maps {variable, type} to the number of
   distinct split values (for Dot: a pair of counts).

   Two robustness points:
   - The plain list of rules is needed to count tabulated variables. Older kernels expose it
     as part 1 of the Dispatch expression, while kernels in which Dispatch tables are atomic
     require Normal; both layouts are handled.
   - Log[(number of tabulated variables) - 1] would be -Infinity when only one variable is
     tabulated; the argument is clamped to at least 1, in line with the {0 -> 1} guard
     already applied to the value-count term. *)
MDLSplitCost[dtree_, varIndexToNValuesRules_Dispatch] :=
  Block[{rules, res},
   rules =
    If[AtomQ[varIndexToNValuesRules],
     Normal[varIndexToNValuesRules],
     varIndexToNValuesRules[[1]]
    ];
   res =
    Which[
     MatchQ[dtree[[1, 4]], Number | Symbol],
     Log[((dtree[[1, 3 ;; 4]] /. varIndexToNValuesRules) - 1) /. {0 -> 1}] +
      Log[Max[1, Length[rules] - 1]],
     MatchQ[dtree[[1, 4]], Dot],
     Log[Total[{dtree[[1, 3, 1]], Dot} /. varIndexToNValuesRules] - 1] +
      Log[2^(Length[Select[rules, #[[1, 2]] === Number &]] - 1)],
     True,
     (* Fail-safe but should not happen*)
     Print["MDLSplitCost::internal error!"];
     0
    ];
   (* the extra 1 is added to every split cost, then the result is made numeric *)
   N[res + 1]
  ];

Clear[PruneDecisionTree, MDLPruneDecisionTreeRec]
Options[PruneDecisionTree] = {Method -> "MDL"};

(* PruneDecisionTree[dtree] returns the pruned version of dtree. The number of labels and
   the table of distinct split values per variable are computed once from the whole tree
   and handed to the recursive worker; the total cost the worker returns is discarded.
   The options are accepted but not consulted in this definition. *)
PruneDecisionTree[dtree_, opts : OptionsPattern[]] :=
  Module[{prunedTree, totalCost},
   {prunedTree, totalCost} =
    MDLPruneDecisionTreeRec[dtree, Length[DecisionTreeLabels[dtree]], DecisionTreeNValuesPerIndex[dtree]];
   prunedTree
  ];

(* MDLPruneDecisionTreeRec[dtree, numberOfLabels, varIndexToNValuesRules] returns
   {prunedTree, cost}.
   - A leaf node (first element passes DecisionTreeLeafQ) is returned unchanged together
     with its leaf cost.
   - For a split node both branches are pruned recursively. The node is replaced by a single
     combined leaf when (cost as one leaf) + 1 does not exceed
     (split cost) + 1 + (cost of pruned left branch) + (cost of pruned right branch);
     otherwise the split is kept with the pruned branches. *)
MDLPruneDecisionTreeRec[dtree_, numberOfLabels_Integer, varIndexToNValuesRules_] :=
  Module[{splitCost, asLeafCost, prunedLeft, costLeft, prunedRight, costRight, collapsedCost, keptCost},
   If[DecisionTreeLeafQ[dtree[[1]]],
    (* leaf node: nothing to prune *)
    {dtree, MDLLeafCost[dtree[[1]], numberOfLabels]},
    (* split node *)
    splitCost = MDLSplitCost[dtree, varIndexToNValuesRules];
    asLeafCost = MDLLeafCost[dtree, numberOfLabels];
    {prunedLeft, costLeft} =
     MDLPruneDecisionTreeRec[dtree[[2]], numberOfLabels, varIndexToNValuesRules];
    {prunedRight, costRight} =
     MDLPruneDecisionTreeRec[dtree[[3]], numberOfLabels, varIndexToNValuesRules];
    collapsedCost = asLeafCost + 1;
    keptCost = splitCost + 1 + costLeft + costRight;
    If[collapsedCost <= keptCost,
     (* collapse the whole subtree into one leaf node *)
     {{DecisionTreeCombinedLeaves[dtree]}, collapsedCost},
     (* keep the split, with the pruned branches *)
     {{dtree[[1]], prunedLeft, prunedRight}, keptCost}
    ]
   ]
  ];

End[]

EndPackage[]

0 comments on commit 8099a60

Please sign in to comment.