@@ -7128,22 +7128,19 @@ def get_device(self) -> Optional[torch.device]:
 class MultiOutput(ExternKernel):
     def codegen(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         wrapper.codegen_multi_output(self)
-        if not self.skip_size_stride_alignment_checks:
-            self.codegen_size_asserts(wrapper)
-            self.codegen_alignment_asserts(wrapper)
+        self.codegen_size_asserts(wrapper)
+        self.codegen_alignment_asserts(wrapper)
 
     def __init__(  # type: ignore[no-untyped-def]
         self,
         layout: OutputSpec,
         input,
         indices: list[tuple[Any, ...]],
-        skip_size_stride_alignment_checks=False,
     ) -> None:
         super().__init__(None, layout, [input], ())
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)
         self.indices = indices
-        self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
 
     def get_free_symbol_uses(
         self, unbacked_only: bool = False
@@ -7510,7 +7507,6 @@ def create_output(output: IRNode, ind: int):
                     ),
                     invoke_subgraph,
                     [(list, ind)],
-                    skip_size_stride_alignment_checks=True,
                 )
 
         outputs = [create_output(output, i) for i, output in enumerate(outputs)]