diff --git a/ChangeLog b/ChangeLog index cccaa0242..bbf99cb4a 100644 --- a/ChangeLog +++ b/ChangeLog @@ -60,6 +60,11 @@ Release date: TBA refs #2789 +* Add support for type parameter defaults added in Python 3.13. + +* Improve ``as_string()`` representation for ``TypeVar``, ``ParamSpec`` and ``TypeVarTuple`` nodes, as well as + type parameter in ``ClassDef``, ``FuncDef`` and ``TypeAlias`` nodes (PEP 695). + What's New in astroid 3.3.11? ============================= diff --git a/astroid/nodes/as_string.py b/astroid/nodes/as_string.py index f1f206b44..7d04b5eb6 100644 --- a/astroid/nodes/as_string.py +++ b/astroid/nodes/as_string.py @@ -176,18 +176,27 @@ def visit_call(self, node: nodes.Call) -> str: args.extend(keywords) return f"{expr_str}({', '.join(args)})" + def _handle_type_params( + self, type_params: list[nodes.TypeVar | nodes.ParamSpec | nodes.TypeVarTuple] + ) -> str: + return ( + f"[{', '.join(tp.accept(self) for tp in type_params)}]" + if type_params + else "" + ) + def visit_classdef(self, node: nodes.ClassDef) -> str: """return an astroid.ClassDef node as string""" decorate = node.decorators.accept(self) if node.decorators else "" + type_params = self._handle_type_params(node.type_params) args = [n.accept(self) for n in node.bases] if node._metaclass and not node.has_metaclass_hack(): args.append("metaclass=" + node._metaclass.accept(self)) args += [n.accept(self) for n in node.keywords] args_str = f"({', '.join(args)})" if args else "" docs = self._docs_dedent(node.doc_node) - # TODO: handle type_params - return "\n\n{}class {}{}:{}\n{}\n".format( - decorate, node.name, args_str, docs, self._stmt_list(node.body) + return "\n\n{}class {}{}{}:{}\n{}\n".format( + decorate, node.name, type_params, args_str, docs, self._stmt_list(node.body) ) def visit_compare(self, node: nodes.Compare) -> str: @@ -336,17 +345,18 @@ def visit_formattedvalue(self, node: nodes.FormattedValue) -> str: def handle_functiondef(self, node: nodes.FunctionDef, keyword: str) -> str: """return a (possibly async) function definition node as string""" decorate = node.decorators.accept(self) if node.decorators else "" + type_params = self._handle_type_params(node.type_params) docs = self._docs_dedent(node.doc_node) trailer = ":" if node.returns: return_annotation = " -> " + node.returns.as_string() trailer = return_annotation + ":" - # TODO: handle type_params - def_format = "\n%s%s %s(%s)%s%s\n%s" + def_format = "\n%s%s %s%s(%s)%s%s\n%s" return def_format % ( decorate, keyword, node.name, + type_params, node.args.accept(self), trailer, docs, @@ -455,7 +465,10 @@ def visit_nonlocal(self, node: nodes.Nonlocal) -> str: def visit_paramspec(self, node: nodes.ParamSpec) -> str: """return an astroid.ParamSpec node as string""" - return node.name.accept(self) + default_value_str = ( + f" = {node.default_value.accept(self)}" if node.default_value else "" + ) + return f"**{node.name.accept(self)}{default_value_str}" def visit_pass(self, node: nodes.Pass) -> str: """return an astroid.Pass node as string""" @@ -545,15 +558,23 @@ def visit_tuple(self, node: nodes.Tuple) -> str: def visit_typealias(self, node: nodes.TypeAlias) -> str: """return an astroid.TypeAlias node as string""" - return node.name.accept(self) if node.name else "_" + type_params = self._handle_type_params(node.type_params) + return f"type {node.name.accept(self)}{type_params} = {node.value.accept(self)}" def visit_typevar(self, node: nodes.TypeVar) -> str: """return an astroid.TypeVar node as string""" - return node.name.accept(self) if node.name else "_" + bound_str = f": {node.bound.accept(self)}" if node.bound else "" + default_value_str = ( + f" = {node.default_value.accept(self)}" if node.default_value else "" + ) + return f"{node.name.accept(self)}{bound_str}{default_value_str}" def visit_typevartuple(self, node: nodes.TypeVarTuple) -> str: """return an astroid.TypeVarTuple node as string""" - return "*" + node.name.accept(self) if node.name else "" + default_value_str = ( + f" = {node.default_value.accept(self)}" if node.default_value else "" + ) + return f"*{node.name.accept(self)}{default_value_str}" def visit_unaryop(self, node: nodes.UnaryOp) -> str: """return an astroid.UnaryOp node as string""" diff --git a/astroid/nodes/node_classes.py b/astroid/nodes/node_classes.py index 372ada7a6..e713da503 100644 --- a/astroid/nodes/node_classes.py +++ b/astroid/nodes/node_classes.py @@ -3383,9 +3383,9 @@ class ParamSpec(_base_nodes.AssignTypeNode): """ - _astroid_fields = ("name",) - + _astroid_fields = ("name", "default_value") name: AssignName + default_value: NodeNG | None def __init__( self, @@ -3404,8 +3404,9 @@ def __init__( parent=parent, ) - def postinit(self, *, name: AssignName) -> None: + def postinit(self, *, name: AssignName, default_value: NodeNG | None) -> None: self.name = name + self.default_value = default_value def _infer( self, context: InferenceContext | None = None, **kwargs: Any @@ -4141,10 +4142,10 @@ class TypeVar(_base_nodes.AssignTypeNode): """ - _astroid_fields = ("name", "bound") - + _astroid_fields = ("name", "bound", "default_value") name: AssignName bound: NodeNG | None + default_value: NodeNG | None def __init__( self, @@ -4163,9 +4164,16 @@ def __init__( parent=parent, ) - def postinit(self, *, name: AssignName, bound: NodeNG | None) -> None: + def postinit( + self, + *, + name: AssignName, + bound: NodeNG | None, + default_value: NodeNG | None = None, + ) -> None: self.name = name self.bound = bound + self.default_value = default_value def _infer( self, context: InferenceContext | None = None, **kwargs: Any @@ -4187,9 +4195,9 @@ class TypeVarTuple(_base_nodes.AssignTypeNode): """ - _astroid_fields = ("name",) - + _astroid_fields = ("name", "default_value") name: AssignName + default_value: NodeNG | None def __init__( self, @@ -4208,8 +4216,11 @@ def __init__( parent=parent, ) - def postinit(self, *, name: AssignName) -> None: + def postinit( + self, *, name: AssignName, default_value: NodeNG | None = None + ) -> None: self.name = name + self.default_value = default_value def _infer( self, context: InferenceContext | None = None, **kwargs: Any diff --git a/astroid/rebuilder.py b/astroid/rebuilder.py index 104d9a416..5814679fb 100644 --- a/astroid/rebuilder.py +++ b/astroid/rebuilder.py @@ -18,7 +18,7 @@ from astroid import nodes from astroid._ast import ParserModule, get_parser_module, parse_function_type_comment -from astroid.const import PY312_PLUS, Context +from astroid.const import PY312_PLUS, PY313_PLUS, Context from astroid.nodes.utils import Position from astroid.typing import InferenceResult @@ -1483,7 +1483,12 @@ def visit_paramspec( ) # Add AssignName node for 'node.name' # https://bugs.python.org/issue43994 - newnode.postinit(name=self.visit_assignname(node, newnode, node.name)) + newnode.postinit( + name=self.visit_assignname(node, newnode, node.name), + default_value=( + self.visit(node.default_value, newnode) if PY313_PLUS else None + ), + ) return newnode def visit_pass(self, node: ast.Pass, parent: nodes.NodeNG) -> nodes.Pass: @@ -1679,6 +1684,9 @@ def visit_typevar(self, node: ast.TypeVar, parent: nodes.NodeNG) -> nodes.TypeVa newnode.postinit( name=self.visit_assignname(node, newnode, node.name), bound=self.visit(node.bound, newnode), + default_value=( + self.visit(node.default_value, newnode) if PY313_PLUS else None + ), ) return newnode @@ -1695,7 +1703,12 @@ def visit_typevartuple( ) # Add AssignName node for 'node.name' # https://bugs.python.org/issue43994 - newnode.postinit(name=self.visit_assignname(node, newnode, node.name)) + newnode.postinit( + name=self.visit_assignname(node, newnode, node.name), + default_value=( + self.visit(node.default_value, newnode) if PY313_PLUS else None + ), + ) return newnode def visit_unaryop(self, node: ast.UnaryOp, parent: nodes.NodeNG) -> nodes.UnaryOp: diff --git a/tests/test_nodes.py b/tests/test_nodes.py index d46c46ebb..9ed2022fb 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -32,6 +32,7 @@ IS_PYPY, PY311_PLUS, PY312_PLUS, + PY313_PLUS, PY314_PLUS, Context, ) @@ -332,15 +333,34 @@ def test_recursion_error_trapped() -> None: class AsStringTypeParamNodes(unittest.TestCase): @staticmethod def test_as_string_type_alias() -> None: - ast = abuilder.string_build("type Point = tuple[float, float]") - type_alias = ast.body[0] - assert type_alias.as_string().strip() == "Point" + ast1 = abuilder.string_build("type Point = tuple[float, float]") + type_alias1 = ast1.body[0] + assert type_alias1.as_string().strip() == "type Point = tuple[float, float]" + ast2 = abuilder.string_build( + "type Point[T, **P] = tuple[float, T, Callable[P, None]]" + ) + type_alias2 = ast2.body[0] + assert ( + type_alias2.as_string().strip() + == "type Point[T, **P] = tuple[float, T, Callable[P, None]]" + ) @staticmethod def test_as_string_type_var() -> None: - ast = abuilder.string_build("type Point[T] = tuple[float, float]") + ast = abuilder.string_build("type Point[T: int | str] = tuple[float, float]") + type_var = ast.body[0].type_params[0] + assert type_var.as_string().strip() == "T: int | str" + + @staticmethod + @pytest.mark.skipif( + not PY313_PLUS, reason="Type parameter defaults were added in 313" + ) + def test_as_string_type_var_default() -> None: + ast = abuilder.string_build( + "type Point[T: int | str = int] = tuple[float, float]" + ) type_var = ast.body[0].type_params[0] - assert type_var.as_string().strip() == "T" + assert type_var.as_string().strip() == "T: int | str = int" @staticmethod def test_as_string_type_var_tuple() -> None: @@ -348,11 +368,41 @@ def test_as_string_type_var_tuple() -> None: type_var_tuple = ast.body[0].type_params[0] assert type_var_tuple.as_string().strip() == "*Ts" + @staticmethod + @pytest.mark.skipif( + not PY313_PLUS, reason="Type parameter defaults were added in 313" + ) + def test_as_string_type_var_tuple_defaults() -> None: + ast = abuilder.string_build("type Alias[*Ts = tuple[int, str]] = tuple[*Ts]") + type_var_tuple = ast.body[0].type_params[0] + assert type_var_tuple.as_string().strip() == "*Ts = tuple[int, str]" + @staticmethod def test_as_string_param_spec() -> None: ast = abuilder.string_build("type Alias[**P] = Callable[P, int]") param_spec = ast.body[0].type_params[0] - assert param_spec.as_string().strip() == "P" + assert param_spec.as_string().strip() == "**P" + + @staticmethod + @pytest.mark.skipif( + not PY313_PLUS, reason="Type parameter defaults were added in 313" + ) + def test_as_string_param_spec_defaults() -> None: + ast = abuilder.string_build("type Alias[**P = [str, int]] = Callable[P, int]") + param_spec = ast.body[0].type_params[0] + assert param_spec.as_string().strip() == "**P = [str, int]" + + @staticmethod + def test_as_string_class_type_params() -> None: + code = abuilder.string_build("class A[T, **P]: ...") + cls_node = code.body[0] + assert cls_node.as_string().strip() == "class A[T, **P]:\n ..." + + @staticmethod + def test_as_string_function_type_params() -> None: + code = abuilder.string_build("def func[T, **P](): ...") + func_node = code.body[0] + assert func_node.as_string().strip() == "def func[T, **P]():\n ..." class _NodeTest(unittest.TestCase): diff --git a/tests/test_type_params.py b/tests/test_type_params.py index 6398f78ad..021aa9a28 100644 --- a/tests/test_type_params.py +++ b/tests/test_type_params.py @@ -5,11 +5,14 @@ import pytest from astroid import extract_node -from astroid.const import PY312_PLUS +from astroid.const import PY312_PLUS, PY313_PLUS from astroid.nodes import ( AssignName, + List, + Name, ParamSpec, Subscript, + Tuple, TypeAlias, TypeVar, TypeVarTuple, @@ -26,6 +29,7 @@ def test_type_alias() -> None: assert isinstance(node.type_params[0].name, AssignName) assert node.type_params[0].name.name == "T" assert node.type_params[0].bound is None + assert node.type_params[0].default_value is None assert isinstance(node.value, Subscript) assert node.value.value.name == "list" @@ -41,12 +45,46 @@ def test_type_alias() -> None: assert assigned is node.value +def test_type_var() -> None: + node = extract_node("type Point[T: int] = T") + param = node.type_params[0] + assert isinstance(param, TypeVar) + assert isinstance(param.bound, Name) + assert param.bound.name == "int" + assert param.default_value is None + + +@pytest.mark.skipif(not PY313_PLUS, reason="Type parameter defaults were added in 313") +def test_type_var_defaults() -> None: + node = extract_node("type Point[T: int = int] = T") + param = node.type_params[0] + assert isinstance(param, TypeVar) + assert isinstance(param.bound, Name) + assert param.bound.name == "int" + assert isinstance(param.default_value, Name) + assert param.default_value.name == "int" + + def test_type_param_spec() -> None: node = extract_node("type Alias[**P] = Callable[P, int]") params = node.type_params[0] assert isinstance(params, ParamSpec) assert isinstance(params.name, AssignName) assert params.name.name == "P" + assert params.default_value is None + + assert node.inferred()[0] is node + + +@pytest.mark.skipif(not PY313_PLUS, reason="Type parameter defaults were added in 313") +def test_type_param_spec_defaults() -> None: + node = extract_node("type Alias[**P = [int, str]] = Callable[P, int]") + params = node.type_params[0] + assert isinstance(params, ParamSpec) + assert isinstance(params.name, AssignName) + assert params.name.name == "P" + assert isinstance(params.default_value, List) + assert len(params.default_value.elts) == 2 assert node.inferred()[0] is node @@ -57,6 +95,23 @@ def test_type_var_tuple() -> None: assert isinstance(params, TypeVarTuple) assert isinstance(params.name, AssignName) assert params.name.name == "Ts" + assert params.default_value is None + + assert node.inferred()[0] is node + + +@pytest.mark.skipif(not PY313_PLUS, reason="Type parameter defaults were added in 313") +def test_type_var_tuple_defaults() -> None: + node = extract_node("type Alias[*Ts = tuple[int, str]] = tuple[*Ts]") + params = node.type_params[0] + assert isinstance(params, TypeVarTuple) + assert isinstance(params.name, AssignName) + assert params.name.name == "Ts" + assert isinstance(params.default_value, Subscript) + assert isinstance(params.default_value.value, Name) + assert params.default_value.value.name == "tuple" + assert isinstance(params.default_value.slice, Tuple) + assert len(params.default_value.slice.elts) == 2 assert node.inferred()[0] is node