This post explains how to use the tf.data.Dataset.repeat() method in TensorFlow.
The repeat() method of the tf.data.Dataset class repeats the elements of a dataset a given number of times, set by the count argument. If repeat(count=None) (the default) or repeat(count=-1) is specified, then the dataset is repeated indefinitely.
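The effect of count can be checked without iterating, by asking the dataset for its size. This is a minimal sketch, assuming TensorFlow 2.3 or later, where Dataset.cardinality() and tf.data.INFINITE_CARDINALITY are available; the variable names are only illustrative.

```python
import tensorflow as tf

# A small dataset with 5 elements: 0, 1, 2, 3, 4
dataset = tf.data.Dataset.from_tensor_slices(tf.range(5))

# Finite repetition: 5 elements repeated 3 times
finite = dataset.repeat(count=3)
# Indefinite repetition: count defaults to None
endless = dataset.repeat()

# cardinality() returns a scalar int64 tensor
finite_size = int(finite.cardinality())
is_infinite = bool(endless.cardinality() == tf.data.INFINITE_CARDINALITY)

print(finite_size)   # 15
print(is_infinite)   # True
```

Checking the cardinality first is a cheap way to confirm whether a loop over the dataset will ever finish.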
tf.data.Dataset.repeat() in TensorFlow
TensorSliceDataset object
import tensorflow as tf
print(tf.__version__)
# Create a tensor holding the values 0 to 4
tensor1 = tf.range(5)
# Create the dataset; this returns a TensorSliceDataset object
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
print(type(dataset))
print(dataset)
# Print every element of the dataset
for i in dataset:
    print(i)
<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
<TensorSliceDataset shapes: (), types: tf.int32>
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.data.Dataset.repeat() with count=3
The dataset object contains 5 elements. Calling repeat with count=3 repeats the whole dataset 3 times, so the sequence 0 to 4 is emitted three times in a row. Each original value therefore appears 3 times in the output, for 15 elements in total.
# Repeat the whole dataset 3 times
dataset = dataset.repeat(count=3)
for i in dataset:
    print(i)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
tf.Tensor(0, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
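Note that repeat() returns a new dataset and leaves the one it was called on unchanged; the example above reuses the name dataset, so that name now refers to the repeated version. The sketch below keeps both under separate, illustrative names (base and repeated) and assumes TensorFlow 2.1 or later for as_numpy_iterator().

```python
import tensorflow as tf

# Original dataset with the values 0 to 4
base = tf.data.Dataset.from_tensor_slices(tf.range(5))

# repeat() builds a new dataset; base itself is not modified
repeated = base.repeat(count=3)

# Collect the elements as plain Python integers
base_values = [int(v) for v in base.as_numpy_iterator()]
repeated_values = [int(v) for v in repeated.as_numpy_iterator()]

print(base_values)      # [0, 1, 2, 3, 4]
print(repeated_values)  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```

The full sequence is emitted once before it starts again, so the values are not interleaved.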
tf.data.Dataset.repeat() with count=None or count=-1
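# count=None repeats the dataset indefinitely
dataset = dataset.repeat(count=None)
print(dataset)
# This loop never ends on its own; interrupt it manually
for i in dataset:
    print(i)
# count=-1 behaves the same way (run it separately from the loop above)
dataset = dataset.repeat(count=-1)
print(dataset)
for i in dataset:
    print(i)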
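Because a dataset repeated with count=None or count=-1 never runs out, the loops above keep printing until they are interrupted. One way to look at only the first few elements is to chain take() after repeat(). This is a minimal sketch, assuming TensorFlow 2.1 or later; the limit of 12 and the variable names are arbitrary choices for illustration.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.range(5))

# Repeat indefinitely (equivalent to count=None or count=-1)
infinite = dataset.repeat()

# take(12) cuts the endless stream off after 12 elements
first_twelve = [int(v) for v in infinite.take(12).as_numpy_iterator()]

print(first_twelve)  # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1]
```

The stream wraps around to 0 each time the original 5 elements are exhausted.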