From 412f63bfe1a8bcc0682f5986b0cf462e59bcae0d Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 03:31:28 +0200 Subject: [PATCH 1/6] Apply union expansion when checking ops to typevars --- mypy/checkexpr.py | 31 +++++++++++++++++++++------ test-data/unit/check-expressions.test | 26 ++++++++++++++++++++++ test-data/unit/fixtures/ops.pyi | 6 ++++++ 3 files changed, 56 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..9a9df97be358 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4155,10 +4155,9 @@ def check_op( """ if allow_reverse: - left_variants = [base_type] + left_variants = self._union_items_from_typevar(base_type) base_type = get_proper_type(base_type) - if isinstance(base_type, UnionType): - left_variants = list(flatten_nested_unions(base_type.relevant_items())) + right_type = self.accept(arg) # Step 1: We first try leaving the right arguments alone and destructure @@ -4196,13 +4195,18 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [(right_type, arg)] - right_type = get_proper_type(right_type) - if isinstance(right_type, UnionType): + right_variants: list[tuple[Type, Expression]] + if isinstance(right_type, ProperType) and isinstance( + right_type, (UnionType, TypeVarType) + ): right_variants = [ (item, TempNode(item, context=context)) - for item in flatten_nested_unions(right_type.relevant_items()) + for item in self._union_items_from_typevar(right_type) ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = get_proper_type(right_type) all_results = [] all_inferred = [] @@ -4252,6 +4256,19 @@ def check_op( context=context, ) + def _union_items_from_typevar(self, typ: Type) -> list[Type]: + variants = [typ] + typ = get_proper_type(typ) + base_type = typ + if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): + typ = get_proper_type(typ.upper_bound) + if isinstance(typ, UnionType): + variants = list(flatten_nested_unions(typ.relevant_items())) + if unwrapped: + assert isinstance(base_type, TypeVarType) + variants = [base_type.copy_modified(upper_bound=item) for item in variants] + return variants + def check_boolean_op(self, e: OpExpr, context: Context) -> Type: """Type check a boolean operation ('and' or 'or').""" diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..5c2cad914bce 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -706,6 +706,32 @@ if int(): class C: def __lt__(self, o: object, x: str = "") -> int: ... +[case testReversibleOpOnTypeVarBound] +from typing import TypeVar, Union + +class A: + def __lt__(self, a: A) -> bool: ... + def __gt__(self, a: A) -> bool: ... + +class B(A): + def __lt__(self, b: B) -> bool: ... # type: ignore[override] + def __gt__(self, b: B) -> bool: ... # type: ignore[override] + +_T = TypeVar("_T", bound=Union[A, B]) + +def check(x: _T, y: _T) -> bool: + return x < y + +[case testReversibleOpOnTypeVarBoundPromotion] +from typing import TypeVar, Union + +_T = TypeVar("_T", bound=Union[int, float]) + +def check(x: _T, y: _T) -> bool: + return x < y +[builtins fixtures/ops.pyi] + + [case testErrorContextAndBinaryOperators] import typing class A: diff --git a/test-data/unit/fixtures/ops.pyi b/test-data/unit/fixtures/ops.pyi index 67bc74b35c51..34e512b34984 100644 --- a/test-data/unit/fixtures/ops.pyi +++ b/test-data/unit/fixtures/ops.pyi @@ -61,6 +61,12 @@ class float: def __rdiv__(self, x: 'float') -> 'float': pass def __truediv__(self, x: 'float') -> 'float': pass def __rtruediv__(self, x: 'float') -> 'float': pass + def __eq__(self, x: object) -> bool: pass + def __ne__(self, x: object) -> bool: pass + def __lt__(self, x: 'float') -> bool: pass + def __le__(self, x: 'float') -> bool: pass + def __gt__(self, x: 'float') -> bool: pass + def __ge__(self, x: 'float') -> bool: pass class complex: def __add__(self, x: complex) -> complex: pass From fdab83007fae33a93020f282887bd14c911b70e4 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 03:58:44 +0200 Subject: [PATCH 2/6] Preserve type identity --- mypy/checkexpr.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9a9df97be358..3c870f90ee2b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4195,17 +4195,10 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants: list[tuple[Type, Expression]] - if isinstance(right_type, ProperType) and isinstance( - right_type, (UnionType, TypeVarType) - ): - right_variants = [ - (item, TempNode(item, context=context)) - for item in self._union_items_from_typevar(right_type) - ] - else: - # Preserve argument identity if we do not intend to modify it - right_variants = [(right_type, arg)] + right_variants = [ + (item, TempNode(item, context=context)) + for item in self._union_items_from_typevar(right_type) + ] right_type = get_proper_type(right_type) all_results = [] @@ -4262,9 +4255,10 @@ def _union_items_from_typevar(self, typ: Type) -> list[Type]: base_type = typ if unwrapped := (isinstance(typ, TypeVarType) and not typ.values): typ = get_proper_type(typ.upper_bound) - if isinstance(typ, UnionType): + if is_union := isinstance(typ, UnionType): variants = list(flatten_nested_unions(typ.relevant_items())) - if unwrapped: + if is_union and unwrapped: + # If not a union, keep the original type assert isinstance(base_type, TypeVarType) variants = [base_type.copy_modified(upper_bound=item) for item in variants] return variants From 30c0e32e81c0dd7415b09d86dbee381a5efd9921 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 16 Jul 2025 04:39:21 +0200 Subject: [PATCH 3/6] Retain original arg if possible --- mypy/checkexpr.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 3c870f90ee2b..d9e4a704155c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4195,11 +4195,17 @@ def check_op( # We don't do the same for the base expression because it could lead to weird # type inference errors -- e.g. see 'testOperatorDoubleUnionSum'. # TODO: Can we use `type_overrides_set()` here? - right_variants = [ - (item, TempNode(item, context=context)) - for item in self._union_items_from_typevar(right_type) - ] - right_type = get_proper_type(right_type) + right_variants: list[tuple[Type, Expression]] + p_right = get_proper_type(right_type) + if isinstance(p_right, (UnionType, TypeVarType)): + right_variants = [ + (item, TempNode(item, context=context)) + for item in self._union_items_from_typevar(right_type) + ] + else: + # Preserve argument identity if we do not intend to modify it + right_variants = [(right_type, arg)] + right_type = p_right all_results = [] all_inferred = [] From bd109f71405aba6777e37a37d1144df4b7334f51 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 8 Aug 2025 18:10:47 +0200 Subject: [PATCH 4/6] Only consider typevars same type after deeper comparison - we already have other places that create same-id copies of typevars --- mypy/subtypes.py | 6 ++++-- test-data/unit/check-expressions.test | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..a8f48fb1321f 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -264,7 +264,9 @@ def is_same_type( a non-simplified union) but are semantically exchangeable in all contexts. """ # First, use fast path for some common types. This is performance-critical. - if ( + if a is b: + return True + elif ( type(a) is Instance and type(b) is Instance and a.type == b.type @@ -272,7 +274,7 @@ def is_same_type( and a.last_known_value is b.last_known_value ): return all(is_same_type(x, y) for x, y in zip(a.args, b.args)) - elif isinstance(a, TypeVarType) and isinstance(b, TypeVarType) and a.id == b.id: + elif a == b: return True # Note that using ignore_promotions=True (default) makes types like int and int64 diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 5c2cad914bce..487c3bbe14a0 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -731,6 +731,21 @@ def check(x: _T, y: _T) -> bool: return x < y [builtins fixtures/ops.pyi] +[case testReversibleOpOnTypeVarProtocol] +# https://github.com/python/mypy/issues/18203 +from typing import Protocol, TypeVar, Union +from typing_extensions import Self + +class A(Protocol): + def __add__(self, other: Union[int, Self]) -> Self: ... + def __radd__(self, other: Union[int, Self]) -> Self: ... + +AT = TypeVar("AT", bound=Union[int, A]) + +def f(a: AT, _b: AT) -> None: + a + a +[builtins fixtures/ops.pyi] + [case testErrorContextAndBinaryOperators] import typing From 80b9fcd441c0863a01a840b0f440dec9a1fcf49f Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 8 Aug 2025 19:01:07 +0200 Subject: [PATCH 5/6] Revert last change --- mypy/subtypes.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a8f48fb1321f..7da258a827f3 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -264,9 +264,7 @@ def is_same_type( a non-simplified union) but are semantically exchangeable in all contexts. """ # First, use fast path for some common types. This is performance-critical. - if a is b: - return True - elif ( + if ( type(a) is Instance and type(b) is Instance and a.type == b.type @@ -274,7 +272,7 @@ def is_same_type( and a.last_known_value is b.last_known_value ): return all(is_same_type(x, y) for x, y in zip(a.args, b.args)) - elif a == b: + elif isinstance(a, TypeVarType) and isinstance(b, TypeVarType) and a.id == b.id: return True # Note that using ignore_promotions=True (default) makes types like int and int64 From 6b4e7a16c9f6f747f4b84585b1a4c21d1e7a69ba Mon Sep 17 00:00:00 2001 From: STerliakov Date: Fri, 8 Aug 2025 19:46:53 +0200 Subject: [PATCH 6/6] Move this check to special case in check_op --- mypy/checkexpr.py | 16 ++++++++- mypy/subtypes.py | 4 +++ test-data/unit/check-expressions.test | 47 ++++++++++++++++++++++++--- 3 files changed, 62 insertions(+), 5 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 61b69753710d..1e6d2dbcc824 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4038,7 +4038,21 @@ def lookup_definer(typ: Instance, attr_name: str) -> str | None: # We store the determined order inside the 'variants_raw' variable, # which records tuples containing the method, base type, and the argument. - if op_name in operators.op_methods_that_shortcut and is_same_type(left_type, right_type): + if ( + op_name in operators.op_methods_that_shortcut + and is_same_type(left_type, right_type) + and not ( + # We consider typevars with equal IDs "same types" even if some narrowing + # has been applied. However, different bounds here might come from union + # expansion applied earlier, so we are not supposed to check them as + # being same types here. For plain union items `is_same_type` will + # return false, but not for typevars having these items as bounds. + # See testReversibleOpOnTypeVarProtocol. + isinstance(left_type, TypeVarType) + and isinstance(right_type, TypeVarType) + and not is_same_type(left_type.upper_bound, right_type.upper_bound) + ) + ): # When we do "A() + A()", for example, Python will only call the __add__ method, # never the __radd__ method. # diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 7da258a827f3..955a4d273e19 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -273,6 +273,10 @@ def is_same_type( ): return all(is_same_type(x, y) for x, y in zip(a.args, b.args)) elif isinstance(a, TypeVarType) and isinstance(b, TypeVarType) and a.id == b.id: + # This is not only a performance optimization. Deeper check will compare upper + # bounds, but we want to consider copies of the same type variable "same type". + # This makes sense semantically: even we have narrowed the upper bound somehow, + # it's still the same object it used to be before. return True # Note that using ignore_promotions=True (default) makes types like int and int64 diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 487c3bbe14a0..fb92305a9818 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -734,7 +734,7 @@ def check(x: _T, y: _T) -> bool: [case testReversibleOpOnTypeVarProtocol] # https://github.com/python/mypy/issues/18203 from typing import Protocol, TypeVar, Union -from typing_extensions import Self +from typing_extensions import Self, runtime_checkable class A(Protocol): def __add__(self, other: Union[int, Self]) -> Self: ... @@ -742,9 +742,48 @@ class A(Protocol): AT = TypeVar("AT", bound=Union[int, A]) -def f(a: AT, _b: AT) -> None: - a + a -[builtins fixtures/ops.pyi] +def f(a: AT, b: AT) -> None: + reveal_type(a + a) # N: Revealed type is "Union[builtins.int, AT`-1]" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]" + if isinstance(a, int): + reveal_type(a) # N: Revealed type is "AT`-1" + reveal_type(a + a) # N: Revealed type is "builtins.int" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, AT`-1]" + reveal_type(b + a) # N: Revealed type is "Union[builtins.int, AT`-1]" + +@runtime_checkable +class B(Protocol): + def __radd__(self, other: Union[int, Self]) -> Self: ... + +BT = TypeVar("BT", bound=Union[int, B]) + +def g(a: BT, b: BT) -> None: + reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \ + # N: Both left and right operands are unions \ + # N: Revealed type is "Union[builtins.int, BT`-1, Any]" + reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \ + # N: Both left and right operands are unions \ + # N: Revealed type is "Union[builtins.int, BT`-1, Any]" + if isinstance(a, int): + reveal_type(a) # N: Revealed type is "BT`-1" + reveal_type(0 + a) # N: Revealed type is "builtins.int" + reveal_type(a + 0) # N: Revealed type is "builtins.int" + reveal_type(a + a) # N: Revealed type is "builtins.int" + reveal_type(a + b) # N: Revealed type is "Union[builtins.int, BT`-1]" + reveal_type(b + a) # E: Unsupported left operand type for + ("BT") \ + # N: Left operand is of type "BT" \ + # N: Revealed type is "Union[builtins.int, Any]" + if isinstance(a, B): + reveal_type(a) # N: Revealed type is "BT`-1" + reveal_type(0 + a) # N: Revealed type is "BT`-1" + reveal_type(a + 0) # E: Unsupported left operand type for + ("BT") \ + # N: Revealed type is "Any" + reveal_type(a + a) # E: Unsupported left operand type for + ("BT") \ + # N: Revealed type is "Any" + reveal_type(a + b) # E: Unsupported left operand type for + ("BT") \ + # N: Right operand is of type "BT" \ + # N: Revealed type is "Any" +[builtins fixtures/isinstance.pyi] [case testErrorContextAndBinaryOperators]