Skip to content

Commit

Permalink
Fixed the cloneability of memory manager classes
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 23, 2023
1 parent 6ad0853 commit ec4c80c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
14 changes: 11 additions & 3 deletions sklearn2pmml/cross_reference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ def __getitem__(self, key):
def __setitem__(self, key, value):
self.data[key] = value

def clear(self):
self.data.clear()

def __len__(self):
return len(self.data)

def __copy__(self):
return self

def __deepcopy__(self, memo):
result = self
memo[id(self)] = result
return result

def __getstate__(self):
state = self.__dict__.copy()
state["data"] = dict()
Expand All @@ -30,6 +35,9 @@ def __getstate__(self):
def __setstate__(self, state):
self.__dict__.update(state)

def clear(self):
self.data.clear()

def make_memorizer_union(memory, names, transform_only = True):
return FeatureUnion([
("memorizer", Memorizer(memory, names, transform_only = transform_only)),
Expand Down
23 changes: 23 additions & 0 deletions sklearn2pmml/cross_reference/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from pandas import DataFrame
from sklearn.base import clone
from sklearn.pipeline import make_pipeline
from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memory, Memorizer, Recaller
from unittest import TestCase

import copy
import numpy
import pickle

Expand All @@ -19,6 +21,13 @@ def test_item_assignment(self):
memory.clear()
self.assertEqual(0, len(memory))

def test_copy(self):
memory = Memory()
memory_copy = copy.copy(memory)
self.assertIs(memory, memory_copy)
memory_deepcopy = copy.deepcopy(memory)
self.assertIs(memory, memory_deepcopy)

def test_pickle(self):
memory = Memory()
memory["int"] = [1, 2, 3]
Expand Down Expand Up @@ -60,6 +69,13 @@ def test_transform(self):
self.assertEqual([1.0, 2.0, 3.0], memory["float"].astype(float).tolist())
self.assertEqual(["one", "two", "three"], memory["str"].tolist())

def test_clone(self):
memory = Memory()
memorizer = Memorizer(memory, ["flag"])
memorizer_clone = clone(memorizer)
self.assertIsNot(memorizer, memorizer_clone)
self.assertIs(memorizer.memory, memorizer_clone.memory)

class RecallerTest(TestCase):

def test_transform(self):
Expand Down Expand Up @@ -90,6 +106,13 @@ def test_transform(self):
self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist())
self.assertEqual(["one", "two", "three"], Xt[:, 2].tolist())

def test_clone(self):
memory = Memory()
recaller = Recaller(memory, ["flag"])
recaller_clone = clone(recaller)
self.assertIsNot(recaller, recaller_clone)
self.assertIs(recaller.memory, recaller_clone.memory)

class FunctionTest(TestCase):

def test_make_memorizer_union(self):
Expand Down

0 comments on commit ec4c80c

Please sign in to comment.