diff --git a/sklearn2pmml/cross_reference/__init__.py b/sklearn2pmml/cross_reference/__init__.py index 2b13428..7b61049 100644 --- a/sklearn2pmml/cross_reference/__init__.py +++ b/sklearn2pmml/cross_reference/__init__.py @@ -16,12 +16,17 @@ def __getitem__(self, key): def __setitem__(self, key, value): self.data[key] = value - def clear(self): - self.data.clear() - def __len__(self): return len(self.data) + def __copy__(self): + return self + + def __deepcopy__(self, memo): + result = self + memo[id(self)] = result + return result + def __getstate__(self): state = self.__dict__.copy() state["data"] = dict() @@ -30,6 +35,9 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) + def clear(self): + self.data.clear() + def make_memorizer_union(memory, names, transform_only = True): return FeatureUnion([ ("memorizer", Memorizer(memory, names, transform_only = transform_only)), diff --git a/sklearn2pmml/cross_reference/tests/__init__.py b/sklearn2pmml/cross_reference/tests/__init__.py index 0dfcb5f..a2178ea 100644 --- a/sklearn2pmml/cross_reference/tests/__init__.py +++ b/sklearn2pmml/cross_reference/tests/__init__.py @@ -1,8 +1,10 @@ from pandas import DataFrame +from sklearn.base import clone from sklearn.pipeline import make_pipeline from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memory, Memorizer, Recaller from unittest import TestCase +import copy import numpy import pickle @@ -19,6 +21,13 @@ def test_item_assignment(self): memory.clear() self.assertEqual(0, len(memory)) + def test_copy(self): + memory = Memory() + memory_copy = copy.copy(memory) + self.assertIs(memory, memory_copy) + memory_deepcopy = copy.deepcopy(memory) + self.assertIs(memory, memory_deepcopy) + def test_pickle(self): memory = Memory() memory["int"] = [1, 2, 3] @@ -60,6 +69,13 @@ def test_transform(self): self.assertEqual([1.0, 2.0, 3.0], memory["float"].astype(float).tolist()) self.assertEqual(["one", "two", "three"], memory["str"].tolist()) + def test_clone(self): + memory = Memory() + memorizer = Memorizer(memory, ["flag"]) + memorizer_clone = clone(memorizer) + self.assertIsNot(memorizer, memorizer_clone) + self.assertIs(memorizer.memory, memorizer_clone.memory) + class RecallerTest(TestCase): def test_transform(self): @@ -90,6 +106,13 @@ def test_transform(self): self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist()) self.assertEqual(["one", "two", "three"], Xt[:, 2].tolist()) + def test_clone(self): + memory = Memory() + recaller = Recaller(memory, ["flag"]) + recaller_clone = clone(recaller) + self.assertIsNot(recaller, recaller_clone) + self.assertIs(recaller.memory, recaller_clone.memory) + class FunctionTest(TestCase): def test_make_memorizer_union(self):