From f51d424f0932953fbbb228ca24e23941c082ec4f Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Fri, 22 Sep 2023 23:02:38 +0300 Subject: [PATCH] Added 'Memory' class --- sklearn2pmml/cross_reference/__init__.py | 25 +++++++++++ .../cross_reference/tests/__init__.py | 41 ++++++++++++++++--- 2 files changed, 60 insertions(+), 6 deletions(-) diff --git a/sklearn2pmml/cross_reference/__init__.py b/sklearn2pmml/cross_reference/__init__.py index afe6309..997a87c 100644 --- a/sklearn2pmml/cross_reference/__init__.py +++ b/sklearn2pmml/cross_reference/__init__.py @@ -5,6 +5,31 @@ import numpy +class Memory(object): + + def __init__(self): + self.data = dict() + + def __getitem__(self, key): + return self.data[key] + + def __setitem__(self, key, value): + self.data[key] = value + + def clear(self): + self.data.clear() + + def __len__(self): + return len(self.data) + + def __getstate__(self): + state = self.__dict__.copy() + state["data"] = dict() + return state + + def __setstate__(self, state): + self.__dict__.update(state) + def make_memorizer_union(memory, names): return FeatureUnion([ ("memorizer", Memorizer(memory, names)), diff --git a/sklearn2pmml/cross_reference/tests/__init__.py b/sklearn2pmml/cross_reference/tests/__init__.py index e4839b7..e63b93f 100644 --- a/sklearn2pmml/cross_reference/tests/__init__.py +++ b/sklearn2pmml/cross_reference/tests/__init__.py @@ -1,8 +1,30 @@ from pandas import DataFrame -from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memorizer, Recaller +from sklearn.pipeline import make_pipeline +from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memory, Memorizer, Recaller from unittest import TestCase import numpy +import pickle + +class MemoryTest(TestCase): + + def test_item_assignment(self): + memory = Memory() + self.assertEqual(0, len(memory)) + memory["int"] = [1, 2, 3] + self.assertEqual(1, len(memory)) + self.assertEqual([1, 2, 3], memory["int"]) + with self.assertRaises(KeyError): + memory["float"] + memory.clear() + self.assertEqual(0, len(memory)) + + def test_pickle(self): + memory = Memory() + memory["int"] = [1, 2, 3] + self.assertEqual(1, len(memory)) + memory_clone = pickle.loads(pickle.dumps(memory)) + self.assertEqual(0, len(memory_clone)) class MemorizerTest(TestCase): @@ -11,7 +33,9 @@ def test_transform(self): self.assertEqual(0, len(memory)) memorizer = Memorizer(memory, ["int"]) X = numpy.asarray([[-1], [1]]) - Xt = memorizer.fit_transform(X) + memorizer.fit(X) + self.assertEquals(0, len(memory)) + Xt = memorizer.transform(X) self.assertEqual((2, 0), Xt.shape) self.assertEqual(1, len(memory)) self.assertEqual([-1, 1], memory["int"].tolist()) @@ -20,7 +44,8 @@ def test_transform(self): self.assertEqual((0, 0), memory.shape) memorizer = Memorizer(memory, ["int", "float", "str"]) X = numpy.asarray([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]]) - Xt = memorizer.fit_transform(X) + pipeline = make_pipeline(memorizer) + Xt = pipeline.fit_transform(X) self.assertEqual((3, 0), Xt.shape) self.assertEqual((3, 3), memory.shape) self.assertEqual(["1", "2", "3"], memory["int"].tolist()) @@ -37,19 +62,23 @@ def test_transform(self): } recaller = Recaller(memory, ["int"]) X = numpy.empty((2, 1), dtype = str) - Xt = recaller.fit_transform(X) + recaller.fit(X) + self.assertEqual(1, len(memory)) + Xt = recaller.transform(X) self.assertEqual((2, 1), Xt.shape) self.assertEqual([-1, 1], Xt[:, 0].tolist()) memory = DataFrame([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]], columns = ["int", "float", "str"]) self.assertEqual((3, 3), memory.shape) recaller = Recaller(memory, ["int"]) + pipeline = make_pipeline(recaller) X = numpy.empty((3, 5), dtype = str) - Xt = recaller.fit_transform(X) + Xt = pipeline.fit_transform(X) self.assertEqual((3, 1), Xt.shape) self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) recaller = Recaller(memory, ["int", "float", "str"]) - Xt = recaller.fit_transform(X) + pipeline = make_pipeline(recaller) + Xt = pipeline.fit_transform(X) self.assertEqual((3, 3), Xt.shape) self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist())