Skip to content

Commit

Permalink
Added 'Memorizer.transform_only' attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 23, 2023
1 parent f51d424 commit 6ad0853
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
31 changes: 20 additions & 11 deletions sklearn2pmml/cross_reference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def __getstate__(self):
def __setstate__(self, state):
self.__dict__.update(state)

def make_memorizer_union(memory, names):
def make_memorizer_union(memory, names, transform_only = True):
return FeatureUnion([
("memorizer", Memorizer(memory, names)),
("memorizer", Memorizer(memory, names, transform_only = transform_only)),
("identity", IdentityTransformer()),
])

Expand All @@ -52,15 +52,13 @@ def __init__(self, memory, names):

class Memorizer(_BaseMemoryManager):

def __init__(self, memory, names):
def __init__(self, memory, names, transform_only = True):
super(Memorizer, self).__init__(memory, names)
self.transform_only = transform_only

def fit(self, X, y = None):
def memorize(self, X):
if X.shape[1] != len(self.names):
raise ValueError()
return self

def transform(self, X):
for idx, name in enumerate(self.names):
if isinstance(X, DataFrame):
x = X.iloc[:, idx]
Expand All @@ -69,17 +67,28 @@ def transform(self, X):
self.memory[name] = x.copy()
return numpy.empty(shape = (X.shape[0], 0), dtype = int)

def fit(self, X, y = None):
if not self.transform_only:
self.memorize(X)
return self

def transform(self, X):
return self.memorize(X)

class Recaller(_BaseMemoryManager):

def __init__(self, memory, names):
super(Recaller, self).__init__(memory, names)

def fit(self, X, y = None):
return self

def transform(self, X):
def recall(self, X):
result = []
for idx, name in enumerate(self.names):
x = self.memory[name]
result.append(x.copy())
return numpy.asarray(result).T

def fit(self, X, y = None):
return self

def transform(self, X):
return self.recall(X)
6 changes: 6 additions & 0 deletions sklearn2pmml/cross_reference/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def test_transform(self):
self.assertEqual(1, len(memory))
self.assertEqual([-1, 1], memory["int"].tolist())

memory = dict()
self.assertEquals(0, len(memory))
memorizer = Memorizer(memory, ["int"], transform_only = False)
memorizer.fit(X)
self.assertEquals(1, len(memory))

memory = DataFrame()
self.assertEqual((0, 0), memory.shape)
memorizer = Memorizer(memory, ["int", "float", "str"])
Expand Down

0 comments on commit 6ad0853

Please sign in to comment.