diff --git a/sklearn2pmml/cross_reference/__init__.py b/sklearn2pmml/cross_reference/__init__.py index 41f20e3..99d27bb 100644 --- a/sklearn2pmml/cross_reference/__init__.py +++ b/sklearn2pmml/cross_reference/__init__.py @@ -1,7 +1,20 @@ from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.pipeline import FeatureUnion import numpy +def make_memorizer_union(memory, names): + return FeatureUnion([ + ("memorizer", Memorizer(memory, names)), + ("passthrough", "passthrough"), + ]) + +def make_recaller_union(memory, names): + return FeatureUnion([ + ("recaller", Recaller(memory, names)), + ("passthrough", "passthrough") + ]) + class _BaseMemoryManager(BaseEstimator, TransformerMixin): def __init__(self, memory, names): diff --git a/sklearn2pmml/cross_reference/tests/__init__.py b/sklearn2pmml/cross_reference/tests/__init__.py index 2f65453..b344b8a 100644 --- a/sklearn2pmml/cross_reference/tests/__init__.py +++ b/sklearn2pmml/cross_reference/tests/__init__.py @@ -1,5 +1,5 @@ from pandas import DataFrame -from sklearn2pmml.cross_reference import Memorizer, Recaller +from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memorizer, Recaller from unittest import TestCase import numpy @@ -32,12 +32,11 @@ def test_fit_transform(self): class RecallerTest(TestCase): def test_fit_transform(self): - X = numpy.empty((100, 5), dtype = str) - memory = { "int": [-1, 1] } recaller = Recaller(memory, ["int"]) + X = numpy.empty((2, 1), dtype = str) Xt = recaller.fit_transform(X) self.assertEqual((2, 1), Xt.shape) self.assertEqual([-1, 1], Xt[:, 0].tolist()) @@ -45,6 +44,7 @@ def test_fit_transform(self): memory = DataFrame([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]], columns = ["int", "float", "str"]) self.assertEqual((3, 3), memory.shape) recaller = Recaller(memory, ["int"]) + X = numpy.empty((3, 5), dtype = str) Xt = recaller.fit_transform(X) self.assertEqual((3, 1), Xt.shape) self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) @@ -54,3 +54,26 @@ def test_fit_transform(self): self.assertEqual([1, 2, 3], Xt[:, 0].tolist()) self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist()) self.assertEqual(["one", "two", "three"], Xt[:, 2].tolist()) + +class FunctionTest(TestCase): + + def test_make_memorizer_union(self): + memory = dict() + self.assertEqual(0, len(memory)) + memorizer = make_memorizer_union(memory, ["int"]) + X = numpy.asarray([[-1], [1]]) + Xt = memorizer.fit_transform(X) + self.assertEqual((2, 1), Xt.shape) + self.assertEqual(1, len(memory)) + self.assertEqual([-1, 1], memory["int"].tolist()) + + def test_make_recaller_union(self): + memory = { + "int": [-1, 1] + } + recaller = make_recaller_union(memory, ["int"]) + X = numpy.full((2, 1), 0, dtype = int) + Xt = recaller.fit_transform(X) + self.assertEqual((2, 2), Xt.shape) + self.assertEqual([-1, 1], Xt[:, 0].tolist()) + self.assertEqual([0, 0], Xt[:, 1].tolist())