The flat_map method of tf.data.Dataset maps the function given as its argument across the dataset and then flattens the result into a single dataset.
The function provided as the argument must return a Dataset object.
Let's understand how flat_map works with an example.
tf.data.Dataset.from_tensor_slices
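import tensorflow as tf

# The nested list has shape (1, 2, 3), so slicing along the first axis yields one element of shape (2, 3).
dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])
for i in dataset:
    print(i)
    print(i.shape)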
tf.Tensor(
[[1 2 3]
 [3 4 5]], shape=(2, 3), dtype=int32)
(2, 3)
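The dataset above holds a single element because from_tensor_slices slices its input along the first axis, and the nested list has shape (1, 2, 3). As a small sketch to confirm this (the variable names data, num_elements and element_shape are only illustrative, and cardinality() is assumed to be available, which is the case in recent TensorFlow 2.x releases), the element count and element shape can be inspected directly:

```python
import tensorflow as tf

# Same values as in the example above, as an explicit tensor of shape (1, 2, 3).
data = tf.constant([[[1, 2, 3], [3, 4, 5]]])
dataset = tf.data.Dataset.from_tensor_slices(data)

# Slicing along the first axis leaves one element of shape (2, 3).
num_elements = int(dataset.cardinality().numpy())
element_shape = dataset.element_spec.shape.as_list()

print(num_elements)
print(element_shape)
```

This is why the loop above prints only one tensor.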
flat_map method on dataset
In the code snippet below, map_func = lambda x: tf.data.Dataset.from_tensor_slices(x**2) is passed to dataset.flat_map. It squares each dataset item and wraps the result in a new dataset, and flat_map then flattens these datasets into one.
Note that map_func must return a Dataset object. If we use map_func = lambda x: x**2 instead, it will throw an error because it does not return a Dataset object.
# map_func squares each element and returns a Dataset; flat_map flattens those datasets into one.
dataset = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x**2))
for i in dataset:
    print(i)
tf.Tensor([1 4 9], shape=(3,), dtype=int32)
tf.Tensor([ 9 16 25], shape=(3,), dtype=int32)
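To see the requirement on map_func in practice, the following sketch passes a function that returns a plain tensor and catches the resulting error, then repeats the call with a function that returns a Dataset. It assumes a recent TensorFlow 2.x release, where this mistake surfaces as a TypeError when the dataset is built; the names error_raised, squared and rows are only illustrative.

```python
import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[[1, 2, 3], [3, 4, 5]]])

# A map_func that returns a tensor instead of a Dataset is rejected by flat_map.
error_raised = False
try:
    dataset.flat_map(lambda x: x**2)
except TypeError as err:
    error_raised = True
    print("flat_map failed:", err)

# Wrapping the squared tensor in a Dataset satisfies the requirement.
squared = dataset.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x**2))
rows = [row.numpy().tolist() for row in squared]
print(rows)
```

The single (2, 3) element becomes two elements of shape (3,), matching the output shown above.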