How to find the k-th and the top "k" elements of a tensor in PyTorch?
PyTorch provides two methods for selecting tensor elements by rank: torch.kthvalue() finds the k-th smallest element, while torch.topk() finds the k largest (or, optionally, the k smallest) elements.
Finding the k-th Element with torch.kthvalue()
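The torch.kthvalue() method returns the k-th smallest element of the input, that is, the element that would sit at position k (counting from 1) if the tensor were sorted in ascending order. It returns a named tuple (values, indices) holding both the value and its index in the original tensor.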
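The "sort, then pick position k" description can be checked directly. The sketch below compares torch.kthvalue() with torch.sort() on the same data used later in this article. It assumes the values are distinct; with repeated values the reported index for a tie is not guaranteed to be the same in both calls.

```python
import torch

# Same values as the examples in this article (all distinct)
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
k = 3

# kthvalue returns a (values, indices) named tuple
kth = torch.kthvalue(T, k)

# Sort ascending and take position k - 1, because k is 1-indexed
sorted_vals, sorted_idx = torch.sort(T)

print(kth.values, sorted_vals[k - 1])    # both tensor(2.3340)
print(kth.indices, sorted_idx[k - 1])    # both tensor(0)
```

torch.kthvalue() avoids building the fully sorted result, which is why it is the natural choice when only one ranked element is needed.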
Syntax
torch.kthvalue(input, k, dim=None, keepdim=False)
Parameters
- input: The input tensor
- k: The rank of the element to find (1-indexed, so k=1 is the smallest)
- dim: The dimension along which to find the k-th value; if omitted, the last dimension is used
- keepdim: If True, the reduced dimension is kept in the output with size 1
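To illustrate the dim and keepdim parameters, here is a minimal sketch on a small 2D tensor (the tensor M and its values are made up for illustration). It also shows that omitting dim selects the last dimension.

```python
import torch

# A 2 x 3 tensor with illustrative values
M = torch.tensor([[3.0, 1.0, 2.0],
                  [6.0, 5.0, 4.0]])

# k=2 along dim=1: the 2nd smallest value in each row
flat = torch.kthvalue(M, 2, dim=1)
kept = torch.kthvalue(M, 2, dim=1, keepdim=True)

# dim omitted: the last dimension (here dim=1) is used
default = torch.kthvalue(M, 2)

print(flat.values.shape)   # torch.Size([2])    -> reduced dimension removed
print(kept.values.shape)   # torch.Size([2, 1]) -> reduced dimension kept with size 1
print(flat.values)         # tensor([2., 5.])
print(flat.indices)        # tensor([2, 1])
```

keepdim=True is handy when the result must broadcast back against the original tensor, for example when comparing each row with its own k-th value.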
Example
import torch
# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:")
print(T)
# Find the 3rd smallest element (k=3)
value, index = torch.kthvalue(T, 3)
print("3rd smallest element value:", value)
print("3rd smallest element index:", index)
Output
Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
3rd smallest element value: tensor(2.3340)
3rd smallest element index: tensor(0)
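A common use of torch.kthvalue() is computing a median. For n elements, the ((n + 1) // 2)-th smallest value is the lower median, which is what torch.median() returns for a 1D tensor with an even number of elements. The sketch below reuses the tensor from the example; the variable names are illustrative.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
n = T.numel()  # 6 elements

# The lower median is the ((n + 1) // 2)-th smallest value, here k = 3
lower_median = torch.kthvalue(T, (n + 1) // 2).values

print(lower_median)      # tensor(2.3340)
print(torch.median(T))   # tensor(2.3340)
```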
Finding Top k Elements with torch.topk()
The torch.topk() method returns the k largest elements of the input (or the k smallest, when largest=False), together with their indices in the original tensor.
Syntax
torch.topk(input, k, dim=None, largest=True, sorted=True)
Parameters
- input: The input tensor
- k: The number of elements to return
- dim: The dimension along which to find the top k values; if omitted, the last dimension is used
- largest: If True, returns the largest elements; if False, returns the smallest
- sorted: If True, the returned elements are in sorted order (descending when largest=True, ascending when largest=False)
Example
import torch
# Create a 1D tensor
T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
print("Original Tensor:")
print(T)
# Find the top 2 largest elements
values, indices = torch.topk(T, 2)
print("Top 2 element values:", values)
print("Top 2 element indices:", indices)
Output
Original Tensor:
tensor([ 2.3340,  4.4330, -4.3300, -0.4330,  5.0000,  4.4430])
Top 2 element values: tensor([5.0000, 4.4430])
Top 2 element indices: tensor([4, 5])
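Setting largest=False turns torch.topk() into a "bottom k" selection. A short sketch on the same tensor; with the default sorted=True the smallest elements come back in ascending order.

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])

# largest=False selects the k smallest elements instead of the k largest
small = torch.topk(T, 2, largest=False)

print(small.values)    # tensor([-4.3300, -0.4330])
print(small.indices)   # tensor([2, 3])
```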
Working with Multi-dimensional Tensors
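Both methods accept a dim argument, so they can operate along a chosen dimension of a multi-dimensional tensor and return one result per slice. The following example applies torch.topk() to each row of a 2D tensor.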
import torch
# Create a 2D tensor
T = torch.tensor([[3.2, 1.5, 4.1],
                  [2.8, 5.0, 1.2]])
print("Original 2D Tensor:")
print(T)
# Find the top 2 elements in each row (dim=1 runs across the columns)
values, indices = torch.topk(T, 2, dim=1)
print("\nTop 2 values along dim=1:")
print("Values:", values)
print("Indices:", indices)
Output
Original 2D Tensor:
tensor([[3.2000, 1.5000, 4.1000],
        [2.8000, 5.0000, 1.2000]])

Top 2 values along dim=1:
Values: tensor([[4.1000, 3.2000],
        [5.0000, 2.8000]])
Indices: tensor([[2, 0],
        [1, 0]])
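torch.kthvalue() works the same way along a dimension. The sketch below uses the same 2D tensor and picks the 2nd smallest value in each row (dim=1) and the smallest value in each column (dim=0).

```python
import torch

T = torch.tensor([[3.2, 1.5, 4.1],
                  [2.8, 5.0, 1.2]])

# k=2 along dim=1: 2nd smallest value in each row
row_vals, row_idx = torch.kthvalue(T, 2, dim=1)
print(row_vals, row_idx)   # tensor([3.2000, 2.8000]) tensor([0, 0])

# k=1 along dim=0: smallest value in each column
col_vals, col_idx = torch.kthvalue(T, 1, dim=0)
print(col_vals, col_idx)   # tensor([2.8000, 1.5000, 1.2000]) tensor([1, 0, 1])
```

Note that with dim=0 the indices refer to rows, because the reduction runs down each column.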
Comparison
| Method | Returns | Use Case |
|---|---|---|
| torch.kthvalue() | The k-th smallest value and its index (one per slice when dim is given) | Finding a specific ranked element |
| torch.topk() | The k largest (or smallest) values and their indices | Finding multiple top elements |
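The two methods are closely related: the last of the k smallest elements is the k-th smallest, and the k-th largest of n elements is the (n - k + 1)-th smallest. A small sketch checking both identities on the 1D tensor from the examples (assuming distinct values):

```python
import torch

T = torch.tensor([2.334, 4.433, -4.33, -0.433, 5.0, 4.443])
n = T.numel()
k = 3

# k-th smallest two ways
kth_smallest = torch.kthvalue(T, k).values
via_topk = torch.topk(T, k, largest=False).values[-1]
print(kth_smallest, via_topk)   # tensor(2.3340) tensor(2.3340)

# k-th largest two ways: the (n - k + 1)-th smallest is the k-th largest
kth_largest = torch.kthvalue(T, n - k + 1).values
via_topk_largest = torch.topk(T, k).values[-1]
print(kth_largest, via_topk_largest)   # tensor(4.4330) tensor(4.4330)
```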
Conclusion
Use torch.kthvalue() to find a specific ranked element and torch.topk() for multiple top elements. Both methods return values and their original indices, making them useful for ranking and selection operations.
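Because the returned indices point into the original tensor, they can be used to look up the selected elements again, or related data stored at the same positions. A minimal sketch using torch.gather() with the per-row indices from the 2D topk example:

```python
import torch

T = torch.tensor([[3.2, 1.5, 4.1],
                  [2.8, 5.0, 1.2]])
values, indices = torch.topk(T, 2, dim=1)

# gather picks T[i][indices[i][j]] along dim=1, reproducing the top-k values
recovered = torch.gather(T, 1, indices)
print(torch.equal(recovered, values))   # True
```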
