[precompile] add ability to precompile torchtitan models #2092
Conversation
```python
self.job_config = job_config
# ...
if job_config.compile.enable_precompilation:
```
qq: Is this simplefsdp-only, or does it also work for fsdp2 + block-level compile?
Maybe you want to add this config to apply_compile here for fsdp2:
torchtitan/torchtitan/models/llama3/infra/parallelize.py (lines 236 to 247 in cbdb311):
```python
def apply_compile(model: nn.Module, compile_config: CompileConfig):
    """
    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
    repeated structure. Alternatively one can compile the whole model (after applying DP).
    """
    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = torch.compile(
            transformer_block, backend=compile_config.backend, fullgraph=True
        )
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Compiling each TransformerBlock with torch.compile")
```
Currently only simplefsdp, but this should work with fsdp2 + block-level compile with some additional work.
```python
# Create a unique filename based on model configuration and rank
filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
return os.path.join("/tmp", filename)
```
This isn't a realistic file path for training on FB infra, as /tmp is cleared if you restart training.
Agreed. For FB infra, we would either package the artifact into the conda or fbpkg build, or place it in oilfs and keep a reference to it. For Torchtitan, using /tmp seemed acceptable, though I can make the location configurable through an environment variable. Did you have a different approach in mind?
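As a rough sketch of the environment-variable approach (the variable name `TORCHTITAN_PRECOMPILE_DIR` and the helper below are hypothetical, not what this PR implements), the path lookup could look like:

```python
import os


def _get_precompile_path(model_name: str, model_flavor: str, rank: int) -> str:
    # Hypothetical env var; falls back to /tmp, matching the PR's current behavior.
    cache_dir = os.environ.get("TORCHTITAN_PRECOMPILE_DIR", "/tmp")
    os.makedirs(cache_dir, exist_ok=True)
    # Create a unique filename based on model configuration and rank.
    filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
    return os.path.join(cache_dir, filename)
```

On FB infra the variable could then point at an oilfs mount or a packaged location, while open-source runs keep the /tmp default.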
```diff
 }
 module_cls = type(
-    f"SimpleFSDP{module.__class__.__name__}",
+    f"SimpleFSDP{module.__class__.__name__}_{_wrap_class_counter}",
```
@bobrenjc93 Will this new precompile option also work with the compiler toolkit experiment?
tianyu-l left a comment:
Sorry, not sure if this is a draft or ready for review, so I'm putting a hold so that it's not accidentally merged as is.
If it's ready for review: the change seems quite intrusive; please consider simplifying it or putting it in the compiler_toolkit experiment folder.
@tianyu-l @aditvenk this PR served as a proof of concept to demonstrate an end-to-end flow where precompilation works with simplefsdp. I'll abandon it for now and shift focus to a more narrowly scoped PR that integrates precompile with the compiler toolkit (which does need some work since
Stack from ghstack (oldest at bottom):
For context, for folks who don't know: precompile is a new technology that allows us to serialize a torch.compile'd model as a file on disk, which we can load in the future to avoid recompilations. It doesn't help with cold starts, but it is quite useful for warm starts and preemptions where the underlying model doesn't change.
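This PR's own save/load mechanism isn't reproduced here, but as a minimal sketch of the warm-start idea — using PyTorch's portable cache-artifact API (`torch.compiler.save_cache_artifacts` / `torch.compiler.load_cache_artifacts`, available in recent PyTorch releases), which serializes the compiler caches rather than the model object itself, and a hypothetical artifact path — the flow looks roughly like this:

```python
import os

import torch
import torch.nn as nn

# Hypothetical location; not the path scheme used in this PR.
ARTIFACT_PATH = "/tmp/compile_artifacts.bin"

model = torch.compile(nn.Linear(16, 16))

# Warm start: seed the compiler caches from a previous run, if an artifact exists.
if os.path.exists(ARTIFACT_PATH):
    with open(ARTIFACT_PATH, "rb") as f:
        torch.compiler.load_cache_artifacts(f.read())

# The first call triggers compilation; on a warm start most of the work is
# served from the loaded caches instead of being recompiled.
model(torch.randn(4, 16))

# Cold start: persist the compiler artifacts so the next run can skip recompilation.
if not os.path.exists(ARTIFACT_PATH):
    artifacts = torch.compiler.save_cache_artifacts()
    if artifacts is not None:
        artifact_bytes, _cache_info = artifacts
        with open(ARTIFACT_PATH, "wb") as f:
            f.write(artifact_bytes)
```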
For simplefsdp dsv3 we see the time taken to get through the first batch go down from 17.99 to 1.73 seconds. For posterity, the command used for testing was:
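```
TORCH_LOGS="all" NGPU=2 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" cache_tlp ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable --activation_checkpoint.mode "none"
```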
For this to work, you'll need a PyTorch checkout newer than pytorch/pytorch#169242.
This has currently only been tested with dsv3 and simplefsdp. Notably, the current implementation does not yet support PP; this will be added at a later time.