TensorFlow | Using the tf.data.Dataset.batch() method

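This code uses TensorFlow 2.0. If you are using an earlier version of TensorFlow, enable eager execution before running it (in TensorFlow 1.x, call tf.enable_eager_execution() at the start of the program), because the examples iterate over datasets with plain Python for loops.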

The batch() method of the tf.data.Dataset class combines consecutive elements of a dataset into batches. The examples below show batch() first on its own and then combined with the repeat() method.
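As a rough mental model only (this is not TensorFlow's implementation), the grouping that batch() performs can be sketched in plain Python. The function name `batch_model` is hypothetical and exists only for illustration: it collects consecutive elements into lists of `batch_size` and emits a shorter final batch when the elements run out, which mirrors the default behavior visible in the outputs below.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_model(elements: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical plain-Python model of Dataset.batch() default behavior."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    current: List[T] = []
    for element in elements:
        current.append(element)
        # Emit a batch as soon as it holds batch_size consecutive elements
        if len(current) == batch_size:
            yield current
            current = []
    # Any leftover elements form a final, smaller batch
    if current:
        yield current


print(list(batch_model(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```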

Using batch() method without repeat()
1. Create dataset

import tensorflow as tf

print(tf.__version__)

# Create a 1-D tensor holding [0, 1, 2, 3, 4]
tensor1 = tf.range(5)

# Create a dataset with one element per slice of tensor1;
# in TensorFlow 2.0 this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(dataset)
print("Original dataset")
for i in dataset:
    print(i)

Output
2.0.0
<TensorSliceDataset shapes: (), types: tf.int32>
Original dataset
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
 

2. Apply batch() to the dataset. Notice how the shape of each tensor changes once batch() is applied.

# Group every 2 consecutive elements into one batch
dataset = dataset.batch(batch_size=2)
print("dataset after applying batch method")
for i in dataset:
    print(i)
 

Output
   
dataset after applying batch method
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4], shape=(1,), dtype=int32)
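The last batch above holds only one element because five elements do not divide evenly into batches of two. batch() also accepts a drop_remainder argument (default False); calling dataset.batch(batch_size=2, drop_remainder=True) discards that incomplete final batch, which also lets TensorFlow report a fixed batch dimension of 2 instead of None. The plain-Python sketch below models only the dropping behavior; `batch_model` is a hypothetical helper, not a TensorFlow function.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch_model(
    elements: Iterable[T], batch_size: int, drop_remainder: bool = False
) -> Iterator[List[T]]:
    """Hypothetical model of Dataset.batch(batch_size, drop_remainder)."""
    current: List[T] = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    # With drop_remainder=True, an incomplete final batch is discarded
    if current and not drop_remainder:
        yield current


print(list(batch_model(range(5), 2)))                       # [[0, 1], [2, 3], [4]]
print(list(batch_model(range(5), 2, drop_remainder=True)))  # [[0, 1], [2, 3]]
```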
 

Using batch() method with repeat()

1. Create dataset

import tensorflow as tf
print(tf.__version__)

# Create a 1-D tensor holding [0, 1, 2, 3, 4]
tensor1 = tf.range(5)

# Create a dataset with one element per slice of tensor1;
# in TensorFlow 2.0 this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(dataset)
print("Original dataset")
for i in dataset:
    print(i)
  

Output
  
   
2.0.0
<TensorSliceDataset shapes: (), types: tf.int32>
Original dataset
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
 

2. Apply repeat() and then batch()

# Repeat the dataset 3 times, then group consecutive elements into batches of 2
dataset = dataset.repeat(3).batch(batch_size=2)
print("dataset after applying batch method with repeat()")
for i in dataset:
    print(i)
 

Output
   
dataset after applying batch method with repeat()
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4], shape=(1,), dtype=int32)
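Because repeat(3) runs before batch(), batches are formed across the boundary between repetitions (for example [4 0]), and only the very last batch is partial. Swapping the order to dataset.batch(2).repeat(3) batches each pass separately, so every repetition ends with its own one-element batch. The plain-Python model below compares the two orderings on the same five elements; `repeat_model` and `batch_model` are hypothetical helpers for illustration, not TensorFlow APIs.

```python
from itertools import chain, repeat
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def repeat_model(elements: List[T], count: int) -> Iterator[T]:
    """Hypothetical model of Dataset.repeat(count): replay the elements count times."""
    return chain.from_iterable(repeat(elements, count))


def batch_model(elements: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical model of Dataset.batch(batch_size) with a partial final batch."""
    current: List[T] = []
    for element in elements:
        current.append(element)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


data = list(range(5))

# repeat(3) then batch(2): batches span repetition boundaries
repeat_then_batch = list(batch_model(repeat_model(data, 3), 2))

# batch(2) then repeat(3): each pass is batched on its own
batch_then_repeat = list(repeat_model(list(batch_model(data, 2)), 3))

# [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4], [0, 1], [2, 3], [4]]
print(repeat_then_batch)
# [[0, 1], [2, 3], [4], [0, 1], [2, 3], [4], [0, 1], [2, 3], [4]]
print(batch_then_repeat)
```

The first ordering matches the TensorFlow output above; pick the second ordering when each epoch must end on a clean batch boundary.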
 

