Skip to content

Commit 116b92b

Browse files
authored
More detailed checking of type objects in stubtest (#18251)
This uses `checkmember.type_object_type` and context to produce better types of type objects.
1 parent 15b8ca9 commit 116b92b

File tree

2 files changed

+84
-10
lines changed

2 files changed

+84
-10
lines changed

mypy/stubtest.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
from typing_extensions import get_origin, is_typeddict
3535

3636
import mypy.build
37+
import mypy.checkexpr
38+
import mypy.checkmember
39+
import mypy.erasetype
3740
import mypy.modulefinder
3841
import mypy.nodes
3942
import mypy.state
@@ -792,7 +795,11 @@ def _verify_arg_default_value(
792795
"has a default value but stub parameter does not"
793796
)
794797
else:
795-
runtime_type = get_mypy_type_of_runtime_value(runtime_arg.default)
798+
type_context = stub_arg.variable.type
799+
runtime_type = get_mypy_type_of_runtime_value(
800+
runtime_arg.default, type_context=type_context
801+
)
802+
796803
# Fallback to the type annotation type if var type is missing. The type annotation
797804
# is an UnboundType, but I don't know enough to know what the pros and cons here are.
798805
# UnboundTypes have ugly question marks following them, so default to var type.
@@ -1247,7 +1254,7 @@ def verify_var(
12471254
):
12481255
yield Error(object_path, "is read-only at runtime but not in the stub", stub, runtime)
12491256

1250-
runtime_type = get_mypy_type_of_runtime_value(runtime)
1257+
runtime_type = get_mypy_type_of_runtime_value(runtime, type_context=stub.type)
12511258
if (
12521259
runtime_type is not None
12531260
and stub.type is not None
@@ -1832,7 +1839,18 @@ def is_subtype_helper(left: mypy.types.Type, right: mypy.types.Type) -> bool:
18321839
return mypy.subtypes.is_subtype(left, right)
18331840

18341841

1835-
def get_mypy_type_of_runtime_value(runtime: Any) -> mypy.types.Type | None:
1842+
def get_mypy_node_for_name(module: str, type_name: str) -> mypy.nodes.SymbolNode | None:
1843+
stub = get_stub(module)
1844+
if stub is None:
1845+
return None
1846+
if type_name not in stub.names:
1847+
return None
1848+
return stub.names[type_name].node
1849+
1850+
1851+
def get_mypy_type_of_runtime_value(
1852+
runtime: Any, type_context: mypy.types.Type | None = None
1853+
) -> mypy.types.Type | None:
18361854
"""Returns a mypy type object representing the type of ``runtime``.
18371855
18381856
Returns None if we can't find something that works.
@@ -1893,14 +1911,45 @@ def anytype() -> mypy.types.AnyType:
18931911
is_ellipsis_args=True,
18941912
)
18951913

1896-
# Try and look up a stub for the runtime object
1897-
stub = get_stub(type(runtime).__module__)
1898-
if stub is None:
1899-
return None
1900-
type_name = type(runtime).__name__
1901-
if type_name not in stub.names:
1914+
skip_type_object_type = False
1915+
if type_context:
1916+
# Don't attempt to process the type object when context is generic
1917+
# This is related to issue #3737
1918+
type_context = mypy.types.get_proper_type(type_context)
1919+
# Callable types with a generic return value
1920+
if isinstance(type_context, mypy.types.CallableType):
1921+
if isinstance(type_context.ret_type, mypy.types.TypeVarType):
1922+
skip_type_object_type = True
1923+
# Type[x] where x is generic
1924+
if isinstance(type_context, mypy.types.TypeType):
1925+
if isinstance(type_context.item, mypy.types.TypeVarType):
1926+
skip_type_object_type = True
1927+
1928+
if isinstance(runtime, type) and not skip_type_object_type:
1929+
1930+
def _named_type(name: str) -> mypy.types.Instance:
1931+
parts = name.rsplit(".", maxsplit=1)
1932+
node = get_mypy_node_for_name(parts[0], parts[1])
1933+
assert isinstance(node, nodes.TypeInfo)
1934+
any_type = mypy.types.AnyType(mypy.types.TypeOfAny.special_form)
1935+
return mypy.types.Instance(node, [any_type] * len(node.defn.type_vars))
1936+
1937+
# Try and look up a stub for the runtime object itself
1938+
# The logic here is similar to ExpressionChecker.analyze_ref_expr
1939+
type_info = get_mypy_node_for_name(runtime.__module__, runtime.__name__)
1940+
if isinstance(type_info, nodes.TypeInfo):
1941+
result: mypy.types.Type | None = None
1942+
result = mypy.typeops.type_object_type(type_info, _named_type)
1943+
if mypy.checkexpr.is_type_type_context(type_context):
1944+
# This is the type in a type[] expression, so substitute type
1945+
# variables with Any.
1946+
result = mypy.erasetype.erase_typevars(result)
1947+
return result
1948+
1949+
# Try and look up a stub for the runtime object's type
1950+
type_info = get_mypy_node_for_name(type(runtime).__module__, type(runtime).__name__)
1951+
if type_info is None:
19021952
return None
1903-
type_info = stub.names[type_name].node
19041953
if isinstance(type_info, nodes.Var):
19051954
return type_info.type
19061955
if not isinstance(type_info, nodes.TypeInfo):

mypy/test/teststubtest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2636,6 +2636,31 @@ class _X1: ...
26362636
error=None,
26372637
)
26382638

2639+
@collect_cases
2640+
def test_type_default_protocol(self) -> Iterator[Case]:
2641+
yield Case(
2642+
stub="""
2643+
from typing import Protocol
2644+
2645+
class _FormatterClass(Protocol):
2646+
def __call__(self, *, prog: str) -> HelpFormatter: ...
2647+
2648+
class ArgumentParser:
2649+
def __init__(self, formatter_class: _FormatterClass = ...) -> None: ...
2650+
2651+
class HelpFormatter:
2652+
def __init__(self, prog: str, indent_increment: int = 2) -> None: ...
2653+
""",
2654+
runtime="""
2655+
class HelpFormatter:
2656+
def __init__(self, prog, indent_increment=2) -> None: ...
2657+
2658+
class ArgumentParser:
2659+
def __init__(self, formatter_class=HelpFormatter): ...
2660+
""",
2661+
error=None,
2662+
)
2663+
26392664

26402665
def remove_color_code(s: str) -> str:
26412666
return re.sub("\\x1b.*?m", "", s) # this works!

0 commit comments

Comments
 (0)