diff --git a/JavaTriesWithFrequencies.m b/JavaTriesWithFrequencies.m index b9f605e1..037bd492 100644 --- a/JavaTriesWithFrequencies.m +++ b/JavaTriesWithFrequencies.m @@ -526,6 +526,31 @@ Mathematica is (C) Copyright 1988-2017 Wolfram Research, Inc. ] ]; +JavaTrieClassify[tr_, records:(_Dataset|{_List..}), "Decision", opts : OptionsPattern[]] := + First @* Keys @* TakeLargest[1] /@ JavaTrieClassify[tr, records, "Probabilities", opts]; + +JavaTrieClassify[tr_, records:(_Dataset|{_List..}), "Probability" -> class_, opts : OptionsPattern[]] := + Map[Lookup[#, class, 0]&, JavaTrieClassify[tr, records, "Probabilities"] ]; + +JavaTrieClassify[tr_, records:(_Dataset|{_List..}), "TopProbabilities", opts : OptionsPattern[]] := + Map[ Select[#, # > 0 &]&, JavaTrieClassify[tr, records, "Probabilities", opts] ]; + +JavaTrieClassify[tr_, records:(_Dataset|{_List..}), "TopProbabilities" -> n_Integer, opts : OptionsPattern[]] := + Map[TakeLargest[#, UpTo[n]]&, JavaTrieClassify[tr, records, "Probabilities", opts] ]; + +JavaTrieClassify[tr_, records:(_Dataset|{_List..}), "Probabilities", opts:OptionsPattern[] ] := + Block[{clRes, classLabels, stencil}, + + clRes = Map[ JavaTrieClassify[tr, #, "Probabilities", opts] &, Normal@records ]; + + classLabels = Union[Flatten[Normal[Keys /@ clRes]]]; + + stencil = AssociationThread[classLabels -> 0]; + + KeySort[Join[stencil, #]] & /@ clRes + ]; + + End[] (* `Private` *) EndPackage[] \ No newline at end of file