diff --git a/vyper/venom/analysis/mem_ssa.py b/vyper/venom/analysis/mem_ssa.py index c1db8789ba..25cb063b18 100644 --- a/vyper/venom/analysis/mem_ssa.py +++ b/vyper/venom/analysis/mem_ssa.py @@ -60,15 +60,6 @@ def __hash__(self) -> int: def __repr__(self) -> str: return f"{self.__class__.__name__}({self.id_str})" - -class LiveOnEntry(MemoryAccess): - """ - For type checking purposes - """ - - pass - - class MemoryDef(MemoryAccess): """Represents a definition of memory state""" @@ -103,6 +94,15 @@ def __init__(self, id: int, block: IRBasicBlock): self.block = block self.operands: list[tuple[MemoryPhiOperand, IRBasicBlock]] = [] +class LiveOnEntry(MemoryDef): + """ + For type checking purposes + """ + + def __init__(self, id: int, addr_space: AddrSpace): + super().__init__(id, IRInstruction("nop", []), addr_space) + + # Type aliases for signatures in this module MemoryDefOrUse = MemoryDef | MemoryUse @@ -123,6 +123,7 @@ class MemSSAAbstract(IRAnalysis): addr_space: AddrSpace mem_alias_type: type[MemoryAliasAnalysisAbstract] + live_on_entry: LiveOnEntry = None def __init__(self, analyses_cache, function): super().__init__(analyses_cache, function) @@ -130,7 +131,7 @@ def __init__(self, analyses_cache, function): self.next_id = 1 # Start from 1 since 0 will be live_on_entry # live_on_entry node - self.live_on_entry = LiveOnEntry(0) + self.live_on_entry = LiveOnEntry(0, self.addr_space) self.memory_defs: dict[IRBasicBlock, list[MemoryDef]] = {} self.memory_uses: dict[IRBasicBlock, list[MemoryUse]] = {} diff --git a/vyper/venom/passes/dft.py b/vyper/venom/passes/dft.py index 02570bb403..6eb3f2dd88 100644 --- a/vyper/venom/passes/dft.py +++ b/vyper/venom/passes/dft.py @@ -1,10 +1,23 @@ from collections import defaultdict +from typing import Dict, List, Optional import vyper.venom.effects as effects from vyper.utils import OrderedSet from vyper.venom.analysis import DFGAnalysis, LivenessAnalysis -from vyper.venom.basicblock import IRBasicBlock, IRInstruction +from vyper.venom.analysis.mem_ssa import ( + LiveOnEntry, + MemoryAccess, + MemoryDef, + MemoryPhi, + MemoryUse, + MemSSA, + MemSSAAbstract, + StorageSSA, + TransientSSA, +) +from vyper.venom.basicblock import IRBasicBlock, IRInstruction, IRVariable from vyper.venom.function import IRFunction +from vyper.venom.memory_location import MemoryLocation from vyper.venom.passes.base_pass import IRPass @@ -17,11 +30,22 @@ class DFTPass(IRPass): # "effect dependency analysis" eda: dict[IRInstruction, OrderedSet[IRInstruction]] + effective_reaching_defs: dict[MemoryUse, MemoryDef] + defs_to_uses: Dict[MemoryDef, List[MemoryUse]] + def run_pass(self) -> None: self.data_offspring = {} self.visited_instructions: OrderedSet[IRInstruction] = OrderedSet() self.dfg = self.analyses_cache.force_analysis(DFGAnalysis) + self.mem_ssa = { + effects.MEMORY: self.analyses_cache.force_analysis(MemSSA), + effects.STORAGE: self.analyses_cache.force_analysis(StorageSSA), + effects.TRANSIENT: self.analyses_cache.force_analysis(TransientSSA), + } + self._calculate_effective_reaching_defs(effects.MEMORY) + self._calculate_effective_reaching_defs(effects.STORAGE) + self._calculate_effective_reaching_defs(effects.TRANSIENT) for bb in self.function.get_basic_blocks(): self._process_basic_block(bb) @@ -111,10 +135,40 @@ def _calculate_dependency_graphs(self, bb: IRBasicBlock) -> None: last_write_effects[write_effect] = inst for read_effect in read_effects: + if read_effect == effects.MEMORY: + self._handle_memory_effect(inst, effects.MEMORY, last_write_effects, last_read_effects) + continue + if read_effect == effects.STORAGE: + self._handle_memory_effect(inst, effects.STORAGE, last_write_effects, last_read_effects) + continue + if read_effect == effects.TRANSIENT: + self._handle_memory_effect(inst, effects.TRANSIENT, last_write_effects, last_read_effects) + continue if read_effect in last_write_effects and last_write_effects[read_effect] != inst: self.eda[inst].add(last_write_effects[read_effect]) last_read_effects[read_effect] = inst + def _handle_memory_effect( + self, + inst: IRInstruction, + effect: effects.Effects, + last_write_effects: dict[effects.Effects, IRInstruction], + last_read_effects: dict[effects.Effects, IRInstruction], + ) -> None: + mem_use = self.mem_ssa[effect].get_memory_use(inst) + mem_def = self.effective_reaching_defs.get(mem_use, None) + + if mem_def is not None and isinstance(mem_def, MemoryDef): + if mem_def.inst.parent == inst.parent: + self.eda[inst].add(mem_def.inst) + elif ( + effects.MEMORY in last_write_effects + and last_write_effects[effects.MEMORY] != inst + ): + self.eda[inst].add(last_write_effects[effects.MEMORY]) + + last_read_effects[effects.MEMORY] = inst + def _calculate_data_offspring(self, inst: IRInstruction): if inst in self.data_offspring: return self.data_offspring[inst] @@ -128,3 +182,47 @@ def _calculate_data_offspring(self, inst: IRInstruction): self.data_offspring[inst] |= res return self.data_offspring[inst] + + def _calculate_effective_reaching_defs(self, effect: effects.Effects): + self.effective_reaching_defs = {} + self.defs_to_uses: Dict[MemoryDef, List[MemoryUse]] = {} + for mem_use in self.mem_ssa[effect].get_memory_uses(): + if mem_use.inst.opcode != "mload": + continue + #if isinstance(mem_use.inst.operands[0], IRVariable): + # continue + mem_def = self._walk_for_effective_reaching_def( + mem_use.reaching_def, mem_use.loc, OrderedSet(), effect + ) + self.effective_reaching_defs[mem_use] = mem_def + if mem_def not in self.defs_to_uses: + self.defs_to_uses[mem_def] = [] + self.defs_to_uses[mem_def].append(mem_use) + + def _walk_for_effective_reaching_def( + self, mem_access: MemoryAccess, query_loc: MemoryLocation, visited: OrderedSet[MemoryAccess], effect: effects.Effects + ) -> Optional[MemoryDef | MemoryPhi | LiveOnEntry]: + current: Optional[MemoryAccess] = mem_access + while current is not None: + if current in visited: + break + visited.add(current) + + if isinstance(current, MemoryDef): + if self.mem_ssa[effect].memalias.may_alias(query_loc, current.loc): + return current + + if isinstance(current, MemoryPhi): + reaching_defs = [] + for access, _ in current.operands: + reaching_def = self._walk_for_effective_reaching_def(access, query_loc, visited, effect) + if reaching_def: + reaching_defs.append(reaching_def) + if len(reaching_defs) == 1: + return reaching_defs[0] + + return current + + current = current.reaching_def + + return MemSSAAbstract.live_on_entry