Skip to content

Commit

Permalink
Added utility functions for constructing memorizer and recaller unions
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Sep 19, 2023
1 parent 69fd0f7 commit ab7bdb4
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
13 changes: 13 additions & 0 deletions sklearn2pmml/cross_reference/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import FeatureUnion

import numpy

def make_memorizer_union(memory, names):
return FeatureUnion([
("memorizer", Memorizer(memory, names)),
("passthrough", "passthrough"),
])

def make_recaller_union(memory, names):
return FeatureUnion([
("recaller", Recaller(memory, names)),
("passthrough", "passthrough")
])

class _BaseMemoryManager(BaseEstimator, TransformerMixin):

def __init__(self, memory, names):
Expand Down
29 changes: 26 additions & 3 deletions sklearn2pmml/cross_reference/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pandas import DataFrame
from sklearn2pmml.cross_reference import Memorizer, Recaller
from sklearn2pmml.cross_reference import make_memorizer_union, make_recaller_union, Memorizer, Recaller
from unittest import TestCase

import numpy
Expand Down Expand Up @@ -32,19 +32,19 @@ def test_fit_transform(self):
class RecallerTest(TestCase):

def test_fit_transform(self):
X = numpy.empty((100, 5), dtype = str)

memory = {
"int": [-1, 1]
}
recaller = Recaller(memory, ["int"])
X = numpy.empty((2, 1), dtype = str)
Xt = recaller.fit_transform(X)
self.assertEqual((2, 1), Xt.shape)
self.assertEqual([-1, 1], Xt[:, 0].tolist())

memory = DataFrame([[1, 1.0, "one"], [2, 2.0, "two"], [3, 3.0, "three"]], columns = ["int", "float", "str"])
self.assertEqual((3, 3), memory.shape)
recaller = Recaller(memory, ["int"])
X = numpy.empty((3, 5), dtype = str)
Xt = recaller.fit_transform(X)
self.assertEqual((3, 1), Xt.shape)
self.assertEqual([1, 2, 3], Xt[:, 0].tolist())
Expand All @@ -54,3 +54,26 @@ def test_fit_transform(self):
self.assertEqual([1, 2, 3], Xt[:, 0].tolist())
self.assertEqual([1.0, 2.0, 3.0], Xt[:, 1].tolist())
self.assertEqual(["one", "two", "three"], Xt[:, 2].tolist())

class FunctionTest(TestCase):

def test_make_memorizer_union(self):
memory = dict()
self.assertEqual(0, len(memory))
memorizer = make_memorizer_union(memory, ["int"])
X = numpy.asarray([[-1], [1]])
Xt = memorizer.fit_transform(X)
self.assertEqual((2, 1), Xt.shape)
self.assertEqual(1, len(memory))
self.assertEqual([-1, 1], memory["int"].tolist())

def test_make_recaller_union(self):
memory = {
"int": [-1, 1]
}
recaller = make_recaller_union(memory, ["int"])
X = numpy.full((2, 1), 0, dtype = int)
Xt = recaller.fit_transform(X)
self.assertEqual((2, 2), Xt.shape)
self.assertEqual([-1, 1], Xt[:, 0].tolist())
self.assertEqual([0, 0], Xt[:, 1].tolist())

0 comments on commit ab7bdb4

Please sign in to comment.