diff --git a/sklearn2pmml/cross_reference/__init__.py b/sklearn2pmml/cross_reference/__init__.py index 997a87c..2b13428 100644 --- a/sklearn2pmml/cross_reference/__init__.py +++ b/sklearn2pmml/cross_reference/__init__.py @@ -30,9 +30,9 @@ def __getstate__(self): def __setstate__(self, state): self.__dict__.update(state) -def make_memorizer_union(memory, names): +def make_memorizer_union(memory, names, transform_only = True): return FeatureUnion([ - ("memorizer", Memorizer(memory, names)), + ("memorizer", Memorizer(memory, names, transform_only = transform_only)), ("identity", IdentityTransformer()), ]) @@ -52,15 +52,13 @@ def __init__(self, memory, names): class Memorizer(_BaseMemoryManager): - def __init__(self, memory, names): + def __init__(self, memory, names, transform_only = True): super(Memorizer, self).__init__(memory, names) + self.transform_only = transform_only - def fit(self, X, y = None): + def memorize(self, X): if X.shape[1] != len(self.names): raise ValueError() - return self - - def transform(self, X): for idx, name in enumerate(self.names): if isinstance(X, DataFrame): x = X.iloc[:, idx] @@ -69,17 +67,28 @@ def transform(self, X): self.memory[name] = x.copy() return numpy.empty(shape = (X.shape[0], 0), dtype = int) + def fit(self, X, y = None): + if not self.transform_only: + self.memorize(X) + return self + + def transform(self, X): + return self.memorize(X) + class Recaller(_BaseMemoryManager): def __init__(self, memory, names): super(Recaller, self).__init__(memory, names) - def fit(self, X, y = None): - return self - - def transform(self, X): + def recall(self, X): result = [] for idx, name in enumerate(self.names): x = self.memory[name] result.append(x.copy()) return numpy.asarray(result).T + + def fit(self, X, y = None): + return self + + def transform(self, X): + return self.recall(X) diff --git a/sklearn2pmml/cross_reference/tests/__init__.py b/sklearn2pmml/cross_reference/tests/__init__.py index e63b93f..0dfcb5f 100644 --- a/sklearn2pmml/cross_reference/tests/__init__.py +++ b/sklearn2pmml/cross_reference/tests/__init__.py @@ -40,6 +40,12 @@ def test_transform(self): self.assertEqual(1, len(memory)) self.assertEqual([-1, 1], memory["int"].tolist()) + memory = dict() + self.assertEquals(0, len(memory)) + memorizer = Memorizer(memory, ["int"], transform_only = False) + memorizer.fit(X) + self.assertEquals(1, len(memory)) + memory = DataFrame() self.assertEqual((0, 0), memory.shape) memorizer = Memorizer(memory, ["int", "float", "str"])