|
| 1 | + |
| 2 | +""" |
| 3 | +Reducing AoT cold start compilation time with regional compilation |
| 4 | +============================================================================ |
| 5 | +
|
| 6 | +**Author:** `Sayak Paul <https://huggingface.co/sayakpaul>`_, `Charles Bensimon <https://huggingface.co/cbensimon>`_, `Angela Yi <https://github.com/angelayi>`_ |
| 7 | +
|
| 8 | +In the `regional compilation recipe <https://docs.pytorch.org/tutorials/recipes/regional_compilation.html>`__, we showed |
| 9 | +how to reduce cold start compilation times while retaining (almost) full compilation benefits. This was demonstrated for |
| 10 | +just-in-time (JIT) compilation. |
| 11 | +
|
| 12 | +This recipe shows how to apply similar principles when compiling a model ahead-of-time (AoT). If you |
| 13 | +are not familiar with AOTInductor and ``torch.export``, we recommend you to check out `this tutorial <https://docs.pytorch.org/tutorials/recipes/torch_export_aoti_python.html>`__. |
| 14 | +
|
| 15 | +Prerequisites |
| 16 | +---------------- |
| 17 | +
|
| 18 | +* Pytorch 2.6 or later |
| 19 | +* Familiarity with regional compilation |
| 20 | +* Familiarity with AOTInductor and ``torch.export`` |
| 21 | +
|
| 22 | +Setup |
| 23 | +----- |
| 24 | +Before we begin, we need to install ``torch`` if it is not already |
| 25 | +available. |
| 26 | +
|
| 27 | +.. code-block:: sh |
| 28 | +
|
| 29 | + pip install torch |
| 30 | +""" |
| 31 | + |
| 32 | +###################################################################### |
| 33 | +# Steps |
| 34 | +# ----- |
| 35 | +# |
| 36 | +# In this recipe, we will follow the same steps as the regional compilation recipe mentioned above: |
| 37 | +# |
| 38 | +# 1. Import all necessary libraries. |
| 39 | +# 2. Define and initialize a neural network with repeated regions. |
| 40 | +# 3. Measure the compilation time of the full model and the regional compilation with AoT. |
| 41 | +# |
| 42 | +# First, let's import the necessary libraries for loading our data: |
| 43 | +# |
| 44 | + |
| 45 | +import torch |
| 46 | +torch.set_grad_enabled(False) |
| 47 | + |
| 48 | +from time import perf_counter |
| 49 | + |
| 50 | +################################################################################### |
| 51 | +# Defining the Neural Network |
| 52 | +# --------------------------- |
| 53 | +# |
| 54 | +# We will use the same neural network structure as the regional compilation recipe. |
| 55 | +# |
| 56 | +# We will use a network, composed of repeated layers. This mimics a |
| 57 | +# large language model, that typically is composed of many Transformer blocks. In this recipe, |
| 58 | +# we will create a ``Layer`` using the ``nn.Module`` class as a proxy for a repeated region. |
| 59 | +# We will then create a ``Model`` which is composed of 64 instances of this |
| 60 | +# ``Layer`` class. |
| 61 | +# |
| 62 | +class Layer(torch.nn.Module): |
| 63 | + def __init__(self): |
| 64 | + super().__init__() |
| 65 | + self.linear1 = torch.nn.Linear(10, 10) |
| 66 | + self.relu1 = torch.nn.ReLU() |
| 67 | + self.linear2 = torch.nn.Linear(10, 10) |
| 68 | + self.relu2 = torch.nn.ReLU() |
| 69 | + |
| 70 | + def forward(self, x): |
| 71 | + a = self.linear1(x) |
| 72 | + a = self.relu1(a) |
| 73 | + a = torch.sigmoid(a) |
| 74 | + b = self.linear2(a) |
| 75 | + b = self.relu2(b) |
| 76 | + return b |
| 77 | + |
| 78 | + |
| 79 | +class Model(torch.nn.Module): |
| 80 | + def __init__(self): |
| 81 | + super().__init__() |
| 82 | + self.linear = torch.nn.Linear(10, 10) |
| 83 | + self.layers = torch.nn.ModuleList([Layer() for _ in range(64)]) |
| 84 | + |
| 85 | + def forward(self, x): |
| 86 | + # In regional compilation, the self.linear is outside of the scope of ``torch.compile``. |
| 87 | + x = self.linear(x) |
| 88 | + for layer in self.layers: |
| 89 | + x = layer(x) |
| 90 | + return x |
| 91 | + |
| 92 | + |
| 93 | +################################################################################## |
| 94 | +# Compiling the model ahead-of-time |
| 95 | +# --------------------------------- |
| 96 | +# |
| 97 | +# Since we're compiling the model ahead-of-time, we need to prepare representative |
| 98 | +# input examples, that we expect the model to see during actual deployments. |
| 99 | +# |
| 100 | +# Let's create an instance of ``Model`` and pass it some sample input data. |
| 101 | +# |
| 102 | + |
| 103 | +model = Model().cuda() |
| 104 | +input = torch.randn(10, 10, device="cuda") |
| 105 | +output = model(input) |
| 106 | +print(f"{output.shape=}") |
| 107 | + |
| 108 | +############################################################################################### |
| 109 | +# Now, let's compile our model ahead-of-time. We will use ``input`` created above to pass |
| 110 | +# to ``torch.export``. This will yield a ``torch.export.ExportedProgram`` which we can compile. |
| 111 | + |
| 112 | +path = torch._inductor.aoti_compile_and_package( |
| 113 | + torch.export.export(model, args=(input,)) |
| 114 | +) |
| 115 | + |
| 116 | +################################################################# |
| 117 | +# We can load from this ``path`` and use it to perform inference. |
| 118 | + |
| 119 | +compiled_binary = torch._inductor.aoti_load_package(path) |
| 120 | +output_compiled = compiled_binary(input) |
| 121 | +print(f"{output_compiled.shape=}") |
| 122 | + |
| 123 | +###################################################################################### |
| 124 | +# Compiling _regions_ of the model ahead-of-time |
| 125 | +# ---------------------------------------------- |
| 126 | +# |
| 127 | +# Compiling model regions ahead-of-time, on the other hand, requires a few key changes. |
| 128 | +# |
| 129 | +# Since the compute pattern is shared by all the blocks that |
| 130 | +# are repeated in a model (``Layer`` instances in this cases), we can just |
| 131 | +# compile a single block and let the inductor reuse it. |
| 132 | + |
| 133 | +model = Model().cuda() |
| 134 | +path = torch._inductor.aoti_compile_and_package( |
| 135 | + torch.export.export(model.layers[0], args=(input,)), |
| 136 | + inductor_configs={ |
| 137 | + # compile artifact w/o saving params in the artifact |
| 138 | + "aot_inductor.package_constants_in_so": False, |
| 139 | + } |
| 140 | +) |
| 141 | + |
| 142 | +################################################### |
| 143 | +# An exported program (``torch.export.ExportedProgram``) contains the Tensor computation, |
| 144 | +# a ``state_dict`` containing tensor values of all lifted parameters and buffer alongside |
| 145 | +# other metadata. We specify the ``aot_inductor.package_constants_in_so`` to be ``False`` to |
| 146 | +# not serialize the model parameters in the generated artifact. |
| 147 | +# |
| 148 | +# Now, when loading the compiled binary, we can reuse the existing parameters of |
| 149 | +# each block. This lets us take advantage of the compiled binary obtained above. |
| 150 | +# |
| 151 | + |
| 152 | +for layer in model.layers: |
| 153 | + compiled_layer = torch._inductor.aoti_load_package(path) |
| 154 | + compiled_layer.load_constants( |
| 155 | + layer.state_dict(), check_full_update=True, user_managed=True |
| 156 | + ) |
| 157 | + layer.forward = compiled_layer |
| 158 | + |
| 159 | +output_regional_compiled = model(input) |
| 160 | +print(f"{output_regional_compiled.shape=}") |
| 161 | + |
| 162 | +##################################################### |
| 163 | +# Just like JIT regional compilation, compiling regions within a model ahead-of-time |
| 164 | +# leads to significantly reduced cold start times. The actual number will vary from |
| 165 | +# model to model. |
| 166 | +# |
| 167 | +# Even though full model compilation offers the fullest scope of optimizations, |
| 168 | +# for practical purposes and depending on the type of model, we have seen regional |
| 169 | +# compilation (both JiT and AoT) providing similar speed benefits, while drastically |
| 170 | +# reducing the cold start times. |
| 171 | + |
| 172 | +################################################### |
| 173 | +# Measuring compilation time |
| 174 | +# -------------------------- |
| 175 | +# Next, let's measure the compilation time of the full model and the regional compilation. |
| 176 | +# |
| 177 | + |
| 178 | +def measure_compile_time(input, regional=False): |
| 179 | + start = perf_counter() |
| 180 | + model = aot_compile_load_model(regional=regional) |
| 181 | + torch.cuda.synchronize() |
| 182 | + end = perf_counter() |
| 183 | + # make sure the model works. |
| 184 | + _ = model(input) |
| 185 | + return end - start |
| 186 | + |
| 187 | +def aot_compile_load_model(regional=False) -> torch.nn.Module: |
| 188 | + input = torch.randn(10, 10, device="cuda") |
| 189 | + model = Model().cuda() |
| 190 | + |
| 191 | + inductor_configs = {} |
| 192 | + if regional: |
| 193 | + inductor_configs = {"aot_inductor.package_constants_in_so": False} |
| 194 | + |
| 195 | + # Reset the compiler caches to ensure no reuse between different runs |
| 196 | + torch.compiler.reset() |
| 197 | + with torch._inductor.utils.fresh_inductor_cache(): |
| 198 | + path = torch._inductor.aoti_compile_and_package( |
| 199 | + torch.export.export( |
| 200 | + model.layers[0] if regional else model, |
| 201 | + args=(input,) |
| 202 | + ), |
| 203 | + inductor_configs=inductor_configs, |
| 204 | + ) |
| 205 | + |
| 206 | + if regional: |
| 207 | + for layer in model.layers: |
| 208 | + compiled_layer = torch._inductor.aoti_load_package(path) |
| 209 | + compiled_layer.load_constants( |
| 210 | + layer.state_dict(), check_full_update=True, user_managed=True |
| 211 | + ) |
| 212 | + layer.forward = compiled_layer |
| 213 | + else: |
| 214 | + model = torch._inductor.aoti_load_package(path) |
| 215 | + return model |
| 216 | + |
| 217 | +input = torch.randn(10, 10, device="cuda") |
| 218 | +full_model_compilation_latency = measure_compile_time(input, regional=False) |
| 219 | +print(f"Full model compilation time = {full_model_compilation_latency:.2f} seconds") |
| 220 | + |
| 221 | +regional_compilation_latency = measure_compile_time(input, regional=True) |
| 222 | +print(f"Regional compilation time = {regional_compilation_latency:.2f} seconds") |
| 223 | + |
| 224 | +assert regional_compilation_latency < full_model_compilation_latency |
| 225 | + |
| 226 | +############################################################################ |
| 227 | +# There may also be layers in a model incompatible with compilation. So, |
| 228 | +# full compilation will result in a fragmented computation graph resulting |
| 229 | +# in potential latency degradation. In these case, regional compilation |
| 230 | +# can be beneficial. |
| 231 | +# |
| 232 | + |
| 233 | +############################################################################ |
| 234 | +# Conclusion |
| 235 | +# ----------- |
| 236 | +# |
| 237 | +# This recipe shows how to control the cold start time when compiling your |
| 238 | +# model ahead-of-time. This becomes effective when your model has repeated |
| 239 | +# blocks, which is typically seen in large generative models. We used this |
| 240 | +# recipe on various models to speed up real-time performance. Learn more |
| 241 | +# `here <https://huggingface.co/blog/zerogpu-aoti>`__. |
0 commit comments