How to squeeze and unsqueeze a tensor in PyTorch?

In PyTorch, you can change a tensor's dimensions with the torch.squeeze() and torch.unsqueeze() functions (also available as the tensor methods .squeeze() and .unsqueeze()). The squeeze operation removes dimensions of size 1, while unsqueeze inserts a new dimension of size 1 at a specified position.

Understanding Squeeze Operation

When called without a dim argument, torch.squeeze() removes every dimension of size 1 from a tensor. For example, if a tensor has shape (2 × 1 × 3 × 1), squeezing it results in shape (2 × 3).

Example

import torch

# Create a tensor with dimensions of size 1
tensor = torch.ones(2, 1, 2, 1)
print("Original tensor shape:", tensor.shape)
print("Original tensor:\n", tensor)

# Squeeze the tensor (removes all size-1 dimensions)
squeezed = torch.squeeze(tensor)
print("\nSqueezed tensor shape:", squeezed.shape)
print("Squeezed tensor:\n", squeezed)

Output

Original tensor shape: torch.Size([2, 1, 2, 1])
Original tensor:
 tensor([[[[1.],
          [1.]]],


        [[[1.],
          [1.]]]])

Squeezed tensor shape: torch.Size([2, 2])
Squeezed tensor:
 tensor([[1., 1.],
        [1., 1.]])
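The shape quoted at the start of this section, (2 × 1 × 3 × 1), can be checked directly, along with two edge cases. This is a minimal sketch; torch.zeros is used only as a convenient way to build tensors of the listed shapes.

```python
import torch

# Shape quoted in the text: (2 x 1 x 3 x 1)
t = torch.zeros(2, 1, 3, 1)
print(torch.squeeze(t).shape)   # torch.Size([2, 3])

# A tensor with no size-1 dimensions keeps its shape
u = torch.zeros(2, 3)
print(torch.squeeze(u).shape)   # torch.Size([2, 3])

# If every dimension has size 1, squeezing leaves a 0-dimensional tensor
s = torch.zeros(1, 1, 1)
print(torch.squeeze(s).shape)   # torch.Size([])
print(torch.squeeze(s).dim())   # 0
```

The last case is worth remembering: squeezing a tensor that holds a single value produces a scalar tensor with no dimensions at all.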

Understanding Unsqueeze Operation

The torch.unsqueeze() method adds a new dimension of size 1 at the specified position. The dim parameter determines where the new dimension is inserted.

Example

import torch

# Create a 1D tensor
tensor = torch.tensor([1, 2, 3, 4])
print("Original tensor shape:", tensor.shape)
print("Original tensor:", tensor)

# Unsqueeze at dimension 0 (adds dimension at the beginning)
unsqueezed_0 = torch.unsqueeze(tensor, dim=0)
print("\nUnsqueezed at dim=0 shape:", unsqueezed_0.shape)
print("Unsqueezed at dim=0:\n", unsqueezed_0)

# Unsqueeze at dimension 1 (for this 1-D tensor, adds the dimension at the end)
unsqueezed_1 = torch.unsqueeze(tensor, dim=1)
print("\nUnsqueezed at dim=1 shape:", unsqueezed_1.shape)
print("Unsqueezed at dim=1:\n", unsqueezed_1)

Output

Original tensor shape: torch.Size([4])
Original tensor: tensor([1, 2, 3, 4])

Unsqueezed at dim=0 shape: torch.Size([1, 4])
Unsqueezed at dim=0:
 tensor([[1, 2, 3, 4]])

Unsqueezed at dim=1 shape: torch.Size([4, 1])
Unsqueezed at dim=1:
 tensor([[1],
        [2],
        [3],
        [4]])
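The dim argument also accepts negative values, and indexing with None is a common shorthand for inserting a size-1 dimension. A minimal sketch using the same 1-D tensor as above:

```python
import torch

tensor = torch.tensor([1, 2, 3, 4])

# Negative dim counts from the end: for a 1-D tensor, dim=-1 is the same as dim=1
print(torch.unsqueeze(tensor, dim=-1).shape)  # torch.Size([4, 1])

# Tensor-method form of the same operation
print(tensor.unsqueeze(0).shape)              # torch.Size([1, 4])

# Indexing with None inserts a size-1 dimension at that position
print(tensor[None, :].shape)                  # torch.Size([1, 4])
print(tensor[:, None].shape)                  # torch.Size([4, 1])

# The None-indexing result has the same values as the explicit unsqueeze
print(torch.equal(tensor[:, None], torch.unsqueeze(tensor, 1)))  # True
```

Which form to use is largely a matter of readability; the explicit unsqueeze call makes the intent clearer when the dimension index is computed at runtime.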

Practical Use Cases

These operations are commonly used in deep learning to reshape tensors so they match the dimensions expected by neural network layers or by mathematical operations such as broadcasting:

import torch

# Example: Preparing data for batch processing
data = torch.tensor([1.0, 2.0, 3.0])
print("Original data:", data.shape)

# Add batch dimension (common in neural networks)
batched = torch.unsqueeze(data, dim=0)
print("With batch dimension:", batched.shape)

# Add feature dimension
features = torch.unsqueeze(batched, dim=2)
print("With feature dimension:", features.shape)

# Remove every size-1 dimension (this also drops the batch dimension)
cleaned = torch.squeeze(features)
print("After squeezing:", cleaned.shape)

Output

Original data: torch.Size([3])
With batch dimension: torch.Size([1, 3])
With feature dimension: torch.Size([1, 3, 1])
After squeezing: torch.Size([3])
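For the mathematical-operations case, a typical pattern is to unsqueeze two 1-D tensors at different positions so that broadcasting pairs every element of one with every element of the other. A minimal sketch; the tensors and variable names are illustrative only:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])   # shape (3,)
b = torch.tensor([10.0, 20.0])      # shape (2,)

col = torch.unsqueeze(a, dim=1)     # shape (3, 1)
row = torch.unsqueeze(b, dim=0)     # shape (1, 2)

# Shapes (3, 1) and (1, 2) broadcast to (3, 2): entry [i, j] is a[i] + b[j]
pairwise_sum = col + row
print(pairwise_sum.shape)           # torch.Size([3, 2])

# Multiplying instead gives the outer product of a and b
print(torch.equal(col * row, torch.outer(a, b)))  # True
```

Without the unsqueeze calls, adding tensors of shapes (3,) and (2,) fails, because broadcasting aligns trailing dimensions and 3 and 2 are incompatible.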

Key Parameters

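  • torch.squeeze(input, dim=None): If dim is omitted, every size-1 dimension is removed. If dim is given, only that dimension is removed, and only if its size is 1; otherwise the shape is left unchanged
  • torch.unsqueeze(input, dim): The dim parameter is required and gives the index at which the new dimension is inserted. It must lie in the range [-input.dim() - 1, input.dim()], with negative values counting from the end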
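The effect of dim can be sketched as follows. Passing dim to squeeze is also one way to avoid what happened in the use-case example above, where a bare squeeze() dropped the batch dimension along with the feature dimension. The exception types caught at the end are an assumption: the out-of-range case is expected to raise IndexError, and RuntimeError is included defensively.

```python
import torch

x = torch.ones(2, 1, 2, 1)

# Naming a size-1 dimension removes only that one
print(torch.squeeze(x, dim=1).shape)   # torch.Size([2, 2, 1])

# Naming a dimension whose size is not 1 leaves the shape unchanged (no error)
print(torch.squeeze(x, dim=0).shape)   # torch.Size([2, 1, 2, 1])

# With a batch of size 1, a bare squeeze() drops the batch dimension too;
# squeezing only the last dimension keeps it
features = torch.ones(1, 3, 1)
print(torch.squeeze(features).shape)          # torch.Size([3])
print(torch.squeeze(features, dim=-1).shape)  # torch.Size([1, 3])

# For a 1-D tensor, unsqueeze accepts dim in [-2, 1]
v = torch.ones(3)
print(torch.unsqueeze(v, dim=-2).shape)       # torch.Size([1, 3])
try:
    torch.unsqueeze(v, dim=2)  # outside the valid range
except (IndexError, RuntimeError) as err:
    print("unsqueeze failed:", type(err).__name__)
```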

Conclusion

Use torch.squeeze() to remove size-1 dimensions and torch.unsqueeze() to add new dimensions. These operations are essential for tensor manipulation in PyTorch, especially when preparing data for neural networks or mathematical operations.

Updated on: 2026-03-26T18:44:17+05:30
