How to find the k-th and the top "k" elements of a tensor in PyTorch?

PyTorch provides two methods for selecting tensor elements by rank: torch.kthvalue() finds the k-th smallest element, while torch.topk() finds the k largest (or, optionally, the k smallest) elements.

Finding the k-th Element with torch.kthvalue()

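The torch.kthvalue() method returns the k-th smallest element of a tensor. This is the element that would sit at position k (counting from 1) if the values were sorted in ascending order. The result is a named tuple (values, indices) holding both the value and its index in the original tensor. For a multi-dimensional tensor, the search runs independently along a single dimension.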
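The sorting description above can be checked directly. The sketch below reuses the values from the example further down and compares torch.kthvalue() with an explicit ascending torch.sort() followed by indexing at position k - 1 (because k is 1-indexed). It assumes the tensor has no duplicate values. With ties, the two approaches could report different indices for the same value.

```python
import torch

# Same values as the example below; no duplicates, so the index is unambiguous
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 3

# kthvalue returns a named tuple with fields .values and .indices
result = torch.kthvalue(T, k)

# Equivalent approach: sort ascending and read position k - 1 (k is 1-indexed)
sorted_vals, sorted_idx = torch.sort(T)

print(result.values, sorted_vals[k - 1])    # tensor(2.3340) tensor(2.3340)
print(result.indices, sorted_idx[k - 1])    # tensor(0) tensor(0)
```

In practice torch.kthvalue() is the direct way to get this element. The sort only makes the definition concrete.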

Syntax

torch.kthvalue(input, k, dim=None, keepdim=False)

Parameters

  • input - The input tensor
  • k - The rank of the element to find (1-indexed, so k=1 is the smallest)
  • dim - The dimension along which to find the k-th value (defaults to the last dimension)
  • keepdim - If True, the output keeps the reduced dimension with size 1

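The following minimal sketch shows how dim and keepdim affect the output. It uses a small 3 x 4 tensor built with torch.arange(), which is an illustrative input and does not come from the examples below. The sketch covers the default last-dimension behaviour and the extra size-1 dimension that keepdim=True preserves.

```python
import torch

# Illustrative 3 x 4 input: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
X = torch.arange(12.0).reshape(3, 4)

# With dim omitted, the last dimension is reduced: 2nd smallest of each row
values, indices = torch.kthvalue(X, 2)
print(values, values.shape)      # tensor([1., 5., 9.]) torch.Size([3])

# keepdim=True keeps the reduced dimension with size 1
values_k, _ = torch.kthvalue(X, 2, dim=-1, keepdim=True)
print(values_k.shape)            # torch.Size([3, 1])

# Reducing dim=0 instead gives the 2nd smallest of each column
values_c, _ = torch.kthvalue(X, 2, dim=0, keepdim=True)
print(values_c.shape)            # torch.Size([1, 4])
```

Along any dimension, k must lie between 1 and the size of that dimension.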
Example

import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:")
print(T)

# Find the 3rd smallest element (k=3)
value, index = torch.kthvalue(T, 3)

print("3rd smallest element value:", value)
print("3rd smallest element index:", index)

Output

Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
3rd smallest element value: tensor(2.3340)
3rd smallest element index: tensor(0)

Finding Top k Elements with torch.topk()

The torch.topk() method returns a named tuple (values, indices) containing the k largest elements along a given dimension and their indices in the original tensor.
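As with torch.kthvalue(), the result can be cross-checked against a plain sort. The minimal sketch below reuses the example tensor, sorts it in descending order, and keeps the first k entries. It assumes no ties among the selected values, because tied elements may come back in a different order.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 2

top = torch.topk(T, k)

# Equivalent approach: sort in descending order and keep the first k entries
desc_vals, desc_idx = torch.sort(T, descending=True)
print(top.values, desc_vals[:k])     # tensor([5.0000, 4.4430]) tensor([5.0000, 4.4430])
print(top.indices, desc_idx[:k])     # tensor([4, 5]) tensor([4, 5])

# The returned indices point back into the original tensor
print(T[top.indices])                # tensor([5.0000, 4.4430])
```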

Syntax

torch.topk(input, k, dim=None, largest=True, sorted=True)

Parameters

  • input - The input tensor
  • k - Number of elements to return
  • dim - The dimension along which to select the elements (defaults to the last dimension)
  • largest - If True, returns the largest elements; if False, returns the smallest
  • sorted - If True, returns the elements in sorted order (descending when largest=True, ascending when largest=False)

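Here is a short sketch of the largest and sorted parameters, using the same example tensor. With sorted=False the order of the returned elements is unspecified, so the code sorts that result before printing it.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])

# largest=False selects the 3 smallest values (ascending, since sorted=True by default)
small_vals, small_idx = torch.topk(T, 3, largest=False)
print(small_vals)    # tensor([-4.3300, -0.4330,  2.3340])
print(small_idx)     # tensor([2, 3, 0])

# sorted=False guarantees the same set of elements but not their order
unsorted_vals, unsorted_idx = torch.topk(T, 3, sorted=False)
print(torch.sort(unsorted_vals, descending=True).values)    # tensor([5.0000, 4.4430, 4.4330])
```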
Example

import torch

# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:")
print(T)

# Find the top 2 largest elements
values, indices = torch.topk(T, 2)

print("Top 2 element values:", values)
print("Top 2 element indices:", indices)

Output

Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])
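torch.kthvalue() always counts from the smallest end, so it cannot ask for the k-th largest element directly. The sketch below shows two options, assuming a 1D tensor with n elements. The first asks torch.kthvalue() for the (n - k + 1)-th smallest element. The second takes the last entry of the sorted torch.topk() result. Both return the 2nd largest value of the example tensor.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 2            # we want the 2nd largest element
n = T.numel()    # number of elements in the 1D tensor

# Option 1: the k-th largest element is the (n - k + 1)-th smallest
v1, i1 = torch.kthvalue(T, n - k + 1)

# Option 2: the last entry of the sorted top-k result is the k-th largest
top = torch.topk(T, k)
v2, i2 = top.values[-1], top.indices[-1]

print(v1, i1)    # tensor(4.4430) tensor(5)
print(v2, i2)    # tensor(4.4430) tensor(5)
```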

Working with Multi-dimensional Tensors

Both methods accept a dim argument that selects the dimension to operate along. Each slice along that dimension is processed independently.

import torch

# Create a 2D tensor
T = torch.tensor([[3.2, 1.5, 4.1], 
                  [2.8, 5.0, 1.2]])
print("Original 2D Tensor:")
print(T)

# Find the top 2 elements in each row (i.e., along dimension 1)
values, indices = torch.topk(T, 2, dim=1)
print("\nTop 2 values along dim=1:")
print("Values:", values)
print("Indices:", indices)

Output

Original 2D Tensor:
tensor([[3.2000, 1.5000, 4.1000],
        [2.8000, 5.0000, 1.2000]])

Top 2 values along dim=1:
Values: tensor([[4.1000, 3.2000],
        [5.0000, 2.8000]])
Indices: tensor([[2, 0],
        [1, 0]])
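The example above only uses torch.topk(). The minimal sketch below applies torch.kthvalue() to the same 2D tensor along dim=1. It also applies torch.topk() along dim=0 to show a column-wise selection.

```python
import torch

T = torch.tensor([[3.2, 1.5, 4.1],
                  [2.8, 5.0, 1.2]])

# 2nd smallest value in each row (dim=1); the row dimension is reduced
row_vals, row_idx = torch.kthvalue(T, 2, dim=1)
print(row_vals, row_idx)    # tensor([3.2000, 2.8000]) tensor([0, 0])

# Largest value in each column (dim=0); topk keeps dim 0 with size k=1
col_vals, col_idx = torch.topk(T, 1, dim=0)
print(col_vals)             # tensor([[3.2000, 5.0000, 4.1000]])
print(col_idx)              # tensor([[0, 1, 0]])
```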

Comparison

Method             Returns                                  Use Case
torch.kthvalue()   k-th smallest value and its index        Finding a specific ranked element
torch.topk()       k largest/smallest values and indices    Finding multiple top elements
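The two methods also differ in output shape. torch.kthvalue() removes the reduced dimension (unless keepdim=True), while torch.topk() keeps it with size k. The sketch below uses a random 4 x 10 tensor purely for illustration. It also checks that the k-th smallest value in each row matches the last of the k smallest values returned by torch.topk() with largest=False.

```python
import torch

torch.manual_seed(0)           # fixed seed so the run is reproducible
X = torch.rand(4, 10)          # illustrative random input; only the shapes matter here
k = 3

kth = torch.kthvalue(X, k, dim=1)    # one value per row: dim 1 is removed
top = torch.topk(X, k, dim=1)        # k values per row: dim 1 is kept with size k

print(kth.values.shape)    # torch.Size([4])
print(top.values.shape)    # torch.Size([4, 3])

# The k-th smallest in each row equals the last of the k smallest in that row
bottom = torch.topk(X, k, dim=1, largest=False)
print(torch.equal(kth.values, bottom.values[:, -1]))    # True
```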

Conclusion

Use torch.kthvalue() to find a specific ranked element and torch.topk() for multiple top elements. Both methods return values and their original indices, making them useful for ranking and selection operations.

Updated on: 2026-03-26T18:42:39+05:30
