How to find mean across the image channels in PyTorch?
RGB images have three channels: Red, Green, and Blue. Computing the mean pixel value of each channel is a common preprocessing step in computer vision. In PyTorch, calling torch.mean() on a [C, H, W] image tensor with dim=[1,2] averages over the spatial dimensions and returns one mean per channel.
Understanding Image Tensors
PyTorch image tensors have shape [C, H, W] where C is channels, H is height, and W is width. Setting dim=[1,2] computes the mean across height and width dimensions, leaving us with three values (one per channel).
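A tiny hand-built tensor makes the role of dim concrete. In the sketch below (the values 0 to 11 are illustrative only), dim=[1,2] yields one mean per channel, keepdim=True preserves the reduced axes as size-1 dimensions, and dim=0 instead averages the three channels together at each pixel:

```python
import torch

# Illustrative 3-channel, 2x2 "image":
# channel 0 holds 0..3, channel 1 holds 4..7, channel 2 holds 8..11
x = torch.arange(12, dtype=torch.float32).reshape(3, 2, 2)

# Reduce over height and width (dims 1 and 2): one mean per channel
per_channel = x.mean(dim=[1, 2])
print(per_channel)  # tensor([1.5000, 5.5000, 9.5000])

# keepdim=True keeps the reduced axes as size-1 dimensions
print(x.mean(dim=[1, 2], keepdim=True).shape)  # torch.Size([3, 1, 1])

# Reduce over the channel axis (dim 0) instead: one value per pixel
per_pixel = x.mean(dim=0)
print(per_pixel)  # one value per pixel: [[4., 5.], [6., 7.]]
```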
Method 1: Using PIL and torch.mean()
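This approach loads the image with PIL, converts it to a tensor with transforms.ToTensor() (which also scales pixel values from [0, 255] to [0.0, 1.0]), and applies torch.mean() directly: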
```python
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np

# Create a sample 4x4 RGB image (H x W x C = 4 x 4 x 3) with random 8-bit pixels;
# randint's upper bound is exclusive, so 256 covers the full 0-255 range
sample_image = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
img = Image.fromarray(sample_image)

# Define transform to convert PIL image to a C x H x W tensor scaled to [0, 1]
transform = transforms.ToTensor()

# Convert image to PyTorch Tensor
imgTensor = transform(img)
print("Shape of Image Tensor:", imgTensor.shape)

# Compute mean across height and width dimensions (dim=[1,2])
R_mean, G_mean, B_mean = torch.mean(imgTensor, dim=[1,2])
print("Mean across Red channel:", R_mean)
print("Mean across Green channel:", G_mean)
print("Mean across Blue channel:", B_mean)
```
Output (the image is random, so the means differ on each run):
```
Shape of Image Tensor: torch.Size([3, 4, 4])
Mean across Red channel: tensor(0.4902)
Mean across Green channel: tensor(0.5059)
Mean across Blue channel: tensor(0.4824)
```
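Because ToTensor() divides 8-bit pixel values by 255, the channel means can be cross-checked against NumPy on the original H x W x C array. The sketch below reproduces that conversion by hand with permute (an assumption that mirrors ToTensor for a uint8 RGB array) so it needs only torch and numpy; the seed is arbitrary:

```python
import numpy as np
import torch

# Reproducible random RGB image in H x W x C layout (the seed is arbitrary)
rng = np.random.default_rng(0)
sample_image = rng.integers(0, 256, size=(4, 4, 3)).astype(np.uint8)

# Manual stand-in for ToTensor() on a uint8 RGB array:
# move channels first (C x H x W) and scale to [0, 1]
img_tensor = torch.from_numpy(sample_image).permute(2, 0, 1).float() / 255.0

# PyTorch: mean over height and width (dims 1 and 2 in C x H x W layout)
torch_means = img_tensor.mean(dim=[1, 2])

# NumPy: mean over height and width (axes 0 and 1 in H x W x C layout), then scale
numpy_means = sample_image.mean(axis=(0, 1)) / 255.0

print(torch_means)
print(numpy_means)
print(torch.allclose(torch_means, torch.from_numpy(numpy_means).float(), atol=1e-6))  # True
```

The only difference between the two is the axis numbering: the spatial axes are 1 and 2 in PyTorch's channels-first layout but 0 and 1 in the channels-last NumPy array.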
Method 2: Using Tensor.mean() Method
The tensor's own mean() method performs the same calculation:
```python
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

# Create another sample 4x4 RGB image with random 8-bit pixels (0-255)
sample_image = np.random.randint(0, 256, (4, 4, 3), dtype=np.uint8)
img = Image.fromarray(sample_image)

# Convert to a C x H x W tensor scaled to [0, 1]
transform = transforms.ToTensor()
imgTensor = transform(img)
print("Shape of Image Tensor:", imgTensor.shape)

# Alternative way using tensor.mean() method
R_mean, G_mean, B_mean = imgTensor.mean(dim=[1,2])
print("Mean across Red channel:", R_mean)
print("Mean across Green channel:", G_mean)
print("Mean across Blue channel:", B_mean)
```
Output (values differ on each run):
```
Shape of Image Tensor: torch.Size([3, 4, 4])
Mean across Red channel: tensor(0.5137)
Mean across Green channel: tensor(0.4941)
Mean across Blue channel: tensor(0.4706)
```
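Passing keepdim=True keeps the result shaped [3, 1, 1], so it broadcasts against the [3, H, W] image. The sketch below uses that to subtract each channel's own mean (per-image mean-centering); the random input and seed are illustrative only:

```python
import torch

torch.manual_seed(0)  # arbitrary seed so the run is reproducible

# Illustrative C x H x W image with values in [0, 1)
img = torch.rand(3, 4, 4)

# Per-channel means with the reduced dims kept as size 1: shape [3, 1, 1]
channel_means = img.mean(dim=[1, 2], keepdim=True)
print(channel_means.shape)  # torch.Size([3, 1, 1])

# Broadcasting subtracts each channel's mean from every pixel of that channel
centered = img - channel_means

# Each channel of the centered image now averages to (approximately) zero
print(centered.mean(dim=[1, 2]))
```

Without keepdim=True, the [3]-shaped result would be aligned with the last (width) axis during broadcasting, which raises an error here and would silently mix up channels and columns if the width happened to be 3.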
Practical Example with Normalization
Channel means are commonly used to normalize datasets in deep learning. The example below computes them over a small batch of random images:
```python
import torch
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

# Transform that converts a PIL image to a C x H x W tensor scaled to [0, 1]
transform = transforms.ToTensor()

# Create a batch of sample images (simulating multiple images)
batch_images = []
for _ in range(3):
    sample_image = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    img = Image.fromarray(sample_image)
    tensor = transform(img)
    batch_images.append(tensor)

# Stack tensors to create a batch of shape [N, C, H, W]
batch_tensor = torch.stack(batch_images)
print("Batch shape:", batch_tensor.shape)

# Compute mean across all images in batch
# dim=[0,2,3] averages over batch, height, and width, leaving one value per channel
overall_means = torch.mean(batch_tensor, dim=[0,2,3])
print("Dataset channel means:", overall_means)

# These values can be used for normalization
# (std=0.5 is only a placeholder; in practice compute the per-channel std the same way)
print("Normalization transform:")
print(f"transforms.Normalize(mean={overall_means.tolist()}, std=[0.5, 0.5, 0.5])")
```
Output (values differ on each run):
```
Batch shape: torch.Size([3, 3, 32, 32])
Dataset channel means: tensor([0.5020, 0.4959, 0.5020])
Normalization transform:
transforms.Normalize(mean=[0.501960813999176, 0.4958823919296265, 0.501960813999176], std=[0.5, 0.5, 0.5])
```
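Stacking every image into one tensor stops working once a dataset no longer fits in memory, and torch.stack also requires all images to have the same size. A common alternative is to accumulate per-channel sums of pixel values and of squared pixel values, then derive the mean and (population) standard deviation at the end. The sketch below assumes C x H x W float tensors; the helper name channel_stats and the randomly generated images are hypothetical, for illustration only:

```python
import torch

def channel_stats(images):
    """Per-channel mean and population std over an iterable of C x H x W tensors.

    Hypothetical helper: pixels are pooled across all images, so larger
    images contribute proportionally more than smaller ones.
    """
    total = None      # running sum of pixel values, one entry per channel
    total_sq = None   # running sum of squared pixel values, one entry per channel
    count = 0         # running number of pixels per channel
    for img in images:
        img = img.double()  # accumulate in float64 to limit rounding error
        s = img.sum(dim=[1, 2])
        sq = (img ** 2).sum(dim=[1, 2])
        total = s if total is None else total + s
        total_sq = sq if total_sq is None else total_sq + sq
        count += img.shape[1] * img.shape[2]
    mean = total / count
    # Var[X] = E[X^2] - (E[X])^2; clamp guards against tiny negative rounding
    var = (total_sq / count - mean ** 2).clamp(min=0)
    return mean.float(), var.sqrt().float()

torch.manual_seed(0)  # arbitrary seed
# Illustrative images of different sizes, values in [0, 1)
images = [torch.rand(3, 32, 32), torch.rand(3, 24, 40), torch.rand(3, 16, 16)]

mean, std = channel_stats(images)
print("mean:", mean)
print("std:", std)

# Normalize one image per channel: (x - mean) / std, the same arithmetic
# that transforms.Normalize(mean, std) applies to a tensor
normalized = (images[0] - mean[:, None, None]) / std[:, None, None]
```

Pooling pixels this way weights each image by its size; averaging per-image means instead would give every image equal weight, which is a different choice.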
Comparison
| Method | Syntax | Best For |
|---|---|---|
| torch.mean() | torch.mean(tensor, dim=[1,2]) | Explicit function calls |
| tensor.mean() | tensor.mean(dim=[1,2]) | Method chaining, cleaner code |
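Both forms perform the same reduction, so they return identical values. One pitfall applies to both: torch.mean only accepts floating-point (or complex) input, so a raw uint8 tensor, for example one built with torch.from_numpy from an 8-bit image array, raises an error unless it is converted first or a dtype is passed. A short sketch using a random illustrative tensor:

```python
import torch

torch.manual_seed(0)  # arbitrary seed
img = torch.rand(3, 8, 8)  # illustrative C x H x W float image

# Function form and method form give the same result
a = torch.mean(img, dim=[1, 2])
b = img.mean(dim=[1, 2])
print(torch.equal(a, b))  # True

# Raw 8-bit pixels, as an image array would hold them before ToTensor()
raw = (img * 255).to(torch.uint8)

try:
    raw.mean(dim=[1, 2])
except RuntimeError as err:
    print("uint8 input rejected:", err)

# Either request a floating-point dtype for the computation...
m1 = raw.mean(dim=[1, 2], dtype=torch.float32)
# ...or convert first; both give means on the 0-255 scale
m2 = raw.float().mean(dim=[1, 2])
print(torch.allclose(m1, m2))  # True
```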
Conclusion
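Use torch.mean(tensor, dim=[1,2]) or tensor.mean(dim=[1,2]) to compute channel-wise means of a [C, H, W] image tensor in PyTorch; for a batch shaped [N, C, H, W], use dim=[0,2,3]. Together with per-channel standard deviations, these values are commonly used for image normalization and for summarizing dataset statistics in computer vision workflows.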
