
Conversation

Contributor

@ruisizhang123 commented Oct 15, 2025

This PR adds support for aten-level manual bucketing in the SimpleFSDP+aot_eager backend. It depends on a PyTorch PR.

TODO List:

  • We should have a better way of handling region info than the current list of string FQNs in manual_bucketed_modules; as is, it would be very easy to miss some of the model's modules. (cc. @xmfan @SherlockNoMad )
  • Currently, the reordering happens under the hood and overlaps with the last/next compute. We should allow users to specify which modules they want to reorder.
  • Investigate the loss difference in multi-node training.
  • DSV3 manual bucketing.

I'll address the TODO items in follow-up PRs. Let's start with this simple FSDP+TP+Llama3 PR.

  1. Performance (FSDP2 runs in eager mode; SimpleFSDP uses the aot_eager backend)

Llama 3-8B

  • Performance (all runs use batch_size = 1). The slower TPS on a single node is more or less expected, since FSDP2 handles copy-in/copy-out in two separate streams, whereas SimpleFSDP handles them in a single stream.
| Node | Method | Parallelism | Memory | TPS | Trace |
|------|--------|-------------|--------|-----|-------|
| 1-Node (8H100) | SimpleFSDP | FSDP=8 | 40.96 GiB (43.12%) | 7,227 | LINK |
| 1-Node (8H100) | FSDP2-eager | FSDP=8 | 47.82 GiB (50.35%) | 7,380 | LINK |
| 8-Node (64H100) | SimpleFSDP | FSDP=64 | 29.37 GiB | 4,984 | |
| 8-Node (64H100) | FSDP2 | FSDP=64 | 31.41 GiB | 5,097 | |
| 1-Node (8H100) | SimpleFSDP | FSDP=4 TP=2 | 28.28 GiB (29.77%) | 5,881 | LINK |
| 1-Node (8H100) | FSDP2 | FSDP=4 TP=2 | 35.33 GiB (37.20%) | 5,898 | LINK |
| 8-Node (64H100) | SimpleFSDP | FSDP=8 TP=8 | | | |
| 8-Node (64H100) | FSDP2 | FSDP=8 TP=8 | | | |

Example SimpleFSDP 1D overlapping trace:

[trace screenshot]

Example SimpleFSDP 2D overlapping trace:
[trace screenshot]

  • Bitwise Loss:

FSDP-only:
[loss curves screenshot]

FSDP+TP:
[loss curves screenshot]

@meta-cla bot added the CLA Signed label Oct 15, 2025
@ruisizhang123 marked this pull request as draft October 15, 2025 17:41
@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 3 times, most recently from c20775e to a5c4027 on October 23, 2025 21:27
@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 4 times, most recently from 8fa2426 to 71cb39b on October 27, 2025 04:51
@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 71cb39b to 27bcc7d on October 28, 2025 07:06
@ruisizhang123 marked this pull request as ready for review October 28, 2025 07:06
@ruisizhang123 changed the title from "[WIP][SimpleFSDP] add manual bucketing pass" to "[SimpleFSDP] add manual bucketing pass" Oct 28, 2025
@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 27bcc7d to 3c46d64 on October 28, 2025 17:57
Contributor

@tianyu-l left a comment


Looks nice. Had some comments.

"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
),
# TODO(ruisizhang123): add back after autobucketing pass is mature
Contributor

Shall we add a manual bucketing test?

We should also add one in the loss unit test.

Contributor Author

I have a few TODO items for reordering. I think it'd be better to add the tests after the API is stable.

"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """

manual_bucketed_modules: list[str] = field(default_factory=list)
Contributor

We need instructions for this field; e.g., it's not super obvious what "tok_embeddings,layers.[0-5],norm+output" means. Since it involves regex, I have a guess, but users might not.

Btw, is the list separated by ,?

Contributor Author

The list is separated by ,, but I didn't do explicit splitting here. Essentially, it's similar to filter_fqns here.
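
For illustration, a minimal sketch of how such a comma-separated spec with abbreviations could be expanded, based on the "tok_embeddings,layers.[0-5],norm+output" example above (the expand_fqns helper is hypothetical, not this PR's implementation):

```python
import re


def expand_fqns(spec: str) -> list[list[str]]:
    """Hypothetical helper: expand bucket abbreviations in a comma-separated spec.

    "layers.[0-2]" -> separate buckets layers.0, layers.1, layers.2;
    "norm+output"  -> one bucket containing both modules.
    """
    buckets: list[list[str]] = []
    for group in spec.split(","):
        m = re.fullmatch(r"(.+)\.\[(\d+)-(\d+)\]", group)
        if m:  # index range: one bucket per layer
            prefix, lo, hi = m.group(1), int(m.group(2)), int(m.group(3))
            buckets.extend([f"{prefix}.{i}"] for i in range(lo, hi + 1))
        else:  # "+" glues several modules into a single bucket
            buckets.append(group.split("+"))
    return buckets


assert expand_fqns("tok_embeddings,layers.[0-2],norm+output") == [
    ["tok_embeddings"],
    ["layers.0"],
    ["layers.1"],
    ["layers.2"],
    ["norm", "output"],
]
```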

Contributor

Should we add an fsdp_ prefix? Or do we imagine this field will be used for other use cases? If so, what are they?

Contributor Author

@ruisizhang123 Nov 4, 2025

Hmm, at least for now it's only for FSDP. I think we can add an fsdp_ prefix; if there are new bucketing cases for other parallelisms, we can update the name.

@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 3c46d64 to d62eb25 on October 30, 2025 04:44
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
Member

What happens by default?

Contributor Author

In bucketing, we shouldn't allow buffer reuse; otherwise, newly created comm copy-in/copy-out buffers will reuse previous buffers, which corrupts the copied-out data and makes the loss NaN.

class Compile:
    model_backend_override: str | None = None
    """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
Member

We should make this subclass torchtitan.config.job_config.Compile.

Contributor Author

It's an additional config extended from job_config.Compile; not sure what you mean here.

Member

Something like class Compile(torchtitan.config.job_config.Compile).
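
For illustration, a minimal sketch of what that subclassing might look like (field names follow this PR's snippets; assuming the base Compile is a dataclass whose fields all have defaults):

```python
from dataclasses import dataclass, field

from torchtitan.config.job_config import Compile as BaseCompile


@dataclass
class Compile(BaseCompile):
    """SimpleFSDP compile options, extending the base job-config Compile."""

    model_backend_override: str | None = None
    """Override backend to compile in SimpleFSDP, e.g. aot_eager_autobucketing."""

    manual_bucketed_modules: list[str] = field(default_factory=list)
    """User-specified FQNs of modules to bucket manually."""
```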

@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from d62eb25 to 1453136 on October 30, 2025 05:21
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """

manual_bucketed_modules: list[str] = field(default_factory=list)
Manual bucket modules based on user specified FQNs
Abbreviations are supported to make specifying modules easier.
Currently, the following abbreviations are available:
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
Contributor

Right now the user has to know how many layers a particular model flavor has when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers?

I even think we shouldn't expose this option in the toml. In the toml, the user should just need to specify bucketing_mode = "none", "transformer_block", or "auto".
If it's transformer_block, we explicitly iterate over all the TransformerBlocks and pass the expanded FQNs to manual_overlap_bucketing. That way, manual_overlap_bucketing doesn't need to be smart about abbreviations.

Happy to hear people's thoughts.

Contributor

I mean, we could have another "manual" mode supporting manually bucketed modules if people really want to override, but a good default of transformer-block-level bucketing should be easier to enable.

Contributor Author

transformer_block is a good idea!

I think we need a manual mode to expose override APIs to users, though; otherwise SimpleFSDP would be the same as FSDP2.

cc. @ezyang

Contributor

I wanted to check how 'transformer_block' would be implemented. Does it assume the transformer blocks are organized a certain way for easy discovery, e.g. in a ModuleList/ModuleDict? How do we even detect which block is a transformer block (unless I missed that this option would have the user pass a class name)?

I think I agree that in principle there should be a way for users to fully control bucketing, but I'm not sure it needs to be exposed from torchtitan's job config; it could be more of an example we provide on using SimpleFSDP in an advanced way, including writing your own graph pass, or something.

Contributor Author

Good point; this block bucketing pass should read in pre-defined block FQN names. However, these can be annotated in model.py or parallelize.py, so users don't need to pass them as part of the job config.

Contributor

@wconstab
I think the config and how it would be consumed are orthogonal.

A concrete way to do this is to have model-specific code consume the config and call into the manual bucketing API, so that transformer-block-level bucketing is a torchtitan framework option rather than a compiler-pass option.

Contributor Author

I have an updated prototype for it, @tianyu-l @wconstab.

We can specify modules to bucket, similar to apply_fsdp in FSDP2's parallelize.py, then convert these modules to FQNs here. These FQNs are passed into the PyTorch manual bucketing & overlapping pass.

I think this is a very clean way to get out-of-the-box perf.
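
For illustration, a rough sketch of that module-to-FQN conversion (modules_to_fqns is a hypothetical helper; the selection mirrors how apply_fsdp-style code walks the model in parallelize.py):

```python
import torch.nn as nn


def modules_to_fqns(model: nn.Module, selected: list[nn.Module]) -> list[str]:
    """Hypothetical helper: map each selected submodule back to its FQN in model."""
    fqn_by_id = {id(m): name for name, m in model.named_modules()}
    return [fqn_by_id[id(m)] for m in selected]


# e.g., treating each transformer block as its own bucket:
# blocks = list(model.layers)
# fsdp_buckets = [[fqn] for fqn in modules_to_fqns(model, blocks)]
# the resulting FQNs are then handed to the manual bucketing & overlapping pass
```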

backend = aot_autograd_backend(
    fw_compiler=aten_manualbucketing_reordering_pass,
    bw_compiler=aten_manualbucketing_reordering_pass,
    keep_inference_input_mutations=True,

Side note: once @soulitzer finishes adding AC support to the default partitioner (pytorch/pytorch#166610), we'll probably want to use the default partitioner here instead of min-cut? (Min-cut tries to automatically recompute ops that it thinks will be free due to fusions, but without Inductor those ops won't end up being free.)
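
If that lands, a sketch of the swap, assuming aot_autograd_backend above is torch._dynamo.backends.common.aot_autograd (which forwards partition_fn to aot_module_simplified); the pass names follow the snippet above:

```python
from torch._dynamo.backends.common import aot_autograd
from torch._functorch.partitioners import default_partition

# Sketch: use the default partitioner instead of min-cut rematerialization,
# since without Inductor the ops min-cut recomputes aren't actually free.
backend = aot_autograd(
    fw_compiler=aten_manualbucketing_reordering_pass,
    bw_compiler=aten_manualbucketing_reordering_pass,
    partition_fn=default_partition,
    keep_inference_input_mutations=True,
)
```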

@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 2 times, most recently from a7bb57c to ec41c3f on November 5, 2025 04:03
Contributor

@tianyu-l left a comment


So you no longer want the pure manual mode? Fine with me.

def get_compile_backend(backend_name: str) -> Union[str, callable]:

def get_compile_backend(
    compile_config: CompileConfig, bucket_module_name: list[list[str] | str]
Contributor

Maybe rename to fsdp_buckets? Not sure what will happen in DDP/HSDP mode.

Suggested change
    compile_config: CompileConfig, bucket_module_name: list[list[str] | str]
    compile_config: CompileConfig, fsdp_buckets: list[list[str] | str]

Contributor Author

It will only bucket FSDP-related all-gather/reduce-scatter in HSDP, and will not touch the all-reduce in DDP/HSDP.

        bw_compiler=aten_autobucketing_reordering_pass,
        keep_inference_input_mutations=True,
    )
elif backend_name == "aot_eager_blockbucketing":
Contributor

Update the config helper message with this option?

@ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from ec41c3f to 26f62a8 on November 6, 2025 05:49
--compile.model_backend_override "aot_eager_autobucketing"
--compile.backend "aot_eager" --compile.model_backend_override "aot_eager_autobucketing"
Contributor

Why do we need --compile.backend "aot_eager"?

        bw_compiler=aten_autobucketing_reordering_pass,
        keep_inference_input_mutations=True,
    )
elif backend_name == "aot_eager_blockbucketing":
Contributor

"block" is ambiguous, maybe transformer_block_bucketing

manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
Contributor

Aren't we doing these passes on the FX graph in the aot_eager backend? Why does it have anything to do with Inductor?

In fact, I have this confusion for all the other torch._inductor fields.


@dataclass
class Compile:
    model_backend_override: str | None = None
Contributor

So the way I think of configuring this would be:

  1. Choose a backend, say aot_eager.
  2. Choose custom passes, say auto_bucketing / transformer_block_bucketing.

It seems to me that you are merging them into the backend altogether because that is the interface exposed by torch.compile. Do you think we can separate them in torchtitan? E.g.:

  • get_compile_backend(job_config.compile) is still there
  • inside it, we use CompileConfig.compiler_passes or CompileConfig.aot_autograd_passes to specify the custom passes, e.g. bucketing, reshard_after_forward, etc.

My point is that we will be having more and more passes, hopefully composable with each other, and we can't afford one custom backend for each combination, whose number grows exponentially.

Maybe not urgent.
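
For illustration, a sketch of that separation (CompileConfig.compiler_passes and PASS_REGISTRY are hypothetical names for the proposal, not existing torchtitan APIs):

```python
from dataclasses import dataclass, field
from typing import Callable

import torch.fx as fx


@dataclass
class CompileConfig:
    backend: str = "aot_eager"
    # named, composable graph passes applied in order, instead of a
    # bespoke backend string for every combination
    compiler_passes: list[str] = field(default_factory=list)


# name -> graph-pass callable, e.g. "auto_bucketing", "transformer_block_bucketing"
PASS_REGISTRY: dict[str, Callable[[fx.GraphModule], fx.GraphModule]] = {}


def compose_passes(cfg: CompileConfig) -> Callable[[fx.GraphModule], fx.GraphModule]:
    """Chain the selected passes; each consumes and returns a GraphModule."""
    selected = [PASS_REGISTRY[name] for name in cfg.compiler_passes]

    def run(gm: fx.GraphModule) -> fx.GraphModule:
        for p in selected:
            gm = p(gm)
        return gm

    return run
```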
