From 69fd0f712f1fd9b3ac40ab1224afe7a8330860c7 Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Tue, 19 Sep 2023 16:34:41 +0300 Subject: [PATCH] Added 'Memorizer' and 'Recaller' transformation types --- setup.py | 1 + sklearn2pmml/cross_reference/__init__.py | 42 ++++++++++++++ .../cross_reference/tests/__init__.py | 56 +++++++++++++++++++ 3 files changed, 99 insertions(+) create mode 100644 sklearn2pmml/cross_reference/__init__.py create mode 100644 sklearn2pmml/cross_reference/tests/__init__.py diff --git a/setup.py b/setup.py index e54316d..2deab95 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ ], packages = [ "sklearn2pmml", + "sklearn2pmml.cross_reference", "sklearn2pmml.decoration", "sklearn2pmml.ensemble", "sklearn2pmml.expression", diff --git a/sklearn2pmml/cross_reference/__init__.py b/sklearn2pmml/cross_reference/__init__.py new file mode 100644 index 0000000..41f20e3 --- /dev/null +++ b/sklearn2pmml/cross_reference/__init__.py @@ -0,0 +1,42 @@ +from sklearn.base import BaseEstimator, TransformerMixin + +import numpy + +class _BaseMemoryManager(BaseEstimator, TransformerMixin): + + def __init__(self, memory, names): + self.memory = memory + if not isinstance(names, list): + raise TypeError() + self.names = names + +class Memorizer(_BaseMemoryManager): + + def __init__(self, memory, names): + super(Memorizer, self).__init__(memory, names) + + def fit(self, X, y = None): + if X.shape[1] != len(self.names): + raise ValueError() + return self + + def transform(self, X): + for idx, name in enumerate(self.names): + x = X[:, idx] + self.memory[name] = x.copy() + return numpy.empty(shape = (X.shape[0], 0), dtype = int) + +class Recaller(_BaseMemoryManager): + + def __init__(self, memory, names): + super(Recaller, self).__init__(memory, names) + + def fit(self, X, y = None): + return self + + def transform(self, X): + result = [] + for idx, name in enumerate(self.names): + x = self.memory[name] + result.append(x.copy()) + return numpy.asarray(result).T diff --git a/sklearn2pmml/cross_reference/tests/__init__.py b/sklearn2pmml/cross_reference/tests/__init__.py new file mode 100644 index 0000000..2f65453 --- /dev/null +++ b/sklearn2pmml/cross_reference/tests/__init__.py @@ -0,0 +1,56 @@ +from pandas import DataFrame +from sklearn2pmml.cross_reference import Memorizer, Recaller +from unittest import TestCase + +import numpy + +class MemorizerTest(TestCase): + + def test_fit_transform(self): + memory = dict() + self.assertEqual(0, len(memory)) + memorizer = Memorizer(memory, ["int"]) + X = numpy.asarray([[-1], [1]]) + Xt = memorizer.fit_transform(X) + self.assertEqual((2, 0), Xt.shape) + self.assertEqual(1, len(memory)) + self.assertEqual([-1, 1], memory["int"].tolist()) + + memory = DataFrame() + self.assertEqual((0, 0), memory.shape) + memorizer = Memorizer(memory, ["int", "float", "str"]) + X = numpy.asarray([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]]) + Xt = memorizer.fit_transform(X) + self.assertEqual((3, 0), Xt.shape) + self.assertEqual((3, 3), memory.shape) + self.assertEqual(["1", "2", "3"], memory["int"].tolist()) + self.assertEqual([1, 2, 3], memory["int"].astype(int).tolist()) + self.assertEqual([str(1.0), str(2.0), str(3.0)], memory["float"].tolist()) + self.assertEqual([1.0, 2.0, 3.0], memory["float"].astype(float).tolist()) + self.assertEqual(["one", "two", "three"], memory["str"].tolist()) + +class RecallerTest(TestCase): + + def test_fit_transform(self): + X = numpy.empty((100, 5), dtype = str) + + memory = { + "int": [-1, 1] + } + recaller = Recaller(memory, ["int"]) + Xt = recaller.fit_transform(X) + self.assertEqual((2, 1), Xt.shape) + self.assertEqual([-1, 1], Xt[:, 0].tolist()) + + memory = DataFrame([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]], columns = ["int", "float", "str"]) + self.assertEqual((3, 3), memory.shape) + recaller = Recaller(memory, ["int"]) + Xt = recaller.fit_transform(X) + self.assertEqual((3, 1), Xt.shape) + self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) + recaller = Recaller(memory, ["int", "float", "str"]) + Xt = recaller.fit_transform(X) + self.assertEqual((3, 3), Xt.shape) + self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) + self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist()) + self.assertEqual(["one", "two", "three"], Xt[:, 2].tolist())