Skip to content

Commit 9eb52a2

Browse files
authored
fix: use combined rms norm for accuracy for dipu (#128)
Use the combined (torch-composed) RMS norm on DIPU instead of the fused operator, to avoid the fused operator's accuracy problem.
1 parent 22bffd4 commit 9eb52a2

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

deeplink_ext/interntrain_ops/rms_norm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
platform_type = deeplink_ext_get_platform_type()
66
if platform_type == PlatformType.TORCH_NPU:
77
# from ._mixed_rms_norm_npu import MixedFusedRMSNorm
8+
# Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative.
89
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
910
elif platform_type == PlatformType.TORCH_DIPU:
10-
from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
11+
# from ._mixed_rms_norm_dipu import MixedFusedRMSNorm
12+
# Due to the accuracy problem of the dipu fused operator, a torch combination is used as an alternative.
13+
from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
1114
else:
1215
raise ImportError
1316

0 commit comments

Comments
 (0)