
Commit b422fa8

refactor python code

refactor python code by ops

1 parent 9eb52a2


47 files changed: +200 -1181 lines

deeplink_ext/ascend_speed/_flash_attention_dipu.py
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@
 
 
 class FlashSelfAttention(torch.autograd.Function):
-
     @staticmethod
     def forward(
         ctx, q, k, v, attention_mask, dropout_p, softmax_scale, head_num, input_layout
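
The only change in this file (and in the three ascend_speed files that follow) is dropping the blank line between the class header and the first @staticmethod; behaviour is untouched. For orientation, a minimal sketch of how an autograd.Function with this forward signature is typically driven is shown below; the tensor shapes, dtype, mask, softmax_scale, and the "BSND" layout string are illustrative assumptions, not taken from this commit.

import torch

from deeplink_ext.ascend_speed._flash_attention_dipu import FlashSelfAttention

# Hypothetical inputs shaped (batch, seq_len, heads, head_dim); values are illustrative only.
q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, requires_grad=True)
k = torch.randn(2, 1024, 16, 64, dtype=torch.float16, requires_grad=True)
v = torch.randn(2, 1024, 16, 64, dtype=torch.float16, requires_grad=True)

# autograd.Function subclasses are invoked through .apply(); ctx is supplied by autograd.
# Argument order follows the forward() signature shown in the hunk above; the mask,
# softmax_scale, and layout passed here are assumed values.
out = FlashSelfAttention.apply(q, k, v, None, 0.0, None, 16, "BSND")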

deeplink_ext/ascend_speed/_rms_norm_dipu.py
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@
 
 
 class RMSNorm(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, hidden_states, weight, eps):
        output = torch.empty_like(hidden_states)
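
For reference, the computation this fused kernel provides is plain RMS normalisation. A small PyTorch sketch of the standard formula follows, based only on the signature above rather than on this file's body; the float32 upcast is an assumption.

import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Normalise by the root mean square over the last dimension, then apply the learned scale.
    variance = hidden_states.float().pow(2).mean(dim=-1, keepdim=True)
    normed = hidden_states.float() * torch.rsqrt(variance + eps)
    return (weight * normed).to(hidden_states.dtype)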

deeplink_ext/ascend_speed/_scaled_masked_softmax_dipu.py
Lines changed: 0 additions & 1 deletion

@@ -11,7 +11,6 @@
 
 
 class ScaledMaskedSoftmax(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, mask, scale, fixed_triu_mask):
        out = torch.empty_like(input)
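
Likewise, the semantics of a scaled masked softmax can be sketched in plain PyTorch as below. This is only a reference under the usual convention that true mask positions are excluded; the fixed_triu_mask path is omitted rather than inferred from this commit.

import torch

def scaled_masked_softmax_reference(input: torch.Tensor, mask: torch.Tensor, scale: float) -> torch.Tensor:
    # Scale the logits, fill masked positions with -inf, then softmax over the last dimension.
    scores = input * scale
    scores = scores.masked_fill(mask.bool(), float("-inf"))
    return torch.softmax(scores, dim=-1)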

deeplink_ext/ascend_speed/_scaled_masked_softmax_npu.py
Lines changed: 0 additions & 1 deletion

@@ -7,7 +7,6 @@
 
 
 class ScaledMaskedSoftmax(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, input, mask, scale, fixed_triu_mask):
        out = torch_npu.npu_scaled_masked_softmax(input, mask, scale, fixed_triu_mask)

deeplink_ext/easyllm_ops/__init__.py
Lines changed: 11 additions & 29 deletions

@@ -3,40 +3,22 @@
 _not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
 
 try:
-    from .adamw import AdamW
+    from deeplink_ext.ops.adamw import AdamW
 except Exception as e:
     print(_not_impl.format(op_name="adamw"))
     from torch.optim import AdamW
 
-try:
-    from .flash_attention import (
-        flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func,
-        flash_attn_func,
-        flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func,
-    )
-except Exception as e:
-    print(_not_impl.format(op_name="flash attention"))
-    from .flash_attention_fallback import (
-        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
-        flash_attn_func_torch as flash_attn_func,
-        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func_torch as flash_attn_varlen_func,
-    )
-
-try:
-    from .rms_norm import rms_norm
-except:
-    print(
-        _not_impl.format(op_name="RMSNorm"),
-    )
-    from .rms_norm_fallback import rms_norm_torch as rms_norm
+from deeplink_ext.ops.flash_attention import (
+    flash_attn_qkvpacked_func,
+    flash_attn_kvpacked_func,
+    flash_attn_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_func,
+)
 
-from .bert_padding import pad_input, unpad_input, index_first_axis
+from deeplink_ext.ops.rms_norm import rms_norm
+from deeplink_ext.ops.bert_padding import pad_input, unpad_input, index_first_axis
 
 __all__ = [
     "AdamW",

deeplink_ext/easyllm_ops/adamw.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

deeplink_ext/easyllm_ops/flash_attention.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

deeplink_ext/easyllm_ops/flash_attention_fallback.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

deeplink_ext/internevo_ops/__init__.py
Lines changed: 11 additions & 34 deletions

@@ -1,46 +1,23 @@
 # Copyright (c) 2024, DeepLink.
 
-_not_impl = "[deeplink_ext] {op_name} is not implemented in diopi. Falling back to the slower torch implementation."
-
 try:
-    from .adamw import AdamW
+    from deeplink_ext.ops.adamw import AdamW
 except Exception as e:
     print(_not_impl.format(op_name="adamw"))
     from torch.optim import AdamW
 
-try:
-    from .flash_attention import (
-        flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func,
-        flash_attn_func,
-        flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func,
-    )
-except Exception as e:
-    print(_not_impl.format(op_name="flash attention"))
-    from .flash_attention_fallback import (
-        flash_attn_qkvpacked_func_torch as flash_attn_qkvpacked_func,
-        flash_attn_kvpacked_func_torch as flash_attn_kvpacked_func,
-        flash_attn_func_torch as flash_attn_func,
-        flash_attn_varlen_qkvpacked_func_torch as flash_attn_varlen_qkvpacked_func,
-        flash_attn_varlen_kvpacked_func_torch as flash_attn_varlen_kvpacked_func,
-        flash_attn_varlen_func_torch as flash_attn_varlen_func,
-    )
+from deeplink_ext.ops.flash_attention import (
+    flash_attn_qkvpacked_func,
+    flash_attn_kvpacked_func,
+    flash_attn_func,
+    flash_attn_varlen_qkvpacked_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_func,
+)
 
-try:
-    from .rms_norm import MixedFusedRMSNorm
-except:
-    print(
-        _not_impl.format(op_name="RMSNorm"),
-    )
-    from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm
+from deeplink_ext.ops.rms_norm import MixedFusedRMSNorm
 
-try:
-    from .rotary_embedding import ApplyRotaryEmb
-except:
-    print(_not_impl.format(op_name="rotary embedding"))
-    from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb
+from deeplink_ext.ops.rotary_embedding import ApplyRotaryEmb
 
 __all__ = [
     "AdamW",
