From 9887ab8cadb98d85a1985ac5744275f9e406c0d8 Mon Sep 17 00:00:00 2001 From: Justin Silver Date: Mon, 9 Jun 2025 14:34:54 -0700 Subject: [PATCH] Fix #3186 - visualizing gradients tutorial The original PRs had a combined visualizing gradients tutorial and a section on understanding leaf vs non-leaf and requires_grad vs. retain_grad. I've broken up these two and the latter is in a second PR. I moved the visualizing gradients tutorial into the intermediate section. Another change I made from the last PR is renaming the forward/backward hook functions to be more clear. --- .../visualizing_gradients_tutorial.png | Bin 0 -> 41537 bytes index.rst | 9 +- .../visualizing_gradients_tutorial.py | 298 ++++++++++++++++++ 3 files changed, 306 insertions(+), 1 deletion(-) create mode 100644 _static/img/thumbnails/cropped/visualizing_gradients_tutorial.png create mode 100644 intermediate_source/visualizing_gradients_tutorial.py diff --git a/_static/img/thumbnails/cropped/visualizing_gradients_tutorial.png b/_static/img/thumbnails/cropped/visualizing_gradients_tutorial.png new file mode 100644 index 0000000000000000000000000000000000000000..6ff6d97f2e200d86705257ac22ca1a1c83993727 GIT binary patch literal 41537 zcmb@u2RxSj`#w%P4Gkouls&T}Qi>vEXGO@~GrOrl$g1ozvy$x0l98E^9VJ`JmQlas zqUZDZetzHY>;M1#N3U1i&vU!o*LA(ec^>C+9OwJKqQco7WHe+XBqTdz&Ye*vAz8L{h*sA6O6cvauth(uoB(bm$&(bCM|kh77!gPDyrKR54DZhp=~ zrjCxb4x&6fR{#10w~f6Ck1c5)yKK;xFlQi6k=;l4t2MXHKcQM2`2k zYJ{jQl}-Mrdzry~sL!YHtF(D+Vr+szr=)Kf7x%BMSDe8u*>&j^e1cq_{2bbP%^re! z{6}0VRf4;@gmsj3G|q&0?l+F^J=p)i-*P=)9m)LA>mustp*sF~x?R={4lxd@KC!G@ z7;f*9Y~P=gzJ8Y^{+rc@k&SqLc|Itf_|aQ#nV`e?#Fj!v86V;k<3ayNA1LHAZ(&`x zeto@vvMRfAeYjDui;s$;Vp#~AifPO9-R3#rKY!YmoiL<~O$lO@+Cq9#M~AMzDj-QM zRXx9;;LG5kw4&lJp3-H$Kg-hr>E&6kLpd(KGHp8S+m)e_c{AX|;=QVa*CK67S03Hp z%c8||jKQ3Bv-ajgw_g_f(F*ffwjZUVqjOmtNg8QM-iwR1w`+V#*RC>&@s!4eR}O)Fs&T( z4$eZ`-#j%WSOsph3ckV1&98DE4Z>e^)Y?IY7mDUTy9-#h z@7lL--$db$6ACdWtjyc;Z1lWVW-m6xNCe0xNG2##rug!TUY($05A_)yHua|$H#aY} z>dHObwVqDQ<>Atg(t|uaJYr&E`9DgQ7REA4<%Tl!O8uqdz64)5!F%AmjLfNYb(K8p zOV&MUUfub23rkD;mb4^QS@?nupV^UT*DtM=ZF=h5xpQIB7ZnvxYG_awyUiN>UjBud zCw*yLPxC4(Yrn8HozvJSqsC})4@{#@kxQVU#iu9HrLNP)^zP%lH*em2a{May#SC4e znc*fp(aV-L1zFix_xXwEs_c~k2Lm!PcuH61XbmeLFo=s6f8dHBKe0fPqg9VNJa_)Q z;pZY(v6Y$T@z2Gy*YiHR`DSMeaO)PCzV__@@1~)i_8A_Ts%27Wl$5sc1%=s>eoNjUHoW4Wh))RF5&K>seojDwi%X1p#<>iztEGz?cVf=}Pk9qXe#l*!!Pk1gq z@h@JUs`ZJ9i5Z!`cW+C=g_skvn7@$56H9VR6jB9lvnjF4$szc1ZMvjgYGK_IyJI(< zzi?r*553@UR~}DENlD_h^N+c?T8yzay!{JazYeE2F)|7a3X;TZD^Ymg+DyLg;lqbz z&raW^7O`ji{P}a@-1<%153#a7fA@q^N*{M9w$OKr)khMKH=>92V`gSXcg(hITqN#n zKqYp_Oo1Ka*YS=lrjk!lA4eQ5KNp1|+=L5S7>m;H+Eu1bSh#t(sciLE~8^PHDa2KPTbGqd3kKY!MTf<}oyk%%p-h0*2B|tzqp&gWmfZIwf;)K`V$$K3WZ48UR>gs%$aDxN2?nexTnN( zth%U*bsN36?YkavzuEu5@#>nIuYGTB%}n$}5c3ffa^Sdg-@t%Wl0sA(P2R3KNpt1f z&)*ak7h{F=RRZ8|NT2w!~^zZ1*HoH|=c>FQ< zB?H8v;n~IG{v9rUCef~UumGt#5)G^TjbkhX!Vh<`-Qx96h>A zN@=($ZZjrT?#a;`(?g9}hE+S==EgNgS308k>-{|P*VMo!$`5)KLPJWh6 zP*!Wtd4Msqb8z6860fMNj8#vqH{&#a$!>?nz>pBtVGm7d+6PCd@=lypqM+i@wfw|==+O1P zzRM*mbK1>ZhmJDxW7l%b3KAjb(~C>m-?n_fym0!@t4~Z!7-y<(@RKetyRrV)k86J{ zgbs&dDIuARw5IOU)zwYxFLs|F#I8-#E$(z&I-9QEoMm!>id*NNdTOW}g;E@r*DgVw z+DBZP1@=tkcgd3w@DeFaD5V&sT!^$cD4GiNqj?e!F7@&l~GB<_s^d5mttdMhw&Ud#WQ=5syu%G z8u~Ucz>6G~m0P*t8+J_5pWjP~{ho_6U-Ejr$~!v33Xu)`4fn*>yll-r%y8RN#B(v5 z&$|0qf|M_nkFT%q*U{0GY_n!**U5@~p{MWe49htnVjur?VnX+u6IONN%#U8Z$)7(J zJv=-@G7YQPul-rcTgR$*Z6!yuE61XZ9T`Y1GAb(c>C;0-J+GZ=u(Hob-ArTk2?_~` zFG5U`jg$7HID7VNMQ^WO0d6K_p!RV@Cw0=H49%>{@d7(LJMXXW>*`da-A36G)YG-3 z6mFG7C8#{4lv3JBckJnHZ*MjXf=@(5#K#Oh2_tMoDcwBXVmGpoKhOB%=KUMR5kg`H0 z*94DF{^i{H%BTaw`1gL*TdBphmnyH0%Ob+W*-ZR4+);6 z3V8hZqBk>}wY9a2tLuAVJEoxj^7;ALP<3tw?=0eG{|5*t$fh38O15$1M(j)9PKz8Z zpsecp`kV!kuMIM#j{~IR-n@BpDzR@HLz;9D1;_iwMivH!8%SopF>$P9zkmPM$+tb_ z_4DH4k@+bH2Ztkv4v`?mXx}}2`}rM4nap*?x)%V?X#DOh&W`ne`*x*?M*?pd8#7~G zn7FNvus@n{?_l$h^zu4!@r4etO&)P+%l2+%V2@;EQ-2r}6QG-sq!cHOEv~sNhF`@8 zxw*Ng)|c)fZZ9>J!{6VZNT4m(B)2F9eq(t>`JP8isptKC{;s#--&h{}?+CY@H%gab zuuko_`|uuKlNYhg9JxD_A7UjwQ%*)8cf0=`7*{jfR4L(|>h6+XgS6(YDMn*QZ|{ng z4rZ)SkfN7)+ICla|L+PtISr?KU-XuxGRX7xxH7(B&&adv;6qAfV`V*hpti2=T~kwF ziRV&yjvzBL(B?Cf#%M-H#=h_0cc_l@%;TAGMY%ttpVDiep1QR#6mzY{F-kMbcsCuL z4g+tG*QHR|3SXWQQ|ILm&rSmY5#jzz2HWjjg7+EOoLQ&N9q#u!c2ZL7MaSg3Y@Ny# z+cNh1vZY=t)pd1OdP_VL4UzQ9dV6DlOoAia`5yuI(Fj}1O%K-F+1tPE?d>%m^W9D( zaPRi*4SV+NX@6z34N*btXW4pn*Yqh{JG(Eb86{s^l2s_(dH_DjilcW)s!9dLtFDt$ zs&3$UIjm=FO`0W4ANQR0c%@;M5Z#d@7cH7*DJouRpGnuwV@A})poaP5AHTZ?y`Ch> zM@+TpT?;$#qb(z5By-Za5fGLd$r719u_*u#9v1yLDyklV?=IjnFu(Dysey#bf${MN zrKJ*BX&W|g=64!Z<<`kBLnxKAA~GESxpzbaZGE_4|M$FJ>D1PBcLJk2vP`Dud#^

#ZO$}2D`eemI=5K*=l5~kTm?{3$kCq9Hb{c~>3N^oCq0)Z zs}4GjecAztzHZ&RC!$U%9naM;r28ej#86VuoH%g;ktJ+#4i_1EQ8@9owKdfJNAY7U zZ?C2A!kxKRzKa8d^9N|94~CgU>wk=wSzhY>Bjh`Fi~K|%3K3%r zm?Kh`VmARpe)he2KqKfwgwd^UBB%!;NbE%w>EFxu|q1g^%b@!iGhJZb!{!bUwn7X`}bRIY;087 zDY3S^8_4O!o;=od+j{hYmj@QWnX_jfKZ&>O&Tk4CzG7sw8yK_FkGB5P(32CMuSe`I zUp^GfEN{@7qDD;C#fukFJK5UW9`!ZD{sK(Bi^{h-R%+YF=H`+1%(xCGzaNFJ(}N@r z4oXCeI1EzwWI7Brh)urRTQpfoi~2L+!>?;AV|s||_n6Kgi9)Ogagqo7`R|xDC)%OR1Qgy>)*m3D;OLl*0{5Oz+^3T> z?suy4nAxw@@zeddI5li1708&dCtuY3m=H@^i?dsCw1mIzw*kSb^Hxjs!-XTCvi+qp ztdbG4fnQowXaU( z2Sm@5-`Qc-nnLMTTv9EBG?utwO60=8f~)ehE2Bwx&k?kFIJHSR2V*6S$u ziGsUW6-0$MTFZUuHHGVU?WOw0#zsf&u}_BXbK`82D`X63!l?HK56>U^a&UN-CbGkg z{}h|~?FuqR>__&{npkaa48nxGMgNZhNo6bE4eDUr?@%(4|OTE~-;|`&+YL+)If18+4@TKCCi@Ek|;4v`k zbj3cspV$t2_U#qAiWtsQ|Uv3*4mm`O6{*^aC4pVF^j7L=9K*OrrW6Z1ZnS= zi1~F{(VK-;8o|sE6J^jPhbm|RJFOjHH89WZL1bj)hY?5h`#ZLmK5R7?1ZzS< zJP)&c_|H$;HcdGe7qxON-~XB(I`5Z^`yndzV3rH&8kR)Z`Sw0IO)IBn+y!5M|Ned6 z^R~q1McOlKf2r?XT!Ji-f9GniN19zq%Z82Phh#KgJF?xlaU(xJf7hNp%A+r5u~LcE zjV$gs{NaXtgiu(NUY_-V^_zFbTlc(1rrUWyRP~z^c4|JT4ZEvXt7>a~iL%IRvZM3q z+;5LLIg{n$dST1ydFL-#CZhb$J+HUw7Gl?;ctPD$`XAV0n! zFLP6<0K?qfdEhV;lQ-&YeD_Fm5_y359JP{?62M1vuh)`ab#=9U)hL1_(07c-)cuZ0 zgp>Q*_HqBj_wS;mt-bbVNzb?;Qs&H=+vA=`Zojm-*1HgQK4H_Pl4r~bdpc%VzF~*l zzjyEQU|kq$-|GrdqPwZ6)W<#lfMekGSg=JLSdC`n(NwebG37hP8r{@0>)!^Z70 zvI+`2xwNv&h@B%QAQ1B4!S-g=fv+>;9l`tbJOcI0?|c~>vrylnMIRU#xM9PFUAuOj zJa=vjSQ#T^-9)A6yYz1NP;K*hExVvv>&UU#haxbyW3|BAw1eghAj!|!*#VGy8!jUs zbbHSC5HyYBkIzxbDoLLwSwEl-sxuZK6)^p99jhS|@j}#T^c}EI^%GvhIx;WsU2i9r zdC9rQg~aHW1HY(Wm~jt{c+M)+lCCX0Fz{GU$n!<>NC*HX2ZAcstChVM#GY;6&-ZRz zQ*XENP?*#1+m}RAgH*6BYTMSWN{JU@yv}AOB_yz24|?=SeOO09!QXzoofpYA zNv^GB#IXkD2}P8`)w7xRUj=&clumhNfIav&G{iC1mT_@vdOCeq7_0wGp_OWZSMS`j z*t!|z5b||94mO^NIM_S)<-nzq_}*uEuQb0{00Mo26C0bG zn@P(B)cW7*OKYW5Nhy68n>AuLRkd#6q>kZWzdPgpYh>1qLh4j%Yj`@Z;nw~?5nX@& zY#BpQ;Ft30$LLLuVeyT#G24_#wKABWwn{eMf9@p8T7CA~<{uRwyut# zk~>R_nVFeVCs=rSGppyY+o?(?0Io>O*J*+mU^#Z|36ao|i~_J2KfE%kJscx^bwbCu zG5XT*0%A7mdj8|aI>AT4v*^@Dm|VFM@@uWGxP)Or+_CH}>9O3{^Xi_DPfFkO+f4Dj zR&k>K_Ely4$}>dJzxw!EpPhE|3~f8rt9KgwQTy!R3Ck>qNR z(?0E7OZ^HT%DSlIdDTN$L-m#vEaKu(pm4Am$|No9zYPxBe|fu+$kKLp1uM!G$`syM z=?Gitx%>QCq%V<$|bGe=}_BJ?`rysJ-^~FE_$W@|Ty5zKY&31cB z3DMc{sTS`vk2oeKzcOO06l7Mc9P+7QSkd(|&*{u|KJR+%=jZnh zQom$MyQD)$=d&G1o{tJ;{CvVy99{2q__&Bj6H5Ioqj$4z;&}FFX9MV^P9))9??(p$ ze0KF@Eg3_LL*>$`SnJA{kIzs?lfOsCOv&;yFT#FcvoJdrEm|JCOlDr$1lF? zZ2Ng9;!~BApNP5g@tZUdWDTJDG!o7pBI^B~j**=pMZx=4;gY!Ff<}_-`|oMhi83`W z^~>|^2gs0A4j~YOqB(rzhi=Nq6l1zqVN@pv(7vl8}lsv z@=Yp1>wf}EtUKy3DA(+tMWFJ6f&#h7V}ZCaM2|D7sv#XGxQG(m&sa{`eDoX{wezC` z5?&sGL#R8{hucw9?H6@4&1|C;G^YT~z&`5)wyg#Q57!PA)XL9AvI`zMzxmmIN5)5W zD;aVB-8sijlA7??&Sz>;2gTnz{>rLL2=K8jLysr(5)dKseoDxb@t)^Oaa+jA^PH%W zvc*)IEFE5c`M5nOzMY{pm)xh9_Dg{4u1hmo+*b`cUKzQ#xrK98hlYjSkBf`zu&%b{ z=LJ68j>l(YU?2rPCHY%}s(qQ%toe9_6$D+VtWo$0q+jjdKmPm>)Er);anF+HzcP_WY;#kS- zrSc(9=Fxk^9#d78-@cK0lNG8jNbHi>$L;QoFuj`3vm1Ih4<3w@sG$hD#c7tGAX zRWBf`AE~tQj?u#y2qOgY=+4jVXtBEv!p-4F@jO^5S+Gg2Ah^H{2}JRjw;UD|>sDp$ z0jCc|q!E>qe1_hh3l}adudJATcuERZ&-mTF-2~z6_!~)h`@s{DNL_uRT*3Pr8N9|qR?Q#V|g=T>l`l>8EA7&U>}ClwTwDv5HZlA3oPI3SPm3<*bnexgTgzT3e&x|d)w zxwPJ)q>vkH{=X?K<<-@x^!4?j)7-+?1w}^I{aNY#!wU#NtXAKz1fBzT{uM%s304V} zc-8RKKrO1Nz=scaHD_XVy>=XVOKVef1LIfvdyxZD1804Na1us_Q@7{|BGyNUVOVbK zuY)#51sRLDJoCNlF{p*F0|ON?*M6Hhb8v}@hC{fEjf(>_DS#s4$gyMXoYs)CKD&)^ z;kV1dDB4}+S$0~+orE3LUxzIkd+pC};!c|4&Njc)Kdo|j>z(416AOLo@$|QE-CEC~ zp4JD2WO=U3=8Uv7k%sB7E!#Cji=R(bPltvm(GG~yC_YbzR4zQWkKxRYhvst1R~V(* z@1y()^9nP*a>e#%->Ef}7+`jr39t;PF-z~dw1NV`W5Epg-cvJ zfx!7S;jwNkikUNK&iovX3*fu<+at>v_3N$8 ze1PA5IeU9PlB$4%Nk~z~aT~Gj$x;99Wl^B$TzP$1E?O*%VC9-ej@ftXh$$*6k}TBioj@MJ?k{{-%+{G>aaYf8?u6?k{?}>%jDmt+J(gha79_V>SJPzc_H1 zHpOn-K`X?FU4hBlO+!O=l2CRa&SPyH#=D3wc{**SE01$!6ZotqOt$7|`Ppmx)kjg> zSX@~60KQ_zP29`NYhhu*7Iz47QU1^Gm0y@ujUpFbV(*~3$}cRO?zG4VjET|v>U8PS zrSqQOU(HnTH;R=b_PNgYT&)$)N_ws>B~@XYq<4hjjNoOKS=K{&?&HmK0o%9U=RHMQ zUgc#&DL$684!Ab}0Y4LYG%LC4ukc_9j!WwED6hS}y<}2z=r1Q^y4v1re-LvXjkIIq z-`7eP>S3P5L^CkFanwy}w)<8cNapHUCh#gSIM~o_g`gV=u?rieDNU1KAzDmxbPPmA zAXGKgc1{{d#?Wwnjb}B=c^63I1_mGj3LKw9sYlg0puL5G(|EJ{9JNf4!Ch6iDA$_? zSL9S@29}sCi46|oMgaxt6&e}REjxA$OgBofNJx~vV`Msd^db0&T+0q zwy;6bss>y5M8HhBK=1q(rShKF40Aho05>nJXU7gys7#CpL{#s*U^)N+Aqm_=?BitID#CNIe)&qwzhU$ z2b36eC;_WZ3OF!)Gm_iJWG-9)Wl@{vqNMQ{)mZ4f`+@DVXCj2G2paTiWZBl*RaifA zRq6(L#7S$K2VEPzzNby`F|F(@4qZM&MMyQL=G0E#T&I{6W@UJ4O?sl#aCzASrH~Q3C1fO?;{@5?z`SOC z3HyE;aXNO!?VPxU3#W|ita0~^JkE$L$TuJnUc7ig#UvPOaIW;R9!Nvk9xJUaN>W!C zuM5$t(mYq2lfROVC*e>^Rt1Zon4v4~J8#ZEL($LPDogmU<&SjyG`i^Up^f z_~y5oZF0_jp0ILe)hX5WvOHGv@3%SUngj%{e%l#&6FQm9xG?cUYd$JKfK{pRY{WLR zecJ<=)&#Bjw8w__OJ#S7#VM)a0qGl0y6_$ zb`ut)sN--Mq?fve4DCFBNQ%lZ@t|e_tYx3m@`<$tK?^vd0Xbk}5iM~mK1cmy^OtT~ za7Elvrifr+OTC!EeE9IqFn$yJzA{pxw1PVW$Wbb{7dQxc&ZO=MFJU5r9RrpZ*sks+ z6YvlQlCS0D5}dQM`+%!z8WMH3;_E}-Y6KYEjbpvMS*GT#_qP4VojD0m&&qnwsP+*& zVwFCW96pw%NR7bX@4z*hb>(vXD4O1h{cXG!D}WSHGCl>0V=9@0p+w{5@aHPDGZ9tv`D+jsBYZCv&C?b{lUROaKyBa))s zd=kz-CZdM<5?~srLAjWuq@?vvTg|aNjoiL}rOr@MYTNrL@ZYMc`}eEzHEb={UpTAG zgi=Q&tXyt}pv(l-<}O?`9d;5D5(b8bz8(s=YT)<1d-vAPbOL6?Ljn{7RY#Q*1E&*+ zac)Ns;Cg7Ap(8ox_fk@-J^SCTGd6$1WL|}~{n~`+);8hL&XU+GCMMb~4T&;7>2OHh z035al*7S|xx9Swg+8v45M7YY}3V@v{LZmU8Y1$5pZ$QlFiuOFOrB@*fmMT0hAHiWGfmNaehPB23bqdJ?YTlbz zQ%Av1kub|WISDd<8x`l}>W2)Vvwa}6WqSOwZx&QhQMn%wFobG_B=VTux}y)?b}#;L zh8md&o#m5PEfJf)_q<-pjC$o7M7R=P=oGv~%|y`2;KQn`&#Bm1LMAhxttX8EjrYoa zKo)FcLxfF<4B-nY^%94*ML4J9rT#6p_oOaU?^#Hnh&V838P}&o3NvFBw`Zv?MSE>` zei*;17ESh7?Q?zceDNWkpAUq&L547Vm|(Wl5~(Sq*i_t!VUSH-cZ5wyuTGu_Kth*? zwiVo1i!}9dZKp>y%MFiT+H-I%GMYy}*zm#lmhAeRh<4F%gEcQ`i$JAYNau~GnZZul zK5OAjsriZVT>Fo7`oCexHRa?#xu!+zqS0Eco2%9cHmKrvKe-mhjGHLSa*mZ7!X3F9 zwF0Tg7uVU^l4fkm^-nyrWR-@df4{$=W!mwXcC)By3p>YV*WkZ_ z){bkZVp+$wZqM_UG_H&{I!ft|FmSawa+Y;@hmmjMg>n$OUhrX;I6)N2+l|>Tq}B^3=tfz)P>y59rTvb%>KR2l!V-yT)+ zoQm@N+9)I*O%GJ|Xc~{sZ;?F0nx}mIcDByky99MxTkpT4O_6M}hne~JybZ|TSR3xum7$l*Dw8wvMes74r$nvT0>!18NlO$QS zQhaG)Jggx!SHXvKkXU>z8;o>zmg;j|B=P8ua#?zFNYhT8SX1gjZncuuRL34Y+DGBrl0gZ8NQ)@tnzm0cT&*D4ggRQ83flkaF{WI+WH zo9E@Dan)*hez6xl*)4-wQPIKo!$`7&16SQO8h)sDQW=gFgmy01h00Aixrs#`FxBX4 zt$9Ov;qO;+STv3MxrOAC>OPlk=j!NO8={m#KB{{93okwYS);1gj5xyu6(n+v{{7Cy z^SvwW=aeY2y2NwWyw6HeBeTL5%A6W|BzVOiTXbk)+_+Z6(&%*F|NY*qf9;l#&Tfr> zk8j)7?>p<46UXrv12`GeNhKfHLQ|z2(N!*+CB3%DTQxNwmtQ#gDb7WUi=aj=Q|z~9**LoNOJ5`IzeOG z^;Z&R)%b)2wHa+7SO&{lm%^%>Vd5-%nw2Pi&f7nGaW^Y;Jx5f1a7b(F#l5?CzX9nX zV%tXo-sUin^tW%{4tM>4EC>P%`IdSoz1TSymt$YPe4!Juf7imy!Qrc0>^Ax3?m;S7 zKbV>veoYw>o&ZGO1=ud`J$P^hz(Cy-$iqZ0tjk zBUNII@?5jfIei3aG=i}9PS)`^Dk>^ozjdp=kbq;P&&ZXzyh0P2zDz?uk^RS!5Yp_ZQsDe!xI2} z&^$CUa)&2gJr$0E&rKGz`~Mp_(4H-1AK9?;x^f}gM$rSqj!d}{ZXJqqA6`cnsu||Y zbPCBU;|~Z%J05ML|)|CCcWJiEqL z#gY$WTWkNOalA+y&AVrPt9K;XGu6}KuDvq&-I0v0G@ewIG3r=ZqRUuCXjGJyQX$PX zy(3UX;0)eHLGcb%mXDS2e}@ezPxN`y)z7P{Qox~G)|5H3@#>k1HS5``0}YQSIz^tp z!RwUxevco0d%B_(8qO@Wad$@wJaD-}VxCiDZnyu3bVwyIN z({3~O(0;}?;B-Nk_knfx(Xp`c9*dg&qNrHtnRl2joF3WLZ{GkuYv^l88zs$&YxxIg2r@e;E_2`SMnmIAG!f zI*}mIN*^IJf}uMD{=qD_@!swurl5?WG;X~7_V!lPtFJD3l+LHa23POlhQs7+$5zG^6Cz3X_Bf3^0@D2uor9*5;^iX zq)x1b-L&6EBG28jY))k!-T0$a3wG^vFM%Ek)=9x zWG%-C^{3MdAGA$ZGicr`vwxWK=cH|Hbdo2RCHq>~?NPn;PeJ$Gt43A8J&wMj!FqNr zMV*;dxf83|94-F0_)5+C=?oX5ZFjawz~3CkER}goMkFR&-nnF_$!0-$&u$Ke{8xYL zM}=1^3~h=glHU`VgaoPTbFoZM8D3+%u_DDtO!O4Lf2@n(29~@%5%hV@p~ilHGfi-d zwzCAcx1ETGd!!fSqScY`O{TB7$vE(8g{uDM8l;&u&~LIdS3hxoq<|!Q?TWk9ug+FV zS|^Nm7^a=wnl^UlZq}`wIQD(3R@Hl>k{P85%n9EDDcp5;_MZEjMyt<}SP4BBP%MI@d!?8#0Evf<*NzDOk^h508%a^7?7gPI@m)Fpqr0Sj2N3C&@ z3JvimGZbs3xKs6rnax>;(5P0Q`?QwWa86&}I^7*UR*9@+RjFk(Y=3d7c(+zSEcfrWr3wp? zE8i&l;+2}a+v0+5avc2oynQJ~g$EB7^ZI=J*tni$^OC~)@#CUcR+1W(8UhwyqQ&QQ z1-AWjo&6?r?e5}aYXD4@Jo!=2L^2tc>EE2~)Txv2_2y^(?ZA+1ORmQ+hDg>H#l0Ha ze~_iGm9tpWQJ;@5b2&c0V6r)BO8$2QUTY^EeEhO*C6oQt37@}kLG5DwqjKRjQ}Odr zWv;H3e0IuuNs%B5Kk3t8JfHQW|3xaUN%fo}kvCb_Z=uFb`|I^flB z!^9KC-BZrtLvhSZqN1{~CIzbE@_PQAjDJ zd^;x_77Y7F71YTJw1!B53n+k>sQ!0o>;LGn`L_M=b&!7lCnlS#U*Fw_vd4q;c}L>+ z!YzdRX%OU(M7;3CZxiK?u+{rm1wDLJJ8Q|c6KHbOm7=TRvxY}ul_}DN2WVDUQ*J}o z>9w1z|2GA)=`zG%f(u4_*G>3hsLcNh2|`K$0(c}02nHnl2}DGy>qtP9zrtajC9Gnw zNPY~_@C09(ZPoP%ILa52L~oP(9?TZIA5hl=sVv+{-&YN%N88vOf?k=ovU@F#u%I{1 zqlzawCZ_TC%2Jq+Rrtsgt@K1tD{!Y!{EB??g!=H|!*%}N-rku|-+lG2VjxX(w9vp| zBl-zl41}Q=e*$G>|7}j$p=X<}?yH`fvygM%ph!XL(gh9a#U3G8J5J*z?b8pRo8_&| z&q>lR?qwZZJj;D)^HaZ=yO%t^bNM~1lkve~z^cH=91GV!m7d8&cR^jWc;Sah(8c|OgT|8<+K1WMQzSE3IFI}r zL{RqpP}sP?J0@aN*$?*%u_`e0TEvb#J?SYSR2<`}zQiF6%e<51iZ?FgGUMIOxPNK3iMsA{;AVq@e6P(Jcevk7#Zq znmt-N0OGS2wBR!~6$qEiGKVAQ=lnd+lnB&$BXhlOK^TwP}3cvPu4*dY!o z=WF0udIkXj!p8z)8=J)Y|A?4@gnN@PEjTGe9y^5=kFXa+)!!bnQ7VgN4?oG4pU%(x z-CH|L66=nzq36TNdlL*#Zy;2E*I4Wa)*ybCj4`Y+r}+0jB)@beqDQ5{7m3?{b% zl)grZN5M#Va&q#!sw$$+D0_JQy8Z?BZfu>*%o^ z)v8)QUX>7f5%4jWEIq(Lb&zr6)2B}fUpnEKC7MQ{x)B~0c&m)v7c~rEVbm$G_ukKM z{P4+>bMW)R^OpFv9Z0TaG-CaU{|BZNG{|nGe)Q{NYOB~>yCE8FVCAbUg(r5(@h$?d zK5Zpq#Z5%UkI1}O0fa39679W!02Y)P#Bv%Lc@Lm?4tG{C_k+-J0LFNxxxwexqv z*kqfWkC;@>NDw!5;YIpbo;)iG597QiZ>)*$t z`>we;OY$ZWht9BZYVQ0~nI0{K(!Ev9#wrWD45rU3SO$rWYV4FS1w$^)l6+LNkH~!+ zIm0?synT(I(dD^QacU4vrKe7vB2qw$$^}BD!jfBnjFi4R4!6W$Z4%z0vOH0xR$xk1 z6Btj(L~!^MF55|{dxXCeWpE*M9}#D?mj@AuM=D)e|jat&#pdv=#T z<>A`sJjizWU~#H{yY3EOOZyo-v2h7~zNi^H|7$WsrJz0Z^q_?g&yKbp5fdwFcV-Vw zsScAMp7{C6i;XlVK3VM?jU6hZyYReID(kX-P;A?nuLmbF?6e>hLR1B`E5?+b(y2UL zeF{S(jUx#u4dbNS&MKey^=+>&&%$VG#-roZcmzp}ry*bS?!I0-?}CB;u^j3wiEeFs&QN-@X3wX>lsP>CD9F`)EX79hMm#I6FJfg z(ef;|1%zxCR5%pATfZY=LNkX8mvNU;|uL+T%Z8*C7A zXadm?X9(7M5Jn2)#hDRGn#k4JbNEj8^m^95{3WW6_!OSIHrT6>cW=NkAX()kdH~fU z<(?4YlRkFPPSlzrHmTo@)nq*zWDk${Bs`w*`H?nf!jr*|egY&Yy=QL-a`CHBLfv@N z2&Dp}F^QlpqY->HL>CD7;mD{cgw8}bPEdzW!bYlfM4A{h>QKI{q2*T2SJO`U$5PBj z`*rW%P%Gm5(bGRRRtv)5nClcHT#+OYSneP`skzt^bC>UDH{WkyJV$Lan>{x#8OOco ztf;2+V#g*@w5TxMzfX)zhisUa&Ltn?XOS{BXr=`PydGvZm`Wp}wDWECy*CgH+`X=E zMkfdTWZ0d26qNlh2CqIOtd&&sZo>?4b4XiSa)Ln~0{ww@Sq%N(_1`!6x8%99x4ABw zq*0cv#;r!@7{HzYjNR|R;y@uH8g0?Abp!1ocXl2y_oF7(5NQHced3kDz>w&>vPREk zl1FPI;|Wa}eVinHeSIgP?V(usD7jzpXImaP) z7Zw&K?_`4gf$%~#B*e$7>RcfP=jcV*rnF@G1R4rJU32BvSAR61n~9Ki&=@>8+om^! z_E7>l0<~zb%|aR15w?o)4u_6>lMAgGXK&auvcXCX7x&)laG(ZbI|Im>p>d*aoan<# zlTzv%Wy7K~o>P5H1iX-VG_;{P;2f+Uq~+-`ke4)>juP|RefT(~ee?E!F#!%o(ba;6 zljP>QA`X7OR5lfPAL%%KBzK)c=X$M4A`Zl|R)&e|?s)S){lkcgOxW-YnNxSv&uRM? z-}=(x>b$dvxd4d$bZ=GstKg-qr=pi4J@)7_wUHTZ-~E&)J%Op=h;(rCa1$*fXbVOc z*u&^(>$wZW_p2+`OjNa+iFm~E?P~zZMtvQI>S|6E6^f0avfc_wsrkdW=`H$-opsUX z>H^ya&MCw1l)jtDA`G$ya>_nWI^xKD=QkegF_fD1XdyH!iK+k9Kq12k{S_Z+P(k9cv6U3afVwB{J0U9ST~U9 zoLVO%iyoKzYUOCTBie-!SLyccdwiTsfq0N^&hYe}4Fdk&3Ky%eXsgDfLF|MP#!Uz- za7?@-w7L=x5o~^AI6Dd0YT5HT6i#~WBf7*u53=Q$e$og~E)EJo4k3;;Ago|7Umk7F z#IJvVSL?Iak`C;wcoAUziLzjJWM2u}W>4F+?o+2uMpf!2qx1t00{b z9k(c!Nc%n}oJWvQ(+R<-o{?97TSwo&$m%zm|8+n)UpIlBWBPkEs+uX-_?mE=&C<7U zb?Kz$OdO0r0_RvQOjD|vfrN=vFz^M5a{u-KUuqcPMQq-v?tsV{!S2VaI-Vi7W%ebA*Wx%%k=Dkn142fKa6phc};78HMKy+dc<0uXgrRe?H*Wm>x z3gMf?IR%9K4N~w{a`GXdy+k2!v;AnL&_{(suuoU7UWNRMCR4#2*S;H@Sy@>x1Md?4 z78rIW%hr)P7>x6w5<_+_P_(EO1HxbqMoLq zbY%%amEZKkA?6Fwa$u~8oHX2Bz<>S`C+YS$R@j#ekZ=h7)e&AG!d-{jEemxG3Fz2k zPNNjC$0VTNCEKzi2q#NiN0VFDm9tpdFGUbh-xeMm+BYl`<pAEI}x4Kf-JWBg}D}&hY^5 zaO@coK!Z47IjIusf2u0K<0sIXo2_M4@cGa82W`pE|AYRmgs8h%BO?GgntvuwTtbLp zeY(P%+KL-)6Fb!g$IQ(l|3x`_8?XEI>(_2tTJ3KRXi4(5v?-D|T-H#gMYAc<_J?Uh zKj0Z0_5p(i{MgB;iUTu_|0fIWkk!|t#r@yQtAUYu`lPEWWLoP1ciWa4O;Ff%@K zY^~S%9?=UupvHUJQY-xtj%wJno%(v5;cU3HapT{z5-Vrn`K+ z`;C02EpxPcr=b4NJ|Rk~j2wcNPSla9f$!{BH#tH$6SI-Z!Ja zzyFqW90MEM`*9a#p~w&IJM zBMusXX00{sgrgaJyN{tcEc5N3IPL!MSjT@Chf0H9ZG)1X8%gMvJIkyZ{p4#lJoUV# zr8YWdW|n5e9k~KCnkic1h$Cqb=Z4_ZCP9X5NXvb<2QE^_k&oVx%b)n8)q}8{;S3YF z8}M&k=Pmt(y7m&N)xbFsIq)Kxz~h@JRH@HAYo+$lJcp)OS1)n?5V5=Ja~!U(Q;6=H zuyB@DqG+aJYqeYCMFt_7kWf_+k;C$FKe}eJ!1MBLCdxPT1iiuO0JV^%i343UH8mg5 zi_=^u&U|2!*vl+;h>tG_#T(J@1crp*!%??v-agArSoDb;NC0PKd!ppSGwr3P-$}hE z3X~DQ{kLl0r>RC1E`K@2v&i8Hj$yN9g}AsyKe58Va1~LufCkuSMZu zIn2lRF+}y(!UEA4t8G(UfsQ|r8wuDK%k!mw5)d=~xyO)fy5D|M-C8ZvWgVwf!j#mA zv=3W^6jc$9#@M5$*Lyfx=`{u}LR!)~4@IbC6+l=G7%$)AXxRjVhDZ^BtfY1!eD?tC z;XDp$M7l{RaAF(~s*DOqMb42nD(RUBVg4QvLR3dfD}|`8qT*#OyIeR$(Kp@g>PURMd_1e|Y{-{MD}MKe#4CY% zPpf)lPRDgx3vGEZ6(S--_@Rjo_08n;0tb$}{5%AV*bgmpBe79nEX@8IOJHvEpufs_ zbKQ7Lx>S~V&gh(pItSQ;lEqODh^D7eq2rJ!ut8yvt(21#4iG;e!Eg~CDeui3#Om0x z3TMmSdt`sp*rBnr_j#Ql8zS5pDJf)yTEGNS)1KFY0BJ%wxG%j5sns3pm0F3b|SX-?xRdf~ua~N7?WbYG`IQRQy-q!v{|j91~;XZ2je}H|3ru_?VL6 z{90Qm42`~J<4_G zd>~8|WNha4Z@p$T53H;ZA>;>>ryWi^0wmIak9r@s&LMC|z?o~MTJ; z1qc=$>sgb7pu}R0gTus`kzlNF${6fuIntb~1Dwclz2(?y(QYogT15~}?l1wfJ+#8s z=tQ~!-82g)Y@nt{`WL$)f7i*)d++`IC+lfs*c`r$CvZ%GW zvUi)GmeW#JN;-b}?Af!wKUEg6NW_vQj#BHx!2`sp8fYfjdL8!sXe=0YQ`^;1z=Eln zyd8dVwP3%EDE*gn)7U2RuxewpNwsR?}`7Q{cLnEceQG|BH~mLF~tT@6ps zsN(JK?(UARpIM)KGDKi79sn5Zdn z>ek=2^;vat;j%sgpCdfBedtF(uPxD!khV&iDXoOQl15(v6&gp1 zW$0LSID`QbIYMlBSAsO1xHsv1VYgIP`ihv{hY$<7XEhwtg`Z3Xh;ssw z`d*@q$Zcj=U-T;a5%@O0belDYzC+j#H}9l-3MhkPO#&k$b*?P^Phwd2c+OCLL>20m zx8S4Ni`|8BIsx$;5fK*VS1yNN16*96Tt ztb@uf93rymI@)g^@#rbwbVihsDq-t2P-_yWA7E%VY}}ZM!#dD1ppj`H75E)lr4H;k zfz^p_ji{@ipCLnVYQEZV^g%4AR`xxR2-8H@Aj)T=mV)h~4%1Fbl$M0(Hq&&r2>o|N zdlQbDigx*WpRbAs%cwgL8MWBjMQ3P`-lfW7p^iTP*4qv+I@7He^)m59qiq?DaRDc4 zFjuvXe*a&qJ&@hDIL;A;BGI!9DT`>6Z903f2@=O)J>mc2FS4OumC(4Fm$>~WMJb)c z_4ohl{HHF~(q>>Kd9)sAx|s5iVu6V*3!zo-+a7CngiFFLZ9e z5T;phE1R-dm%NO?-9TWQ%CPL^0WX03b^tI*J{uw&-j*Bw-Rpg_sEFN|RUx%?4|4~JK#^?W6Yi}7=<@&CB4+KR~LP=>wML@b6 zL_k16S{gw@x;qpE1SF-AknR)^=>{p01_5d5l9GC_N9TX7wf9mD@Kp-ncG_Za$?;0%J84t^dE z47AOS1At4ifdvu@WHyo;$VPSd-Ay7J_MX)j8r~2@Uk^vZ83bU0pc^kUSJ#8WDNCnL z2&CGO=d!V}eIXQq&Kt5-!ntCErhpObM9|V>b@^qCl6XtINh}UP}bDx8S6G5TD&cOmzeH5fi>VFGE5wR&ECRagn8as+U zA2Trv=2_t3v?2e1+zM%d0Qe3U&`tCpV$+3N#^CehzbAz(-6j5S-*OPvwzjpUs!VmI zBZeiw{2`IQW$ob|$gH2k?rm6Vlkq6_KA0Jx(kTGJa3+W|z!N$%GoxpMJh;=)F)Hm9 zl7R|2NEwi703@=}WirSPwcYWL!e*Q$JV`Zq*3RY*wbU1iwF$&s31Y=IDBuGkMkhdP zJdFc#WbUBX-z}hOCYQ6J z+Um;})ZY4uLbXd*2q%h!E@EWf)%sdgDfWpLlYtC#(Q`EV(v6pgVIP4NrlBagcX@o-hcH^ z=0CBhIeRt_TGtX~hHW4AGSlA2-<_0IqJHU zbWAuk7d2R9g$w$xSma?kJhIr0`72^7@)=*2e;L1fSApa5Up3B%gnZ{T?RQNY@qhlm z)trXz1+O|&+5*a^s_n0msgcM+qn&ZD^{%wHdclB=ZHcmygZN*KQwiZ*!|r9@Htz6< z*$fRqwgA0<$_VrFZewn&K zZr18w1JVNn_TaPu*6-`A>!UO)sWJaZq8>S44~;c;(yX{nEneF4kBsb5Hn+v@r^A<~ z<^M^azR#Y@>rGmZZ$vLYWuy4ddPCumPUNXZLR_}8%}-wk!K zy;1Qg;+{>J3n)!FvR5p|M^+>LUZmSrecEE_-4`M+m~gUAEWcRW@R!Ihv9{a$OKnWK z_G$6-{A+nbn?tkT&=xygTgMHk)cZIX7k&jSg9S^1b%@Wi8 zf34&*+SSl0W*|_x+!|rXs2V?!*D7RQrwwQO%J5@$`gI{B0qCWcT+DRZ{~qO`pmzTj zgdvfH3@R)L6Zt>Jc__s61=b}9@?JwO4Ap@DP`^jHL2?S^S|p@NblaBdFGIL{UIzCY1aO{wgs2B(C8&xu8E|K);motL*f!&4&tRYUthmi6rdu@k1v z+zh97UmnfZK_4%#+H&k~68BC{OJ9VC|6i)4P_R9K8z>E0FBv7reshmc%t|(KN#4B|M#fw@IA=5ZjOVdJ+6l=gm73wo&bD92vqVRG;4UF zsGnU7Rg~^~ywY-V2-OTK1$fMb0hO6qgC>Z%xc(&?60fv#iZA_%>uR=_eqnhgR5lL9 zoO~Zoy*v?r^rh87g_C3=SmooyDcYAHnj2+>({?6|mp_K!AJjjh{7InOf6=@J(K&I| zP3T`nc(%Yph01>VAKg;rW{ub%-4dGVzv-4-=idl~*RSdXhlNI6Ha=J%xN)vpE9Oua z&wFgROD7_7X7WX-E#crK^Al92;MolR3BFs-W!fD1HPWLR%cY}oFqZu{y7Oz}WWrGc zhER@y69d`sTfjXqs*~)={{%%67?7)X5C@jSt?h^B!1ZHABU(wv${1}eOisQ1&@1R> zE+7g}PtXF!{Cy6md3dux8(=a1DF+1|jsKw}KRuaK?B1X87?OQ%(|3u9JrGzuHb?Xt zNXh%3=IEn+42UnfL1Ll$VFN0&W>T4|F(;j@{cFA&*F(C&Zp7y<&)7`CUxkfxhVfE6 zT{OLSk+7R@?7(#tF#YN}qc>gCDW|5#L~QqwNMQ;rs7>*gEJMj@ywgcU?aYI|&{# z95hkg&m7^Ct}@)YbzT=_?!acOZMTIGjKR(f^6x7UhK;PitOJB20+HA~mlNo|5{Ucq z`QQ!83CeYz@|1o|DEd0cZnZAgG>a*WM%;Ao>qzsJ@+C@xrxg+$11n}Ex`VpJe?o>8 zp7?OJ8E3*D8*BF~GVJMTljB6>b^D(`vHO}w_0bbBQSAfw^*%q7f3i_jQJ-$uvr2b4 zd+aJeUQD`cFF(@pwCxN{QF$Q$5^QT7E_8%TBQ89p=Q(;K4h z?{>=yQPcXNX7t9af6RM>UCf?33%g0D67oNmLR-7sSS~_5TJLb${9e{@CwSu5rz&>E zAF!lEt(W;{-9OjKTizVj?3 zB`ET4$pWUyPZKmWbWwj|rX2rnNqpz|HgC{lgxFiOABEV7Xx-26C>4zR+BKYzt9E#A z!sYWdntb11d?P$p32Ht(DL}kir_~+Xd3Cm@&>{YqX5&yhoy4aTTtifOmUjiP^ zz0h(nIC4?R=nXFsFi2=THc!NHeu0N~#W94HpAa05+Ef)48%% zJ86fcT3aMWc|iDaLzBMOy);@gKe0C7UG2A>(R=bQXEY~er}>)>_P4lLy$4;Hm2}ev zpNdi*|NcC^HF1WyyTO(aK9R`)NLJ+BkiV&G%_`gvPbf@vl%a|e>_2aW`FKh0o+#KH zwO6LDy*wXpAYCGUs(D3K&fAabJO9M0lA1#jQOn=!9qP%XZ8WFq+TO*GELr;bo@Y4& zQ?O%HGKpaS$wI{3ZMNH$&sBTt_YQsAo^5ysR*b$Fb^omQCvi}24wHEyurqJLwvF|Z z=&4_!gni3kw%#=YmZb3GvPV)(i7E9NQpufFm9Ku7Ki^*pxPd&eN@&IYO9DrO7T;_y z|NNH5#UL>i6tCE_VfL-?vEY3JX`7DV(aetJLOvH_gsyS@lTcT zDfD|c%Ij%KEaaNO?mdS2ly@XY&OsEqRQ!a|;Vp`QsFpUQG)J#Kc=qVxWWmtmCvV|| z2kiPW;m5}`oZ!R9MJ0h29QX}ioQd{LIf zhC(E_#M%ezTTDl3z?gWZp?hev)f~t| zIRAj~ri{6&rPNa86`fOJVg*}ch2fw`{VK)Y*1Dg+%cAHy%se`aUg@4xC9l|ub9H=S z(4ASX>ton#*TN?GCNIIkfcSZoZj=l->_)y1>)d2Mo%ud!2C1aM{K)!&>v)^ajHR${ zOgOpyJ%3@-v52|S@gJL{O&7dy-OS2g{m7p@*(KS^9v#?scs3-qM$g@9@(9x296U9X zDi`UFuGuAK2irYqqFr$KTAMkN4%n30KfMEa__mp2(`uq47iFqQW`*^G$h6nk7?!(r z`clWA+KT9Ak7@ao$FAhXAs(Nc(1AR+z_AX6apktMJh#L{8}7_-_N{%J>9)nH{;c5m z{3iP6jqZ$=>*3XfzN_Jyjc~f?Vx1pp?c<`(hI8ozKYoLAbR>qs}FzLRF?!L~Y11FxXS@?uQml3|aYl0Ismmoiu>i%NXF{D_I)8=w-Ldfr32EuZ(vQJ3Aki z?ixoWYt4j{N4q>a!K)@Kq4ZJb$u2Jc!YVA}eV|esT!n7BzA3ho(-IJey51cwWn{qkMU6@j_{g{Icnrb+b~d0+II%qXx5{4tQZ>L|Y1X zPTCGzIeshJ^w+JdF%h?^lDxI0`R zvWp8Q_K!5(FXqcW(+#B(t7>OdHbwR_vd>YGy$6w)u|>;xOO@yCOn!D37Va4i2tJmN zw!K)Hn0Yxc>54?Um$EA}a&YuKCi0-L=hObZkLK`Rx?xbhX*jj2FI~!9!_U6Wq;leY zcr{B7=_9Ghij)CtjKwn#uXj?haF?>~_u-T{>)uBA4P>LF@y+$(RHo(lK$miZ*`nQ^NTw`0)1$rG;C7l+uc7A z|1~91axf|UeHezuB&$HJsx|c2O*MG@uMQ#iYoswSAoiIj zdqP~9k`=B;^-Uf~GxKre9h{R6Qg5Mc+DIv=!f(U>=4_~QNU zZBX2NuS2#ryRJpwerv?@eA($5o!lQx9Yki-Jf=c#5OGs?=5w3>Q%k4wU)vsw(@<<)Uh|r-RDQHg_6i z;C&&yi#-(`JR1@TbfHwXmA0WHyLcyKetW1_66I8>6KEp8sX2&5N09^-c_-Cv{L@Pc zl^T_jSBzUGUX^vgwL1>MY-b4l9TAW*%uro+%5pa9*=Loz)=@<=FgZ2!OgV1&cSTy1 z24UQv-5o@yhr*ukqEgev&@JI2r(=5&AWztNYdWT9wGVz9bLZ)<9nowjX z%4^B;I~|)H#2Y_!kL`g) z@$}#UErhNu3yMVO)mo2t>84uWFIF^N;tIxKL}yVx`v(LbTUJ)+8h!@& zay}efkEKrS0x9IAm<-Sxpk7~H*Adz-S5+5e#{5i-OHSdRnLVJIp93AD-SU7xsGh^$ zy$dv}GaY8~#ii_?hj9X1&|@xwMybQbWDsbgL5d6Zm~)`L z3xvKukm(k||BehZKniU@(JfV&EX_&`P#-$oXJaz8G(S-vp7nXsWM~X|qS*1E8JVQg zF62}H7TmpSbd3+**-ctR9c=am3#fRj-l%bxF3~FJRfK>J*d1?@aKds2`>WD%XKHv%3Gbqx3j;?Gfon9MfZycVP^>bO!&VLk0)Lo~+LXZ}@r^ z$1^*U5FNN|gzX7fJJOIz8oRn{45#Xt>R|l}c)`(F5W|IMragDTE6GBz8#ZxBZIB2u z(skZQp4E&mnjWx^h`|Ib75%Jb>IfSVnz(e^ynv1X`eh8o<;$1ZZ&QQM0$2{;fH~Fg z4E{8jAos@KADK0HZ~HU&HKB!iEFB6YZ|KR!3IT)`QzUmVWl*K7Ms?Pne%Ymw&GFmD zaLG<=v82b&1rVce(T{$lSeo%hitqZ52)nQR0F@ncBtgMuCpE~B`18s zVoZre`YD|XNBdI`#5kxfk0M6<3Ho(Y^}*Vy{BGcZL*ere54@?jw>su)Z*R!R$jHs4 zcOcPze$?Fjpx=46cxLSTI!IWMseWBa!gVA6OJx%p#DrCs4V9FXe6;1BX5^iC{E7=u zHX2AXRnAhz@+Mtgnxc-xcRl;@I#Mt1p&WyonFR$uC8tCx>eV$zfdYo{2*9LfzU;!KZoI6yq7_cd7* z%09@fDxz_V|ExfsSH4jU_E$iWkg({Yf`cRZ#mnw6nARCB5N?XjT~cWyakY_FJWC7tgxAM|a4I9aSL*5{+seM^j=jp!|L>Psi&FK?{LY7l80BxB3oVgz@sisOV^s;)stCbXx&lAQt}RtaZ4Arq+0bn zM~Iz;(sbPcV^kQ%_$0cB{nR6J(|BTT$+kN?vy*|Sdrmh=s0srpDC(5qeG0Ec?Y(+@ z*(l*ZwYw3kC-{`%&zLgIrp?&@RAlWga0bN%z_CG3uZ0Zil+QKS7Q43^lOI)9!9|P7 zDdRik6CH@|$E3NKIu%&kt=`Z3^k4_F`M^k;Z1qH(7mqq!C^b8Ww*471Z$(2w2R6WW zXFIs4-iiJQS9PUc!KfVi_%p-cH*bDGVf_b`OaBNBVFpiL-aSymVAlgNqzlWG7MNqt zpWg!J7&RAH6et*qEvJNFswM7A>GYYhVeMHMv10&@YH}gB7*JCqy%Z4Rzly&k4Qf!N z`HPIl0uKL39HD&asEeAbsi{(1P>8R?xgUWP)@HnkN?+Jbq99y|+3D=NHlFnolIYK) zhxv(L7Lw)JPhL@@YRd*A=GgmMEU2@PeCk9DssfBF&MXE`R=$MDMYJz)(!uA?LLVH! zJ>Kwa!#Ib3Ym>(Z*&rh!VDJ(fEYg#1JCD_2MhJ*(UQE^2WOu_5SY(DFD25TG`mg1| zKp>%MYF_c22Vtr8?_V!K`vl63g(i;xz4#f?XP#JgE<#ftQQ~L({wkT!B5M9F$v$$_2;x2m#?w7 zq+Fx;3->h+zZI4-_t6$0Wi~h7?e#CN*5(%eWETqrt^B7#?o4ONZ*3L*pckZ9*UBb1 z%-3WBX$)wc-@*Kc4?d;;&^c{-OTv*YF!SALwLbShl@EYIK^XJk*zS$@)-W2$Ku^*> zP|P5xIk*2YtFta~{eD+0 z0sO^VV1X%C{jF~_C+{iiUw6~^b+k11=mO)I>?`f*Bi4zZ6W=Nnw4=?b8q4>w$u0&@ znU)Lp`V5$@>r{J#$n1_^*}}XBV;*Thjc1)ImyD<1m>Wd+8vX*|R-Me`*i#8(esvX> z-8Bx83JoMrn|9W^JxyyDaV>PYb6eE|KM19>iwTW3obC)iEjqlVxKVh#t7!H(Q`O06 zmbG@&9_QsVEwZh3nlf{RdLtzNYmZ=1Sw$O-*{!ZOBKgtvfjC^s6${yBe3Os<%~bk zj6~UGnMr$jJf-zk){shPxBnY7ZU?DJH?NOo3ff`X@E~aTJy(V|z_0x! z2lb3hNL7S`*Lu(?R>t}AOct{(Dj>NtNvG?8vR_Q@AW|Xt$cL26olJmZQqocO@B9Pz7ksRa{q{qIvaEFU85}= z;YkVD*XrVJb*qZbbE6Q9A2yiyI*!DHkR=&phuZJD)Vxbs&&KL`~FnWt^J%; z^`R#ZjNPjpu_k&6m+*}oXJWQ$0;C`&5%5&Sl!+28lni($Ebk~i)F(K9r$Fd|g3B(6 z*0&|Xg=3qy*UvZ~v3jJKiq&49z_#R$B~`yZN&1Zlvn+tIxPvhIc-g+l`sQ61<1-$P z4I{qD#9V`Kadq+pPTT7egEsl8%h2DKr!ls75enw(nC;Z#apa6w=ax{kpjwUesl~72n|B*gSH`Uk|GLsZgIf(zyuNMw-KgGpn%mG@(Or8Ln57I=6j>*p&TU;~ zf48y+WiWnn>dZcIpPPbKHMb2cVUAQw^{hHGl%ho z3ToyWD3o?ARaPJFJVUkan#&H0RQhNPBExp0!oS&GKit(eJl?-ssV-Qf@w>=bFN;o; z_1nKSe(=0o=I*Cn<=pV%THTyy>EXFV=F zB?2Eh)3a&y@+gPhniwmcT&%a6N<>l$56Vink~*&(JMoXpX*Z0DcCY%4ZLR&Z@Ud8X zQxbou%ddXq93CFq);#E!^T7SfzGs3_7q{#@$rQw;Cs z!Xr0@GHab(;^Wv{7*=&p|_x*fgm`ccg7ZV?-P z?4g!<>#UMQp_c@@0sD;o)1|6xkNe{@aZJQygnI+7s43ki9dyrg8M5pPsabE33+o%RP1QKw~~FN@8}LeBZFsHw$-6l27NT$(@oM-&r6zL%QS3c1|=^JPI;*n9ggRG0u`U1Q27EpBuoK$NZu`84+ zBS$l=QuG9e(r<(~zIdbkO4U=~p|oLQz8IU7Y_Gw{qUTlnM~0h9gcY(<9!9C>6~is3 zM2aR-GZbY~dpGxeR-8WbZ1&A2JSQj%c$YMj@`cZaE(lY2;)VLg40IUH;=acB1f^4# zuyr|qqJXd@^!>9exxR5x_b*!*O_TzD%<1xN6aG|fzMHzVqu!bHaOq`Cs(5YaN4@Ae zB!>y1c{4>^5Wk1v!d>ki^qu}r0=DGQ`wm2d~s#d&OvvtY0YV_;KFuOd#kU5 z`NOW!++(-RYlWvyp^x!h6x$}U6|^;=cvZFZDO+T(xCcA-$QY9@v_|VEK~vk${UjaL#J0hi_euL_qDDF%zC4*C5$dfp3~+~2S{ulVTYvl^ zA!}ZRcgnNzH|$yt8`2V#io>T(Q6tY$pSg3J;AP{d8bhCpFhRitck|AE%gXzjaPek2 zzo)F*7P-Xt3>S$v$8yBXgrYF&(_ZDLiIm@(XubdC1;@DXj~hx^Pm)Dd+ofMU-U~}Z z{X(n#^?0VSjJ(IaE{@3D@}Bina_%!v%2cvTU!a?jalA-{lGweH<|s4u1y}uP5KN9N zIPu^Jd$EG^E-o_UokVwpXuk7$VunJorP22Uii6?)1ZSs?Br|?%pLzbCZ9QY#+!}2g z7pAlQ38B-<7~gl^aUql&?M)5Xd3!F(!6|^keKUnke9GolZDWgljgULwJ@Xp-aVh&r z!{#_0`iSBK^dOqovZC??lj@!^b8iLWjy~_SCUZrSW)=xS@lWeZ`yd-UiFu=*S6>3} z)J-4{-~717j6C+PESV>>q2NF;9P3Ws2JI*QqH<*{Rek(^=k=8uP1(-$@^~7Zw?FA! zO9XF`rcESLdxl4Y=g_OE`PYCuG_XmfTQ(EmV%qs*U|L# z??xS~Q@JPnUb1k%>(>k`n(@7V zvq*%Ce=zh7iKG-0)l0AOJp~1R(kOpR@^e9-n_2`n7xnDiGfG%#-e3xZd>B1N<5)4} z=A^4w9ue!z-AS#CD!6^un$|JlY{&4VE68g~<_jJ5`%C*}Y5Gm-OCjNA- z3rZy%0$pX%WoS89moMWL`(H;1^p)MP6vlg<8Ms=iFzCd?Uo@9pLbuhK{ymst* z+u5##%KFOEE*8~;))KcPH5BFAcZjU=TV=iz&o#`8al;v1^<9lWMIDS60Mz|LBlry>+uyKD-kXyR&9Wde+Qi zB}4w=r{S_=&fkw2PTwa4C4c`SntEq*%hCBLSFmVBU}~Mu(m{yzK-mO)NjLT90|V%7 zjC4VZxcanzYPGf)_4KTKH|EG#(oHbIoIEz`V()8?g+g{!r0p@QO8LA$x2E$Wy0wi2 zlYZSq+Z1uX)wireV`I*khSa;EHp@}+@R*of<(H~D6qzOKcsRR`_w^FF$--G+UbUBA z@5v|HM4@~)=*A^qqq!Cdy+QdU0`X^SuVlK$tRf&a16^vD6ZM;Cj@&ugo;QZAU=gbf zxxTD8qgWdIE2&QH6vDvTR$ya6Q=~@obE23zw`5ev&o2|@i4FLI!$qH8$kX6@c2!x+ z%$HxiY+mpoi}RGZ$+*L3!dJZAVzK}C!dx_3^p|oce&uC4bKdLz&t7<`K zHDR8`*k{^ZIQ!!q_UG^9w_iAMf0D9xl80mmB1Vl>=kpIAb{arSMlddVdM`SQ4aQQS zhe=)oHCr*+*>z4gM-wWvNhVMgWvOH9OSIi#KLkWmc&r~RGtFrEPOKlD%O15hKc*mc z*%=BgM|uL4brhM9O4^wy<+hhIIJWKVd4?qSGuWA``}HO*DSS|j*=W!2EJgp7h!JL4 zpE&B=~IBX}bYT8I_U6OVl^UskCCm;(1nc*iY9IK-{`~EK~ z-yTh){>0+6^lHxMt1Fgx7#zvUbu86-y+i!Y_t(Cv2%^D{GNFFOSG*B)ktpRvH3ZZT)V~lXmu(;b)$6a3Z?j~ zfl@2$jUe&p6X2LakkOKb7;h( zCMCgi=XZJ7M%sLFX;nlEmsX6;b~bf7X)9*hS_{_*a(ntYX5}3hQgkX@+VAhlA>n+R zC}5+C4*MpKrK(y?{+<3M!rCBSRlWFj8vjRBdhv3@f=rD4Z_F3 z5KFc2^~A1fXxOkM>XBz8IW)wDM3d zR>N%2j~N>CUmVUcMd zjXgQi*L{wGLHOO(cf%;Z#&BU`V{)gf>cSrThR1%DxrV8rp9rQ__J=xc&z+k4{U@e9 z6z2*ex$-l&KZm_NFhLb!GA`bE;YKJYcHd@4SN#q>h2DcuV#<|`(X_TG!Mo)lQUPzq zTx#vdeoR&z>l0w0DG#q$&Bx(H7Y+^y$0cgqVjol2EExKruIi`meU+HLu)^gl>XH8G zXZvFA-Nu0=?d2hHe{)F>ORn_`^pgAnJvzKf8;M%k@`p!3&*wBh(1}(avguYPEYXwc zVE-U{CjIEUn}uL~+19{zkd&K=v&M(ePN!IF+U8s057z8!mG8W3=$5*~O?%sAYWVk& z_Apu3rB1?poakIr>&nxu02QOK5zWD&#=9{K+?6@`=dJR)lM3^s+YT?c=ZO^zWaV;d zsr7$~PFt?X*nKa!>?(Ehk-xv=-DEqHqa&0Gf0;lSr+)ZNHrAdBK6=8Rql1?OtIAL1 z#&ijuXe*k^s*}1H<0MwT{XP?zZ_#^rIZL!kGh1FuCF$(|GjTAQJ-2(?b1V_T$YkG@)Ht#+YJ8r!@ir&`QlX;lyHr@NWWc2SH zzgDj$hjtcaJ5o9v7hYk&o{jrO-lJsJ^sZ*`NxNSx7Pd?O%OP7oi(7XHjpElwIwowa zOnXIN)?gTB-A<`$y(NB=z~(x~>G~&ydm{Sh(HrwTPR}t2enmGPhj$)qD`f}|RBdYM z!wD>ti|btE+o-o|qJ24m@gw&zqj~Q5s)*+fwN-3(bTe&=5NqG25cdPFw^CRWS96o< zlDA$7x$FhmWK;Y+-!PE+#E~MaKbF$F6@8c~ZYJ-5@S;EW;|>oyb`~qPAH>)ZtHk`< z3y0$9=(PI3`?^W`>xg{j3l2)H@R$f`AFtG%jg-`;jt5!a61tk^+4!*0K>VcqdO6UTJ=s3KMadq`r148n(!=z%6bXijW_E; zZ=!Qrhf|#AR2?Q-^eLuDTrqd8ostC$>L>1DZN;nOE;3cCi;Y3e$+DOvc0t<1L&B3L zT4IO08hNG})$5<~0Xj2RiHZHdq1({Zlqz)sCbbL5ym-JF%BT5q%^EAzTE-{92yKvh4FfbM-B#*1cR5*_4i@^ z0`_-nqc$1S7;+0E`kA z7dMu63bQ*u44NQA*lb6|a0Ooj$u3W;ipYz)etet^VG3zG&l93V3=9yPEyD5w6h-Q% zPbAQ{Z38d_xZ-e7$N)!)U~+#jEn-hj=F17GC@Y@_Ke&>LO5@-zQ&rLm{xz|)%qPG6 za2wxyc+WIv*1z1Ue_8BUu@UUx#&6}uz@kfz$Afh+t@vZkNF?cLJ#)r9Mv5%v&$CbV z=RTaF!`%tjv~^wd~Oi!epQzyKnc zgw<`K2$K%x=AH+Z0LvW)X;GjyJDKZ}KYk7+JkfMz*n^6OE|PwRwhLC6U!A2b7y;Mq!@`=u!w{uT?%}s*;Wxp zWNPIW|Ki1FziLK0#U-2D$#&ICy6YGX>#aMj9bR7#P?jEcd4-lPO#HjN1Z~2ve_S?ap>#Z@+*7$$#b|usX+L~c@w2>Wg0ZZO;=yAD1gr1efGSfrmoHGswceiV zFsz_sWHc}Lc=00J*(`jl&}Q|A z;9?J|(20c+478WE9xo)OzO7xXF2Zinp0ygY&0DhXJK}M#CC?OWn=}ZtBMurAa3sNK&9TS5}rWsFO<2CVY}}*T(qUKldDj>Vcx#Ye&g=y4fcUXnmN{v zHm0e4E&|Dl29{fd9=Nxs?JSs%6VCw zUf%y8qNWxr?yRN;^;5z4iPrx9ov2+h%FZuee$3As$kTFg3}`OShhPM~li$+l{bdYMvb@0mEh%|~Z^$6;qG zbU?p?b#1#Y(1L7yA}!yXdb!MfZ*TA7<;!B)+KH82#vC_+o({IMZVO->8OT$y*e7ug z?C!2_0Y(uqdJHejc3!7k-g(-%Xg97sD(H z-=r1DcW5Xu<&&`LVVv$&)FDJqSjE9`_5gUe@dc|dRD(A1#+<0bIKQh!I{MtcnSyi6 zC7jKA<^h69K~PcBGBPk5Gyo6}v6#pxt-?_5xGp+AK2jLS?gYe|EpShRl^q3^eD2<` zb0|7Gy00c>$dE;PdU_V8P2JV8vQ7a>EEHJJXjoaVLerRCZlPSSCBSvRf$Zr{f^t5u z!_AxF0)6iYP{@J75VX=2AYAmPnE{O$O4U4^a5VB|*G z^~HD?H+NG?FZ@18YQmdu`Mi1a=9a${7Q85R?HvKU7Jlmv9lUl+>X8Dx_Aopq4Zvh! zw^*fa{ExpV_5J)P=_M}|Dk~?)q{4l9*JWT}AUP)oOFF$|Vo#=`)qxmuT#MYy(Hhno zPI7*-;B>}k;WmCbxw&Yl+S=Nil5kYd6X1VgR^8ByOYB#5Qn79}(jXqsQvoovKW9^` zS}le5gP%o)S88f$jTHjJQC~BlEwUjeyrYJUqIS)y}}<#N&^U_4$v#aUJLV z?LTKi?*Dkw9<~EPbt_Z&PVG#&gw;i4) zNCTPZIoow$ncQ}`%e4W3vNqtf_0_mK^6>KB;Nc+w<|PsTy?8?wlF+EA3!a{y@b}XQ zBN=9&jV_;<{Je07?%LKrfejE@nvihi+BGlOyELnva>-p*JeCh1GijR+;uHesHdvFfOJN2-z_ra`|VzJT} zU=4c;kO&ANe88*^NP=g;aerEAlRol64O`pm-c=xwz#WAkW>Eb#uaH$1CdsE|WCW0F zEh3N(aJlJjOxBRFJ%2MYJ@OIauxB6+asuCJF+7hj^+gW`pzHY=lqx7HS_66*#;eM$ zKH-amXiPL;Wfmr#az_VI2~&bGaHQI~^i7x=?q1qM5p#3KfPer{=$SIT>zoIsy|}38 zt2bfxOp%e1C=`$}!)S`#Vcb$ZzMR0bt%S_rQh;!SI0G=L zy|c3k&?jgpU?5Wfh%%&&%2RW$JrC28a(Z%G)zZq+W$l zbN68mBCtT3aoRG>M@!C229r0yBGinN-hTiRj$*@ZQdqra0fN#3f4%ulhPh#Kk`f4K zdVWU87nWC6_5r||=!Ylux-2luh5{f&*Z4Vd8*cer%+T2rKHfx!qhW^IWDO_>Fr(o< z3`qq&Q43CHKLm6aadF#Vk~$g+fGEwt6n+UzayF?Y5i0jzR{`_}Lwvmv0y->pF`TAE zfTSN?H-}{I9B?hreD1$0DParThCiHKTmlG`7XhvL2Eb1+lbH&D`Fei({cwh@wr0ey zkaDzZB6nj_QV3wcVeUB`@Eej3<}l*%K4|wCzrgsLV*3@9u$#-|5c_)r&KjacbQI`b ztoK&cfZxU<^@1i?DhvfUCji~tU}o;sG~^%+i;ngK0LE3CQgbhAYFGggS1V{RvEiOC z7AtXv1+fVpj}QpSW20d#CXC{~1ei;GuWvBRS1(Re+YX z0^=4X01X0D<$eClFz8D~25}|2uJgjHl;@;cA*1pEbLA!2d+tAfo`#@Z;VZsC#c6Sw zjZixnCm{SnADHk|+*IV8Ct1u@=8g7DX@F};!PmBd4 zm9CFN3X$NVJ8U~ z4=)h)1yJ~=0k{T}E%plJe|moRpX0;k5!RiRpHIk-annQ+c9bkJP|IEcx+jqId|~#y zw%e9C!W~4xJkaUok1e9C#x3s{n2Kpwft6&6f0t9LNfHRDC$0H%R2t#jwtqkJ=9<#t%OG@ez zEVBUh^!oYJ7*5yCyLT@FJdU%_6EH?-C?IWlKf2Qdf~&4^vLFeV^IShMK>&S9ms6>khax6fme*^qMaQQoaLR8h}8tU_BnGv^6R%Mrca#Swck2+HYa}D>4N6 zCK?(-%SDFKKa`V`w6>OQZ1VR)|LH zj1gtv=D1#1!n?c5EKfcK~~**l&xF~z+t8r7yV#N8w!Qc ztPpq?{4~@eUcr!u?-LW&C%^6CoVFed!EIVpQBh}bucf60ws0Z?2yGV@)$@srdfqnx zmAGqTerCoCl+|k0!`PCu# zs*(~$YTp}V5-&_rxDR)s_3829on2r#zlA}$5L^=SSkGVp7Orar#<;$w@;tr@@yXBr z%#aqCUhSZ<1r~smqN1-U2!?QB5uP-7@{rvuzB(F&uYu$6tbWMKx(Kl|N_V}FmE=kV zA=OpR5J6ZU+UOY=tXD@^154m`N(F!)MAQ&0+y?n(5bRO95Mw_s-x$g#YYigv15CH> z!RGYl))s4FJOC0WT~-(Y#c=^R*{^{+0J!f!fRx?1eY+*a!W6XAEWnh79h~OQoeOXQ zcs!1IT=v(^3rb;XT=lPSGO%_xfkHI7f5{3+mJ!j>G`zf{nwF3=fk;4Qfefh@)SN_%c-coSAPU?7duc?vtMj=TazE(!&j zVt-%(Q*&|>!sP&|d<7xxRkiR-Ues?WFs|BbgXXdk-W|EnDwy>F*=cDQKqBRVZC)^5 zHYuxgc8T)lHSx!fv6(bKy@aRk_+-N~+Rp3#gP9-5D=n} oe!)Mf0o+snV>0kxW)UaC_a?NfF-r8|_)#b+aoGn2_jSGgKW}2Dr2qf` literal 0 HcmV?d00001 diff --git a/index.rst b/index.rst index f6d6f327fb..3e2d1ee07c 100644 --- a/index.rst +++ b/index.rst @@ -99,6 +99,13 @@ Welcome to PyTorch Tutorials :link: intermediate/pinmem_nonblock.html :tags: Getting-Started +.. customcarditem:: + :header: Visualizing Gradients in PyTorch + :card_description: Visualize the gradient flow of a network. + :image: _static/img/thumbnails/cropped/visualizing_gradients_tutorial.png + :link: intermediate/visualizing_gradients_tutorial.html + :tags: Getting-Started + .. Image/Video .. customcarditem:: @@ -849,4 +856,4 @@ Additional Resources :maxdepth: 1 :hidden: - prototype/prototype_index + prototype/prototype_index \ No newline at end of file diff --git a/intermediate_source/visualizing_gradients_tutorial.py b/intermediate_source/visualizing_gradients_tutorial.py new file mode 100644 index 0000000000..c33f4aa3e7 --- /dev/null +++ b/intermediate_source/visualizing_gradients_tutorial.py @@ -0,0 +1,298 @@ +""" +Visualizing Gradients +===================== + +**Author:** `Justin Silver `__ + +This tutorial explains how to extract and visualize gradients at any +layer in a neural network. By inspecting how information flows from the +end of the network to the parameters we want to optimize, we can debug +issues such as `vanishing or exploding +gradients `__ that occur during +training. + +Before starting, make sure you understand `tensors and how to manipulate +them `__. +A basic knowledge of `how autograd +works `__ +would also be useful. + +""" + + +###################################################################### +# Setup +# ----- +# +# First, make sure `PyTorch is +# installed `__ and then import +# the necessary libraries. +# + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +###################################################################### +# Next, we’ll be creating a network intended for the MNIST dataset, +# similar to the architecture described by the `batch normalization +# paper `__. +# +# To illustrate the importance of gradient visualization, we will +# instantiate one version of the network with batch normalization +# (BatchNorm), and one without it. Batch normalization is an extremely +# effective technique to resolve `vanishing/exploding +# gradients `__, and we will be verifying +# that experimentally. +# +# The model we use has a configurable number of repeating fully-connected +# layers which alternate between ``nn.Linear``, ``norm_layer``, and +# ``nn.Sigmoid``. If batch normalization is enabled, then ``norm_layer`` +# will use +# `BatchNorm1d `__, +# otherwise it will use the +# `Identity `__ +# transformation. +# + +def fc_layer(in_size, out_size, norm_layer): + """Return a stack of linear->norm->sigmoid layers""" + return nn.Sequential(nn.Linear(in_size, out_size), norm_layer(out_size), nn.Sigmoid()) + +class Net(nn.Module): + """Define a network that has num_layers of linear->norm->sigmoid transformations""" + def __init__(self, in_size=28*28, hidden_size=128, + out_size=10, num_layers=3, batchnorm=False): + super().__init__() + if batchnorm is False: + norm_layer = nn.Identity + else: + norm_layer = nn.BatchNorm1d + + layers = [] + layers.append(fc_layer(in_size, hidden_size, norm_layer)) + + for i in range(num_layers-1): + layers.append(fc_layer(hidden_size, hidden_size, norm_layer)) + + layers.append(nn.Linear(hidden_size, out_size)) + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + x = torch.flatten(x, 1) + return self.layers(x) + + +###################################################################### +# Next we set up some dummy data, instantiate two versions of the model, +# and initialize the optimizers. +# + +# set up dummy data +x = torch.randn(10, 28, 28) +y = torch.randint(10, (10, )) + +# init model +model_bn = Net(batchnorm=True, num_layers=3) +model_nobn = Net(batchnorm=False, num_layers=3) + +model_bn.train() +model_nobn.train() + +optimizer_bn = optim.SGD(model_bn.parameters(), lr=0.01, momentum=0.9) +optimizer_nobn = optim.SGD(model_nobn.parameters(), lr=0.01, momentum=0.9) + + + +###################################################################### +# We can verify that batch normalization is only being applied to one of +# the models by probing one of the internal layers: +# + +print(model_bn.layers[0]) +print(model_nobn.layers[0]) + + +###################################################################### +# Registering hooks +# ----------------- +# + + +###################################################################### +# Because we wrapped up the logic and state of our model in a +# ``nn.Module``, we need another method to access the intermediate +# gradients if we want to avoid modifying the module code directly. This +# is done by `registering a +# hook `__. +# +# .. warning:: +# +# Using backward pass hooks attached to output tensors is preferred over using ``retain_grad()`` on the tensors themselves. An alternative method is to directly attach module hooks (e.g. ``register_full_backward_hook()``) so long as the ``nn.Module`` instance does not do perform any in-place operations. For more information, please refer to `this issue `__. +# +# The following code defines our hooks and gathers descriptive names for +# the network’s layers. +# + +# note that wrapper functions are used for Python closure +# so that we can pass arguments. + +def hook_forward(module_name, grads, hook_backward): + def hook(module, args, output): + """Forward pass hook which attaches backward pass hooks to intermediate tensors""" + output.register_hook(hook_backward(module_name, grads)) + return hook + +def hook_backward(module_name, grads): + def hook(grad): + """Backward pass hook which appends gradients""" + grads.append((module_name, grad)) + return hook + +def get_all_layers(model, hook_forward, hook_backward): + """Register forward pass hook (which registers a backward hook) to model outputs + + Returns: + - layers: a dict with keys as layer/module and values as layer/module names + e.g. layers[nn.Conv2d] = layer1.0.conv1 + - grads: a list of tuples with module name and tensor output gradient + e.g. grads[0] == (layer1.0.conv1, tensor.Torch(...)) + """ + layers = dict() + grads = [] + for name, layer in model.named_modules(): + # skip Sequential and/or wrapper modules + if any(layer.children()) is False: + layers[layer] = name + layer.register_forward_hook(hook_forward(name, grads, hook_backward)) + return layers, grads + +# register hooks +layers_bn, grads_bn = get_all_layers(model_bn, hook_forward, hook_backward) +layers_nobn, grads_nobn = get_all_layers(model_nobn, hook_forward, hook_backward) + + +###################################################################### +# Training and visualization +# -------------------------- +# +# Let’s now train the models for a few epochs: +# + +epochs = 10 + +for epoch in range(epochs): + + # important to clear, because we append to + # outputs everytime we do a forward pass + grads_bn.clear() + grads_nobn.clear() + + optimizer_bn.zero_grad() + optimizer_nobn.zero_grad() + + y_pred_bn = model_bn(x) + y_pred_nobn = model_nobn(x) + + loss_bn = F.cross_entropy(y_pred_bn, y) + loss_nobn = F.cross_entropy(y_pred_nobn, y) + + loss_bn.backward() + loss_nobn.backward() + + optimizer_bn.step() + optimizer_nobn.step() + + +###################################################################### +# After running the forward and backward pass, the gradients for all the +# intermediate tensors should be present in ``grads_bn`` and +# ``grads_nobn``. We compute the mean absolute value of each gradient +# matrix so that we can compare the two models. +# + +def get_grads(grads): + layer_idx = [] + avg_grads = [] + for idx, (name, grad) in enumerate(grads): + if grad is not None: + avg_grad = grad.abs().mean() + avg_grads.append(avg_grad) + # idx is backwards since we appended in backward pass + layer_idx.append(len(grads) - 1 - idx) + return layer_idx, avg_grads + +layer_idx_bn, avg_grads_bn = get_grads(grads_bn) +layer_idx_nobn, avg_grads_nobn = get_grads(grads_nobn) + + +###################################################################### +# With the average gradients computed, we can now plot them and see how +# the values change as a function of the network depth. Notice that when +# we don’t apply batch normalization, the gradient values in the +# intermediate layers fall to zero very quickly. The batch normalization +# model, however, maintains non-zero gradients in its intermediate layers. +# + +fig, ax = plt.subplots() +ax.plot(layer_idx_bn, avg_grads_bn, label="With BatchNorm", marker="o") +ax.plot(layer_idx_nobn, avg_grads_nobn, label="Without BatchNorm", marker="x") +ax.set_xlabel("Layer depth") +ax.set_ylabel("Average gradient") +ax.set_title("Gradient flow") +ax.grid(True) +ax.legend() +plt.show() + + +###################################################################### +# Conclusion +# ---------- +# +# In this tutorial, we demonstrated how to visualize the gradient flow +# through a neural network wrapped in a ``nn.Module`` class. We +# qualitatively showed how batch normalization helps to alleviate the +# vanishing gradient issue which occurs with deep neural networks. +# +# If you would like to learn more about how PyTorch’s autograd system +# works, please visit the `references <#references>`__ below. If you have +# any feedback for this tutorial (improvements, typo fixes, etc.) then +# please use the `PyTorch Forums `__ and/or +# the `issue tracker `__ to +# reach out. +# + + +###################################################################### +# (Optional) Additional exercises +# ------------------------------- +# +# - Try increasing the number of layers (``num_layers``) in our model and +# see what effect this has on the gradient flow graph +# - How would you adapt the code to visualize average activations instead +# of average gradients? (*Hint: in the hook_forward() function we have +# access to the raw tensor output*) +# - What are some other methods to deal with vanishing and exploding +# gradients? +# + + +###################################################################### +# References +# ---------- +# +# - `A Gentle Introduction to +# torch.autograd `__ +# - `Automatic Differentiation with +# torch.autograd `__ +# - `Autograd +# mechanics `__ +# - `Batch Normalization: Accelerating Deep Network Training by Reducing +# Internal Covariate Shift `__ +# - `On the difficulty of training Recurrent Neural +# Networks `__ +# \ No newline at end of file