From 362196a9cabfb4ab0b90ce92aad8dcdcddc05495 Mon Sep 17 00:00:00 2001 From: Justin Harris Date: Tue, 14 Aug 2018 14:27:33 -0400 Subject: [PATCH 1/5] pad_sequences: Add support for string value. --- keras_preprocessing/sequence.py | 5 +++-- tests/sequence_test.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index 5244651b..f918e32e 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -39,12 +39,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', sequences: List of lists, where each element is a sequence. maxlen: Int, maximum length of all sequences. dtype: Type of the output sequences. + To pad sequences with variable length strings, you can use `object`. padding: String, 'pre' or 'post': pad either before or after each sequence. truncating: String, 'pre' or 'post': remove values from sequences larger than `maxlen`, either at the beginning or at the end of the sequences. - value: Float, padding value. + value: Float or String, padding value. # Returns x: Numpy array with shape `(len(sequences), maxlen)` @@ -74,7 +75,7 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', sample_shape = np.asarray(s).shape[1:] break - x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype) + x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) for idx, s in enumerate(sequences): if not len(s): continue # empty list/array was found diff --git a/tests/sequence_test.py b/tests/sequence_test.py index a8638b55..1c260d7c 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -2,6 +2,7 @@ import pytest import numpy as np from numpy.testing import assert_allclose +from numpy.testing import assert_equal from numpy.testing import assert_raises import keras @@ -35,6 +36,24 @@ def test_pad_sequences(): assert_allclose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]]) +def test_pad_sequences_str(): + a = [['1'], ['1', '2'], ['1', '2', '3']] + + # test padding + b = sequence.pad_sequences(a, maxlen=3, padding='pre', value='pad', dtype=object) + assert_equal(b, [['pad', 'pad', '1'], ['pad', '1', '2'], ['1', '2', '3']]) + b = sequence.pad_sequences(a, maxlen=3, padding='post', value='pad', dtype=' Date: Fri, 17 Aug 2018 10:28:46 -0400 Subject: [PATCH 2/5] pad_sequences: Add warning if `dtype` is incompatible for string `value`. --- keras_preprocessing/sequence.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index f918e32e..e0658d27 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -9,6 +9,8 @@ import random import json from six.moves import range +import six +import warnings from . import get_keras_submodule @@ -75,6 +77,13 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', sample_shape = np.asarray(s).shape[1:] break + if isinstance(value, six.string_types) and dtype != object \ + and not np.issubdtype(dtype, np.str_) \ + and not np.issubdtype(dtype, np.unicode_): + warnings.warn("`dtype` {} is not compatible with `value`'s type: {}\n" + "You should set `dtype=object` for variable length strings." + .format(dtype, type(value))) + x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) for idx, s in enumerate(sequences): if not len(s): From 3c3cc362ee8c801b5f8e7585608ef610138bd0b2 Mon Sep 17 00:00:00 2001 From: Justin Harris Date: Tue, 21 Aug 2018 13:05:26 -0400 Subject: [PATCH 3/5] pad_sequences: Throw error with invalid dtype when pad is a str. --- keras_preprocessing/sequence.py | 12 +++++------- tests/sequence_test.py | 9 +++++++++ 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/keras_preprocessing/sequence.py b/keras_preprocessing/sequence.py index e0658d27..78af2a0d 100644 --- a/keras_preprocessing/sequence.py +++ b/keras_preprocessing/sequence.py @@ -10,7 +10,6 @@ import json from six.moves import range import six -import warnings from . import get_keras_submodule @@ -77,12 +76,11 @@ def pad_sequences(sequences, maxlen=None, dtype='int32', sample_shape = np.asarray(s).shape[1:] break - if isinstance(value, six.string_types) and dtype != object \ - and not np.issubdtype(dtype, np.str_) \ - and not np.issubdtype(dtype, np.unicode_): - warnings.warn("`dtype` {} is not compatible with `value`'s type: {}\n" - "You should set `dtype=object` for variable length strings." - .format(dtype, type(value))) + is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.unicode_) + if isinstance(value, six.string_types) and dtype != object and not is_dtype_str: + raise ValueError("`dtype` {} is not compatible with `value`'s type: {}\n" + "You should set `dtype=object` for variable length strings." + .format(dtype, type(value))) x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) for idx, s in enumerate(sequences): diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 1c260d7c..dc51a483 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -53,6 +53,15 @@ def test_pad_sequences_str(): dtype='" + "\nYou should set `dtype=object`" + " for variable length strings.",) + def test_pad_sequences_vector(): a = [[[1, 1]], From 56a506a52ed8ae7f288edecbd91142a56fc4f976 Mon Sep 17 00:00:00 2001 From: Justin Harris Date: Thu, 23 Aug 2018 20:02:51 -0400 Subject: [PATCH 4/5] test/pad_sequences: Use pytest.raises to check for exceptions. --- tests/sequence_test.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/sequence_test.py b/tests/sequence_test.py index dc51a483..9871064f 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -53,14 +53,8 @@ def test_pad_sequences_str(): dtype='" - "\nYou should set `dtype=object`" - " for variable length strings.",) def test_pad_sequences_vector(): From bfb72976cc923623b4ffbb8a65bb02b5352dd059 Mon Sep 17 00:00:00 2001 From: Justin Harris Date: Thu, 23 Aug 2018 20:08:33 -0400 Subject: [PATCH 5/5] test/pad_sequences: Add missing : --- tests/sequence_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sequence_test.py b/tests/sequence_test.py index 9871064f..810cd5bb 100644 --- a/tests/sequence_test.py +++ b/tests/sequence_test.py @@ -53,7 +53,7 @@ def test_pad_sequences_str(): dtype='