Skip to content

Commit 4148f29

Browse files
liji-nv authored and Ransiki committed
feat: Add non UB AR + Residual + Norm + Quant fusion (NVIDIA#6320)
Signed-off-by: Jin Li <[email protected]>
Signed-off-by: Ransiki Zhang <[email protected]>
1 parent 5cb4098 commit 4148f29

File tree

4 files changed

+644
-537
lines changed

4 files changed

+644
-537
lines changed

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -13,9 +13,8 @@
1313
from tensorrt_llm import logger
1414

1515
from .multi_stream.auto_multi_stream import multi_stream_schedule
16-
from .patterns.ar_residual_norm import register_ar_residual_norm
16+
from .patterns.ar_residual_norm import register_ar_fusions
1717
from .patterns.residual_add_norm import register_add_norm
18-
from .patterns.ub_allreduce import register_ub_patterns
1918
from .piecewise_optimizer import piecewise_optimizer
2019
from .recover_pass import recover_pass
2120
from .remove_copy_pass import remove_copy_for_mutates_args
@@ -76,10 +75,9 @@ def get_custom_pass(cls, enable_userbuffers):
7675
# Currently torch compile cannot work properly with lamport fusion kernel
7776
# TO-DO: Fix this issue
7877
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
79-
register_ar_residual_norm(cls._custom_pass_instances[0])
80-
if enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
81-
):
82-
register_ub_patterns(cls._custom_pass_instances)
78+
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
79+
)
80+
register_ar_fusions(cls._custom_pass_instances, ub_enabled)
8381
else:
8482
register_add_norm(cls._custom_pass_instances[0])
8583
return cls._custom_pass_instances

0 commit comments

Comments
 (0)