tf.keras.backend.batch_flatten
method in TensorFlow flattens the each data samples of a batch.
If batch_flatten
is applied on a Tensor having dimension like 3D,4D,5D or ND it always turn that tensor to 2D.
0th dimension would remain same in both input tensor and output tensor.
Lets see with below example.
Example 1
Create a 4D tensor with tf.ones
import tensorflow as tf
t1_batch = tf.ones((2,3,2,2))
print(t1_batch)
tf.Tensor(
[[[[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]]]
[[[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]]
[[1. 1.]
[1. 1.]]]], shape=(2, 3, 2, 2), dtype=float32)
Apply batch_flatten
to convert t1_batch
to 2D tensor
t1_batch_flatten = tf.keras.backend.batch_flatten(t1_batch)
print(t1_batch_flatten)
tf.Tensor: shape=(2, 12), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)
Example 2
Create a 5D tensor with tf.ones
import tensorflow as tf
t2_batch = tf.ones((3,4,2,2,4))
Apply batch_flatten
to convert t2_batch
to 2D tensor
t2_batch_flatten = tf.keras.backend.batch_flatten(t2_batch)
print(t2_batch_flatten)
tf.Tensor: shape=(3, 64), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]],
dtype=float32)
Similar Articles