Skip to content

Commit 8843d6b

Browse files
committed
Simplify the canonical enum clone branches to a copy statement
1 parent 5881ed4 commit 8843d6b

File tree

1 file changed

+126
-16
lines changed

1 file changed

+126
-16
lines changed

compiler/rustc_mir_transform/src/match_branches.rs

Lines changed: 126 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification {
5858
} else if otherwise_is_unreachable
5959
&& let Some(new_stmts) = simplify_switch(
6060
tcx,
61+
body,
62+
bb,
6163
first_case,
6264
targets,
6365
typing_env,
64-
&body.basic_blocks,
66+
discr,
6567
discr_local,
6668
discr_ty,
6769
)
@@ -293,10 +295,12 @@ fn can_cast(
293295
/// ```
294296
fn simplify_switch<'tcx>(
295297
tcx: TyCtxt<'tcx>,
298+
body: &Body<'tcx>,
299+
switch_bb: BasicBlock,
296300
first_case: BasicBlock,
297301
targets: &SwitchTargets,
298302
typing_env: ty::TypingEnv<'tcx>,
299-
bbs: &IndexSlice<BasicBlock, BasicBlockData<'tcx>>,
303+
discr: &Operand<'tcx>,
300304
discr_local: Local,
301305
discr_ty: Ty<'tcx>,
302306
) -> Option<Vec<StatementKind<'tcx>>> {
@@ -307,22 +311,25 @@ fn simplify_switch<'tcx>(
307311
if !targets.is_distinct() {
308312
return None;
309313
}
310-
if !bbs[targets.otherwise()].is_empty_unreachable() {
314+
if !body.basic_blocks[targets.otherwise()].is_empty_unreachable() {
311315
return None;
312316
}
313-
let first_bb: &BasicBlockData<'tcx> = &bbs[first_case];
317+
let first_bb: &BasicBlockData<'tcx> = &body.basic_blocks[first_case];
314318

315319
let mut stmts_iter: Vec<_> =
316-
targets.iter().map(|(case, bb)| (case, bbs[bb].statements.iter())).collect();
320+
targets.iter().map(|(case, bb)| (case, body.basic_blocks[bb].statements.iter())).collect();
317321
let mut current_line_stmts: Vec<(u128, &StatementKind<'tcx>)> =
318322
stmts_iter.iter_mut().map(|(case, bb)| (*case, &bb.next().unwrap().kind)).collect();
319323
let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap();
320324
let mut new_stmts: Vec<StatementKind<'tcx>> = Vec::with_capacity(first_bb.statements.len());
321325
'finish: loop {
322326
let new_stmt = simplify_stmt(
323327
tcx,
328+
body,
329+
switch_bb,
324330
current_line_stmts.as_slice(),
325331
typing_env,
332+
discr,
326333
discr_local,
327334
discr_ty,
328335
discr_layout,
@@ -342,8 +349,11 @@ fn simplify_switch<'tcx>(
342349

343350
fn simplify_stmt<'tcx>(
344351
tcx: TyCtxt<'tcx>,
352+
body: &Body<'tcx>,
353+
switch_bb: BasicBlock,
345354
stmts: &[(u128, &StatementKind<'tcx>)],
346355
typing_env: ty::TypingEnv<'tcx>,
356+
discr: &Operand<'tcx>,
347357
discr_local: Local,
348358
discr_ty: Ty<'tcx>,
349359
discr_layout: TyAndLayout<'_>,
@@ -353,21 +363,28 @@ fn simplify_stmt<'tcx>(
353363
if stmts.into_iter().skip(1).all(|&(_, stmt)| first_stmt == stmt) {
354364
return Some(first_stmt.clone());
355365
}
356-
if let StatementKind::Assign(box (first_lhs, Rvalue::Use(Operand::Constant(box first_const)))) =
357-
first_stmt
366+
let StatementKind::Assign(box (first_lhs, first_rval)) = first_stmt else {
367+
return None;
368+
};
369+
if !stmts.into_iter().skip(1).all(|&(_, stmt)| {
370+
let StatementKind::Assign(box (other_lhs, _)) = stmt else {
371+
return false;
372+
};
373+
first_lhs == other_lhs
374+
}) {
375+
return None;
376+
}
377+
if let Rvalue::Use(Operand::Constant(box first_const)) = first_rval
358378
&& first_const.ty().is_integral()
359379
&& let Some(first_scalar_int) = first_const.const_.try_eval_scalar_int(tcx, typing_env)
360380
{
361381
if stmts.into_iter().skip(1).all(|&(_, stmt)| {
362-
let StatementKind::Assign(box (
363-
other_lhs,
364-
Rvalue::Use(Operand::Constant(box other_const)),
365-
)) = stmt
382+
let StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(box other_const)))) =
383+
stmt
366384
else {
367385
return false;
368386
};
369-
first_lhs == other_lhs
370-
&& first_const.ty() == other_const.ty()
387+
first_const.ty() == other_const.ty()
371388
&& other_const.const_.try_eval_scalar_int(tcx, typing_env) == Some(first_scalar_int)
372389
}) {
373390
return Some(first_stmt.clone());
@@ -378,7 +395,7 @@ fn simplify_stmt<'tcx>(
378395
if can_cast(tcx, first_case, discr_layout, first_const.ty(), first_scalar_int) {
379396
if stmts.into_iter().skip(1).all(|&(other_case, stmt)| {
380397
let StatementKind::Assign(box (
381-
other_lhs,
398+
_,
382399
Rvalue::Use(Operand::Constant(box other_const)),
383400
)) = stmt
384401
else {
@@ -389,8 +406,7 @@ fn simplify_stmt<'tcx>(
389406
else {
390407
return false;
391408
};
392-
first_lhs == other_lhs
393-
&& first_const.ty() == other_const.ty()
409+
first_const.ty() == other_const.ty()
394410
&& can_cast(tcx, other_case, discr_layout, other_const.ty(), other_scalar_int)
395411
}) {
396412
let operand = Operand::Copy(Place::from(discr_local));
@@ -403,5 +419,99 @@ fn simplify_stmt<'tcx>(
403419
}
404420
}
405421
}
422+
423+
if let Some(new_stmt) =
424+
simplify_to_copy(tcx, body, switch_bb, discr, typing_env, first_lhs, stmts)
425+
{
426+
return Some(new_stmt);
427+
}
428+
406429
None
407430
}
431+
432+
/// This is primarily used to merge these copy statements that simplified the canonical enum clone method by GVN.
433+
/// The GVN simplified
434+
/// ```ignore (syntax-highlighting-only)
435+
/// match a {
436+
/// Foo::A(x) => Foo::A(*x),
437+
/// Foo::B => Foo::B
438+
/// }
439+
/// ```
440+
/// to
441+
/// ```ignore (syntax-highlighting-only)
442+
/// match a {
443+
/// Foo::A(_x) => a, // copy a
444+
/// Foo::B => Foo::B
445+
/// }
446+
/// ```
447+
/// This function will simplify into a copy statement.
448+
fn simplify_to_copy<'tcx>(
449+
tcx: TyCtxt<'tcx>,
450+
body: &Body<'tcx>,
451+
switch_bb: BasicBlock,
452+
discr: &Operand<'tcx>,
453+
typing_env: ty::TypingEnv<'tcx>,
454+
first_case_lhs: &Place<'tcx>,
455+
stmts: &[(u128, &StatementKind<'tcx>)],
456+
) -> Option<StatementKind<'tcx>> {
457+
let bbs = &body.basic_blocks;
458+
// Check if the copy source matches the following pattern.
459+
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
460+
// switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1];
461+
let &Statement {
462+
kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))),
463+
..
464+
} = bbs[switch_bb].statements.last()?
465+
else {
466+
return None;
467+
};
468+
if discr.place() != Some(discr_place) {
469+
return None;
470+
}
471+
let src_ty = copy_src_place.ty(body.local_decls(), tcx);
472+
if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() {
473+
return None;
474+
}
475+
let dest_ty = first_case_lhs.ty(body.local_decls(), tcx);
476+
if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() {
477+
return None;
478+
}
479+
let ty::Adt(def, _) = dest_ty.ty.kind() else {
480+
return None;
481+
};
482+
483+
for &(case, stmt) in stmts.iter() {
484+
let StatementKind::Assign(box (_, rvalue)) = stmt else {
485+
return None;
486+
};
487+
match rvalue {
488+
// Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`.
489+
Rvalue::Use(Operand::Constant(box constant))
490+
if let Const::Val(const_, ty) = constant.const_ =>
491+
{
492+
let (ecx, op) =
493+
mk_eval_cx_for_const_val(tcx.at(constant.span), typing_env, const_, ty)?;
494+
let variant = ecx.read_discriminant(&op).discard_err()?;
495+
if !def.variants()[variant].fields.is_empty() {
496+
return None;
497+
}
498+
let Discr { val, .. } = ty.discriminant_for_variant(tcx, variant)?;
499+
if val != case {
500+
return None;
501+
}
502+
}
503+
Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {}
504+
// Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`.
505+
Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields)
506+
if fields.is_empty()
507+
&& let Some(Discr { val, .. }) =
508+
src_ty.ty.discriminant_for_variant(tcx, *variant_index)
509+
&& val == case => {}
510+
_ => return None,
511+
}
512+
}
513+
Some(StatementKind::Assign(Box::new((
514+
*first_case_lhs,
515+
Rvalue::Use(Operand::Copy(copy_src_place)),
516+
))))
517+
}

0 commit comments

Comments
 (0)