@@ -28,23 +28,32 @@ def test_init_attention_mask_builder(self):
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
                          torch.float16)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
                          (1024, 1024))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
                          torch.tensor(float("-inf"), dtype=torch.float16))
 
-        # generate attention_mask_builder with int8
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=512,
-                                                      dtype=torch.int8)
-        self.assertEqual(attention_mask_builder._seq_len_cached, 512)
+        # generate attention_mask_builder with bfloat16
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
         self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
-                         torch.int8)
-        self.assertEqual(attention_mask_builder.splitfuse_mask_value, -10000)
+                         torch.bfloat16)
         self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (512, 512))
+                         (2048, 2048))
         self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(1, dtype=torch.int8))
+                         torch.tensor(1, dtype=torch.bfloat16))
+
+    def test_get_mask_scale_factor(self):
+        # supported data types
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
+        self.assertEqual(
+            AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
+        # mask_scale_factor now only supports torch.float16 and torch.bfloat16;
+        # any other dtype raises a ValueError
+        with self.assertRaises(ValueError):
+            AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
 
     def test_get_attn_mask(self):
         # if the len is less than max_seq_len, the attn_mask_cache will not be updated
@@ -77,80 +86,48 @@ def test_get_splitfuse_attn_mask(self):
         attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
                                                       dtype=torch.float16)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 20, 100]),
+            position=torch.tensor([7, 8, 9, 18, 19, 99]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1, 512))
+        self.assertEqual(attn_mask.shape, (6, 100))
         self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
 
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[2048],
-            query_lens=[1024],
-            position=torch.tensor([0]),
+            seq_lens=torch.tensor([10, 3000, 2000]),
+            position=torch.tensor([7, 8, 9, 2999, 1999]),
             dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(attn_mask.shape, (1024, 2048))
+        self.assertEqual(attn_mask.shape, (5, 3000))
+        self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
+
+        # splitfuse_attn_mask now only supports torch.float16 and torch.bfloat16;
+        # any other dtype raises a ValueError
+        with self.assertRaises(ValueError):
+            attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
+                seq_lens=torch.tensor([10, 20, 100]),
+                position=torch.tensor([7, 8, 9, 18, 19, 99]),
+                dtype=torch.int8,
+                device=torch.device("cpu"),
+            )
+
+    def test_mask_value_cleanliness(self):
+        attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
+                                                      dtype=torch.bfloat16)
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
 
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
-                                                      dtype=torch.int8)
         attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[512],
-            query_lens=[512],
-            position=torch.tensor([0]),
-            dtype=torch.int8,
-            device=torch.device("cpu"),
-        )
-        self.assertEqual(attn_mask.shape, (1, 512))
-
-    def test_use_multiple_masks(self):
-        max_seq_lens = [128, 512, 1024]
-        dtypes = [torch.float16, torch.bfloat16, torch.int8]
-        for max_seq_len, dtype in zip(max_seq_lens, dtypes):
-            with self.subTest(max_seq_len=max_seq_len, dtype=dtype):
-                self._test_use_multiple_masks(max_seq_len, dtype)
-
-    def _test_use_multiple_masks(self, max_seq_len, dtype):
-        expected_mask_value = torch.finfo(
-            torch.float32).min if dtype == torch.float16 else 1
-        if dtype == torch.float16:
-            expected_splitfuse_mask_value = expected_mask_value
-        elif dtype == torch.bfloat16:
-            expected_splitfuse_mask_value = -10000
-        else:
-            assert dtype == torch.int8, "Unsupported dtype for attention mask"
-            expected_splitfuse_mask_value = -16
-
-        attention_mask_builder = AttentionMaskBuilder(max_seq_len=max_seq_len,
-                                                      dtype=dtype)
-
-        splitfuse_attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
-            seq_lens=[max_seq_len],
-            query_lens=[max_seq_len],
-            position=torch.tensor([0]),
-            dtype=dtype,
+            seq_lens=torch.tensor([6]),
+            position=torch.tensor([3, 4, 5]),
+            dtype=torch.float16,
             device=torch.device("cpu"),
         )
-        self.assertEqual(splitfuse_attn_mask.shape, (1, max_seq_len))
         self.assertEqual(
-            splitfuse_attn_mask[0][-1],
-            torch.tensor(expected_splitfuse_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-
-        attn_mask = attention_mask_builder.get_attn_mask(
-            max_seq_len=max_seq_len, dtype=dtype, device=torch.device("cpu"))
-        self.assertEqual(attn_mask.shape, (max_seq_len, max_seq_len))
-        self.assertEqual(attn_mask[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
-        self.assertEqual(attention_mask_builder._seq_len_cached, max_seq_len)
-        self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
-                         (max_seq_len, max_seq_len))
-        self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
-                         torch.tensor(expected_mask_value, dtype=dtype))
+            attn_mask[-2][-1],
+            torch.tensor(-10000, dtype=torch.bfloat16,
+                         device=attn_mask.device))
+        self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
+                         torch.tensor(1, dtype=torch.bfloat16))
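
For reference, a minimal sketch of the dtype handling that test_get_mask_scale_factor pins down, assuming get_mask_scale_factor is a plain dtype lookup; this is illustrative only, not the PR's actual implementation:

import torch

def get_mask_scale_factor(dtype: torch.dtype) -> int:
    # float16 masks can hold float("-inf") directly, so no scaling is needed
    if dtype == torch.float16:
        return 1
    # bfloat16 masks keep 1s in the cache and are scaled to a large negative value
    if dtype == torch.bfloat16:
        return -10000
    # any other dtype is rejected, matching the ValueError the tests expect
    raise ValueError(f"Unsupported dtype for attention mask: {dtype}")

assert get_mask_scale_factor(torch.float16) == 1
assert get_mask_scale_factor(torch.bfloat16) == -10000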