The take() method of tf.data.Dataset is used to limit the number of items in a dataset. This code snippet uses TensorFlow 2.0; if you are using an earlier version of TensorFlow, enable eager execution to run the code.
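As a minimal sketch, assuming TensorFlow 2.x with eager execution enabled by default, the snippet below shows how the count argument of take() behaves on a small dataset built with tf.data.Dataset.range(). According to the tf.data API documentation, a count of -1, or a count larger than the dataset size, keeps every element.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution is on by default).
dataset = tf.data.Dataset.range(5)  # elements 0, 1, 2, 3, 4

# take(n) keeps at most the first n elements.
first_two = [int(x) for x in dataset.take(2)]

# A count larger than the dataset size keeps every element.
too_many = [int(x) for x in dataset.take(100)]

# A count of -1 also keeps every element.
all_items = [int(x) for x in dataset.take(-1)]

print(first_two)  # [0, 1]
print(too_many)   # [0, 1, 2, 3, 4]
print(all_items)  # [0, 1, 2, 3, 4]
```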
Let's look at the snippet below to understand the take() method.
Create a dataset with tf.data.Dataset.from_tensor_slices()
import tensorflow as tf
# Create a tensor holding the values 0, 1, 2, 3, 4
tensor1 = tf.range(5)
# Create a dataset; this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
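from_tensor_slices() slices its input along the first dimension, so in the snippet above each dataset element is one scalar from tf.range(5). A short sketch with a 2-D tensor makes the slicing visible; the values in matrix are made up purely for illustration, and TensorFlow 2.1 or later is assumed.

```python
import tensorflow as tf

# Illustrative 2-D tensor: 3 rows, 2 columns.
matrix = tf.constant([[1, 2], [3, 4], [5, 6]])

# from_tensor_slices splits along the first dimension,
# so each dataset element is one row of shape (2,).
rows = tf.data.Dataset.from_tensor_slices(matrix)

element_shape = rows.element_spec.shape.as_list()
row_values = [row.numpy().tolist() for row in rows]

print(element_shape)  # [2]
print(row_values)     # [[1, 2], [3, 4], [5, 6]]
```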
Apply repeat() and batch() on the dataset
print("dataset after applying batch and repeat")
# Repeat the 5 values 6 times (30 values), then group them into batches of 2 (15 batches)
dataset = dataset.repeat(6).batch(batch_size=2)
for i in dataset:
    print(i)
Example Output:
dataset after applying batch and repeat
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
tf.Tensor([1 2], shape=(2,), dtype=int32)
tf.Tensor([3 4], shape=(2,), dtype=int32)
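The code above calls repeat() before batch(), which is why a batch such as [4 0] spans two passes over the data. A small sketch, assuming TensorFlow 2.x, comparing both orders on the same five values:

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# repeat() before batch(): the passes are concatenated first,
# so a batch can span the boundary between two passes.
repeat_then_batch = [b.numpy().tolist() for b in base.repeat(2).batch(2)]

# batch() before repeat(): each pass is batched separately,
# so the last batch of every pass is smaller (5 values is odd).
batch_then_repeat = [b.numpy().tolist() for b in base.batch(2).repeat(2)]

print(repeat_then_batch)  # [[0, 1], [2, 3], [4, 0], [1, 2], [3, 4]]
print(batch_then_repeat)  # [[0, 1], [2, 3], [4], [0, 1], [2, 3], [4]]
```

Passing drop_remainder=True to batch() would discard the short [4] batches in the second ordering.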
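Apply take() on the dataset to select the first few elements. Because the dataset is already batched, each element is a batch of 2 values, so take(3) returns the first 3 batches.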
print("dataset after applying take() method")
for i in dataset.take(3):
    print(i)
Example Output:
dataset after applying take() method
tf.Tensor([0 1], shape=(2,), dtype=int32)
tf.Tensor([2 3], shape=(2,), dtype=int32)
tf.Tensor([4 0], shape=(2,), dtype=int32)
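take() counts elements of whatever dataset it is called on. The sketch below, assuming TensorFlow 2.x, contrasts calling it after batch(), as in the example above, with calling it before batch():

```python
import tensorflow as tf

base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# take() after batch(): 3 elements of the batched dataset = 3 batches.
take_after_batch = [b.numpy().tolist() for b in base.repeat(6).batch(2).take(3)]

# take() before batch(): 3 individual values, which batch into 2 batches.
take_before_batch = [b.numpy().tolist() for b in base.repeat(6).take(3).batch(2)]

print(take_after_batch)   # [[0, 1], [2, 3], [4, 0]]
print(take_before_batch)  # [[0, 1], [2]]
```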