[SimpleFSDP] add manual bucketing pass #1881
Conversation
Force-pushed from c20775e to a5c4027.
Force-pushed from 8fa2426 to 71cb39b.
Force-pushed from 71cb39b to 27bcc7d.
Force-pushed from 27bcc7d to 3c46d64.
Looks nice. Had some comments.
| "1D+aot_eager_autobucketing", | ||
| "1d_aot_eager_autobucketing", | ||
| ), | ||
| # TODO(ruisizhang123): add back after autobucketing pass is mature |
Shall we add a manual bucketing test?
We should also add one in the loss unit test.
I have a few to-do items for reordering. I think it'd be better to add the tests after the API is stable?
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """ | ||
|
|
||
| manual_bucketed_modules: list[str] = field(default_factory=list) |
we need to have instructions for this field. E.g., it's not super obvious what "tok_embeddings,layers.[0-5],norm+output" means; as it involves regex I have a guess, but users might not.
btw, is the list separated by `,`?
The list is separated by `,`, but I didn't do explicit splitting here. Essentially, it's similar to `filter_fqns` here.
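For illustration, a hypothetical reading of the field's format; the PR itself does not split explicitly, and the meaning of `+` is my reading of the example string, not documented in the thread:

```python
# Hypothetical illustration of the comma-separated, filter_fqns-style format
# discussed above; not the PR's actual parsing code.
raw = "tok_embeddings,layers.[0-5],norm+output"
manual_bucketed_modules = raw.split(",")
# -> ["tok_embeddings", "layers.[0-5]", "norm+output"]
# "[0-5]" abbreviates an index range; "norm+output" presumably groups two
# modules into one bucket (my reading, not confirmed in the thread).
```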
Should we add an fsdp_ prefix? Or do we imagine this field will be used for other use cases? If so, what are they?
hmmm at least for now it's only for FSDP. I think we can add an fsdp prefix -- if there are new bucketing cases for other parallelisms, we can update the name.
Force-pushed from 3c46d64 to d62eb25.
```python
    manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
```
what happens by default?
In bucketing, we shouldn't allow buffer reuse; otherwise newly created comm copy-in/copy-out buffers will reuse previous buffers, which messes up the copied-out data values and makes the loss NaN.
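For reference, a minimal sketch of the flag in question, with the author's rationale restated as comments (`allow_buffer_reuse` is a real Inductor config field shown in the hunk above):

```python
import torch

# With bucketing, Inductor must not alias the newly created communication
# copy-in/copy-out buffers with earlier buffers; otherwise the copied-out
# values get clobbered and the loss turns NaN.
torch._inductor.config.allow_buffer_reuse = False
```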
```python
class Compile:
    model_backend_override: str | None = None
    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
```
should make this subclass `torchtitan.config.job_config.Compile`
It's an additional config extended from `job_config.Compile`; not sure wdym here.
something like `class Compile(torchtitan.config.job_config.Compile)`
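A minimal sketch of that suggestion, assuming the `torchtitan.config.job_config` path named in the thread:

```python
from dataclasses import dataclass

import torchtitan.config.job_config as job_config

# Extend the upstream Compile config instead of redefining it, so SimpleFSDP
# only contributes its extra field.
@dataclass
class Compile(job_config.Compile):
    model_backend_override: str | None = None
    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
```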
Force-pushed from d62eb25 to 1453136.
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """ | ||
|
|
||
| manual_bucketed_modules: list[str] = field(default_factory=list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add fsdp_ prefix? Or do we imagine this field will be use for other use cases, if so what are the use cases?
```python
Manual bucket modules based on user specified FQNs
Abbreviations are supported to make specifying modules easier.
Currently, the following abbreviations are available:
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
```
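For illustration, a minimal sketch of how such an abbreviation could expand; `expand_fqn_abbreviation` is a hypothetical helper, not the PR's code:

```python
import re

# Hypothetical helper: expand "layers.[0-2]" into ["layers.0", "layers.1", "layers.2"].
def expand_fqn_abbreviation(fqn: str) -> list[str]:
    m = re.fullmatch(r"(.+)\.\[(\d+)-(\d+)\]", fqn)
    if m is None:
        return [fqn]  # no range abbreviation; return as-is
    prefix, lo, hi = m.group(1), int(m.group(2)), int(m.group(3))
    return [f"{prefix}.{i}" for i in range(lo, hi + 1)]

assert expand_fqn_abbreviation("layers.[0-2]") == ["layers.0", "layers.1", "layers.2"]
```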
Right now the user has to know how many layers a particular flavor of the model has when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers?
I even think we shouldn't expose this option in toml. In toml the user should just need to specify bucketing_mode = "none", "transformer_block", or "auto".
And if it's transformer_block, we explicitly iterate over all the TransformerBlocks and pass the expanded FQNs to manual_overlap_bucketing. That means manual_overlap_bucketing doesn't need to be smart about abbreviations.
Happy to hear people's thoughts.
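A sketch of how the proposed knob could look; the mode names are the suggestion in this thread, not a merged API, and the default is an assumption:

```python
from dataclasses import dataclass

@dataclass
class Compile:
    # Proposed here: "none" | "transformer_block" | "auto". With
    # "transformer_block", the framework would iterate over all
    # TransformerBlocks itself and hand the expanded FQNs to
    # manual_overlap_bucketing, so the pass needs no abbreviation logic.
    bucketing_mode: str = "none"
```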
I mean we could have another "manual" mode supporting manually bucketed modules if people really want to override, but a good default of transformer-block-level bucketing should be enabled more easily.
transformer_block is a good idea!
I think we need to have a manual mode to expose override APIs to users tho; otherwise SimpleFSDP would be the same as FSDP2 lolll.
cc. @ezyang
I wanted to check how 'transformer_block' would be implemented. Does it assume the transformer blocks are organized a certain way for easy discovery, e.g. a ModuleList/ModuleDict? How do we even detect which block is a transformer block (unless I missed that this option would have the user pass a class name)?
I think I agree that in principle there should be a way for users to fully control bucketing, but I'm not sure it needs to be exposed from torchtitan's job config - it could be more of an example we provide on using SimpleFSDP in an advanced way, including your own graph pass, or something.
Good point, this block bucketing pass should read in pre-defined block FQN names. However, these can be annotated in model.py or parallelize.py, and users don't need to pass them as part of the job config.
@wconstab
I think config and how this config would be consumed are orthogonal.
A concrete way to do this is having model-specific code consume this config and call into the manual bucketing API, so this transformer-block-level bucketing is a torchtitan framework option rather than a compiler pass option.
I have an updated prototype for it @tianyu-l @wconstab.
We can specify modules to bucket similarly to apply_fsdp in FSDP2's parallelize.py. Then, we convert these modules to FQNs here. These FQNs are passed into PyTorch's manual bucketing & overlapping pass.
I think this is a very clean way to get out-of-the-box perf.
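A minimal sketch of that module-to-FQN step, under stated assumptions: `module_fqns_to_bucket` is a hypothetical helper, not the prototype's actual code.

```python
import torch.nn as nn

def module_fqns_to_bucket(model: nn.Module, block_cls: type) -> list[list[str]]:
    # One bucket per matching block: collect the FQN of every submodule that
    # is an instance of the chosen block class (e.g. TransformerBlock).
    return [
        [name]
        for name, module in model.named_modules()
        if isinstance(module, block_cls)
    ]
```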
Force-pushed from 1453136 to df7b9cd.
```python
backend = aot_autograd_backend(
    fw_compiler=aten_manualbucketing_reordering_pass,
    bw_compiler=aten_manualbucketing_reordering_pass,
    keep_inference_input_mutations=True,
```
side note - once @soulitzer finishes adding AC support to the default partitioner (pytorch/pytorch#166610), we'll probably want to use the default partitioner here instead of min-cut? (min-cut tries to automatically recompute ops that it thinks will be free due to fusions, but without Inductor those ops won't end up being free).
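A sketch of what that swap might look like, assuming `aot_autograd_backend` in the hunk above is PyTorch's `aot_autograd` wrapper from `torch._dynamo.backends.common`, and using a trivial stand-in for the PR's reordering pass:

```python
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.partitioners import default_partition

# Stand-in for the PR's aten_manualbucketing_reordering_pass: a fw/bw
# compiler only needs to return something callable on the example inputs,
# and the traced GraphModule itself qualifies.
def aten_manualbucketing_reordering_pass(gm, example_inputs):
    return gm

backend = aot_autograd(
    fw_compiler=aten_manualbucketing_reordering_pass,
    bw_compiler=aten_manualbucketing_reordering_pass,
    # Instead of the default min_cut_rematerialization_partition.
    partition_fn=default_partition,
    keep_inference_input_mutations=True,
)
```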
Force-pushed from a7bb57c to ec41c3f.
so you no longer want the pure manual mode? Fine with me.
```diff
-def get_compile_backend(backend_name: str) -> Union[str, callable]:
+def get_compile_backend(
+    compile_config: CompileConfig, bucket_module_name: list[list[str] | str]
```
maybe rename to `fsdp_buckets`?
not sure what will happen if it's in DDP / HSDP mode
Suggested change:
```diff
-    compile_config: CompileConfig, bucket_module_name: list[list[str] | str]
+    compile_config: CompileConfig, fsdp_buckets: list[list[str] | str]
```
it will only bucket FSDP-related AG/RS (all-gather / reduce-scatter) in HSDP, and will not touch all-reduce in DDP/HSDP.
```python
    bw_compiler=aten_autobucketing_reordering_pass,
    keep_inference_input_mutations=True,
)
elif backend_name == "aot_eager_blockbucketing":
```
update the config helper message with this option?
Force-pushed from ec41c3f to 26f62a8.
```diff
---compile.model_backend_override "aot_eager_autobucketing"
+--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_autobucketing"
```
why do we need --compile.backend "aot_eager"?
```python
    bw_compiler=aten_autobucketing_reordering_pass,
    keep_inference_input_mutations=True,
)
elif backend_name == "aot_eager_blockbucketing":
```
"block" is ambiguous, maybe transformer_block_bucketing
```python
    manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
```
aren't we doing passes on the fx graph in the aot_eager backend? why does it have anything to do with Inductor?
In fact, I have this confusion for all the other torch._inductor fields.
```python
@dataclass
class Compile:
    model_backend_override: str | None = None
```
So the way I think of configuring this would be:
1. choose backend, say `aot_eager`
2. choose custom passes, say `auto_bucketing` / `transformer_block_bucketing`

It seems to me that you are merging them into backend altogether because that is the interface exposed by torch.compile. Do you think we can separate them in torchtitan? E.g.
- `get_compile_backend(job_config.compile)` is still there
- inside it, we use `CompileConfig.compiler_passes` or `CompileConfig.aot_autograd_passes` to specify the custom passes, e.g. bucketing, reshard_after_forward, etc.

My point is we will be having more and more passes, hopefully composable with each other, and we can't afford having one custom backend for each combination, whose amount grows exponentially.
Maybe not urgent.
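A minimal sketch of that separation, under stated assumptions: `compiler_passes` and `wrap_backend_with_pass` are hypothetical names from this comment, not existing torchtitan APIs.

```python
from dataclasses import dataclass, field

@dataclass
class CompileConfig:
    backend: str = "aot_eager"
    # Hypothetical field: passes compose with each other instead of each
    # combination needing its own custom backend name.
    compiler_passes: list[str] = field(
        default_factory=lambda: ["transformer_block_bucketing"]
    )

def wrap_backend_with_pass(backend, pass_name: str):
    # Stub: a real implementation would layer an fx/aot_autograd pass
    # (bucketing, reshard_after_forward, ...) on top of the base backend.
    return backend

def get_compile_backend(compile_config: CompileConfig):
    # Resolve the base backend once, then apply each requested pass on top.
    backend = compile_config.backend
    for pass_name in compile_config.compiler_passes:
        backend = wrap_backend_with_pass(backend, pass_name)
    return backend
```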
This PR adds support for aten-level manual bucketing in the SimpleFSDP + `aot_eager` backend. Dependent on a PyTorch PR.

TODO list:
- `manual_bucketed_modules`: it would be very easy to miss some of the model's modules. (cc @xmfan @SherlockNoMad)

I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+llama3 PR.

Results (`aot_eager` backend), Llama 3-8B:
Example SimpleFSDP 1D overlapping trace:
Example SimpleFSDP 2D overlapping trace:
[trace screenshot]
FSDP-only:
[trace screenshot]
FSDP+TP:
[trace screenshot]