Skip to content

Commit

Permalink
Added 'position' parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Nov 17, 2023
1 parent cd5bf09 commit f87f1d3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
24 changes: 18 additions & 6 deletions sklearn2pmml/cross_reference/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,17 +49,29 @@ def clear(self):
else:
self.data.clear()

def make_memorizer_union(memory, names, transform_only = True):
return FeatureUnion([
def _set_position(transformers, position):
if position == "first":
return transformers
elif position == "last":
return transformers[1:] + transformers[0:1]
else:
raise ValueError()

def make_memorizer_union(memory, names, transform_only = True, position = "first"):
transformers = [
("memorizer", Memorizer(memory, names, transform_only = transform_only)),
("identity", IdentityTransformer()),
])
]
transformers = _set_position(transformers, position = position)
return FeatureUnion(transformers)

def make_recaller_union(memory, names):
return FeatureUnion([
def make_recaller_union(memory, names, position = "first"):
transformers = [
("recaller", Recaller(memory, names)),
("identity", IdentityTransformer())
])
]
transformers = _set_position(transformers, position = position)
return FeatureUnion(transformers)

class _BaseMemoryManager(BaseEstimator, TransformerMixin):

Expand Down
14 changes: 10 additions & 4 deletions sklearn2pmml/cross_reference/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from sklearn.base import clone
from sklearn.pipeline import make_pipeline
from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memory, Memorizer, Recaller
from sklearn2pmml.preprocessing import IdentityTransformer
from unittest import TestCase

import copy
Expand Down Expand Up @@ -137,9 +138,14 @@ class FunctionTest(TestCase):
def test_make_memorizer_union(self):
memory = dict()
self.assertEqual(0, len(memory))
memorizer = make_memorizer_union(memory, ["int"])
memorizer_union = make_memorizer_union(memory, ["int"], position = "first")
self.assertIsInstance(memorizer_union.transformer_list[0][1], Memorizer)
self.assertIsInstance(memorizer_union.transformer_list[1][1], IdentityTransformer)
memorizer_union = make_memorizer_union(memory, ["int"], position = "last")
self.assertIsInstance(memorizer_union.transformer_list[0][1], IdentityTransformer)
self.assertIsInstance(memorizer_union.transformer_list[1][1], Memorizer)
X = numpy.asarray([[-1], [1]])
Xt = memorizer.fit_transform(X)
Xt = memorizer_union.fit_transform(X)
self.assertEqual((2, 1), Xt.shape)
self.assertEqual(1, len(memory))
self.assertEqual([-1, 1], memory["int"].tolist())
Expand All @@ -148,9 +154,9 @@ def test_make_recaller_union(self):
memory = {
"int": [-1, 1]
}
recaller = make_recaller_union(memory, ["int"])
recaller_union = make_recaller_union(memory, ["int"])
X = numpy.full((2, 1), 0, dtype = int)
Xt = recaller.fit_transform(X)
Xt = recaller_union.fit_transform(X)
self.assertEqual((2, 2), Xt.shape)
self.assertEqual([-1, 1], Xt[:, 0].tolist())
self.assertEqual([0, 0], Xt[:, 1].tolist())

0 comments on commit f87f1d3

Please sign in to comment.