TensorFlow | How to use tf.data.Dataset.map() function in TensorFlow

The map() method of tf.data.Dataset transforms the items in a dataset: it applies a function to each element and returns a new dataset containing the results. The snippets below show how to use map().

These snippets use TensorFlow 2.0. If you are using an earlier version of TensorFlow (1.x), enable eager execution first, for example with tf.enable_eager_execution(), to run the code.

Create a dataset with tf.data.Dataset.from_tensor_slices

  
  import tensorflow as tf
  
  print(tf.__version__)
  
  # Create a 1-D tensor with the values 0, 1, 2, 3, 4
  tensor1 = tf.range(5)
  
  # Create a dataset with one element per slice of tensor1;
  # this returns a TensorSliceDataset object
  dataset = tf.data.Dataset.from_tensor_slices(tensor1)
  print(dataset)
  print("Original dataset")
  for i in dataset:
      print(i)
  
  

Example Output:

  
  2.0.0
  <TensorSliceDataset shapes: (), types: tf.int32>
  Original dataset
  
  tf.Tensor(0, shape=(), dtype=int32)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(2, shape=(), dtype=int32)
  tf.Tensor(3, shape=(), dtype=int32)
  tf.Tensor(4, shape=(), dtype=int32)
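
from_tensor_slices slices its input along the first dimension, which is why each element of the 1-D tensor above comes out as a scalar with shape (). A minimal sketch with a 2-D tensor (the names matrix, rows and row_list are illustrative) shows that each dataset element then becomes one row:

```python
import tensorflow as tf

# A 3x2 matrix: [[0, 1], [2, 3], [4, 5]]
matrix = tf.reshape(tf.range(6), (3, 2))

# from_tensor_slices splits along the first dimension, so this dataset
# has three elements, each a row with shape (2,).
rows = tf.data.Dataset.from_tensor_slices(matrix)

# In eager mode each element is a tf.Tensor; .numpy() converts it to NumPy.
row_list = [row.numpy().tolist() for row in rows]
print(row_list)  # [[0, 1], [2, 3], [4, 5]]
```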
  
  

Transform dataset items using TensorFlow map() function

  
  # Transform dataset items using map(): cube each element
  print("dataset after applying map function")
  dataset = dataset.map(lambda x: x * x * x)
  for i in dataset:
      print(i)
  
  
  

Example Output after applying the map() function:

  
  dataset after applying map function
  
  tf.Tensor(0, shape=(), dtype=int32)
  tf.Tensor(1, shape=(), dtype=int32)
  tf.Tensor(8, shape=(), dtype=int32)
  tf.Tensor(27, shape=(), dtype=int32)
  tf.Tensor(64, shape=(), dtype=int32)
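
One way to check what map() did is to compare it with Python's built-in map() on a plain list. The sketch below (the names values, source and cubed are illustrative) also shows that Dataset.map() returns a new dataset and leaves the source dataset unchanged, which is why the example above reassigns the result with dataset = dataset.map(...).

```python
import tensorflow as tf

values = [0, 1, 2, 3, 4]

# Python's built-in map() applies a function to each item of a list.
python_cubes = list(map(lambda x: x * x * x, values))

# Dataset.map() applies the function to each element of the dataset and
# returns a NEW dataset; the source dataset itself is not modified.
source = tf.data.Dataset.from_tensor_slices(values)
cubed = source.map(lambda x: x * x * x)

# Convert each tf.Tensor element to a plain Python int for comparison.
tf_cubes = [int(x.numpy()) for x in cubed]
source_values = [int(x.numpy()) for x in source]

print(python_cubes)   # [0, 1, 8, 27, 64]
print(tf_cubes)       # [0, 1, 8, 27, 64]
print(source_values)  # [0, 1, 2, 3, 4]
```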
  
  

Normalizing images in the dataset with TensorFlow map() function

Download the CIFAR-10 dataset with TensorFlow Datasets (tensorflow_datasets) using the code snippet below:

  
  import tensorflow as tf
  import tensorflow_datasets as tfds
  
  # as_supervised=True makes each element an (image, label) tuple;
  # with_info=True also returns a DatasetInfo object (dsinfo)
  ds, dsinfo = tfds.load('cifar10', split='train', as_supervised=True, with_info=True)
  
  

Let's analyze the pixel values in a sample image from the dataset

  
  for i in ds:
    print(i)
    break
  

Example Output:

  
  (<tf.Tensor: shape=(32, 32, 3), dtype=uint8, numpy=
  array([[[143,  96,  70],
          [141,  96,  72],
          [135,  93,  72],
          ...,
          [ 96,  37,  19],
          [105,  42,  18],
          [104,  38,  20]],
  
         [[128,  98,  92],
          [146, 118, 112],
          [170, 145, 138],
          ...,
  
  

From the output above we can see that the image is not normalized: its pixel values are uint8 integers in the range 0 to 255. Let's normalize the images in the dataset with the map() method in two steps:

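  • Create a function to normalize the image:

      def normalize_image(image, label):
        # Cast uint8 pixels to float32 and scale them from [0, 255] to [0, 1]
        return tf.cast(image, tf.float32) / 255., label

  • Apply the normalize_image function to the dataset using the map() method:

      # map() calls normalize_image(image, label) on every (image, label) pair
      ds = ds.map(normalize_image)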
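
Because tfds.load(..., as_supervised=True) yields (image, label) tuples, map() unpacks each tuple and passes its parts to normalize_image as two arguments. The sketch below reproduces this offline with a tiny hypothetical dataset built from NumPy arrays instead of CIFAR-10. It also passes num_parallel_calls=tf.data.experimental.AUTOTUNE, an optional map() argument that lets tf.data choose how many elements to process in parallel; the original snippet does not use it.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the CIFAR-10 split: two tiny 2x2 RGB "images"
# stored as uint8, plus integer labels. No download is needed.
images = np.array([
    np.full((2, 2, 3), 255, dtype=np.uint8),
    np.full((2, 2, 3), 51, dtype=np.uint8),
])
labels = np.array([3, 7], dtype=np.int64)

# from_tensor_slices on a tuple yields (image, label) pairs, the same
# element structure that tfds.load(..., as_supervised=True) produces.
ds = tf.data.Dataset.from_tensor_slices((images, labels))

def normalize_image(image, label):
    # map() unpacks each (image, label) tuple into these two arguments.
    return tf.cast(image, tf.float32) / 255., label

# Optional: let tf.data run normalize_image on several elements in parallel.
ds = ds.map(normalize_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

for image, label in ds:
    print(image.dtype, float(image[0, 0, 0]), int(label))
```

Element order is preserved by default even with num_parallel_calls, so the labels come out in their original order.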
      
      

Let's analyze the pixel values in a sample image from the dataset after applying the map() method:

  
  for i in ds:
    print(i)
    break
  

Example Output:

  
  (<tf.Tensor: shape=(32, 32, 3), dtype=float32, numpy=
  array([[[0.56078434, 0.3764706 , 0.27450982],
          [0.5529412 , 0.3764706 , 0.28235295],
          [0.5294118 , 0.3647059 , 0.28235295],
          ...,
          [0.3764706 , 0.14509805, 0.07450981],
          [0.4117647 , 0.16470589, 0.07058824],
          [0.40784314, 0.14901961, 0.07843138]],
  
         [[0.5019608 , 0.38431373, 0.36078432],
          [0.57254905, 0.4627451 , 0.4392157 ],
          [0.6666667 , 0.5686275 , 0.5411765 ],
          ...,
      
      

From the output above we can see that the image is now normalized: after normalize_image is applied to the dataset with map(), the pixel values are float32 numbers in the range 0 to 1.
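
Rather than inspecting printed arrays, the range can also be checked programmatically by tracking the minimum and maximum pixel value over the dataset. This sketch uses two hypothetical 1x1 images so it runs without downloading CIFAR-10; the same loop would work on the normalized ds above, although it would visit all 50,000 training images.

```python
import numpy as np
import tensorflow as tf

# Hypothetical uint8 images, shape (2, 1, 1, 3), covering 0 and 255.
images = np.array([[[[0, 128, 255]]], [[[64, 32, 16]]]], dtype=np.uint8)
labels = np.array([0, 1], dtype=np.int64)

def normalize_image(image, label):
    return tf.cast(image, tf.float32) / 255., label

ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(normalize_image)

# Track the smallest and largest pixel value seen across every image.
min_pixel, max_pixel = float("inf"), float("-inf")
for image, _ in ds:
    min_pixel = min(min_pixel, float(tf.reduce_min(image)))
    max_pixel = max(max_pixel, float(tf.reduce_max(image)))

print(min_pixel, max_pixel)  # 0.0 1.0
```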

