Commit 2c6aac9

Update fx_conv_bn_fuser.py translation (#370)

* Update fx_conv_bn_fuser.py

1 parent b0696f4 commit 2c6aac9

File tree: 1 file changed, 88 additions, 100 deletions
@@ -1,26 +1,25 @@
 # -*- coding: utf-8 -*-
 """
-(beta) Building a Convolution/Batch Norm fuser in FX
-*******************************************************
-**Author**: `Horace He <https://github.com/chillee>`_
+(beta) Building a Convolution/Batch Norm fuser in FX
+****************************************************************************
+**Author**: `Horace He <https://github.com/chillee>`_
 
-In this tutorial, we are going to use FX, a toolkit for composable function
-transformations of PyTorch, to do the following:
+**Translation:** `오찬희 <https://github.com/kozeldark>`_
 
-1) Find patterns of conv/batch norm in the data dependencies.
-2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
+In this tutorial, we will use FX, a toolkit for composable function transformations of PyTorch, to do the following:
 
-Note that this optimization only works for models in inference mode (i.e. `mode.eval()`)
+1) Find conv/batch norm patterns in the data dependencies.
+2) For the patterns found in 1), fold the batch norm statistics into the convolution weights.
 
-We will be building the fuser that exists here:
+Note that this optimization only applies to models in inference mode (i.e. `model.eval()`).
+
+We will be building the fuser that exists at the following link:
 
 https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/fx/experimental/fuser.py
 
 """
 
-
 ######################################################################
-# First, let's get some imports out of the way (we will be using all
-# of these later in the code).
+# First, let's take care of a few imports (we will use all of them later in the code).
 
 from typing import Type, Dict, Any, Tuple, Iterable
 import copy
@@ -29,10 +28,10 @@
 import torch.nn as nn
 
 ######################################################################
-# For this tutorial, we are going to create a model consisting of convolutions
-# and batch norms. Note that this model has some tricky components - some of
-# the conv/batch norm patterns are hidden within Sequentials and one of the
-# BatchNorms is wrapped in another Module.
+# For this tutorial, we will create a model consisting of convolutions
+# and batch norms. This model has some tricky components:
+# some of the conv/batch norm patterns are hidden within Sequentials,
+# and one of the batch norms is wrapped in another Module.
 
 class WrappedBatchNorm(nn.Module):
     def __init__(self):
@@ -66,42 +65,40 @@ def forward(self, x):
 model.eval()
 
 ######################################################################
-# Fusing Convolution with Batch Norm
-# -----------------------------------------
-# One of the primary challenges with trying to automatically fuse convolution
-# and batch norm in PyTorch is that PyTorch does not provide an easy way of
-# accessing the computational graph. FX resolves this problem by symbolically
-# tracing the actual operations called, so that we can track the computations
-# through the `forward` call, nested within Sequential modules, or wrapped in
-# an user-defined module.
+# Fusing Convolution with Batch Norm
+# ----------------------------------
+# One of the biggest challenges in automatically fusing convolution and
+# batch norm in PyTorch is that PyTorch does not provide an easy way to
+# access the computational graph. FX resolves this problem by symbolically
+# tracing the actual operations called, so that we can track the computation
+# through `forward` calls, nested within Sequential modules, or wrapped in
+# a user-defined module.
 
 traced_model = torch.fx.symbolic_trace(model)
 print(traced_model.graph)
 
 ######################################################################
-# This gives us a graph representation of our model. Note that both the modules
-# hidden within the sequential as well as the wrapped Module have been inlined
-# into the graph. This is the default level of abstraction, but it can be
-# configured by the pass writer. More information can be found at the FX
-# overview https://pytorch.org/docs/master/fx.html#module-torch.fx
+# This gives us a graph representation of our model.
+# Both the modules hidden within the Sequential and the wrapped module
+# have been inlined into the graph.
+# This is the default level of abstraction, but it can be configured by the pass writer.
+# More information can be found in the FX overview at the following link:
+# https://pytorch.org/docs/master/fx.html#module-torch.fx
 
 
 ####################################
-# Fusing Convolution with Batch Norm
-# ----------------------------------
-# Unlike some other fusions, fusion of convolution with batch norm does not
-# require any new operators. Instead, as batch norm during inference
-# consists of a pointwise add and multiply, these operations can be "baked"
-# into the preceding convolution's weights. This allows us to remove the batch
-# norm entirely from our model! Read
-# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The
-# code here is copied from
+# Fusing Convolution with Batch Norm
+# ----------------------------------
+# Unlike some other fusions, fusing convolution with batch norm does not require any new operators.
+# Instead, since batch norm during inference consists of a pointwise add and multiply,
+# these operations can be "baked" into the preceding convolution's weights.
+# This allows us to remove the batch norm entirely from our model!
+# Read the following link for further details:
+# https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
+# For clarity, this code is copied from the following link:
 # https://github.com/pytorch/pytorch/blob/orig/release/1.8/torch/nn/utils/fusion.py
-# clarity purposes.
+
 def fuse_conv_bn_eval(conv, bn):
     """
-    Given a conv Module `A` and an batch_norm module `B`, returns a conv
-    module `C` such that C(x) == B(A(x)) in inference mode.
+    Given a conv module `A` and a batch norm module `B`, returns a conv
+    module `C` such that C(x) == B(A(x)) in inference mode.
     """
     assert(not (conv.training or bn.training)), "Fusion only for eval!"
     fused_conv = copy.deepcopy(conv)
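The folding that `fuse_conv_bn_eval` performs can be checked with plain arithmetic: at inference time batch norm is an affine map, so its scale and shift can be absorbed into the preceding convolution's weight and bias. A minimal sketch with scalars (a toy 1-D stand-in for the tensor version, the function name is illustrative):

```python
import math

def fold_bn_into_conv(w, b, mean, var, eps, gamma, beta):
    # Batch norm at inference: y = gamma * (x - mean) / sqrt(var + eps) + beta.
    # Scaling the weight and shifting the bias makes conv'(x) == bn(conv(x)).
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta

# Toy 1-D "convolution": y = w*x + b, followed by batch norm.
w, b = 2.0, 1.0
mean, var, eps, gamma, beta = 0.5, 4.0, 1e-5, 3.0, 0.25

w2, b2 = fold_bn_into_conv(w, b, mean, var, eps, gamma, beta)
x = 1.7
conv_then_bn = gamma * ((w * x + b) - mean) / math.sqrt(var + eps) + beta
fused = w2 * x + b2
assert abs(conv_then_bn - fused) < 1e-9
```

The real implementation applies the same two formulas per output channel of the convolution.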
@@ -128,17 +125,15 @@ def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
 
 
 ####################################
-# FX Fusion Pass
-# ----------------------------------
-# Now that we have our computational graph as well as a method for fusing
-# convolution and batch norm, all that remains is to iterate over the FX graph
-# and apply the desired fusions.
-
+# FX Fusion Pass
+# --------------
+# Now that we have our computational graph as well as a method for fusing
+# convolution and batch norm, all that remains is to iterate over the FX graph
+# and apply the desired fusions.
 
 def _parent_name(target : str) -> Tuple[str, str]:
     """
-    Splits a qualname into parent path and last atom.
-    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
+    Splits a qualified name (qualname) into parent path and last atom.
+    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)
     """
     *parent, name = target.rsplit('.', 1)
     return parent[0] if parent else '', name
@@ -151,62 +146,57 @@ def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):
 
 def fuse(model: torch.nn.Module) -> torch.nn.Module:
     model = copy.deepcopy(model)
-    # The first step of most FX passes is to symbolically trace our model to
-    # obtain a `GraphModule`. This is a representation of our original model
-    # that is functionally identical to our original model, except that we now
-    # also have a graph representation of our forward pass.
+    # The first step of most FX passes is to symbolically trace our model
+    # to obtain a `GraphModule`.
+    # This is a representation that is functionally identical to our original model,
+    # except that we now also have a graph representation of the forward pass.
     fx_model: fx.GraphModule = fx.symbolic_trace(model)
     modules = dict(fx_model.named_modules())
 
-    # The primary representation for working with FX are the `Graph` and the
-    # `Node`. Each `GraphModule` has a `Graph` associated with it - this
-    # `Graph` is also what generates `GraphModule.code`.
-    # The `Graph` itself is represented as a list of `Node` objects. Thus, to
-    # iterate through all of the operations in our graph, we iterate over each
-    # `Node` in our `Graph`.
+    # The primary representations for working with FX are the `Graph` and the `Node`.
+    # Each `GraphModule` has a `Graph` associated with it.
+    # This `Graph` is also what generates `GraphModule.code`.
+    # The `Graph` itself is represented as a list of `Node` objects.
+    # Thus, to iterate over all operations in the graph, we iterate over each `Node` in the `Graph`.
     for node in fx_model.graph.nodes:
-        # The FX IR contains several types of nodes, which generally represent
-        # call sites to modules, functions, or methods. The type of node is
-        # determined by `Node.op`.
-        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.
+        # The FX IR contains several types of nodes, which generally represent
+        # call sites to modules, functions, or methods.
+        # The type of a node is determined by `Node.op`.
+        if node.op != 'call_module': # If the current node isn't calling a Module, we can ignore it.
             continue
-        # For call sites, `Node.target` represents the module/function/method
-        # that's being called. Here, we check `Node.target` to see if it's a
-        # batch norm module, and then check `Node.args[0].target` to see if the
-        # input `Node` is a convolution.
+        # For call sites, `Node.target` represents the module/function/method being called.
+        # Here, we check `Node.target` to see if it is a batch norm module,
+        # then check `Node.args[0].target` to see if the input `Node` is a convolution.
         if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:
-            if len(node.args[0].users) > 1: # Output of conv is used by other nodes
+            if len(node.args[0].users) > 1: # The conv's output is used by other nodes
                 continue
             conv = modules[node.args[0].target]
            bn = modules[node.target]
             fused_conv = fuse_conv_bn_eval(conv, bn)
             replace_node_module(node.args[0], modules, fused_conv)
-            # As we've folded the batch nor into the conv, we need to replace all uses
-            # of the batch norm with the conv.
+            # Since we have folded the batch norm into the conv,
+            # we need to replace all uses of the batch norm with the conv.
             node.replace_all_uses_with(node.args[0])
-            # Now that all uses of the batch norm have been replaced, we can
-            # safely remove the batch norm.
+            # Now that all uses of the batch norm have been replaced,
+            # we can safely remove the batch norm.
             fx_model.graph.erase_node(node)
     fx_model.graph.lint()
-    # After we've modified our graph, we need to recompile our graph in order
-    # to keep the generated code in sync.
+    # After modifying the graph, we need to recompile it to keep the generated code in sync.
    fx_model.recompile()
    return fx_model
 
 
 ######################################################################
 # .. note::
-#       We make some simplifications here for demonstration purposes, such as only
-#       matching 2D convolutions. View
+#       We make some simplifications here for demonstration purposes,
+#       such as only matching 2D convolutions.
+#       See the following link for a more usable pass:
 #       https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py
-#       for a more usable pass.
 
 ######################################################################
-# Testing out our Fusion Pass
-# -----------------------------------------
-# We can now run this fusion pass on our initial toy model and verify that our
-# results are identical. In addition, we can print out the code for our fused
-# model and verify that there are no more batch norms.
+# Testing out our Fusion Pass
+# ---------------------------
+# We can now run this fusion pass on our initial toy model and verify that the results are identical.
+# In addition, we can print the code of our fused model and verify that there are no more batch norms.
 
 
 fused_model = fuse(model)
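The replace-then-erase step at the heart of the pass can be illustrated without FX at all. A toy sketch with a hypothetical `Node` class (illustrative only, not torch.fx's actual API) that mimics `replace_all_uses_with` followed by erasing the batch norm node:

```python
class Node:
    def __init__(self, name, args=()):
        self.name, self.args = name, list(args)
        self.users = []               # nodes that consume this node's output
        for a in args:
            a.users.append(self)

    def replace_all_uses_with(self, other):
        # Re-point every consumer of `self` at `other`.
        for user in list(self.users):
            user.args = [other if a is self else a for a in user.args]
            other.users.append(user)
        self.users = []

# conv -> bn -> output, mirroring the matched pattern in the pass.
conv = Node('conv')
bn = Node('bn', [conv])
out = Node('output', [bn])

# After folding bn into conv, splice bn out of the graph.
bn.replace_all_uses_with(conv)
assert out.args == [conv] and bn.users == []   # bn is now safe to erase
```

In the real pass, `fx_model.graph.erase_node(node)` then removes the now-unused batch norm node, which is only legal because it has no remaining users.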
@@ -216,10 +206,10 @@ def fuse(model: torch.nn.Module) -> torch.nn.Module:
 
 
 ######################################################################
-# Benchmarking our Fusion on ResNet18
-# ----------
-# We can test our fusion pass on a larger model like ResNet18 and see how much
-# this pass improves inference performance.
+# Benchmarking our Fusion on ResNet18
+# -----------------------------------
+# We can test our fusion pass on a larger model like ResNet18
+# and see how much this pass improves inference performance.
 import torchvision.models as models
 import time
 
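The `benchmark` helper used below follows the usual warm-up-then-average timing pattern. A minimal stdlib sketch of that pattern (the workload and iteration count are illustrative, not the tutorial's ResNet18 runs):

```python
import time

def benchmark(fn, iters=20):
    fn()                              # warm-up run, excluded from timing
    begin = time.time()
    for _ in range(iters):
        fn()
    return (time.time() - begin) / iters   # average seconds per call

avg = benchmark(lambda: sum(range(10_000)))
assert avg >= 0.0
```

The warm-up call matters for the real benchmark too: it keeps one-time costs (allocator warm-up, lazy initialization) out of the measured average.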
@@ -241,22 +231,20 @@ def benchmark(model, iters=20):
 print("Unfused time: ", benchmark(rn18))
 print("Fused time: ", benchmark(fused_rn18))
 ######################################################################
-# As we previously saw, the output of our FX transformation is
-# (Torchscriptable) PyTorch code, we can easily `jit.script` the output to try
-# and increase our performance even more. In this way, our FX model
-# transformation composes with Torchscript with no issues.
+# As we saw earlier, the output of our FX transformation is (TorchScript-able) PyTorch code.
+# Therefore, we can easily `jit.script` the output to try to increase our performance even more.
+# In this way, our FX model transformation composes with TorchScript without any issues.
+
 jit_rn18 = torch.jit.script(fused_rn18)
 print("jit time: ", benchmark(jit_rn18))
 
 
-############
-# Conclusion
-# ----------
-# As we can see, using FX we can easily write static graph transformations on
-# PyTorch code.
+############
+# Conclusion
+# ----------
+# As we can see, using FX we can easily write static graph transformations on PyTorch code.
 #
-# Since FX is still in beta, we would be happy to hear any
-# feedback you have about using it. Please feel free to use the
-# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker
-# (https://github.com/pytorch/pytorch/issues) to provide any feedback
-# you might have.
+# Since FX is still in beta, we would be happy to hear any feedback about using it.
+# Please provide your feedback via the two links below:
+# the PyTorch Forums (https://discuss.pytorch.org/)
+# and the issue tracker (https://github.com/pytorch/pytorch/issues).

0 commit comments