Skip to content

[mypyc] Call generator helper method directly in await expression #19376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypyc/irbuild/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def c() -> None:
)

# Re-enter the FuncItem and visit the body of the function this time.
gen_generator_func_body(builder, fn_info, sig, func_reg)
gen_generator_func_body(builder, fn_info, func_reg)
else:
func_ir, func_reg = gen_func_body(builder, sig, cdef, is_singledispatch)

Expand Down
64 changes: 17 additions & 47 deletions mypyc/irbuild/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from typing import Callable

from mypy.nodes import ARG_OPT, FuncDef, Var
from mypyc.common import ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME, SELF_NAME
from mypyc.common import ENV_ATTR_NAME, NEXT_LABEL_ATTR_NAME
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg
from mypyc.ir.func_ir import FuncDecl, FuncIR
from mypyc.ir.ops import (
NO_TRACEBACK_LINE_NO,
BasicBlock,
Expand Down Expand Up @@ -78,17 +78,15 @@ def gen_generator_func(
return func_ir, func_reg


def gen_generator_func_body(
builder: IRBuilder, fn_info: FuncInfo, sig: FuncSignature, func_reg: Value | None
) -> None:
def gen_generator_func_body(builder: IRBuilder, fn_info: FuncInfo, func_reg: Value | None) -> None:
"""Generate IR based on the body of a generator function.

Add "__next__", "__iter__" and other generator methods to the generator
class that implements the function (each function gets a separate class).

Return the symbol table for the body.
"""
builder.enter(fn_info, ret_type=sig.ret_type)
builder.enter(fn_info, ret_type=object_rprimitive)
setup_env_for_generator_class(builder)

load_outer_envs(builder, builder.fn_info.generator_class)
Expand Down Expand Up @@ -117,7 +115,7 @@ class that implements the function (each function gets a separate class).

args, _, blocks, ret_type, fn_info = builder.leave()

add_methods_to_generator_class(builder, fn_info, sig, args, blocks, fitem.is_coroutine)
add_methods_to_generator_class(builder, fn_info, args, blocks, fitem.is_coroutine)

# Evaluate argument defaults in the surrounding scope, since we
# calculate them *once* when the function definition is evaluated.
Expand Down Expand Up @@ -153,10 +151,9 @@ def instantiate_generator_class(builder: IRBuilder) -> Value:


def setup_generator_class(builder: IRBuilder) -> ClassIR:
name = f"{builder.fn_info.namespaced_name()}_gen"

generator_class_ir = ClassIR(name, builder.module_name, is_generated=True, is_final_class=True)
generator_class_ir.reuse_freed_instance = True
mapper = builder.mapper
assert isinstance(builder.fn_info.fitem, FuncDef)
generator_class_ir = mapper.fdef_to_generator[builder.fn_info.fitem]
if builder.fn_info.can_merge_generator_and_env_classes():
builder.fn_info.env_class = generator_class_ir
else:
Expand Down Expand Up @@ -216,46 +213,25 @@ def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int)
def add_methods_to_generator_class(
builder: IRBuilder,
fn_info: FuncInfo,
sig: FuncSignature,
arg_regs: list[Register],
blocks: list[BasicBlock],
is_coroutine: bool,
) -> None:
helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, sig, fn_info)
add_next_to_generator_class(builder, fn_info, helper_fn_decl, sig)
add_send_to_generator_class(builder, fn_info, helper_fn_decl, sig)
helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, fn_info)
add_next_to_generator_class(builder, fn_info, helper_fn_decl)
add_send_to_generator_class(builder, fn_info, helper_fn_decl)
add_iter_to_generator_class(builder, fn_info)
add_throw_to_generator_class(builder, fn_info, helper_fn_decl, sig)
add_throw_to_generator_class(builder, fn_info, helper_fn_decl)
add_close_to_generator_class(builder, fn_info)
if is_coroutine:
add_await_to_generator_class(builder, fn_info)


def add_helper_to_generator_class(
builder: IRBuilder,
arg_regs: list[Register],
blocks: list[BasicBlock],
sig: FuncSignature,
fn_info: FuncInfo,
builder: IRBuilder, arg_regs: list[Register], blocks: list[BasicBlock], fn_info: FuncInfo
) -> FuncDecl:
"""Generates a helper method for a generator class, called by '__next__' and 'throw'."""
sig = FuncSignature(
(
RuntimeArg(SELF_NAME, object_rprimitive),
RuntimeArg("type", object_rprimitive),
RuntimeArg("value", object_rprimitive),
RuntimeArg("traceback", object_rprimitive),
RuntimeArg("arg", object_rprimitive),
),
sig.ret_type,
)
helper_fn_decl = FuncDecl(
"__mypyc_generator_helper__",
fn_info.generator_class.ir.name,
builder.module_name,
sig,
internal=True,
)
helper_fn_decl = fn_info.generator_class.ir.method_decls["__mypyc_generator_helper__"]
helper_fn_ir = FuncIR(
helper_fn_decl, arg_regs, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name
)
Expand All @@ -272,9 +248,7 @@ def add_iter_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
builder.add(Return(builder.self()))


def add_next_to_generator_class(
builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature
) -> None:
def add_next_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
"""Generates the '__next__' method for a generator class."""
with builder.enter_method(fn_info.generator_class.ir, "__next__", object_rprimitive, fn_info):
none_reg = builder.none_object()
Expand All @@ -289,9 +263,7 @@ def add_next_to_generator_class(
builder.add(Return(result))


def add_send_to_generator_class(
builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature
) -> None:
def add_send_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
"""Generates the 'send' method for a generator class."""
with builder.enter_method(fn_info.generator_class.ir, "send", object_rprimitive, fn_info):
arg = builder.add_argument("arg", object_rprimitive)
Expand All @@ -307,9 +279,7 @@ def add_send_to_generator_class(
builder.add(Return(result))


def add_throw_to_generator_class(
builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl, sig: FuncSignature
) -> None:
def add_throw_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
"""Generates the 'throw' method for a generator class."""
with builder.enter_method(fn_info.generator_class.ir, "throw", object_rprimitive, fn_info):
typ = builder.add_argument("type", object_rprimitive)
Expand Down
17 changes: 15 additions & 2 deletions mypyc/irbuild/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def f(x: int) -> int:
from typing import Any, Callable, TypeVar, cast

from mypy.build import Graph
from mypy.nodes import ClassDef, Expression, MypyFile
from mypy.nodes import ClassDef, Expression, FuncDef, MypyFile
from mypy.state import state
from mypy.types import Type
from mypyc.analysis.attrdefined import analyze_always_defined_attrs
Expand All @@ -37,7 +37,11 @@ def f(x: int) -> int:
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.mapper import Mapper
from mypyc.irbuild.prebuildvisitor import PreBuildVisitor
from mypyc.irbuild.prepare import build_type_map, find_singledispatch_register_impls
from mypyc.irbuild.prepare import (
build_type_map,
create_generator_class_if_needed,
find_singledispatch_register_impls,
)
from mypyc.irbuild.visitor import IRBuilderVisitor
from mypyc.irbuild.vtable import compute_vtable
from mypyc.options import CompilerOptions
Expand Down Expand Up @@ -76,6 +80,15 @@ def build_ir(
pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove, types)
module.accept(pbv)

# Declare generator classes for nested async functions and generators.
for fdef in pbv.nested_funcs:
if isinstance(fdef, FuncDef):
# Make generator class name sufficiently unique.
suffix = f"___{fdef.line}"
create_generator_class_if_needed(
module.fullname, None, fdef, mapper, name_suffix=suffix
)

# Construct and configure builder objects (cyclic runtime dependency).
visitor = IRBuilderVisitor()
builder = IRBuilder(
Expand Down
11 changes: 10 additions & 1 deletion mypyc/irbuild/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(self, group_map: dict[str, str | None]) -> None:
self.type_to_ir: dict[TypeInfo, ClassIR] = {}
self.func_to_decl: dict[SymbolNode, FuncDecl] = {}
self.symbol_fullnames: set[str] = set()
# The corresponding generator class that implements a generator/async function
self.fdef_to_generator: dict[FuncDef, ClassIR] = {}

def type_to_rtype(self, typ: Type | None) -> RType:
if typ is None:
Expand Down Expand Up @@ -171,7 +173,14 @@ def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignatu
for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds)
]
arg_pos_onlys = [name is None for name in fdef.type.arg_names]
ret = self.type_to_rtype(fdef.type.ret_type)
# TODO: We could probably support decorators sometimes (static and class method?)
if (fdef.is_coroutine or fdef.is_generator) and not fdef.is_decorated:
# Give a more precise type for generators, so that we can optimize
# code that uses them. They return a generator object, which has a
# specific class. Without this, the type would have to be 'object'.
ret: RType = RInstance(self.fdef_to_generator[fdef])
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the first part of the optimization.

else:
ret = self.type_to_rtype(fdef.type.ret_type)
else:
# Handle unannotated functions
arg_types = [object_rprimitive for _ in fdef.arguments]
Expand Down
47 changes: 44 additions & 3 deletions mypyc/irbuild/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from mypy.semanal import refers_to_fullname
from mypy.traverser import TraverserVisitor
from mypy.types import Instance, Type, get_proper_type
from mypyc.common import PROPSET_PREFIX, get_id_from_name
from mypyc.common import PROPSET_PREFIX, SELF_NAME, get_id_from_name
from mypyc.crash import catch_errors
from mypyc.errors import Errors
from mypyc.ir.class_ir import ClassIR
Expand All @@ -51,7 +51,14 @@
RuntimeArg,
)
from mypyc.ir.ops import DeserMaps
from mypyc.ir.rtypes import RInstance, RType, dict_rprimitive, none_rprimitive, tuple_rprimitive
from mypyc.ir.rtypes import (
RInstance,
RType,
dict_rprimitive,
none_rprimitive,
object_rprimitive,
tuple_rprimitive,
)
from mypyc.irbuild.mapper import Mapper
from mypyc.irbuild.util import (
get_func_def,
Expand Down Expand Up @@ -115,7 +122,7 @@ def build_type_map(

# Collect all the functions also. We collect from the symbol table
# so that we can easily pick out the right copy of a function that
# is conditionally defined.
# is conditionally defined. This doesn't include nested functions!
for module in modules:
for func in get_module_func_defs(module):
prepare_func_def(module.fullname, None, func, mapper, options)
Expand Down Expand Up @@ -179,6 +186,8 @@ def prepare_func_def(
mapper: Mapper,
options: CompilerOptions,
) -> FuncDecl:
create_generator_class_if_needed(module_name, class_name, fdef, mapper)

kind = (
FUNC_STATICMETHOD
if fdef.is_static
Expand All @@ -190,6 +199,38 @@ def prepare_func_def(
return decl


def create_generator_class_if_needed(
    module_name: str, class_name: str | None, fdef: FuncDef, mapper: Mapper, name_suffix: str = ""
) -> None:
    """Declare the generator class for a generator or async function, if it is one.

    Does nothing unless 'fdef' is a coroutine or a generator. Otherwise a
    generated, final ClassIR is created and recorded in
    'mapper.fdef_to_generator' under 'fdef', and the internal
    "__mypyc_generator_helper__" method is declared on that class.

    The class name is the function name and (when given) the class name
    joined by "_", followed by "_gen" and then 'name_suffix'.
    """
    if not (fdef.is_coroutine or fdef.is_generator):
        return

    # Skip empty/None components (class_name is None for module-level functions).
    base_name = "_".join(filter(None, [fdef.name, class_name]))
    gen_class_name = base_name + "_gen" + name_suffix

    gen_class_ir = ClassIR(gen_class_name, module_name, is_generated=True, is_final_class=True)
    gen_class_ir.reuse_freed_instance = True
    mapper.fdef_to_generator[fdef] = gen_class_ir

    # The helper takes self plus four further arguments; all of them, and
    # the return value, are typed as plain objects.
    helper_args = (RuntimeArg(SELF_NAME, object_rprimitive),) + tuple(
        RuntimeArg(arg_name, object_rprimitive)
        for arg_name in ("type", "value", "traceback", "arg")
    )
    helper_signature = FuncSignature(helper_args, object_rprimitive)

    # The implementation of most generator functionality is behind this magic method.
    helper_decl = FuncDecl(
        "__mypyc_generator_helper__",
        gen_class_name,
        module_name,
        helper_signature,
        internal=True,
    )
    gen_class_ir.method_decls[helper_decl.name] = helper_decl


def prepare_method_def(
ir: ClassIR,
module_name: str,
Expand Down
39 changes: 33 additions & 6 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@
)
from mypyc.common import TEMP_ATTR_NAME
from mypyc.ir.ops import (
ERR_NEVER,
NAMESPACE_MODULE,
NO_TRACEBACK_LINE_NO,
Assign,
BasicBlock,
Branch,
Call,
InitStatic,
Integer,
LoadAddress,
Expand Down Expand Up @@ -930,16 +932,41 @@ def emit_yield_from_or_await(
to_yield_reg = Register(object_rprimitive)
received_reg = Register(object_rprimitive)

get_op = coro_op if is_await else iter_op
if isinstance(get_op, PrimitiveDescription):
iter_val = builder.primitive_op(get_op, [val], line)
helper_method = "__mypyc_generator_helper__"
if (
isinstance(val, (Call, MethodCall))
and isinstance(val.type, RInstance)
and val.type.class_ir.has_method(helper_method)
):
# This is a generated native generator class, and we can use a fast path.
# This allows two optimizations:
# 1) No need to call CPy_GetCoro() or iter() since for native generators
# it just returns the generator object (implemented here).
# 2) Instead of calling next(), call generator helper method directly,
# since next() just calls __next__ which calls the helper method.
iter_val: Value = val
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in this function implemented the bulk of the optimization.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a bit more detailed comment (essentially summarizing the PR description) so we will not forget the motivation.

else:
iter_val = builder.call_c(get_op, [val], line)
get_op = coro_op if is_await else iter_op
if isinstance(get_op, PrimitiveDescription):
iter_val = builder.primitive_op(get_op, [val], line)
else:
iter_val = builder.call_c(get_op, [val], line)

iter_reg = builder.maybe_spill_assignable(iter_val)

stop_block, main_block, done_block = BasicBlock(), BasicBlock(), BasicBlock()
_y_init = builder.call_c(next_raw_op, [builder.read(iter_reg)], line)

if isinstance(iter_reg.type, RInstance) and iter_reg.type.class_ir.has_method(helper_method):
# Second fast path optimization: call helper directly (see also comment above).
obj = builder.read(iter_reg)
nn = builder.none_object()
m = MethodCall(obj, helper_method, [nn, nn, nn, nn], line)
# Generators have custom error handling, so disable normal error handling.
m.error_kind = ERR_NEVER
_y_init = builder.add(m)
else:
_y_init = builder.call_c(next_raw_op, [builder.read(iter_reg)], line)

builder.add(Branch(_y_init, stop_block, main_block, Branch.IS_ERROR))

# Try extracting a return value from a StopIteration and return it.
Expand All @@ -948,7 +975,7 @@ def emit_yield_from_or_await(
builder.assign(result, builder.call_c(check_stop_op, [], line), line)
# Clear the spilled iterator/coroutine so that it will be freed.
# Otherwise, the freeing of the spilled register would likely be delayed.
err = builder.add(LoadErrorValue(object_rprimitive))
err = builder.add(LoadErrorValue(iter_reg.type))
builder.assign(iter_reg, err, line)
builder.goto(done_block)

Expand Down
Loading