What does Tensor.detach() do in PyTorch?
Tensor.detach() is used to detach a tensor from the current computational graph. It returns a new tensor that doesn't require a gradient.
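We detach a tensor when we no longer want autograd to track the operations performed on it for gradient computation.
We also need to detach a tensor before converting it to a NumPy array, for example when copying results from the GPU back to the CPU for plotting or logging. Tensor.numpy() raises an error on a tensor that requires grad, so the usual idiom is x.detach().cpu().numpy().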
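To see that a detached tensor is treated as a constant during backpropagation, the short sketch below (an illustration, not part of the original examples) multiplies x by its own detached copy. Autograd differentiates only through the non-detached factor, so the gradient equals the value of x rather than 2x.

```python
import torch

# Leaf tensor that autograd tracks
x = torch.tensor(3.0, requires_grad=True)

# x.detach() acts as a constant c = 3.0, so y = c * x as far as autograd is concerned
y = x * x.detach()
y.backward()

# dy/dx = c = 3.0, not 2 * x = 6.0 as it would be for y = x * x
print(x.grad)  # tensor(3.)
```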
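The sketch below, assuming NumPy is installed alongside PyTorch, shows where detaching is actually required: Tensor.numpy() raises a RuntimeError on a tensor that requires grad, while x.detach().cpu().numpy() works. Here .cpu() is a no-op because the tensor is already on the CPU; it is included only to show the usual GPU-to-CPU idiom.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

# Tensor.numpy() refuses tensors that require grad
try:
    x.numpy()
except RuntimeError as err:
    print("Error:", err)

# Detach first, then move to CPU (no-op here), then convert
arr = x.detach().cpu().numpy()
print(arr)  # [1. 2.]
```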
Syntax
Tensor.detach()
It returns a new tensor with requires_grad = False that shares the same data as the original tensor. Gradients are not computed with respect to the returned tensor, and no gradient flows back through it to the original.
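Because the detached tensor shares storage with the original rather than copying it, an in-place change made through one is visible in the other. A minimal sketch illustrating this:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x_detach = x.detach()

# The detached tensor is outside the graph
print(x_detach.requires_grad, x_detach.grad_fn)  # False None

# Both tensors point at the same memory
print(x_detach.data_ptr() == x.data_ptr())  # True

# An in-place write through the detached tensor also changes x
x_detach[0] = 10.0
print(x)  # first element of x is now 10.
```

If an independent copy is needed, x.detach().clone() allocates new memory instead.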
Steps
Import the torch library. Make sure it is already installed.
import torch
Create a PyTorch tensor with requires_grad = True and print the tensor.
x = torch.tensor(2.0, requires_grad = True)
print("x:", x)
Call Tensor.detach() and optionally assign the result to a new variable.
x_detach = x.detach()
Print the tensor returned by the .detach() operation.
print("Tensor with detach:", x_detach)
Example 1
# import torch library
import torch
# create a tensor with requires_grad=True
x = torch.tensor(2.0, requires_grad = True)
# print the tensor
print("Tensor:", x)
# detach the tensor from the computational graph
x_detach = x.detach()
print("Tensor with detach:", x_detach)
Output
Tensor: tensor(2., requires_grad=True)
Tensor with detach: tensor(2.)
Notice that in the above output, the detached tensor does not have requires_grad = True.
Example 2
# import torch library
import torch
# define a tensor with requires_grad=True
x = torch.rand(3, requires_grad = True)
print("x:", x)
# y is computed from x, so it stays in the computational graph
y = 3 + x
# z is computed from a detached view of x, so it is not tracked
z = 3 * x.detach()
print("y:", y)
print("z:", z)
Output
x: tensor([0.5656, 0.8402, 0.6661], requires_grad=True)
y: tensor([3.5656, 3.8402, 3.6661], grad_fn=<AddBackward0>)
z: tensor([1.6968, 2.5207, 1.9984])
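In the output above, y carries grad_fn=<AddBackward0> because it was computed from x, while z has no grad_fn because it was computed from x.detach(). The sketch below, an extension of Example 2 rather than part of the original, shows the practical consequence: backpropagating through y reaches x, while z is disconnected from the graph. The values of x are random, but the gradient does not depend on them.

```python
import torch

x = torch.rand(3, requires_grad=True)
y = 3 + x
z = 3 * x.detach()

# Gradient flows from y back to x: d(sum(3 + x))/dx = 1 for every element
y.sum().backward()
print(x.grad)  # tensor([1., 1., 1.])

# z is outside the graph, so it neither requires grad nor records an operation
print(z.requires_grad, z.grad_fn)  # False None
```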