###################################################################################
# Defining the Neural Network
# ---------------------------
#
# We will use the same neural network structure as the regional compilation recipe.
#
# We will use a network composed of repeated layers. This mimics the repeated
# blocks typically seen in large generative models.
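##################################################################################
# The exact module definition is not shown in this excerpt. A minimal sketch that
# is consistent with how the model is used below (``Model().cuda()``, inputs of
# shape ``(10, 10)``, and a ``model.layers`` container of repeated blocks) might
# look like the following; the ``Block`` name and the layer count are
# illustrative assumptions:

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A small feed-forward block; every block shares the same compute pattern.
        self.linear = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Repeated, identically structured blocks, as in the regional compilation recipe.
        self.layers = torch.nn.ModuleList([Block() for _ in range(8)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x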
##################################################################################
# Compiling the model ahead-of-time
# ---------------------------------
#
# Since we're compiling the model ahead-of-time, we need to prepare representative
# input examples that we expect the model to see during actual deployment.
#
# Let's create an instance of ``Model`` and pass it some sample input data.
#

model = Model().cuda()
input = torch.randn(10, 10, device="cuda")
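##################################################################################
# With the model and a sample input in hand, the full model can be exported with
# ``torch.export`` and compiled ahead-of-time with AOTInductor. The exact code is
# not shown in this excerpt; a minimal sketch, mirroring the helper defined later
# in this recipe, might look like this:

path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model, args=(input,))
)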
######################################################################################
# Compiling _regions_ of the model ahead-of-time
# ----------------------------------------------
#
# Compiling model regions ahead-of-time, on the other hand, requires a few key changes.
#
# Since the compute pattern is shared by all the blocks that make up the model, we
# only need to export and compile a single block; the resulting artifact can then be
# reused for every block, as sketched below.
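##################################################################################
# The exact compile call is not shown in this excerpt; a minimal sketch, consistent
# with the helper defined later in this recipe, is to export a single block and
# package it with AOTInductor (``path`` below is the artifact reused for every
# block in the loop that follows):

path = torch._inductor.aoti_compile_and_package(
    torch.export.export(model.layers[0], args=(input,)),
    inductor_configs={"aot_inductor.package_constants_in_so": False},
)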

###################################################
# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation,
# a ``state_dict`` containing tensor values of all lifted parameters and buffers, alongside
# other metadata. We set ``aot_inductor.package_constants_in_so`` to ``False`` so that the
# model parameters are not serialized into the generated artifact.
#
# Now, when loading the compiled binary, we can reuse the existing parameters of
# each block. This lets us take advantage of the compiled binary obtained above.
#

for layer in model.layers:
    compiled_layer = torch._inductor.aoti_load_package(path)
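    # (Assumed continuation; the rest of this loop is not shown in this excerpt.)
    # One plausible way to reuse each block's existing parameters is to load them
    # into the compiled artifact and route the block's forward through it; the
    # exact API may differ:
    compiled_layer.load_constants(
        layer.state_dict(), check_full_update=True, user_managed=True
    )
    layer.forward = compiled_layer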
def aot_compile_load_model(regional=False) -> torch.nn.Module:
    input = torch.randn(10, 10, device="cuda")
    model = Model().cuda()

    inductor_configs = {}
    if regional:
        # Keep parameters out of the compiled binary so the same artifact can be
        # reused across blocks, each with its own weights.
        inductor_configs = {"aot_inductor.package_constants_in_so": False}

    # Reset the compiler caches to ensure no reuse between different runs
    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        path = torch._inductor.aoti_compile_and_package(
            torch.export.export(
                # Export a single block for regional compilation, or the whole model otherwise.
                model.layers[0] if regional else model,
                args=(input,)
            ),
            inductor_configs=inductor_configs,
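############################################################################
# A sketch (for illustration; the actual measurement code is not shown in this
# excerpt) of how the two latencies compared below might be collected with the
# recipe's ``measure_compile_time(input, regional=...)`` helper:

full_model_compilation_latency = measure_compile_time(input, regional=False)
regional_compilation_latency = measure_compile_time(input, regional=True)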
assert regional_compilation_latency < full_model_compilation_latency

############################################################################
# There may also be layers in a model that are incompatible with compilation. In
# that case, compiling the full model produces a fragmented computation graph,
# with potential latency degradation. In such cases, regional compilation can be
# beneficial; a sketch of one such incompatible block follows.
#
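# As an illustrative sketch (an assumption, not part of the recipe), a block with
# data-dependent Python control flow is one kind of region that cannot be captured
# in a single graph:

class DataDependentBlock(torch.nn.Module):
    def forward(self, x):
        # A Python-level branch on a tensor value cannot be traced into a single
        # graph and causes a graph break under ``torch.compile``.
        if x.sum() > 0:
            return x * 2
        return x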

############################################################################
# Conclusion
# -----------
#
# This recipe shows how to control the cold start time when compiling your
# model ahead-of-time. This becomes effective when your model has repeated
# blocks, as is typically the case in large generative models.