torch.argmax() Method in Python PyTorch


To find the index of the maximum element of an input tensor, we can use the torch.argmax() function. It returns only the index, not the element value. If the input tensor contains several maximal values, the function returns the index of the first one. When no dimension is given, the returned index refers to the flattened tensor. With the optional dim argument, torch.argmax() computes the indices of the maximum values along that dimension.
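As a small illustration of the tie-breaking behaviour, the sketch below uses a made-up tensor in which the maximum value appears twice. The tensor values and variable names are assumptions chosen for illustration; returning the first occurrence is what recent PyTorch releases document.

```python
import torch

# Hypothetical tensor: the maximum value 5.0 appears at positions 1 and 2
t = torch.tensor([1., 5., 5., 2.])

# argmax returns an index, not the value itself
idx = torch.argmax(t)

print(idx)         # a 0-dimensional integer tensor holding the index
print(idx.item())  # the same index as a plain Python int
```

Calling .item() is a convenient way to turn the 0-dimensional result into an ordinary Python integer.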

Syntax

torch.argmax(input, dim=None, keepdim=False)
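Besides the input tensor, torch.argmax() accepts the optional dim and keepdim arguments. The following sketch, which uses a hypothetical 2x3 tensor chosen only for illustration, shows how these arguments may change the shape of the result.

```python
import torch

# Hypothetical 2x3 tensor used only for illustration
t = torch.tensor([[1., 9., 3.],
                  [7., 2., 8.]])

flat = torch.argmax(t)                                # index into the flattened tensor
per_col = torch.argmax(t, dim=0)                      # one row index per column
per_row = torch.argmax(t, dim=1)                      # one column index per row
per_row_kept = torch.argmax(t, dim=1, keepdim=True)   # keeps the reduced dimension

print(flat.shape, per_col.shape, per_row.shape, per_row_kept.shape)
```

With keepdim=True the reduced dimension is retained with size 1, which can be handy when the indices are later passed to functions such as torch.gather().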

Steps

We can use the following steps to find the index of the maximum value among all elements of an input tensor −

  • Import the required library. In all the following examples, the required Python library is torch. Make sure you have already installed it.

import torch

  • Define an input tensor input.

input = torch.randn(3,4)

  • Compute the index of the maximum value among all the elements of the tensor input.

indices = torch.argmax(input)

  • Print the computed index.

print("Indices:\n", indices)

Example 1

# Import the required library
import torch

# define an input tensor
input = torch.tensor([0., -1., 2., 8.])

# print above defined tensor
print("Input Tensor:\n", input)

# Compute indices of the maximum value
indices = torch.argmax(input)

# print the indices
print("Indices:\n", indices)

Output

Input Tensor:
   tensor([ 0., -1., 2., 8.])
Indices:
   tensor(3)

In the above Python example, we find the index of the maximum element of a 1D input tensor. The maximum value in the input tensor is 8, and the index of this element is 3.
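Since torch.argmax() returns only the index, a reader may also want the value itself. The sketch below reuses the tensor from Example 1 and shows two possible ways to get it: indexing with the result, or calling torch.max() with a dim argument, which returns both values and indices. The variable names are illustrative assumptions.

```python
import torch

# Same 1D tensor as in Example 1
t = torch.tensor([0., -1., 2., 8.])

idx = torch.argmax(t)   # index of the maximum element
value = t[idx]          # index back into the tensor to get the value

# torch.max with dim returns both the maximum values and their indices
max_value, max_index = torch.max(t, dim=0)

print(idx.item(), value.item())
print(max_value.item(), max_index.item())
```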

Example 2

In this program, we compute the indices of the maximum values of a 2D tensor, first over all of its elements and then along each dimension.

# Import the required library
import torch

# define an input tensor
input = torch.randn(4,4)

# print above defined tensor
print("Input Tensor:\n", input)

# Compute indices of the maximum value
indices = torch.argmax(input)

# print the indices
print("Indices:\n", indices)

# Compute indices of the maximum value in dim 0
indices = torch.argmax(input, dim=0)

# print the indices
print("Indices in dim 0:\n", indices)

# Compute indices of the maximum value in dim 1
indices = torch.argmax(input, dim=1)

# print the indices
print("Indices in dim 1:\n", indices)

Output

Input Tensor:
   tensor([[-1.6729, 1.2613, -1.2882, -0.8133],
   [ 0.9192, 0.9301, -0.2372, 0.0162],
   [-0.4669, 0.6604, -0.7982, 0.2621],
   [ 0.6436, 1.0328, 2.4573, 0.0606]])
Indices:
   tensor(14)
Indices in dim 0:
   tensor([1, 0, 3, 2])
Indices in dim 1:
   tensor([1, 1, 1, 2])

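In the above Python example, we find the indices of the maximum values of a 2D input tensor along different dimensions. Without dim, the result is an index into the flattened tensor (14 here). With dim=0 we get one row index for each column, and with dim=1 we get one column index for each row. Because the elements of the input tensor are generated with torch.randn(), you will get a different input tensor and different indices each time you run the program.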
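The index 14 in the sample output refers to the flattened tensor. One possible way to turn it into a (row, column) pair is plain integer arithmetic, assuming the usual row-major layout. The sketch below copies the tensor from the sample output so the result is reproducible; the variable names are illustrative.

```python
import torch

# Tensor copied from the sample output above so the result is reproducible
t = torch.tensor([[-1.6729, 1.2613, -1.2882, -0.8133],
                  [ 0.9192, 0.9301, -0.2372,  0.0162],
                  [-0.4669, 0.6604, -0.7982,  0.2621],
                  [ 0.6436, 1.0328,  2.4573,  0.0606]])

flat_idx = torch.argmax(t).item()         # index into the flattened tensor
row, col = divmod(flat_idx, t.shape[1])   # row-major layout: flat = row * ncols + col

print(flat_idx, row, col)
print(t[row, col].item() == t.max().item())
```

Here the flat index 14 corresponds to row 3 and column 2, which is where the value 2.4573 sits.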

Updated on: 27-Jan-2022
