TensorFlow | How to use the Dataset API take() method in TensorFlow

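The take() method of tf.data.Dataset limits the number of elements in a dataset: dataset.take(count) returns a new dataset containing at most the first count elements of the original. The code in this post uses TensorFlow 2.0. If you are using TensorFlow 1.x, enable eager execution first (tf.enable_eager_execution()) so that the for loops below can iterate over the dataset.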
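As a quick illustration of these semantics, here is a minimal sketch that assumes TensorFlow 2.x with eager execution. It builds a ten-element dataset with tf.data.Dataset.range and shows that a count larger than the dataset, or a count of -1, simply returns every element. The helper `to_list` is hypothetical and not part of the tf.data API.

```python
import tensorflow as tf

# A small dataset of ten scalars: 0, 1, ..., 9
numbers = tf.data.Dataset.range(10)

def to_list(ds):
    """Hypothetical helper: collect a finite dataset into a Python list (eager mode)."""
    return [int(x) for x in ds]

first_three = to_list(numbers.take(3))   # count smaller than the dataset size
all_of_them = to_list(numbers.take(20))  # count larger than the dataset size
minus_one = to_list(numbers.take(-1))    # count of -1 keeps every element

print(first_three)  # [0, 1, 2]
print(all_of_them)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
print(minus_one)    # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```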

Let's look at the snippet below to see how take() works.

Create a dataset with tf.data.Dataset.from_tensor_slices

  
import tensorflow as tf
# Create a tensor containing the values 0, 1, 2, 3, 4
tensor1 = tf.range(5)

# Create a dataset with one element per value; this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
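from_tensor_slices slices its input along the first dimension, so each element of the dataset is one slice of the tensor. With the 1-D tensor above, each element is a single scalar. The sketch below (TensorFlow 2.x, eager mode assumed; the 3x2 matrix is made up for illustration) shows that with a 2-D tensor every element is a row, and take() then counts rows rather than individual numbers.

```python
import tensorflow as tf

# A 3x2 matrix: [[0, 1], [2, 3], [4, 5]]
matrix = tf.reshape(tf.range(6), (3, 2))

# from_tensor_slices splits along the first dimension: one element per row
rows = tf.data.Dataset.from_tensor_slices(matrix)

row_values = [row.numpy().tolist() for row in rows]
print(row_values)  # [[0, 1], [2, 3], [4, 5]]

# take(2) keeps the first two elements, i.e. the first two rows
first_two_rows = [row.numpy().tolist() for row in rows.take(2)]
print(first_two_rows)  # [[0, 1], [2, 3]]
```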

Apply repeat() and batch() to the dataset

  
print("dataset after applying batch and repeat")
dataset = dataset.repeat(6).batch(batch_size=2)
for i in dataset:
    print(i)

Example Output:

  
dataset after applying batch and repeat
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
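The order of repeat() and batch() matters. In the snippet above, repeat(6) runs first, so the 30 values are batched as one continuous stream and every batch holds two values (which is why [4 0] appears). The sketch below compares that with calling batch() first, where each pass over the five values ends with a one-value batch, and with batch(..., drop_remainder=True), which discards that short batch. It assumes TensorFlow 2.x in eager mode; `batch_sizes` is a hypothetical helper.

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))

def batch_sizes(ds):
    """Hypothetical helper: number of values in each batch of a finite dataset."""
    return [int(tf.size(batch)) for batch in ds]

# repeat() first: 30 values batched as one stream, so every batch is full
repeat_then_batch = batch_sizes(base.repeat(6).batch(2))

# batch() first: each pass over the 5 values yields batches of 2, 2 and 1
batch_then_repeat = batch_sizes(base.batch(2).repeat(6))

# drop_remainder=True drops the incomplete batch at the end of each pass
dropped = batch_sizes(base.batch(2, drop_remainder=True).repeat(6))

print(len(repeat_then_batch), repeat_then_batch[:3])  # 15 [2, 2, 2]
print(len(batch_then_repeat), batch_then_repeat[:3])  # 18 [2, 2, 1]
print(len(dropped))                                   # 12
```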

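Apply take() to the dataset to select its first few elements. Because the dataset is already batched, each element is a batch of two values, so take(3) returns the first three batches rather than three individual values.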

  
print("dataset after applying take() method")
# take(3) returns a new dataset holding only the first 3 elements (batches)
for i in dataset.take(3):
    print(i)

Example Output:

  
dataset after applying take() method
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
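Since take() counts the elements of whatever dataset it is applied to, its position in the pipeline changes the result. Here is a minimal sketch (TensorFlow 2.x, eager mode assumed) contrasting take(3) after batching with take(3) before batching. It also checks that take() returns a new dataset and leaves the original one unchanged.

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))
batched = base.repeat(6).batch(2)

# After batching, take(3) keeps 3 batches, i.e. 6 values
after_batch = [b.numpy().tolist() for b in batched.take(3)]

# Before batching, take(3) keeps 3 values, which then form 2 batches
before_batch = [b.numpy().tolist() for b in base.repeat(6).take(3).batch(2)]

# take() does not modify `batched`; it still yields all 15 batches
total_batches = sum(1 for _ in batched)

print(after_batch)    # [[0, 1], [2, 3], [4, 0]]
print(before_batch)   # [[0, 1], [2]]
print(total_batches)  # 15
```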
