Skip to content

[mypyc] Fix error value check for GetAttr that allows nullable values #19378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions mypyc/codegen/emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def visit_goto(self, op: Goto) -> None:
if op.label is not self.next_block:
self.emit_line("goto %s;" % self.label(op.label))

def error_value_check(self, value: Value, compare: str) -> str:
typ = value.type
if isinstance(typ, RTuple):
# TODO: What about empty tuple?
return self.emitter.tuple_undefined_check_cond(
typ, self.reg(value), self.c_error_value, compare
)
else:
return f"{self.reg(value)} {compare} {self.c_error_value(typ)}"

def visit_branch(self, op: Branch) -> None:
true, false = op.true, op.false
negated = op.negated
Expand All @@ -225,15 +235,8 @@ def visit_branch(self, op: Branch) -> None:
expr_result = self.reg(op.value)
cond = f"{neg}{expr_result}"
elif op.op == Branch.IS_ERROR:
typ = op.value.type
compare = "!=" if negated else "=="
if isinstance(typ, RTuple):
# TODO: What about empty tuple?
cond = self.emitter.tuple_undefined_check_cond(
typ, self.reg(op.value), self.c_error_value, compare
)
else:
cond = f"{self.reg(op.value)} {compare} {self.c_error_value(typ)}"
cond = self.error_value_check(op.value, compare)
else:
assert False, "Invalid branch"

Expand Down Expand Up @@ -358,8 +361,8 @@ def get_attr_expr(self, obj: str, op: GetAttr | SetAttr, decl_cl: ClassIR) -> st
return f"({cast}{obj})->{self.emitter.attr(op.attr)}"

def visit_get_attr(self, op: GetAttr) -> None:
if op.allow_null:
self.get_attr_with_allow_null(op)
if op.allow_error_value:
self.get_attr_with_allow_error_value(op)
return
dest = self.reg(op)
obj = self.reg(op.obj)
Expand Down Expand Up @@ -429,8 +432,11 @@ def visit_get_attr(self, op: GetAttr) -> None:
elif not always_defined:
self.emitter.emit_line("}")

def get_attr_with_allow_null(self, op: GetAttr) -> None:
"""Handle GetAttr with allow_null=True which allows NULL without raising AttributeError."""
def get_attr_with_allow_error_value(self, op: GetAttr) -> None:
"""Handle GetAttr with allow_error_value=True.

This allows NULL or other error value without raising AttributeError.
"""
dest = self.reg(op)
obj = self.reg(op.obj)
rtype = op.class_type
Expand All @@ -443,7 +449,8 @@ def get_attr_with_allow_null(self, op: GetAttr) -> None:

# Only emit inc_ref if not NULL
if attr_rtype.is_refcounted and not op.is_borrowed:
self.emitter.emit_line(f"if ({dest} != NULL) {{")
check = self.error_value_check(op, "!=")
self.emitter.emit_line(f"if ({check}) {{")
self.emitter.emit_inc_ref(dest, attr_rtype)
self.emitter.emit_line("}")

Expand Down
12 changes: 9 additions & 3 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,17 +811,23 @@ class GetAttr(RegisterOp):
error_kind = ERR_MAGIC

def __init__(
self, obj: Value, attr: str, line: int, *, borrow: bool = False, allow_null: bool = False
self,
obj: Value,
attr: str,
line: int,
*,
borrow: bool = False,
allow_error_value: bool = False,
) -> None:
super().__init__(line)
self.obj = obj
self.attr = attr
self.allow_null = allow_null
self.allow_error_value = allow_error_value
assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type
self.class_type = obj.type
attr_type = obj.type.attr_type(attr)
self.type = attr_type
if allow_null:
if allow_error_value:
self.error_kind = ERR_NEVER
elif attr_type.error_overlap:
self.error_kind = ERR_MAGIC_OVERLAPPING
Expand Down
8 changes: 2 additions & 6 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,13 +709,9 @@ def read(
assert False, "Unsupported lvalue: %r" % target

def read_nullable_attr(self, obj: Value, attr: str, line: int = -1) -> Value:
"""Read an attribute that might be NULL without raising AttributeError.

This is used for reading spill targets in try/finally blocks where NULL
indicates the non-return path was taken.
"""
"""Read an attribute that might have an error value without raising AttributeError."""
assert isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class
return self.add(GetAttr(obj, attr, line, allow_null=True))
return self.add(GetAttr(obj, attr, line, allow_error_value=True))

def assign(self, target: Register | AssignmentTarget, rvalue_reg: Value, line: int) -> None:
if isinstance(target, Register):
Expand Down
12 changes: 12 additions & 0 deletions mypyc/test/test_emitfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def add_local(name: str, rtype: RType) -> Register:
"y": int_rprimitive,
"i1": int64_rprimitive,
"i2": int32_rprimitive,
"t": RTuple([object_rprimitive, object_rprimitive]),
}
ir.bitmap_attrs = ["i1", "i2"]
compute_vtable(ir)
Expand Down Expand Up @@ -418,6 +419,17 @@ def test_get_attr_with_bitmap(self) -> None:
""",
)

def test_get_attr_nullable_with_tuple(self) -> None:
self.assert_emit(
GetAttr(self.r, "t", 1, allow_error_value=True),
"""cpy_r_r0 = ((mod___AObject *)cpy_r_r)->_t;
if (cpy_r_r0.f0 != NULL) {
CPy_INCREF(cpy_r_r0.f0);
CPy_INCREF(cpy_r_r0.f1);
}
""",
)

def test_set_attr(self) -> None:
self.assert_emit(
SetAttr(self.r, "y", self.m, 1),
Expand Down