Skip to content

Commit 523eede

Browse files
zhaoguochun1995 authored and yangbofun committed
[diopi]add attention define and impl on ascend (#1228)
1 parent 8ce8a54 commit 523eede

File tree

9 files changed

+1076
-63
lines changed

9 files changed

+1076
-63
lines changed

diopi_test/diopi_stub/codegen/gen.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,6 +50,7 @@ def get_func_info(content):
5050
paras_can_be_none = []
5151
ins_vector, outs_vector = {}, {}
5252
out_ptr = []
53+
var_len_array_out = {}
5354
type_change = False
5455
row = content.replace('\n', '').replace('(', ',').replace(')', '')
5556
arg_define = row.split(',')
@@ -73,6 +74,16 @@ def get_func_info(content):
7374
out_ptr.append(arg_index)
7475
arg_type = 'PtrWrapper<diopiTensor>'
7576
break
77+
elif next_arg[0] == 'int64_t*':
78+
type_change = True
79+
next_arg_process = '(*static_cast<int64_t*>(' + next_arg[1] + '))'
80+
if arg_type == 'diopiTensorHandle_t*':
81+
outs_vector[arg_index] = next_arg_process
82+
else:
83+
ins_vector[arg_index] = next_arg_process
84+
arg_type = 'py::list&'
85+
var_len_array_out[arg_index] = ({"param": arg, "param_num": next_arg_process})
86+
break
7687
elif next_arg[0] == 'int64_t':
7788
type_change = True
7889
if arg_type == 'diopiTensorHandle_t*':
@@ -98,7 +109,7 @@ def get_func_info(content):
98109
if arg_type in can_be_none:
99110
paras_can_be_none.append(len(args) - 1)
100111
arg_index += 1
101-
return type_change, args, attr_types, paras_can_be_none, ins_vector, outs_vector, out_ptr
112+
return type_change, args, attr_types, paras_can_be_none, ins_vector, outs_vector, out_ptr, var_len_array_out
102113

103114

104115
def get_export(content, ft, exports):
@@ -116,7 +127,7 @@ def get_export(content, ft, exports):
116127
idx2 = row1.find(")")
117128
temp_content += row1.replace(';', '')
118129
idx += 1
119-
type_change, args, attr_types, paras_none, ins_vector, outs_vector, out_ptr = get_func_info(temp_content)
130+
type_change, args, attr_types, paras_none, ins_vector, outs_vector, out_ptr, var_len_array_out = get_func_info(temp_content)
120131
call_args = copy.deepcopy(args)
121132
type_change = True
122133
if type_change:
@@ -142,6 +153,8 @@ def get_export(content, ft, exports):
142153
out_copy += "if ({param}.get() != nullptr && {param}Handle != nullptr)\n \
143154
*{param} = *{param}Handle;\n".format(param=call_args[out])
144155
call_args[out] = '&' + call_args[out] + 'Handle'
156+
for out_array in var_len_array_out.values():
157+
out_copy += OT.var_len_array_out_template.substitute(param=out_array['param'], param_num=out_array['param_num'])
145158
call_func = func_name + '(' + ', '.join(call_args) + ')'
146159
exports.append(ft.substitute(env=dict(func_name=func_name, attrs=', '.join(attrs), convert=convert,
147160
out_copy=out_copy, call_func=call_func)))

diopi_test/diopi_stub/codegen/op_template.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,10 @@ class OpTemplate(object):
4848
for (int i = 0; i < ${param_num}; ++i)
4949
${param}V[i] = ${param}[i].cast<PtrWrapper<diopiTensor>>().get();
5050
auto ${param}DIOPI = ${param}V.data();
51+
""")
52+
53+
var_len_array_out_template = CodeTemplate("""\
54+
for (int i = 0; i < ${param_num}; ++i) {
55+
${param}[i] = ${param}DIOPI[i];
56+
}
5157
""")

diopi_test/python/configs/diopi_configs.py

Lines changed: 174 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8974,6 +8974,170 @@
89748974
),
89758975
),
89768976

8977+
'attention': dict(
8978+
name=['attention'],
8979+
interface=['CustomizedTest'],
8980+
dtype=[np.float16],
8981+
saved_args=dict(out=0),
8982+
atol_half=5e-2,
8983+
rtol_half=5e-2,
8984+
para=dict(
8985+
dropout_p=[0, 0, 0, 0,
8986+
0, 0, 0, 0,
8987+
0, 0, 0, 0,
8988+
0, 0, 0, 0,
8989+
0, 0, 0, 0,
8990+
0, 0, 0, 0],
8991+
is_causal=[True, False, True, False,
8992+
True, False, True, True,
8993+
True, True, True, False,
8994+
False, True, False, True,
8995+
True, False, True, False,
8996+
False, True, False, True,],
8997+
scale=[0.0883, None, 0.125, None,
8998+
0.0883, None, 0.125, 0.0625,
8999+
0.0883, 0.0221, None, 0.0625,
9000+
None, None, None, None,
9001+
None, None, None, None,
9002+
None, None, None, None],
9003+
),
9004+
tensor_para=dict(
9005+
gen_fn='Genfunc.rand',
9006+
args=[
9007+
{
9008+
"ins": ['query'],
9009+
"shape": ((1, 64, 64, 128), (1, 64, 32, 128), (1, 32, 64, 512), (8, 128, 32, 256),
9010+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 32, 32, 128), (8, 8, 256, 64),
9011+
(2, 128, 64, 128), (4, 512, 128, 64), (6, 32, 128, 256), (8, 1024, 8, 64),
9012+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 128, 32, 128), (8, 8, 256, 64),
9013+
(64, 8, 8, 16), (8, 32, 256, 512), (16, 8, 256, 128), (8, 16, 256, 64),
9014+
(1, 64, 64, 128), (1, 256, 16, 128), (1, 64, 32, 128), (1, 16, 8, 64),),
9015+
"requires_grad": [True],
9016+
},
9017+
{
9018+
"ins": ['key'],
9019+
"shape": ((1, 64, 64, 128), (1, 64, 32, 128), (1, 32, 64, 512), (8, 128, 32, 256),
9020+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 32, 32, 128), (8, 8, 256, 64),
9021+
(2, 128, 64, 128), (4, 512, 128, 64), (6, 32, 128, 256), (8, 512, 8, 64),
9022+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 32, 32, 128), (8, 8, 256, 64),
9023+
(64, 8, 8, 16), (8, 32, 256, 512), (16, 8, 256, 128), (8, 16, 256, 64),
9024+
(1, 64, 64, 128), (1, 256, 16, 128), (1, 64, 32, 128), (1, 16, 8, 64),),
9025+
"requires_grad": [True],
9026+
},
9027+
{
9028+
"ins": ['value'],
9029+
"shape": ((1, 64, 64, 128), (1, 64, 32, 128), (1, 32, 64, 512), (8, 128, 32, 256),
9030+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 32, 32, 128), (8, 8, 256, 64),
9031+
(2, 128, 64, 128), (4, 512, 128, 64), (6, 32, 128, 256), (8, 512, 8, 64),
9032+
(2, 64, 128, 128), (4, 16, 256, 128), (6, 32, 32, 128), (8, 8, 256, 64),
9033+
(64, 8, 8, 16), (8, 32, 256, 512), (16, 8, 256, 128), (8, 16, 256, 64),
9034+
(1, 64, 64, 128), (1, 256, 16, 128), (1, 64, 32, 128), (1, 16, 8, 64),),
9035+
"requires_grad": [True],
9036+
},
9037+
{
9038+
"ins": ['attn_bias'],
9039+
"shape": (None, None, None, (8, 32, 128, 128),
9040+
None, None, None, None,
9041+
None, None, None, None,
9042+
None, None, None, None,
9043+
None, None, None, None,
9044+
None, None, None, None,),
9045+
"requires_grad": [False],
9046+
},
9047+
],
9048+
),
9049+
),
9050+
9051+
'attention_varlen': dict(
9052+
name=['attention_varlen'],
9053+
interface=['CustomizedTest'],
9054+
dtype=[np.float16],
9055+
saved_args=dict(out=0),
9056+
atol=1e-3,
9057+
rtol=1e-4,
9058+
para=dict(
9059+
dropout_p=[0, 0, 0, 0,
9060+
0, 0, 0, 0,
9061+
0, 0, 0, 0,
9062+
0, 0, 0, 0,
9063+
0, 0, 0, 0],
9064+
is_causal=[False, True, False, True,
9065+
True, False, True, False,
9066+
True, False, True, True,
9067+
True, True, False, True,
9068+
False, True, False, False],
9069+
scale=[None, 0.0883, None, 0.125,
9070+
None, None, None, None,
9071+
None, None, None, None,
9072+
None, None, None, None,
9073+
None, None, None, None],
9074+
max_seqlen_q=[32, 32, 128, 64,
9075+
32, 32, 128, 64,
9076+
384, 384, 64, 53,
9077+
400, 200, 64, 131,
9078+
1024, 1024, 256, 72],
9079+
max_seqlen_kv=[32, 32, 128, 64,
9080+
32, 32, 128, 64,
9081+
384, 384, 64, 53,
9082+
400, 200, 64, 131,
9083+
1024, 1024, 256, 72],
9084+
),
9085+
tensor_para=dict(
9086+
gen_fn='Genfunc.randn',
9087+
args=[
9088+
{
9089+
"ins": ['query'],
9090+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9091+
(32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9092+
(1098, 64, 256), (128, 64, 128), (128, 16, 128), (128, 8, 32),
9093+
(2048, 32, 128), (2048, 32, 8), (256, 256, 128), (512, 256, 128),
9094+
(4096, 128, 64), (4096, 128, 64), (512, 128, 8), (256, 128, 128),),
9095+
"requires_grad": [True],
9096+
},
9097+
{
9098+
"ins": ['key'],
9099+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9100+
(32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9101+
(1098, 64, 256), (128, 64, 128), (128, 16, 128), (128, 8, 32),
9102+
(2048, 32, 128), (2048, 32, 8), (256, 256, 128), (512, 256, 128),
9103+
(4096, 128, 64), (4096, 128, 64), (512, 128, 8), (256, 128, 128),),
9104+
"requires_grad": [True],
9105+
},
9106+
{
9107+
"ins": ['value'],
9108+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9109+
(32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64),
9110+
(1098, 64, 256), (128, 64, 128), (128, 16, 128), (128, 8, 32),
9111+
(2048, 32, 128), (2048, 32, 8), (256, 256, 128), (512, 256, 128),
9112+
(4096, 128, 64), (4096, 128, 64), (512, 128, 8), (256, 128, 128),),
9113+
"requires_grad": [True],
9114+
},
9115+
{
9116+
"ins": ['cu_seqlens_q'],
9117+
"value": ([0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128],
9118+
[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128],
9119+
[0, 200, 352, 600, 616, 1000, 1098], [0, 16, 48, 64, 128], [0, 16, 48, 64, 128], [0, 16, 48, 64, 75, 128],
9120+
[0, 100, 300, 600, 1000, 1250, 1490, 1800, 1900, 2048], [0, 100, 150, 300, 500, 600, 800, 1000, 1150, 1250, 1300, 1490, 1600, 1800, 1900, 2048], [0, 32, 64, 96, 128, 160, 192, 256], [0, 2, 7, 19, 32, 64, 96, 128, 256, 387, 512],
9121+
[0, 1024, 2048, 3072, 4000, 4096], [0, 1024, 2048, 3072, 4096], [0, 26, 52, 79, 112, 128, 256, 512], [0, 11, 32, 90, 128, 200, 256],),
9122+
"gen_policy": "gen_tensor_by_value",
9123+
"dtype": [np.int64],
9124+
"requires_grad": [False],
9125+
},
9126+
{
9127+
"ins": ['cu_seqlens_kv'],
9128+
"value": ([0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128],
9129+
[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128],
9130+
[0, 200, 352, 600, 616, 1000, 1098], [0, 16, 48, 64, 128], [0, 16, 48, 64, 128], [0, 16, 48, 64, 75, 128],
9131+
[0, 100, 300, 600, 1000, 1250, 1490, 1800, 1900, 2048], [0, 100, 150, 300, 500, 600, 800, 1000, 1150, 1250, 1300, 1490, 1600, 1800, 1900, 2048], [0, 32, 64, 96, 128, 160, 192, 256], [0, 2, 7, 19, 32, 64, 96, 128, 256, 387, 512],
9132+
[0, 1024, 2048, 3072, 4000, 4096], [0, 1024, 2048, 3072, 4096], [0, 26, 52, 79, 112, 128, 256, 512], [0, 11, 32, 90, 128, 200, 256],),
9133+
"dtype": [np.int64],
9134+
"gen_policy": "gen_tensor_by_value",
9135+
"requires_grad": [False],
9136+
},
9137+
],
9138+
),
9139+
),
9140+
89779141
'flash_attention_varlen': dict(
89789142
name=['flash_attention_varlen'],
89799143
interface=['CustomizedTest'],
@@ -8982,28 +9146,28 @@
89829146
atol=1e-3,
89839147
rtol=1e-4,
89849148
para=dict(
8985-
p_dropout=[0, 0, 0, 0],
8986-
is_causal=[True, True, False, True],
8987-
softmax_scale=[None, 0.0883, None, 0.125],
8988-
max_seqlen_q=[32, 32, 128, 64],
8989-
max_seqlen_kv=[32, 32, 128, 64],
8990-
cu_seqlens_q=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
8991-
cu_seqlens_kv=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128]],
9149+
p_dropout=[0, 0, 0, 0, 0],
9150+
is_causal=[True, True, False, True, False],
9151+
softmax_scale=[None, 0.0883, None, 0.125, None],
9152+
max_seqlen_q=[32, 32, 128, 64, 256],
9153+
max_seqlen_kv=[32, 32, 128, 64, 256],
9154+
cu_seqlens_q=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128], [0, 2, 7, 19, 32, 64, 96, 128, 256, 512]],
9155+
cu_seqlens_kv=[[0, 32], [0, 16, 48, 64], [0, 32, 64, 128, 256], [0, 16, 48, 64, 128], [0, 2, 7, 19, 32, 64, 96, 128, 256, 512]],
89929156
),
89939157
tensor_para=dict(
89949158
gen_fn='Genfunc.randn',
89959159
args=[
89969160
{
89979161
"ins": ['q'],
8998-
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64)),
9162+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64), (512, 8, 64)),
89999163
},
90009164
{
90019165
"ins": ['k'],
9002-
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64)),
9166+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64), (512, 8, 64)),
90039167
},
90049168
{
90059169
"ins": ['v'],
9006-
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64)),
9170+
"shape": ((32, 32, 128), (64, 64, 128), (256, 16, 128), (128, 8, 64), (512, 8, 64)),
90079171
},
90089172
],
90099173
),

0 commit comments

Comments
 (0)