How to read a JPEG or PNG image in PyTorch?


Reading images is a basic first step in image processing and computer vision tasks. The torchvision.io package provides functions for different I/O operations. To read an image, it provides the read_image() function, which reads JPEG and PNG images and returns a 3D RGB or grayscale tensor.

The three dimensions of the tensor correspond to [C,H,W]. C is the number of channels, and H and W are the height and width of the image, respectively.

For RGB, the number of channels is 3, so read_image() returns a tensor of shape [3,H,W] for an RGB image. The values of the output tensor are 8-bit unsigned integers in the range [0,255].
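To make the [C,H,W] layout concrete, here is a minimal sketch that uses a small random tensor as a stand-in for a decoded image (an assumption, so that no image file is needed). It shows how to pick out one channel, reorder the dimensions to [H,W,C], and rescale the values to floating point.

```python
import torch

# Synthetic stand-in for a decoded RGB image: 3 channels, height 4, width 6.
# The values are random; this is not a real image.
img = torch.randint(0, 256, (3, 4, 6), dtype=torch.uint8)

channels, height, width = img.shape

# Channels are indexed along the first dimension: 0 = red, 1 = green, 2 = blue.
red = img[0]  # shape [H, W]

# Many plotting and array libraries expect [H, W, C] instead of [C, H, W].
img_hwc = img.permute(1, 2, 0)

# Rescale from integers in [0, 255] to floats in [0, 1].
img_float = img.to(torch.float32) / 255.0

print(img.shape, img.dtype)
print(red.shape)
print(img_hwc.shape)
print(img_float.dtype)
```

The permute call only changes how the data is viewed, so the pixel values themselves are unchanged.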

Syntax

torchvision.io.read_image(path)

Parameters

  • path - The input JPEG or PNG image path.

Output

It returns a torch tensor of size [image_channels, image_height, image_width].

Steps

You can use the following steps to read and visualize a JPEG or PNG image in PyTorch.

  • Import the required libraries. In all the following examples, the required Python libraries are torch and torchvision. Make sure you have already installed them.

import torch
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
  • Read a JPEG or PNG image using the read_image() function. Specify the image path, including the file extension (.jpg or .png). The output of this function is a torch tensor of size [image_channels, image_height, image_width].

img = read_image('butterfly.jpg')
  • Optionally, compute the different image properties, i.e., image type, image size, etc.

  • To display the image, first we convert the image tensor to a PIL image and then display the image.

img = T.ToPILImage()(img)
img.show()
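The optional "image properties" step can be sketched as a small helper. The function name describe_image is hypothetical and not part of torchvision, and the zero-filled tensor is only a stand-in for the result of read_image(). The helper expects the tensor, so call it before converting the image to a PIL image.

```python
import torch


def describe_image(img: torch.Tensor) -> dict:
    """Collect basic properties of a [C, H, W] image tensor (hypothetical helper)."""
    channels, height, width = img.shape
    return {
        "is_tensor": torch.is_tensor(img),
        "dtype": img.dtype,
        "channels": channels,
        "height": height,
        "width": width,
        "min": int(img.min()),
        "max": int(img.max()),
    }


# Stand-in for `img = read_image('butterfly.jpg')`; replace with a real image tensor.
img = torch.zeros((3, 465, 700), dtype=torch.uint8)
img[0, 0, 0] = 255

props = describe_image(img)
for name, value in props.items():
    print(f"{name}: {value}")
```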

Input Images

The following examples use two input files, butterfly.jpg and elephant.png.

Example 1

Here's the full Python code to read a JPEG image using PyTorch.

# Import the required libraries
import torch
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T

# read a JPEG image
img = read_image('butterfly.jpg')

# display the image properties
print("Image data:\n", img)

# check if input image is a PyTorch tensor
print("Is image a PyTorch Tensor:", torch.is_tensor(img))
print("Type of Image:", type(img))

# size of the image
print(img.size())

# convert the torch tensor to PIL image
img = T.ToPILImage()(img)

# display the image
img.show()

Output

Image data:
   tensor([[[146, 169, 191, ..., 71, 61, 53],
      [140, 169, 192, ..., 75, 63, 53],
      [126, 161, 186, ..., 85, 68, 58],
      ...,
      [ 33, 31, 30, ..., 218, 221, 223],
      [ 30, 30, 31, ..., 216, 219, 224],
      [ 41, 45, 52, ..., 218, 219, 220]],

      [[130, 151, 170, ..., 47, 41, 35],
      [124, 151, 171, ..., 52, 42, 36],
      [110, 145, 168, ..., 61, 48, 39],
      ...,
      [ 29, 26, 25, ..., 197, 198, 200],
      [ 25, 25, 26, ..., 195, 198, 200],
      [ 20, 25, 33, ..., 200, 201, 202]],

      [[ 79, 101, 123, ..., 21, 17, 13],
      [ 73, 101, 126, ..., 21, 13, 10],
      [ 61, 96, 122, ..., 23, 11, 6],
      ...,
      [ 20, 20, 19, ..., 166, 167, 169],
      [ 19, 19, 20, ..., 164, 167, 172],
      [ 25, 27, 29, ..., 164, 165, 166]]],
dtype=torch.uint8)
Is image a PyTorch Tensor: True
Type of Image: <class 'torch.Tensor'>
torch.Size([3, 465, 700])

Notice that the output of read_image() is a torch Tensor with values in the range [0,255], and that the data type of the Tensor is torch.uint8.

Example 2

In this Python code, we will see how to read a PNG image using PyTorch.

import torch
import torchvision

# read a png image
img = torchvision.io.read_image('elephant.png')

# display the properties of image
print("Image data:\n", img)
print(img.size())
print(type(img))

# display the png image
# convert the image tensor to PIL image
img = torchvision.transforms.ToPILImage()(img)

# display the PIL image
img.show()

Output

Image data:
   tensor([[[ 14, 13, 11, ..., 22, 21, 13],
      [ 13, 12, 9, ..., 24, 27, 21],
      [ 12, 10, 7, ..., 26, 33, 32],
      ...,
      [ 54, 15, 25, ..., 39, 76, 111],
      [ 79, 29, 32, ..., 38, 61, 84],
      [112, 60, 60, ..., 23, 47, 72]],

      [[ 14, 13, 11, ..., 11, 11, 5],
      [ 13, 12, 9, ..., 14, 17, 13],
      [ 12, 10, 7, ..., 15, 23, 23],
      ...,
      [ 38, 0, 9, ..., 25, 62, 97],
      [ 58, 8, 9, ..., 28, 50, 70],
      [ 91, 39, 37, ..., 13, 36, 58]],

      [[ 12, 11, 9, ..., 15, 12, 2],
      [ 11, 10, 7, ..., 15, 16, 10],
      [ 10, 8, 5, ..., 13, 21, 18],
      ...,
      [ 38, 0, 9, ..., 24, 61, 96],
      [ 65, 15, 15, ..., 27, 48, 67],
      [ 98, 46, 43, ..., 12, 34, 55]]],
dtype=torch.uint8)
torch.Size([3, 466, 700])
<class 'torch.Tensor'>

Notice that the output of read_image() is a torch Tensor with values in the range [0,255], and that the data type of the Tensor is torch.uint8.

Updated on: 20-Jan-2022
