This code snippet uses TensorFlow 2.0;
if you are using an earlier version of TensorFlow,
then enable eager execution to
run the code.
The batch()
method of the tf.data.Dataset
class is used for combining consecutive elements of a dataset into batches. In the example below we look
into the use of batch(), first without using the repeat()
method and then with the repeat()
method.
batch()
method without repeat()
import tensorflow as tf

# Show which TensorFlow version is in use; iterating a dataset directly in a
# for loop (as done below) relies on eager execution.
print(tf.__version__)

# Create a tensor holding the values 0..4.
tensor1 = tf.range(5)

# Create the dataset; this returns a TensorSliceDataset whose elements are the
# individual slices (here: scalars) of the tensor.
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(dataset)

print("Original dataset")
for element in dataset:
    print(element)
======= Output ======
2.0.0
<TensorSliceDataset shapes: (), types: tf.int32>
Original dataset
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
Applying batch()
on the dataset: notice the change in the shape of each tensor after applying the batch()
method.
# Combine consecutive elements into batches of two. drop_remainder defaults to
# False, so the leftover fifth element is emitted as a final, smaller batch.
dataset = dataset.batch(batch_size=2)
print("dataset after applying batch method")
for batch in dataset:
    print(batch)
dataset after applying batch method
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4], shape=(1,), dtype=int32)
batch()
with repeat()
method on dataset.
import tensorflow as tf

# Show which TensorFlow version is in use; iterating a dataset directly in a
# for loop (as done below) relies on eager execution.
print(tf.__version__)

# Create a tensor holding the values 0..4.
tensor1 = tf.range(5)

# Create the dataset; this returns a TensorSliceDataset whose elements are the
# individual slices (here: scalars) of the tensor.
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(dataset)

print("Original dataset")
for element in dataset:
    print(element)
2.0.0
<TensorSliceDataset shapes: (), types: tf.int32>
Original dataset
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
batch()
method with repeat()
# Using batch method with repeat.
# repeat(3) is applied BEFORE batch(), so the 15 repeated elements are batched
# as one continuous stream: batches can straddle a repetition boundary
# (e.g. [4 0]) and only the very last batch is smaller than batch_size.
dataset = dataset.repeat(3).batch(batch_size=2)
print("dataset after applying batch method with repeat()")
for batch in dataset:
    print(batch)
dataset after applying batch method with repeat()
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4], shape=(1,), dtype=int32)
Similar Articles