From 78c19b7c847f76ba8edab4d3f54769c5c6e97f2f Mon Sep 17 00:00:00 2001 From: Anton Antonov Date: Wed, 14 Aug 2013 09:06:19 -0400 Subject: [PATCH] Implemented value checks and messages for DecisionTreeClassificationSuccess and DecisionForestClassificationSuccess. Extended their signatures to take selection function for the classification results. --- AVCDecisionTreeForest.m | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/AVCDecisionTreeForest.m b/AVCDecisionTreeForest.m index ac8e978b..8c3c5f57 100644 --- a/AVCDecisionTreeForest.m +++ b/AVCDecisionTreeForest.m @@ -645,28 +645,39 @@ Mathematica is (C) Copyright 1988-2012 Wolfram Research, Inc. Clear[DecisionTreeOrForestClassificationSuccess, DecisionTreeClassificationSuccess, DecisionForestClassificationSuccess] DecisionTreeOrForestClassificationSuccess[classFunc : (DecisionTreeClassify | DecisionForestClassify), dTreeOrForest_, dataArr_?MatrixQ] := DecisionTreeOrForestClassificationSuccess[classFunc, dTreeOrForest, dataArr, Union[dataArr[[All, -1]]]]; -DecisionTreeOrForestClassificationSuccess[classFunc : (DecisionTreeClassify | DecisionForestClassify), dTreeOrForest_, dataArr_?MatrixQ, labels_?VectorQ] := - Block[{guesses, guessStats, tdata, t}, +DecisionTreeOrForestClassificationSuccess[classFunc : (DecisionTreeClassify | DecisionForestClassify), dTreeOrForest_, dataArr_?MatrixQ, labels_?VectorQ, selectionFunc_: First] := + Block[{guesses, guessStats, tdata, t, dataLabels=Union[dataArr[[All, -1]]]}, t = Table[ - (tdata = Select[dataArr, #[[-1]] == lbl &]; - guesses = classFunc[dTreeOrForest, Most[#]][[1, 2]] & /@ tdata; + If[ !MemberQ[ dataLabels, lbl], + If[ TrueQ[classFunc === DecisionTreeClassify], + Message[DecisionTreeClassificationSuccess::nlbl, lbl, dataLabels], + Message[DecisionForestClassificationSuccess::nlbl, lbl, dataLabels], + ]; + {0,0}, + tdata = Select[dataArr, #[[-1]] == lbl &]; + guesses = selectionFunc[classFunc[dTreeOrForest, Most[#]]][[2]] & /@ tdata; guessStats = MapThread[Equal, {guesses, tdata[[All, -1]]}]; - {Count[guessStats, True], Count[guessStats, False]}/Length[tdata] // N) + {Count[guessStats, True], Count[guessStats, False]}/Length[tdata] // N + ] , {lbl, labels}]; t = MapThread[{{#1, True} -> #2[[1]], {#1, False} -> #2[[2]]} &, {labels, t}]; guesses = classFunc[dTreeOrForest, Most[#]][[1, 2]] & /@ dataArr; guessStats = MapThread[Equal, {guesses, dataArr[[All, -1]]}]; - Flatten[#, 1] &@ - Join[t, {{All, True} -> (Count[guessStats, True]/Length[dataArr] // N), - {All, False} -> (Count[guessStats, False]/Length[dataArr] // N)}] + Flatten[#, 1] &@ Join[t, {{All, True} -> (Count[guessStats, True]/Length[dataArr] // N), {All, False} -> (Count[guessStats, False]/Length[dataArr] // N)}] ]; +DecisionTreeClassificationSuccess::nlbl = "The specified label `1` is not one of the data array labels `2`." +DecisionTreeClassificationSuccess::wsig = "The first two arguments are expected to be a decision tree and a data array."; DecisionTreeClassificationSuccess[dTreeOrForest_, dataArr_?MatrixQ, x___] := DecisionTreeOrForestClassificationSuccess[DecisionTreeClassify, dTreeOrForest, dataArr, x]; +DecisionTreeClassificationSuccess[___]:=Message[DecisionTreeClassificationSuccess::wsig]; +DecisionForestClassificationSuccess::nlbl = "The specified label `1` is not one of the data array labels `2`." +DecisionForestClassificationSuccess::wsig = "The first two arguments are expected to be a decision forest and a data array."; DecisionForestClassificationSuccess[dTreeOrForest_, dataArr_?MatrixQ, x___] := DecisionTreeOrForestClassificationSuccess[DecisionForestClassify, dTreeOrForest, dataArr, x]; +DecisionForestClassificationSuccess[___]:=Message[DecisionForestClassificationSuccess::wsig]; End[]