Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4038,7 +4038,21 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None:
# We store the determined order inside the 'variants_raw' variable,
# which records tuples containing the method, base type, and the argument.

if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type):
if (
op_name in operators.op_methods_that_shortcut
and is_same_type(left_type, right_type)
and not (
# We consider typevars with equal IDs "same types" even if some narrowing
# has been applied. However, different bounds here might come from union
# expansion applied earlier, so we are not supposed to check them as
# being same types here. For plain union items `is_same_type` will
# return false, but not for typevars having these items as bounds.
# See testReversibleOpOnTypeVarProtocol.
isinstance(left_type, TypeVarType)
and isinstance(right_type, TypeVarType)
and not is_same_type(left_type.upper_bound, right_type.upper_bound)
)
):
# When we do "A() + A()", for example, Python will only call the __add__ method,
# never the __radd__ method.
#
Expand Down Expand Up @@ -4151,10 +4165,9 @@ def check_op(
"""

if allow_reverse:
left_variants = [base_type]
left_variants = self._union_items_from_typevar(base_type)
base_type = get_proper_type(base_type)
if isinstance(base_type, UnionType):
left_variants = list(flatten_nested_unions(base_type.relevant_items()))

right_type = self.accept(arg)

# Step 1: We first try leaving the right arguments alone and destructure
Expand Down Expand Up @@ -4192,13 +4205,17 @@ def check_op(
# We don't do the same for the base expression because it could lead to weird
# type inference errors -- e.g. see 'testOperatorDoubleUnionSum'.
# TODO: Can we use `type_overrides_set()` here?
right_variants = [(right_type, arg)]
right_type = get_proper_type(right_type)
if isinstance(right_type, UnionType):
right_variants: list[tuple[Type, Expression]]
p_right = get_proper_type(right_type)
if isinstance(p_right, (UnionType, TypeVarType)):
right_variants = [
(item, TempNode(item, context=context))
for item in flatten_nested_unions(right_type.relevant_items())
for item in self._union_items_from_typevar(right_type)
]
else:
# Preserve argument identity if we do not intend to modify it
right_variants = [(right_type, arg)]
right_type = p_right

all_results = []
all_inferred = []
Expand Down Expand Up @@ -4248,6 +4265,20 @@ def check_op(
context=context,
)

def _union_items_from_typevar(self, typ: Type) -> list[Type]:
variants = [typ]
typ = get_proper_type(typ)
base_type = typ
if unwrapped := (isinstance(typ, TypeVarType) and not typ.values):
typ = get_proper_type(typ.upper_bound)
if is_union := isinstance(typ, UnionType):
variants = list(flatten_nested_unions(typ.relevant_items()))
if is_union and unwrapped:
# If not a union, keep the original type
assert isinstance(base_type, TypeVarType)
variants = [base_type.copy_modified(upper_bound=item) for item in variants]
return variants

def check_boolean_op(self, e: OpExpr) -> Type:
"""Type check a boolean operation ('and' or 'or')."""

Expand Down
4 changes: 4 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ def is_same_type(
):
return all(is_same_type(x, y) for x, y in zip(a.args, b.args))
elif isinstance(a, TypeVarType) and isinstance(b, TypeVarType) and a.id == b.id:
# This is not only a performance optimization. Deeper check will compare upper
# bounds, but we want to consider copies of the same type variable "same type".
# This makes sense semantically: even we have narrowed the upper bound somehow,
# it's still the same object it used to be before.
return True

# Note that using ignore_promotions=True (default) makes types like int and int64
Expand Down
80 changes: 80 additions & 0 deletions test-data/unit/check-expressions.test
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,86 @@ if int():
class C:
def __lt__(self, o: object, x: str = "") -> int: ...

[case testReversibleOpOnTypeVarBound]
from typing import TypeVar, Union

class A:
def __lt__(self, a: A) -> bool: ...
def __gt__(self, a: A) -> bool: ...

class B(A):
def __lt__(self, b: B) -> bool: ... # type: ignore[override]
def __gt__(self, b: B) -> bool: ... # type: ignore[override]

_T = TypeVar("_T", bound=Union[A, B])

def check(x: _T, y: _T) -> bool:
return x < y

[case testReversibleOpOnTypeVarBoundPromotion]
from typing import TypeVar, Union

_T = TypeVar("_T", bound=Union[int, float])

def check(x: _T, y: _T) -> bool:
return x < y
[builtins fixtures/ops.pyi]

[case testReversibleOpOnTypeVarProtocol]
# https://github.com/python/mypy/issues/18203
from typing import Protocol, TypeVar, Union
from typing_extensions import Self, runtime_checkable

class A(Protocol):
def __add__(self, other: Union[int, Self]) -> Self: ...
def __radd__(self, other: Union[int, Self]) -> Self: ...

AT = TypeVar("AT", bound=Union[int, A])

def f(a: AT, b: AT) -> None:
reveal_type(a + a) # N: Revealed type is "Union[builtins.int, AT`-1]"
reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]"
if isinstance(a, int):
reveal_type(a) # N: Revealed type is "AT`-1"
reveal_type(a + a) # N: Revealed type is "builtins.int"
reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]"
reveal_type(b + a) # N: Revealed type is "Union[builtins.int, AT`-1]"

@runtime_checkable
class B(Protocol):
def __radd__(self, other: Union[int, Self]) -> Self: ...

BT = TypeVar("BT", bound=Union[int, B])

def g(a: BT, b: BT) -> None:
reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \
# N: Both left and right operands are unions \
# N: Revealed type is "Union[builtins.int, BT`-1, Any]"
reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \
# N: Both left and right operands are unions \
# N: Revealed type is "Union[builtins.int, BT`-1, Any]"
if isinstance(a, int):
reveal_type(a) # N: Revealed type is "BT`-1"
reveal_type(0 + a) # N: Revealed type is "builtins.int"
reveal_type(a + 0) # N: Revealed type is "builtins.int"
reveal_type(a + a) # N: Revealed type is "builtins.int"
reveal_type(a + b) # N: Revealed type is "Union[builtins.int, BT`-1]"
reveal_type(b + a) # E: Unsupported left operand type for + ("BT") \
# N: Left operand is of type "BT" \
# N: Revealed type is "Union[builtins.int, Any]"
if isinstance(a, B):
reveal_type(a) # N: Revealed type is "BT`-1"
reveal_type(0 + a) # N: Revealed type is "BT`-1"
reveal_type(a + 0) # E: Unsupported left operand type for + ("BT") \
# N: Revealed type is "Any"
reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \
# N: Revealed type is "Any"
reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \
# N: Right operand is of type "BT" \
# N: Revealed type is "Any"
[builtins fixtures/isinstance.pyi]


[case testErrorContextAndBinaryOperators]
import typing
class A:
Expand Down
6 changes: 6 additions & 0 deletions test-data/unit/fixtures/ops.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ class float:
def __rdiv__(self, x: 'float') -> 'float': pass
def __truediv__(self, x: 'float') -> 'float': pass
def __rtruediv__(self, x: 'float') -> 'float': pass
def __eq__(self, x: object) -> bool: pass
def __ne__(self, x: object) -> bool: pass
def __lt__(self, x: 'float') -> bool: pass
def __le__(self, x: 'float') -> bool: pass
def __gt__(self, x: 'float') -> bool: pass
def __ge__(self, x: 'float') -> bool: pass

class complex:
def __add__(self, x: complex) -> complex: pass
Expand Down
Loading