Understanding tf.stack()
The tf.stack() operation combines a list of tensors along a new dimension, so the result has a rank one higher than each input. Key features:
- Requires all input tensors to have identical shapes (and the same dtype)
- Creates the new dimension at the position given by the axis parameter
- Commonly used to build batches for neural network inputs
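A minimal sketch of the first two rules, assuming TensorFlow 2.x in eager mode. Stacking two rank-1 tensors of shape (3,) yields a rank-2 tensor of shape (2, 3), while a tensor with a different shape is rejected. The exact exception class for the mismatch may differ between TensorFlow versions, so the sketch catches both InvalidArgumentError and ValueError.

```python
import tensorflow as tf

# Two rank-1 tensors with identical shape (3,)
a = tf.constant([1.0, 2.0, 3.0])
b = tf.constant([4.0, 5.0, 6.0])

stacked = tf.stack([a, b])  # axis defaults to 0
print(stacked.shape)  # (2, 3): rank grew from 1 to 2

# A tensor with a different shape cannot be stacked with `a`
c = tf.constant([7.0, 8.0])
try:
    tf.stack([a, c])
    mismatch_rejected = False
except (tf.errors.InvalidArgumentError, ValueError) as err:
    mismatch_rejected = True
    print("Shape mismatch rejected:", type(err).__name__)
```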
Creating Base Tensors
Generate random tensors and verify their properties:
import tensorflow as tf
# Create three 2x2 random tensors
tensor1 = tf.random.uniform(shape=[2,2], minval=3, maxval=5)
tensor2 = tf.random.uniform(shape=[2,2], minval=3, maxval=5)
tensor3 = tf.random.uniform(shape=[2,2], minval=3, maxval=5)
# Verify tensor properties
print(f"Tensor 1:\n{tensor1}\nRank: {tf.rank(tensor1).numpy()}\n")
print(f"Tensor 2:\n{tensor2}\nRank: {tf.rank(tensor2).numpy()}\n")
print(f"Tensor 3:\n{tensor3}\nRank: {tf.rank(tensor3).numpy()}")
Each tensor is rank 2 with shape (2, 2). Its values are drawn uniformly from [3, 5), since minval is inclusive and maxval is exclusive.
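Before stacking, it can help to confirm that the inputs really are compatible. The sketch below uses a hypothetical helper, is_stackable (not part of TensorFlow), that checks every tensor against the first one's shape and dtype, the two properties tf.stack requires to match.

```python
import tensorflow as tf

def is_stackable(tensors):
    """Hypothetical helper: True if every tensor matches the first one's shape and dtype."""
    first = tensors[0]
    return all(t.shape == first.shape and t.dtype == first.dtype for t in tensors)

# Same construction as the three tensors above (float32, shape (2, 2))
tensors = [tf.random.uniform(shape=[2, 2], minval=3, maxval=5) for _ in range(3)]
print(is_stackable(tensors))

# A dtype mismatch or a shape mismatch both make the list unstackable
print(is_stackable([tensors[0], tf.cast(tensors[1], tf.float64)]))
print(is_stackable([tensors[0], tf.zeros([3, 2])]))
```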
Stacking Tensors with tf.stack()
Combine the tensors along a new first dimension (axis=0):
# Stack tensors along first dimension
stacked_tensor = tf.stack(values=[tensor1, tensor2, tensor3], axis=0)
print("Stacked Tensor:")
print(stacked_tensor)
print(f"\nStacked Tensor Rank: {tf.rank(stacked_tensor).numpy()}")
print(f"Stacked Tensor Shape: {stacked_tensor.shape}")
The resulting tensor has rank 3 and shape (3, 2, 2): one (2, 2) slice per input tensor.
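One way to see what axis=0 does: each input becomes one slice of the result, so indexing the result at i returns input i. The sketch below recreates three (2, 2) tensors so it runs on its own, and also checks that tf.stack gives the same result as adding a length-1 axis to each input with tf.expand_dims and then joining them with tf.concat.

```python
import tensorflow as tf

tensors = [tf.random.uniform(shape=[2, 2], minval=3, maxval=5) for _ in range(3)]
stacked = tf.stack(tensors, axis=0)  # shape (3, 2, 2)

# Slice i along the new axis is exactly input tensor i
slices_match = all(
    bool(tf.reduce_all(tf.equal(stacked[i], t))) for i, t in enumerate(tensors)
)

# Equivalent construction: add a length-1 axis to each input, then concatenate along it
manual = tf.concat([tf.expand_dims(t, axis=0) for t in tensors], axis=0)
matches_manual = bool(tf.reduce_all(tf.equal(stacked, manual)))

print(stacked.shape, slices_match, matches_manual)
```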
Axis Parameter Variations
Stacking Along Different Axes
# Stack along last dimension (axis=-1)
stacked_axis_minus1 = tf.stack([tensor1, tensor2], axis=-1)
With two (2, 2) inputs, the shape becomes (2, 2, 2); the new last dimension has one entry per input tensor.
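Using small integer tensors instead of random ones (an illustrative choice) makes the layout of axis=-1 visible: position [i, j] of the result holds a vector of the inputs' values at [i, j], and slicing the last axis at k recovers input k.

```python
import tensorflow as tf

t1 = tf.constant([[1, 2], [3, 4]])
t2 = tf.constant([[5, 6], [7, 8]])

# New last axis: element [i, j] of every input is gathered into one length-2 vector
stacked_last = tf.stack([t1, t2], axis=-1)
print(stacked_last.shape)          # (2, 2, 2)
print(stacked_last[0, 0].numpy())  # [1 5]: t1[0, 0] and t2[0, 0] side by side

# Slicing the last axis recovers each original input
print(stacked_last[..., 1].numpy())  # same values as t2
```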
Visual Representation
Stacking three (2, 2) tensors produces these shapes, depending on the axis:
- axis=0: (3, 2, 2)
- axis=1: (2, 3, 2)
- axis=2 (equivalently axis=-1): (2, 2, 3)
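The shapes listed above can be checked directly. This sketch stacks three random (2, 2) tensors along every valid axis; for rank-2 inputs the valid range is -3 to 2, and each negative axis counts from the end of the rank-3 result.

```python
import tensorflow as tf

tensors = [tf.random.uniform(shape=[2, 2], minval=3, maxval=5) for _ in range(3)]

# Record the result shape for every valid axis, including negative ones
shapes = {axis: tf.stack(tensors, axis=axis).shape.as_list() for axis in range(-3, 3)}
for axis, shape in shapes.items():
    print(f"axis={axis:>2}: {shape}")
```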
Best Practices for tf.stack()
- Ensure all input tensors have identical shapes and dtypes
- Use axis=-1 for channel-wise stacking
- Use tf.unstack() along the same axis to reverse a stack
- Prefer tf.stack() over tf.concat() when you need a new dimension; tf.concat() joins tensors along an existing dimension instead
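A short sketch of the last two practices, using three random (2, 2) tensors: tf.stack adds a dimension while tf.concat lengthens an existing one, and tf.unstack along the same axis returns the original list of tensors.

```python
import tensorflow as tf

tensors = [tf.random.uniform(shape=[2, 2], minval=3, maxval=5) for _ in range(3)]

# tf.stack adds a dimension; tf.concat joins along an existing one
stacked = tf.stack(tensors, axis=0)        # shape (3, 2, 2)
concatenated = tf.concat(tensors, axis=0)  # shape (6, 2)

# tf.unstack along the same axis returns the original list of (2, 2) tensors
restored = tf.unstack(stacked, axis=0)
round_trip_ok = len(restored) == len(tensors) and all(
    bool(tf.reduce_all(tf.equal(r, t))) for r, t in zip(restored, tensors)
)
print(stacked.shape, concatenated.shape, round_trip_ok)
```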
Pro Tip: Real-World Application
Use tf.stack() to add a batch dimension to a list of individual examples:
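# img1, img2, img3: image tensors of identical shape (height, width, channels), assumed already defined
# Convert a list of single images to batch format: (3, height, width, channels)
image_batch = tf.stack([img1, img2, img3], axis=0)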
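A self-contained version of the batching tip, assuming three hypothetical 28x28 RGB images (height, width, channels) built with tf.zeros, tf.ones, and tf.fill as stand-ins for real decoded images. Because tf.stack rejects mismatched shapes, an image of a different size is first brought to 28x28 with tf.image.resize in this sketch; other resizing or cropping strategies would work as well.

```python
import tensorflow as tf

# Hypothetical stand-ins for decoded images: (height, width, channels)
img1 = tf.zeros([28, 28, 3])
img2 = tf.ones([28, 28, 3])
img3 = tf.fill([28, 28, 3], 0.5)

# Stack along a new leading axis to get (batch, height, width, channels)
image_batch = tf.stack([img1, img2, img3], axis=0)
print(image_batch.shape)  # (3, 28, 28, 3)

# An image of another size must be resized to a common shape before stacking
img4 = tf.image.resize(tf.zeros([32, 40, 3]), [28, 28])  # now (28, 28, 3)
bigger_batch = tf.stack([img1, img2, img3, img4], axis=0)
print(bigger_batch.shape)  # (4, 28, 28, 3)
```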