You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
# Spatial size every preprocessed image is cropped to, and the channel count
# images are decoded with.
IMG_HEIGHT = IMG_WIDTH = 224
IMG_CHANNELS = 3
# Flower class names (presumably indexed by the records' integer label —
# confirm against the dataset writer).
CLASS_NAMES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
class _Preprocessor:
    """Read flower images from TFRecords or JPEG files and resize/crop them.

    Every image goes through ``preproc_layers``. That model first resizes
    with padding to ``2*IMG_HEIGHT x 2*IMG_WIDTH``, then center-crops to
    ``IMG_HEIGHT x IMG_WIDTH``.
    """

    def __init__(self):
        # The Keras layers run on batches; ``preprocess`` wraps each single
        # image in a batch of one before calling this model.
        self.preproc_layers = tf.keras.Sequential([
            tf.keras.layers.Lambda(
                lambda img: tf.image.resize_with_pad(
                    img, 2 * IMG_HEIGHT, 2 * IMG_WIDTH),
                input_shape=(None, None, IMG_CHANNELS)),
            # Stable (non-experimental) location of the CenterCrop layer.
            tf.keras.layers.CenterCrop(
                height=IMG_HEIGHT, width=IMG_WIDTH),
        ])

    def read_from_tfr(self, proto):
        """Parse one serialized example into ``(image, label_int)``.

        The record stores the pixels as a flat float list under ``image``
        and the image dimensions under ``shape``. The integer label comes
        from ``label_int`` and defaults to 0.

        Args:
            proto: one serialized ``tf.train.Example`` (string tensor).

        Returns:
            ``(img, label_int)``: a float32 image with static shape
            ``(None, None, IMG_CHANNELS)`` and an int64 label.
        """
        feature_description = {
            'image': tf.io.VarLenFeature(tf.float32),
            'shape': tf.io.VarLenFeature(tf.int64),
            'label': tf.io.FixedLenFeature([], tf.string, default_value=''),
            'label_int': tf.io.FixedLenFeature([], tf.int64, default_value=0),
        }
        rec = tf.io.parse_single_example(proto, feature_description)
        shape = tf.sparse.to_dense(rec['shape'])
        img = tf.reshape(tf.sparse.to_dense(rec['image']), shape)
        # Reshaping by a runtime ``shape`` tensor leaves the static rank
        # unknown. tf.image.resize_with_pad treats unknown-rank input as a
        # single unbatched image. That drops the batch axis added in
        # ``preprocess`` and makes its squeeze fail ("Can not squeeze dim[0],
        # expected a dimension of 1, got 224"). Pin the rank, and the channel
        # count the model's input_shape declares, so the graph traces
        # correctly.
        img = tf.ensure_shape(img, [None, None, IMG_CHANNELS])
        label_int = rec['label_int']
        return img, label_int

    def read_from_jpegfile(self, filename):
        """Read and decode a JPEG file into a float32 image in ``[0, 1]``.

        Uses the same decoding as 05_create_dataset/jpeg_to_tfrecord.py.

        Args:
            filename: path of the JPEG file (string or string tensor).

        Returns:
            A float32 image tensor with ``IMG_CHANNELS`` channels.
        """
        img = tf.io.read_file(filename)
        img = tf.image.decode_jpeg(img, channels=IMG_CHANNELS)
        # convert_image_dtype also rescales uint8 [0, 255] to float [0, 1].
        img = tf.image.convert_image_dtype(img, tf.float32)
        return img

    def preprocess(self, img):
        """Resize-with-pad and center-crop a single image.

        Args:
            img: one image tensor of rank 3 (height, width, channels).

        Returns:
            The image cropped to ``IMG_HEIGHT x IMG_WIDTH``.
        """
        # Guarantee a static rank of 3. With an unknown rank the resize layer
        # would not recognise the batch axis added below and would drop it.
        img = tf.ensure_shape(img, [None, None, None])
        # The Keras layers expect a batch: add one, run them, remove it.
        x = tf.expand_dims(img, 0)
        x = self.preproc_layers(x)
        x = tf.squeeze(x, 0)
        return x
def create_preproc_dataset(pattern):
    """Build a dataset of preprocessed ``(image, label_int)`` pairs.

    Args:
        pattern: glob pattern (e.g. a ``gs://`` path) matching
            GZIP-compressed TFRecord files.

    Returns:
        A ``tf.data.Dataset`` yielding center-cropped images and their
        integer labels, in file/record order.
    """
    preproc = _Preprocessor()
    # tf.io.gfile.glob already returns a list of matching paths.
    filenames = tf.io.gfile.glob(pattern)
    # Parallel map keeps element order (deterministic by default) while
    # decoding/resizing several records at once.
    return tf.data.TFRecordDataset(
        filenames, compression_type='GZIP'
    ).map(
        preproc.read_from_tfr, num_parallel_calls=tf.data.AUTOTUNE
    ).map(
        lambda img, label: (preproc.preprocess(img), label),
        num_parallel_calls=tf.data.AUTOTUNE)
def create_preproc_image(filename):
    """Decode the JPEG at ``filename`` and return it resized and center-cropped."""
    preprocessor = _Preprocessor()
    return preprocessor.preprocess(preprocessor.read_from_jpegfile(filename))
# Preprocess records read from the TFRecord shards and plot the first five.
trainds = create_preproc_dataset(
    'gs://practical-ml-vision-book/flowers_tfr/train-*')
f, ax = plt.subplots(1, 5, figsize=(15, 15))
for idx, (img, label) in enumerate(trainds.take(5)):
    ax[idx].imshow(img.numpy())
Running the cell above fails while building the dataset, with this error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
[<ipython-input-72-a567f27bc3c2>](https://localhost:8080/#) in <cell line: 2>()
1 # preprocessing records read from a TensorFlow dataset
----> 2 trainds = create_preproc_dataset('gs://practical-ml-vision-book/flowers_tfr/train-*')
3 f, ax = plt.subplots(1, 5, figsize=(15,15))
4 for idx, (img, label) in enumerate(trainds.take(5)):
5 ax[idx].imshow((img.numpy()));
24 frames
[/usr/local/lib/python3.10/dist-packages/tensorflow/python/framework/ops.py](https://localhost:8080/#) in _create_c_op(graph, node_def, inputs, control_inputs, op_def, extract_traceback)
1971 except errors.InvalidArgumentError as e:
1972 # Convert to ValueError for backwards compatibility.
-> 1973 raise ValueError(e.message)
1974
1975 # Record the current Python stack trace as the creating stacktrace of this
ValueError: in user code:
File "<ipython-input-71-ac9c68a069e7>", line 56, in None *
lambda img, label: (preproc.preprocess(img), label)
File "<ipython-input-71-ac9c68a069e7>", line 47, in preprocess *
x = tf.squeeze(x, 0)
ValueError: Can not squeeze dim[0], expected a dimension of 1, got 224 for '{{node Squeeze}} = Squeeze[T=DT_FLOAT, squeeze_dims=[0]](sequential_24/center_crop_16/cond/Identity)' with input shapes: [224,224,?].
The text was updated successfully, but these errors were encountered:
code:
The text was updated successfully, but these errors were encountered: