@@ -131,14 +131,6 @@ def setUp(self):
         self.model_tester = MllamaText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MllamaTextConfig, has_text_modality=True)
 
-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
-
-    @unittest.skip(reason="Quanto test is borken")
-    def test_generate_with_quant_cache(self):
-        pass
-
 
 class MllamaVisionText2TextModelTester:
     def __init__(
@@ -201,6 +193,7 @@ def __init__(
         self.image_size = 224
         self.max_num_images = 1
         self.max_image_tiles = 4
+        self.image_length = 904
 
     def get_config(self):
         return MllamaConfig(
@@ -319,86 +312,50 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(ops.allclose(out_embeds, out_ids))
 
-    @unittest.skip(reason="Static cache not supported")
-    def test_static_cache_matches_dynamic(self):
-        # TypeError: list indices must be integers or slices, not tuple
-        # TODO: @raushan, please look into this for new cache format
-        pass
-
-    @unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
-    def test_generate_compile_fullgraph(self):
-        pass
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Mllama has cross attention layers and those have a different shape than normal attention layers
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+        cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+            expected_shape_cross = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                self.model_tester.image_length,
+            )
+            expected_shapes = [
+                expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
+                for layer_idx in range(len(iter_attentions))
+            ]
+            self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)
 
-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
 
-    @unittest.skip(reason="Mllama is not yet supported by compile")
-    def test_sdpa_can_compile_dynamic(self):
-        # TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
-        # relevant issue: https://github.com/pytorch/pytorch/issues/133166
+    @unittest.skip(reason="The test itself is broken")  # TODO @zucchini-nlp
+    def test_generate_with_quant_cache(self):
         pass
 
     @unittest.skip(reason="The test itself is broken")  # TODO @zucchini-nlp
-    def test_generate_with_quant_cache(self):
+    def test_beam_search_low_memory(self):
         pass
 
     @unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
     def test_model_parallelism(self):
         pass
 
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_compile_cuda_graph_time(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_torch_compile_fullgraph(self):
-        pass
-
-    @unittest.skip(reason="Device side assert triggered")
-    def test_assisted_decoding_with_num_logits_to_keep(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_sample_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_constrained_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_dola_decoding_sample():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_generate_methods_with_num_logits_to_keep():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_greedy_generate_dict_outputs():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_group_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_model_parallel_beam_search():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_new_cache_format_2():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_sample_generate_dict_output():
-        pass
-
 
 @require_mindspore
 class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
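For reviewers, the shape convention that the new _check_attentions_for_generate override encodes can be reduced to a few lines. The sketch below is illustrative only and is not part of the patch: expected_attention_shapes is a hypothetical helper, and the batch size, head count, and layer indices in the usage example are made up; image_length=904 mirrors the tester default added above. The key points are that with a KV cache the query length collapses to 1 per decoding step, self-attention key length grows with the decoded sequence, and cross-attention key length stays pinned at the number of image tokens.

# Illustrative sketch only -- "expected_attention_shapes" is a hypothetical
# helper, not part of the test file. It mirrors the override's logic:
# self-attention keys grow with the decoded sequence, while cross-attention
# keys stay fixed at the number of image tokens (image_length).

def expected_attention_shapes(
    batch_size,
    num_heads,
    num_layers,
    cross_attention_layers,  # indices of the cross-attention layers
    step,                    # 0-based decoding step
    min_length,              # prompt length when generation starts
    image_length,            # number of image tokens seen by cross attention
    use_cache=True,
    num_beam_groups=1,
):
    """Return the expected attention shape for each decoder layer at one step."""
    # With a KV cache, each step computes attention only for the newest token.
    tgt_len = 1 if use_cache else min_length + step
    # Self-attention key length grows as tokens are generated.
    src_len = min_length + step
    shapes = []
    for layer_idx in range(num_layers):
        kv_len = image_length if layer_idx in cross_attention_layers else src_len
        shapes.append((batch_size * num_beam_groups, num_heads, tgt_len, kv_len))
    return shapes


if __name__ == "__main__":
    # Hypothetical numbers: 2 sequences, 4 heads, 5 layers with cross attention
    # at layers 1 and 3, third decoding step (step=2) from a 10-token prompt.
    # image_length=904 matches the tester default added in this patch.
    shapes = expected_attention_shapes(
        batch_size=2, num_heads=4, num_layers=5,
        cross_attention_layers=[1, 3],
        step=2, min_length=10, image_length=904,
    )
    for layer_idx, shape in enumerate(shapes):
        print(f"layer {layer_idx}: {shape}")
    # Self-attention layers print (2, 4, 1, 12); layers 1 and 3 print (2, 4, 1, 904).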