@@ -58,10 +58,12 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
5858 } else if otherwise_is_unreachable
5959 && let Some ( new_stmts) = simplify_switch (
6060 tcx,
61+ body,
62+ bb,
6163 first_case,
6264 targets,
6365 typing_env,
64- & body . basic_blocks ,
66+ discr ,
6567 discr_local,
6668 discr_ty,
6769 )
@@ -293,10 +295,12 @@ fn can_cast(
293295/// ```
294296fn simplify_switch < ' tcx > (
295297 tcx : TyCtxt < ' tcx > ,
298+ body : & Body < ' tcx > ,
299+ switch_bb : BasicBlock ,
296300 first_case : BasicBlock ,
297301 targets : & SwitchTargets ,
298302 typing_env : ty:: TypingEnv < ' tcx > ,
299- bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
303+ discr : & Operand < ' tcx > ,
300304 discr_local : Local ,
301305 discr_ty : Ty < ' tcx > ,
302306) -> Option < Vec < StatementKind < ' tcx > > > {
@@ -307,22 +311,25 @@ fn simplify_switch<'tcx>(
307311 if !targets. is_distinct ( ) {
308312 return None ;
309313 }
310- if !bbs [ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
314+ if !body . basic_blocks [ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
311315 return None ;
312316 }
313- let first_bb: & BasicBlockData < ' tcx > = & bbs [ first_case] ;
317+ let first_bb: & BasicBlockData < ' tcx > = & body . basic_blocks [ first_case] ;
314318
315319 let mut stmts_iter: Vec < _ > =
316- targets. iter ( ) . map ( |( case, bb) | ( case, bbs [ bb] . statements . iter ( ) ) ) . collect ( ) ;
320+ targets. iter ( ) . map ( |( case, bb) | ( case, body . basic_blocks [ bb] . statements . iter ( ) ) ) . collect ( ) ;
317321 let mut current_line_stmts: Vec < ( u128 , & StatementKind < ' tcx > ) > =
318322 stmts_iter. iter_mut ( ) . map ( |( case, bb) | ( * case, & bb. next ( ) . unwrap ( ) . kind ) ) . collect ( ) ;
319323 let discr_layout = tcx. layout_of ( typing_env. as_query_input ( discr_ty) ) . unwrap ( ) ;
320324 let mut new_stmts: Vec < StatementKind < ' tcx > > = Vec :: with_capacity ( first_bb. statements . len ( ) ) ;
321325 ' finish: loop {
322326 let new_stmt = simplify_stmt (
323327 tcx,
328+ body,
329+ switch_bb,
324330 current_line_stmts. as_slice ( ) ,
325331 typing_env,
332+ discr,
326333 discr_local,
327334 discr_ty,
328335 discr_layout,
@@ -342,8 +349,11 @@ fn simplify_switch<'tcx>(
342349
343350fn simplify_stmt < ' tcx > (
344351 tcx : TyCtxt < ' tcx > ,
352+ body : & Body < ' tcx > ,
353+ switch_bb : BasicBlock ,
345354 stmts : & [ ( u128 , & StatementKind < ' tcx > ) ] ,
346355 typing_env : ty:: TypingEnv < ' tcx > ,
356+ discr : & Operand < ' tcx > ,
347357 discr_local : Local ,
348358 discr_ty : Ty < ' tcx > ,
349359 discr_layout : TyAndLayout < ' _ > ,
@@ -353,21 +363,28 @@ fn simplify_stmt<'tcx>(
353363 if stmts. into_iter ( ) . skip ( 1 ) . all ( |& ( _, stmt) | first_stmt == stmt) {
354364 return Some ( first_stmt. clone ( ) ) ;
355365 }
356- if let StatementKind :: Assign ( box ( first_lhs, Rvalue :: Use ( Operand :: Constant ( box first_const) ) ) ) =
357- first_stmt
366+ let StatementKind :: Assign ( box ( first_lhs, first_rval) ) = first_stmt else {
367+ return None ;
368+ } ;
369+ if !stmts. into_iter ( ) . skip ( 1 ) . all ( |& ( _, stmt) | {
370+ let StatementKind :: Assign ( box ( other_lhs, _) ) = stmt else {
371+ return false ;
372+ } ;
373+ first_lhs == other_lhs
374+ } ) {
375+ return None ;
376+ }
377+ if let Rvalue :: Use ( Operand :: Constant ( box first_const) ) = first_rval
358378 && first_const. ty ( ) . is_integral ( )
359379 && let Some ( first_scalar_int) = first_const. const_ . try_eval_scalar_int ( tcx, typing_env)
360380 {
361381 if stmts. into_iter ( ) . skip ( 1 ) . all ( |& ( _, stmt) | {
362- let StatementKind :: Assign ( box (
363- other_lhs,
364- Rvalue :: Use ( Operand :: Constant ( box other_const) ) ,
365- ) ) = stmt
382+ let StatementKind :: Assign ( box ( _, Rvalue :: Use ( Operand :: Constant ( box other_const) ) ) ) =
383+ stmt
366384 else {
367385 return false ;
368386 } ;
369- first_lhs == other_lhs
370- && first_const. ty ( ) == other_const. ty ( )
387+ first_const. ty ( ) == other_const. ty ( )
371388 && other_const. const_ . try_eval_scalar_int ( tcx, typing_env) == Some ( first_scalar_int)
372389 } ) {
373390 return Some ( first_stmt. clone ( ) ) ;
@@ -378,7 +395,7 @@ fn simplify_stmt<'tcx>(
378395 if can_cast ( tcx, first_case, discr_layout, first_const. ty ( ) , first_scalar_int) {
379396 if stmts. into_iter ( ) . skip ( 1 ) . all ( |& ( other_case, stmt) | {
380397 let StatementKind :: Assign ( box (
381- other_lhs ,
398+ _ ,
382399 Rvalue :: Use ( Operand :: Constant ( box other_const) ) ,
383400 ) ) = stmt
384401 else {
@@ -389,8 +406,7 @@ fn simplify_stmt<'tcx>(
389406 else {
390407 return false ;
391408 } ;
392- first_lhs == other_lhs
393- && first_const. ty ( ) == other_const. ty ( )
409+ first_const. ty ( ) == other_const. ty ( )
394410 && can_cast ( tcx, other_case, discr_layout, other_const. ty ( ) , other_scalar_int)
395411 } ) {
396412 let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
@@ -403,5 +419,99 @@ fn simplify_stmt<'tcx>(
403419 }
404420 }
405421 }
422+
423+ if let Some ( new_stmt) =
424+ simplify_to_copy ( tcx, body, switch_bb, discr, typing_env, first_lhs, stmts)
425+ {
426+ return Some ( new_stmt) ;
427+ }
428+
406429 None
407430}
431+
432+ /// This is primarily used to merge these copy statements that simplified the canonical enum clone method by GVN.
433+ /// The GVN simplified
434+ /// ```ignore (syntax-highlighting-only)
435+ /// match a {
436+ /// Foo::A(x) => Foo::A(*x),
437+ /// Foo::B => Foo::B
438+ /// }
439+ /// ```
440+ /// to
441+ /// ```ignore (syntax-highlighting-only)
442+ /// match a {
443+ /// Foo::A(_x) => a, // copy a
444+ /// Foo::B => Foo::B
445+ /// }
446+ /// ```
447+ /// This function will simplify into a copy statement.
448+ fn simplify_to_copy < ' tcx > (
449+ tcx : TyCtxt < ' tcx > ,
450+ body : & Body < ' tcx > ,
451+ switch_bb : BasicBlock ,
452+ discr : & Operand < ' tcx > ,
453+ typing_env : ty:: TypingEnv < ' tcx > ,
454+ first_case_lhs : & Place < ' tcx > ,
455+ stmts : & [ ( u128 , & StatementKind < ' tcx > ) ] ,
456+ ) -> Option < StatementKind < ' tcx > > {
457+ let bbs = & body. basic_blocks ;
458+ // Check if the copy source matches the following pattern.
459+ // _2 = discriminant(*_1); // "*_1" is the expected the copy source.
460+ // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
461+ let & Statement {
462+ kind : StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( copy_src_place) ) ) ,
463+ ..
464+ } = bbs[ switch_bb] . statements . last ( ) ?
465+ else {
466+ return None ;
467+ } ;
468+ if discr. place ( ) != Some ( discr_place) {
469+ return None ;
470+ }
471+ let src_ty = copy_src_place. ty ( body. local_decls ( ) , tcx) ;
472+ if !src_ty. ty . is_enum ( ) || src_ty. variant_index . is_some ( ) {
473+ return None ;
474+ }
475+ let dest_ty = first_case_lhs. ty ( body. local_decls ( ) , tcx) ;
476+ if dest_ty. ty != src_ty. ty || dest_ty. variant_index . is_some ( ) {
477+ return None ;
478+ }
479+ let ty:: Adt ( def, _) = dest_ty. ty . kind ( ) else {
480+ return None ;
481+ } ;
482+
483+ for & ( case, stmt) in stmts. iter ( ) {
484+ let StatementKind :: Assign ( box ( _, rvalue) ) = stmt else {
485+ return None ;
486+ } ;
487+ match rvalue {
488+ // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
489+ Rvalue :: Use ( Operand :: Constant ( box constant) )
490+ if let Const :: Val ( const_, ty) = constant. const_ =>
491+ {
492+ let ( ecx, op) =
493+ mk_eval_cx_for_const_val ( tcx. at ( constant. span ) , typing_env, const_, ty) ?;
494+ let variant = ecx. read_discriminant ( & op) . discard_err ( ) ?;
495+ if !def. variants ( ) [ variant] . fields . is_empty ( ) {
496+ return None ;
497+ }
498+ let Discr { val, .. } = ty. discriminant_for_variant ( tcx, variant) ?;
499+ if val != case {
500+ return None ;
501+ }
502+ }
503+ Rvalue :: Use ( Operand :: Copy ( src_place) ) if * src_place == copy_src_place => { }
504+ // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
505+ Rvalue :: Aggregate ( box AggregateKind :: Adt ( _, variant_index, _, _, None ) , fields)
506+ if fields. is_empty ( )
507+ && let Some ( Discr { val, .. } ) =
508+ src_ty. ty . discriminant_for_variant ( tcx, * variant_index)
509+ && val == case => { }
510+ _ => return None ,
511+ }
512+ }
513+ Some ( StatementKind :: Assign ( Box :: new ( (
514+ * first_case_lhs,
515+ Rvalue :: Use ( Operand :: Copy ( copy_src_place) ) ,
516+ ) ) ) )
517+ }
0 commit comments