Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
Merge pull request #40 from keras-team/redesign
Browse files Browse the repository at this point in the history
Remove reliance on Keras submodule imports.
  • Loading branch information
fchollet authored Aug 21, 2018
2 parents b9d1424 + 3075487 commit dad7fcc
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 90 deletions.
8 changes: 1 addition & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@ matrix:
include:
- python: 2.7
env: TEST_MODE=PEP8
- python: 2.7
env: TEST_MODE=INTEGRATION_TESTS
- python: 2.7
env: KERAS_HEAD=true
- python: 3.6
env: KERAS_HEAD=true
- python: 2.7
- python: 3.6
- python: 3.6
env: TEST_MODE=INTEGRATION_TESTS
install:
# code below is taken from http://conda.pydata.org/docs/travis.html
# We do this conditionally because it saves us some downloading if the
Expand Down Expand Up @@ -57,8 +53,6 @@ install:
script:
- if [[ "$TEST_MODE" == "PEP8" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0;
elif [[ "$TEST_MODE" == "INTEGRATION_TESTS" ]]; then
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/integration_test.py;
else
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --cov-config .coveragerc --cov=keras_preprocessing tests/ --ignore=tests/integration_test.py;
PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --cov-config .coveragerc --cov=keras_preprocessing tests/;
fi
4 changes: 4 additions & 0 deletions keras_preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@


def set_keras_submodules(backend, utils):
# Deprecated, will be removed in the future.
global _KERAS_BACKEND
global _KERAS_UTILS
_KERAS_BACKEND = backend
_KERAS_UTILS = utils


def get_keras_submodule(name):
# Deprecated, will be removed in the future.
if name not in {'backend', 'utils'}:
raise ImportError(
'Can only retrieve "backend" and "utils". '
Expand All @@ -36,3 +38,5 @@ def get_keras_submodule(name):
return _KERAS_BACKEND
elif name == 'utils':
return _KERAS_UTILS

__version__ = '1.0.3'
74 changes: 35 additions & 39 deletions keras_preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
import multiprocessing.pool
from functools import partial

from . import get_keras_submodule

backend = get_keras_submodule('backend')
keras_utils = get_keras_submodule('utils')

try:
from PIL import ImageEnhance
from PIL import Image as pil_image
Expand Down Expand Up @@ -349,7 +344,7 @@ def flip_axis(x, axis):
return x


def array_to_img(x, data_format=None, scale=True):
def array_to_img(x, data_format='channels_last', scale=True, dtype='float32'):
"""Converts a 3D Numpy array to a PIL Image instance.
# Arguments
Expand All @@ -358,6 +353,7 @@ def array_to_img(x, data_format=None, scale=True):
either "channels_first" or "channels_last".
scale: Whether to rescale image values
to be within `[0, 255]`.
dtype: Dtype to use.
# Returns
A PIL Image instance.
Expand All @@ -369,13 +365,11 @@ def array_to_img(x, data_format=None, scale=True):
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
'The use of `array_to_img` requires PIL.')
x = np.asarray(x, dtype=backend.floatx())
x = np.asarray(x, dtype=dtype)
if x.ndim != 3:
raise ValueError('Expected image array to have rank 3 (single image). '
'Got array with shape: %s' % (x.shape,))

if data_format is None:
data_format = backend.image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Invalid data_format: %s' % data_format)

Expand Down Expand Up @@ -403,28 +397,27 @@ def array_to_img(x, data_format=None, scale=True):
raise ValueError('Unsupported channel number: %s' % (x.shape[2],))


def img_to_array(img, data_format=None):
def img_to_array(img, data_format='channels_last', dtype='float32'):
"""Converts a PIL Image instance to a Numpy array.
# Arguments
img: PIL Image instance.
data_format: Image data format,
either "channels_first" or "channels_last".
dtype: Dtype to use for the returned array.
# Returns
A 3D Numpy array.
# Raises
ValueError: if invalid `img` or `data_format` is passed.
"""
if data_format is None:
data_format = backend.image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format: %s' % data_format)
# Numpy array x has format (height, width, channel)
# or (channel, height, width)
# but original PIL image has format (width, height, channel)
x = np.asarray(img, dtype=backend.floatx())
x = np.asarray(img, dtype=dtype)
if len(x.shape) == 3:
if data_format == 'channels_first':
x = x.transpose(2, 0, 1)
Expand All @@ -440,9 +433,10 @@ def img_to_array(img, data_format=None):

def save_img(path,
x,
data_format=None,
data_format='channels_last',
file_format=None,
scale=True, **kwargs):
scale=True,
**kwargs):
"""Saves an image stored as a Numpy array to a path or file object.
# Arguments
Expand Down Expand Up @@ -602,6 +596,7 @@ class ImageDataGenerator(object):
If you never set it, then it will be "channels_last".
validation_split: Float. Fraction of images reserved for validation
(strictly between 0 and 1).
dtype: Dtype to use for the generated arrays.
# Examples
Example of using `.flow(x, y)`:
Expand Down Expand Up @@ -728,10 +723,9 @@ def __init__(self,
vertical_flip=False,
rescale=None,
preprocessing_function=None,
data_format=None,
validation_split=0.0):
if data_format is None:
data_format = backend.image_data_format()
data_format='channels_last',
validation_split=0.0,
dtype='float32'):
self.featurewise_center = featurewise_center
self.samplewise_center = samplewise_center
self.featurewise_std_normalization = featurewise_std_normalization
Expand All @@ -751,6 +745,7 @@ def __init__(self,
self.vertical_flip = vertical_flip
self.rescale = rescale
self.preprocessing_function = preprocessing_function
self.dtype = dtype

if data_format not in {'channels_last', 'channels_first'}:
raise ValueError(
Expand Down Expand Up @@ -983,7 +978,7 @@ def standardize(self, x):
if self.samplewise_center:
x -= np.mean(x, keepdims=True)
if self.samplewise_std_normalization:
x /= (np.std(x, keepdims=True) + backend.epsilon())
x /= (np.std(x, keepdims=True) + 1e-6)

if self.featurewise_center:
if self.mean is not None:
Expand All @@ -995,7 +990,7 @@ def standardize(self, x):
'first by calling `.fit(numpy_data)`.')
if self.featurewise_std_normalization:
if self.std is not None:
x /= (self.std + backend.epsilon())
x /= (self.std + 1e-6)
else:
warnings.warn('This ImageDataGenerator specifies '
'`featurewise_std_normalization`, '
Expand Down Expand Up @@ -1202,7 +1197,7 @@ def fit(self, x,
this is how many augmentation passes over the data to use.
seed: Int (default: None). Random seed.
"""
x = np.asarray(x, dtype=backend.floatx())
x = np.asarray(x, dtype=self.dtype)
if x.ndim != 4:
raise ValueError('Input to `.fit()` should have rank 4. '
'Got array with shape: ' + str(x.shape))
Expand All @@ -1225,7 +1220,7 @@ def fit(self, x,
if augment:
ax = np.zeros(
tuple([rounds * x.shape[0]] + list(x.shape)[1:]),
dtype=backend.floatx())
dtype=self.dtype)
for r in range(rounds):
for i in range(x.shape[0]):
ax[i + r * x.shape[0]] = self.random_transform(x[i])
Expand All @@ -1243,7 +1238,7 @@ def fit(self, x,
broadcast_shape = [1, 1, 1]
broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]
self.std = np.reshape(self.std, broadcast_shape)
x /= (self.std + backend.epsilon())
x /= (self.std + 1e-6)

if self.zca_whitening:
if scipy is None:
Expand All @@ -1257,7 +1252,7 @@ def fit(self, x,
self.principal_components = (u * s_inv).dot(u.T)


class Iterator(keras_utils.Sequence):
class Iterator(object):
"""Base class for image data iterators.
Every `Iterator` must implement the `_get_batches_of_transformed_samples`
Expand Down Expand Up @@ -1375,13 +1370,15 @@ class NumpyArrayIterator(Iterator):
(if `save_to_dir` is set).
subset: Subset of data (`"training"` or `"validation"`) if
validation_split is set in ImageDataGenerator.
dtype: Dtype to use for the generated arrays.
"""

def __init__(self, x, y, image_data_generator,
batch_size=32, shuffle=False, sample_weight=None,
seed=None, data_format=None,
seed=None, data_format='channels_last',
save_to_dir=None, save_prefix='', save_format='png',
subset=None):
subset=None, dtype='float32'):
self.dtype = dtype
if (type(x) is tuple) or (type(x) is list):
if type(x[1]) is not list:
x_misc = [np.asarray(x[1])]
Expand Down Expand Up @@ -1423,9 +1420,7 @@ def __init__(self, x, y, image_data_generator,
x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]
if y is not None:
y = y[split_idx:]
if data_format is None:
data_format = backend.image_data_format()
self.x = np.asarray(x, dtype=backend.floatx())
self.x = np.asarray(x, dtype=self.dtype)
self.x_misc = x_misc
if self.x.ndim != 4:
raise ValueError('Input data in `NumpyArrayIterator` '
Expand Down Expand Up @@ -1461,12 +1456,12 @@ def __init__(self, x, y, image_data_generator,

def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
dtype=backend.floatx())
dtype=self.dtype)
for i, j in enumerate(index_array):
x = self.x[j]
params = self.image_data_generator.get_random_transform(x.shape)
x = self.image_data_generator.apply_transform(
x.astype(backend.floatx()), params)
x.astype(self.dtype), params)
x = self.image_data_generator.standardize(x)
batch_x[i] = x

Expand Down Expand Up @@ -1654,19 +1649,19 @@ class DirectoryIterator(Iterator):
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
dtype: Dtype to use for generated arrays.
"""

def __init__(self, directory, image_data_generator,
target_size=(256, 256), color_mode='rgb',
classes=None, class_mode='categorical',
batch_size=32, shuffle=True, seed=None,
data_format=None,
data_format='channels_last',
save_to_dir=None, save_prefix='', save_format='png',
follow_links=False,
subset=None,
interpolation='nearest'):
if data_format is None:
data_format = backend.image_data_format()
interpolation='nearest',
dtype='float32'):
self.directory = directory
self.image_data_generator = image_data_generator
self.target_size = tuple(target_size)
Expand Down Expand Up @@ -1702,6 +1697,7 @@ def __init__(self, directory, image_data_generator,
self.save_prefix = save_prefix
self.save_format = save_format
self.interpolation = interpolation
self.dtype = dtype

if subset is not None:
validation_split = self.image_data_generator._validation_split
Expand Down Expand Up @@ -1769,7 +1765,7 @@ def __init__(self, directory, image_data_generator,
def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros(
(len(index_array),) + self.image_shape,
dtype=backend.floatx())
dtype=self.dtype)
# build batch of image data
for i, j in enumerate(index_array):
fname = self.filenames[j]
Expand Down Expand Up @@ -1802,11 +1798,11 @@ def _get_batches_of_transformed_samples(self, index_array):
elif self.class_mode == 'sparse':
batch_y = self.classes[index_array]
elif self.class_mode == 'binary':
batch_y = self.classes[index_array].astype(backend.floatx())
batch_y = self.classes[index_array].astype(self.dtype)
elif self.class_mode == 'categorical':
batch_y = np.zeros(
(len(batch_x), self.num_classes),
dtype=backend.floatx())
dtype=self.dtype)
for i, label in enumerate(self.classes[index_array]):
batch_y[i, label] = 1.
else:
Expand Down
9 changes: 3 additions & 6 deletions keras_preprocessing/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@
import json
from six.moves import range

from . import get_keras_submodule

keras_utils = get_keras_submodule('utils')


def pad_sequences(sequences, maxlen=None, dtype='int32',
padding='pre', truncating='pre', value=0.):
Expand Down Expand Up @@ -251,7 +247,7 @@ def _remove_long_seq(maxlen, seq, label):
return new_seq, new_label


class TimeseriesGenerator(keras_utils.Sequence):
class TimeseriesGenerator(object):
"""Utility class for generating batches of temporal data.
This class takes in a sequence of data-points gathered at
Expand Down Expand Up @@ -325,7 +321,8 @@ def __init__(self, data, targets, length,

if len(data) != len(targets):
raise ValueError('Data and targets have to be' +
' of same length. Data length is {}'.format(len(data)) +
' of same length. '
'Data length is {}'.format(len(data)) +
' while target length is {}'.format(len(targets)))

self.data = data
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
'''

setup(name='Keras_Preprocessing',
version='1.0.2',
version='1.0.3',
description='Easy data preprocessing and data augmentation '
'for deep learning models',
long_description=long_description,
author='Keras Team',
url='https://github.com/keras-team/keras-preprocessing',
download_url='https://github.com/keras-team/'
'keras-preprocessing/tarball/1.0.2',
'keras-preprocessing/tarball/1.0.3',
license='MIT',
install_requires=['keras>=2.1.6',
'numpy>=1.9.1',
Expand Down
Loading

0 comments on commit dad7fcc

Please sign in to comment.