"""
Understanding requires_grad, retain_grad, Leaf, and Non-leaf Tensors
====================================================================

**Author:** `Justin Silver <https://github.com/j-silv>`__

This tutorial explains the subtleties of ``requires_grad``,
``retain_grad``, leaf, and non-leaf tensors using a simple example.

Before starting, make sure you understand `tensors and how to manipulate
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
A basic knowledge of `how autograd
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
would also be useful.

"""


######################################################################
# Setup
# -----
#
# First, make sure `PyTorch is
# installed <https://pytorch.org/get-started/locally/>`__ and then import
# the necessary libraries.
#

import torch
import torch.nn.functional as F


######################################################################
# Next, we instantiate a simple network so that we can focus on the
# gradients. This will be an affine layer, followed by a ReLU activation,
# and ending with an MSE loss between the prediction and label tensors.
#
# .. math::
#
#    \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
#
# .. math::
#
#    L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
#
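# With ``F.mse_loss``'s default ``reduction='mean'``, the loss is the
# average of the squared element-wise differences:
#
# .. math::
#
#    L = \frac{1}{n} \sum_{i=1}^{n} \left( y_{\text{pred},i} - y_i \right)^2
#
# where :math:`n` is the number of elements in :math:`\mathbf{y}` (here,
# :math:`n = 2`).
#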
# Note that ``requires_grad=True`` is necessary for the parameters
# (``W`` and ``b``) so that PyTorch tracks operations involving those
# tensors. We’ll discuss more about this in a future
# `section <#requires-grad>`__.
#

# tensor setup
x = torch.ones(1, 3)                      # input with shape: (1, 3)
W = torch.ones(3, 2, requires_grad=True)  # weights with shape: (3, 2)
b = torch.ones(1, 2, requires_grad=True)  # bias with shape: (1, 2)
y = torch.ones(1, 2)                      # output with shape: (1, 2)

# forward pass
z = (x @ W) + b               # pre-activation with shape: (1, 2)
y_pred = F.relu(z)            # activation with shape: (1, 2)
loss = F.mse_loss(y_pred, y)  # scalar loss


######################################################################
# Leaf vs. non-leaf tensors
# -------------------------
#
# After running the forward pass, PyTorch autograd has built up a `dynamic
# computational
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
# which is shown below. This is a `Directed Acyclic Graph
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
# keeps a record of input tensors (leaf nodes), all subsequent operations
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
# The graph is used to compute gradients for each tensor starting from the
# graph roots (outputs) to the leaves (inputs) using the `chain
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
#
# .. math::
#
#    \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
#
# .. math::
#
#    \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
#    \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
#    \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
#    \cdots \cdot
#    \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
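#
# Applied to the small network above, for example, the gradient of the
# loss with respect to the weights factors as
#
# .. math::
#
#    \frac{\partial L}{\partial \mathbf{W}} =
#    \frac{\partial L}{\partial \mathbf{y}_{\text{pred}}} \cdot
#    \frac{\partial \mathbf{y}_{\text{pred}}}{\partial \mathbf{z}} \cdot
#    \frac{\partial \mathbf{z}}{\partial \mathbf{W}}
#
# where each factor is supplied by one node of the graph.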
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-1.png
#    :alt: Computational graph after forward pass
#
#    Computational graph after forward pass
#
# PyTorch considers a node to be a *leaf* if it is not the result of a
# tensor operation with at least one input having ``requires_grad=True``
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
# programmatically by probing the ``is_leaf`` attribute of the tensors:
#

# prints True because new tensors are leaves by convention
print(f"{x.is_leaf = }")

# prints False because the tensor is the result of an operation with at
# least one input having requires_grad=True
print(f"{z.is_leaf = }")
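
# a quick sketch of the "record of operations" mentioned above: each
# non-leaf tensor stores the backward function that produced it in its
# grad_fn attribute, while leaf tensors have grad_fn=None
print(f"{x.grad_fn = }")     # None because x is a leaf
print(f"{z.grad_fn = }")     # an AddBackward0 object
print(f"{loss.grad_fn = }")  # a MseLossBackward0 object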


######################################################################
# The distinction between leaf and non-leaf determines whether the
# tensor’s gradient will be stored in the ``grad`` property after the
# backward pass, and thus be usable for `gradient
# descent <https://en.wikipedia.org/wiki/Gradient_descent>`__. We’ll cover
# this some more in the `following section <#retain-grad>`__.
#
# Let’s now investigate how PyTorch calculates and stores gradients for
# the tensors in its computational graph.
#


######################################################################
# ``requires_grad``
# -----------------
#
# To build the computational graph which can be used for gradient
# calculation, we need to pass in the ``requires_grad=True`` parameter to
# a tensor constructor. By default, the value is ``False``, and thus
# PyTorch does not track gradients on any created tensors. To verify this,
# try not setting ``requires_grad``, re-run the forward pass, and then run
# backpropagation. You will see:
#
# ::
#
#    >>> loss.backward()
#    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
#
# This error means that autograd can’t backpropagate to any leaf tensors
# because ``loss`` is not tracking gradients. If you need to change the
# property, you can call ``requires_grad_()`` on the tensor (notice the \_
# suffix).
#
# We can sanity check which nodes require gradient calculation, just like
# we did above with the ``is_leaf`` attribute:
#

print(f"{x.requires_grad = }")  # prints False because requires_grad=False by default
print(f"{W.requires_grad = }")  # prints True because we set requires_grad=True in constructor
print(f"{z.requires_grad = }")  # prints True because tensor is a non-leaf node
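
# a short illustrative sketch: requires_grad_() flips the property in-place
# on an existing tensor (here a throwaway tensor, not part of our network)
tmp = torch.ones(2, 2)
print(f"{tmp.requires_grad = }")  # prints False by default
tmp.requires_grad_()
print(f"{tmp.requires_grad = }")  # prints True after the in-place call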


######################################################################
# It’s useful to remember that a non-leaf tensor has
# ``requires_grad=True`` by definition, since backpropagation would fail
# otherwise. If the tensor is a leaf, then it will only have
# ``requires_grad=True`` if it was specifically set by the user. Another
# way to phrase this is that if at least one of the inputs to an
# operation requires the gradient, then the resulting tensor will require
# the gradient as well.
#
# There are two exceptions to this rule (both are illustrated in the
# short sketch after this list):
#
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
#    ``requires_grad=True`` for its parameters (see
#    `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
#
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which tensors have their ``grad`` field populated, which
# is the topic of the next section.
#

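
######################################################################
# As a minimal sketch of these two exceptions (using a throwaway
# ``nn.Linear`` layer that is not part of the network above):
#

# (1) parameters registered on an nn.Module require gradients by default
layer = torch.nn.Linear(3, 2)
print(f"{layer.weight.requires_grad = }")  # prints True

# (2) operations run under torch.no_grad() do not track gradients, even
# though W and b have requires_grad=True
with torch.no_grad():
    z_no_grad = (x @ W) + b
print(f"{z_no_grad.requires_grad = }")  # prints False
print(f"{z_no_grad.is_leaf = }")        # prints True, since no graph was recorded
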

######################################################################
# ``retain_grad``
# ---------------
#
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
# the backward pass so that we can extract the gradients.
#

loss.backward()


######################################################################
# Calling ``backward()`` populates the ``grad`` field of all leaf tensors
# which had ``requires_grad=True``. The ``grad`` is the gradient of the
# loss with respect to the tensor we are probing. Before running
# ``backward()``, this attribute is set to ``None``.
#

print(f"{W.grad = }")
print(f"{b.grad = }")
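
# a minimal sketch of how these gradients could drive one step of plain
# gradient descent; the learning rate is a hypothetical value chosen only
# for illustration, and we build new tensors instead of updating W and b
# in-place so that the rest of the tutorial is unaffected
learning_rate = 0.1
with torch.no_grad():
    W_updated = W - learning_rate * W.grad
    b_updated = b - learning_rate * b.grad
print(f"{W_updated = }")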


######################################################################
# You might be wondering about the other tensors in our network. Let’s
# check the remaining leaf nodes:
#

# prints all None because requires_grad=False
print(f"{x.grad = }")
print(f"{y.grad = }")


######################################################################
# The gradients for these tensors haven’t been populated because we did
# not explicitly tell PyTorch to calculate their gradient
# (``requires_grad=False``).
#
# Let’s now look at an intermediate non-leaf node:
#

print(f"{z.grad = }")


######################################################################
# PyTorch returns ``None`` for the gradient and also warns us that a
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
# has to calculate intermediate gradients for backpropagation to work, it
# assumes you don’t need to access the values afterwards. To change this
# behavior, we can use the ``retain_grad()`` function on a tensor. This
# tells the autograd engine to populate that tensor’s ``grad`` after
# calling ``backward()``.
#

# we have to re-run the forward pass
z = (x @ W) + b
y_pred = F.relu(z)
loss = F.mse_loss(y_pred, y)

# tell PyTorch to store the gradients after backward()
z.retain_grad()
y_pred.retain_grad()
loss.retain_grad()

# we have to zero out the gradients, otherwise they would accumulate
W.grad = None
b.grad = None

# backpropagation
loss.backward()

# print gradients for all tensors that have requires_grad=True
print(f"{W.grad = }")
print(f"{b.grad = }")
print(f"{z.grad = }")
print(f"{y_pred.grad = }")
print(f"{loss.grad = }")


######################################################################
# We get the same result for ``W.grad`` as before. Also note that because
# the loss is scalar, the gradient of the loss with respect to itself is
# simply ``1.0``.
#
# If we look at the state of the computational graph now, we see that the
# ``retains_grad`` attribute has changed for the intermediate tensors. By
# convention, this attribute will print ``False`` for any leaf node, even
# if it requires its gradient.
#
# .. figure:: /_static/img/understanding_leaf_vs_nonleaf/comp-graph-2.png
#    :alt: Computational graph after backward pass
#
#    Computational graph after backward pass
#
# If you call ``retain_grad()`` on a leaf node, it results in a no-op.
# If we call ``retain_grad()`` on a node that has ``requires_grad=False``,
# PyTorch actually throws an error, since it can’t store the gradient if
# it is never calculated.
#
# ::
#
#    >>> x.retain_grad()
#    RuntimeError: can't retain_grad on Tensor that has requires_grad=False
#
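

######################################################################
# As a quick check of the convention and the no-op behavior described
# above:
#

# retains_grad is False for leaf tensors and True for the non-leaf
# tensors we explicitly retained
print(f"{W.retains_grad = }")  # prints False, even though W.grad is populated
print(f"{z.retains_grad = }")  # prints True, because we called z.retain_grad()

# calling retain_grad() on a leaf tensor that requires grad is a no-op
W.retain_grad()
print(f"{W.retains_grad = }")  # still False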


######################################################################
# Summary table
# -------------
#
# Using ``retain_grad()`` and ``retains_grad`` only makes sense for
# non-leaf nodes, since the ``grad`` attribute will already be populated
# for leaf tensors that have ``requires_grad=True``. By default, these
# non-leaf nodes do not retain (store) their gradient after
# backpropagation. We can change that by rerunning the forward pass,
# telling PyTorch to store the gradients, and then performing
# backpropagation.
#
# The following table summarizes the above discussions and can be used as
# a reference. These scenarios are the only ones that are valid for
# PyTorch tensors.
#
#
#
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``is_leaf``    | ``requires_grad``      | ``retains_grad``       | ``requires_grad_()``                              | ``retain_grad()``                   |
# +================+========================+========================+===================================================+=====================================+
# | ``True``       | ``False``              | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``True``       | ``True``               | ``False``              | sets ``requires_grad`` to ``True`` or ``False``   | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``False``              | no-op                                             | sets ``retains_grad`` to ``True``   |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
# | ``False``      | ``True``               | ``True``               | no-op                                             | no-op                               |
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
#


######################################################################
# Conclusion
# ----------
#
# In this tutorial, we covered when and how PyTorch computes gradients for
# leaf and non-leaf tensors. By using ``retain_grad``, we can access the
# gradients of intermediate tensors within autograd’s computational graph.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
# any feedback for this tutorial (improvements, typo fixes, etc.) then
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
# reach out.
#


######################################################################
# References
# ----------
#
# - `A Gentle Introduction to
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
# - `Automatic Differentiation with
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
#