Skip to content

Commit d50a24e

Browse files
Robustly handle signatures (#199)
Co-authored-by: Bernát Gábor <[email protected]>
1 parent aa97f01 commit d50a24e

File tree

3 files changed

+100
-91
lines changed

3 files changed

+100
-91
lines changed

src/sphinx_autodoc_typehints/__init__.py

Lines changed: 83 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import textwrap
66
import typing
77
from ast import FunctionDef, Module, stmt
8+
from functools import partial
89
from typing import Any, AnyStr, NewType, TypeVar, get_type_hints
910

1011
from sphinx.application import Sphinx
@@ -16,28 +17,19 @@
1617

1718
from .version import version as __version__
1819

19-
logger = logging.getLogger(__name__)
20-
pydata_annotations = {"Any", "AnyStr", "Callable", "ClassVar", "Literal", "NoReturn", "Optional", "Tuple", "Union"}
21-
22-
__all__ = [
23-
"__version__",
24-
]
20+
_LOGGER = logging.getLogger(__name__)
21+
_PYDATA_ANNOTATIONS = {"Any", "AnyStr", "Callable", "ClassVar", "Literal", "NoReturn", "Optional", "Tuple", "Union"}
2522

2623

2724
def get_annotation_module(annotation: Any) -> str:
28-
# Special cases
2925
if annotation is None:
3026
return "builtins"
31-
3227
if sys.version_info >= (3, 10) and isinstance(annotation, NewType): # type: ignore # isinstance NewType is Callable
3328
return "typing"
34-
3529
if hasattr(annotation, "__module__"):
3630
return annotation.__module__ # type: ignore # deduced Any
37-
3831
if hasattr(annotation, "__origin__"):
3932
return annotation.__origin__.__module__ # type: ignore # deduced Any
40-
4133
raise ValueError(f"Cannot determine the module of {annotation}")
4234

4335

@@ -124,7 +116,7 @@ def format_annotation(annotation: Any, fully_qualified: bool = False, simplify_o
124116

125117
full_name = f"{module}.{class_name}" if module != "builtins" else class_name
126118
prefix = "" if fully_qualified or full_name == class_name else "~"
127-
role = "data" if class_name in pydata_annotations else "class"
119+
role = "data" if class_name in _PYDATA_ANNOTATIONS else "class"
128120
args_format = "\\[{}]"
129121
formatted_args = ""
130122

@@ -232,7 +224,7 @@ def _is_dataclass(name: str, what: str, qualname: str) -> bool:
232224
return False
233225

234226
if "<locals>" in obj.__qualname__ and not _is_dataclass(name, what, obj.__qualname__):
235-
logger.warning('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)', name)
227+
_LOGGER.warning('Cannot treat a function defined as a local function: "%s" (use @functools.wraps)', name)
236228
return None
237229

238230
if parameters:
@@ -287,7 +279,7 @@ def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
287279
if isinstance(exc, TypeError) and _future_annotations_imported(obj) and "unsupported operand type" in str(exc):
288280
rv = obj.__annotations__
289281
except NameError as exc:
290-
logger.warning('Cannot resolve forward reference in type annotations of "%s": %s', name, exc)
282+
_LOGGER.warning('Cannot resolve forward reference in type annotations of "%s": %s', name, exc)
291283
rv = obj.__annotations__
292284

293285
if rv:
@@ -305,7 +297,7 @@ def get_all_type_hints(obj: Any, name: str) -> dict[str, Any]:
305297
except (AttributeError, TypeError):
306298
pass
307299
except NameError as exc:
308-
logger.warning('Cannot resolve forward reference in type annotations of "%s": %s', name, exc)
300+
_LOGGER.warning('Cannot resolve forward reference in type annotations of "%s": %s', name, exc)
309301
rv = obj.__annotations__
310302

311303
return rv
@@ -327,7 +319,7 @@ def _one_child(module: Module) -> stmt | None:
327319
children = module.body # use the body to ignore type comments
328320

329321
if len(children) != 1:
330-
logger.warning('Did not get exactly one node from AST for "%s", got %s', name, len(children))
322+
_LOGGER.warning('Did not get exactly one node from AST for "%s", got %s', name, len(children))
331323
return None
332324

333325
return children[0]
@@ -353,7 +345,7 @@ def _one_child(module: Module) -> stmt | None:
353345
try:
354346
comment_args_str, comment_returns = type_comment.split(" -> ")
355347
except ValueError:
356-
logger.warning('Unparseable type hint comment for "%s": Expected to contain ` -> `', name)
348+
_LOGGER.warning('Unparseable type hint comment for "%s": Expected to contain ` -> `', name)
357349
return {}
358350

359351
rv = {}
@@ -368,7 +360,7 @@ def _one_child(module: Module) -> stmt | None:
368360
comment_args.insert(0, None) # self/cls may be omitted in type comments, insert blank
369361

370362
if len(args) != len(comment_args):
371-
logger.warning('Not enough type comments found on "%s"', name)
363+
_LOGGER.warning('Not enough type comments found on "%s"', name)
372364
return rv
373365

374366
for at, arg in enumerate(args):
@@ -442,80 +434,71 @@ def process_docstring(
442434
app: Sphinx, what: str, name: str, obj: Any, options: Options | None, lines: list[str] # noqa: U100
443435
) -> None:
444436
original_obj = obj
445-
if isinstance(obj, property):
446-
obj = obj.fget
447-
448-
if callable(obj):
449-
if inspect.isclass(obj):
450-
obj = obj.__init__
437+
obj = obj.fget if isinstance(obj, property) else obj
438+
if not callable(obj):
439+
return
440+
obj = obj.__init__ if inspect.isclass(obj) else obj
441+
obj = inspect.unwrap(obj)
451442

452-
obj = inspect.unwrap(obj)
443+
try:
453444
signature = sphinx_signature(obj)
454-
type_hints = get_all_type_hints(obj, name)
455-
456-
for arg_name, annotation in type_hints.items():
457-
if arg_name == "return":
458-
continue # this is handled separately later
459-
default = signature.parameters[arg_name].default
460-
if arg_name.endswith("_"):
461-
arg_name = f"{arg_name[:-1]}\\_"
462-
463-
formatted_annotation = format_annotation(
464-
annotation,
465-
fully_qualified=app.config.typehints_fully_qualified,
466-
simplify_optional_unions=app.config.simplify_optional_unions,
467-
)
468-
469-
search_for = [f":{field} {arg_name}:" for field in ("param", "parameter", "arg", "argument")]
470-
insert_index = None
471-
472-
for i, line in enumerate(lines):
473-
if any(line.startswith(search_string) for search_string in search_for):
474-
insert_index = i
475-
break
476-
477-
if insert_index is None and app.config.always_document_param_types:
478-
lines.append(f":param {arg_name}:")
479-
insert_index = len(lines)
480-
481-
if insert_index is not None:
482-
type_annotation = f":type {arg_name}: {formatted_annotation}"
483-
if app.config.typehints_defaults:
484-
formatted_default = format_default(app, default)
485-
if formatted_default:
486-
if app.config.typehints_defaults.endswith("after"):
487-
lines[insert_index] += formatted_default
488-
else: # add to last param doc line
489-
type_annotation += formatted_default
490-
lines.insert(insert_index, type_annotation)
491-
492-
if "return" in type_hints and not inspect.isclass(original_obj):
493-
# This avoids adding a return type for data class __init__ methods
494-
if what == "method" and name.endswith(".__init__"):
495-
return
496-
497-
formatted_annotation = format_annotation(
498-
type_hints["return"],
499-
fully_qualified=app.config.typehints_fully_qualified,
500-
simplify_optional_unions=app.config.simplify_optional_unions,
501-
)
502-
445+
except (ValueError, TypeError):
446+
signature = None
447+
type_hints = get_all_type_hints(obj, name)
448+
449+
formatter = partial(
450+
format_annotation,
451+
fully_qualified=app.config.typehints_fully_qualified,
452+
simplify_optional_unions=app.config.simplify_optional_unions,
453+
)
454+
for arg_name, annotation in type_hints.items():
455+
if arg_name == "return":
456+
continue # this is handled separately later
457+
default = inspect.Parameter.empty if signature is None else signature.parameters[arg_name].default
458+
if arg_name.endswith("_"):
459+
arg_name = f"{arg_name[:-1]}\\_"
460+
461+
formatted_annotation = formatter(annotation)
462+
463+
search_for = {f":{field} {arg_name}:" for field in ("param", "parameter", "arg", "argument")}
464+
insert_index = None
465+
for at, line in enumerate(lines):
466+
if any(line.startswith(search_string) for search_string in search_for):
467+
insert_index = at
468+
break
469+
470+
if insert_index is None and app.config.always_document_param_types:
471+
lines.append(f":param {arg_name}:")
503472
insert_index = len(lines)
504-
for i, line in enumerate(lines):
505-
if line.startswith(":rtype:"):
506-
insert_index = None
507-
break
508-
elif line.startswith(":return:") or line.startswith(":returns:"):
509-
insert_index = i
510-
511-
if insert_index is not None and app.config.typehints_document_rtype:
512-
if insert_index == len(lines):
513-
# Ensure that :rtype: doesn't get joined with a paragraph of text, which
514-
# prevents it being interpreted.
515-
lines.append("")
516-
insert_index += 1
517473

518-
lines.insert(insert_index, f":rtype: {formatted_annotation}")
474+
if insert_index is not None:
475+
type_annotation = f":type {arg_name}: {formatted_annotation}"
476+
if app.config.typehints_defaults:
477+
formatted_default = format_default(app, default)
478+
if formatted_default:
479+
if app.config.typehints_defaults.endswith("after"):
480+
lines[insert_index] += formatted_default
481+
else: # add to last param doc line
482+
type_annotation += formatted_default
483+
lines.insert(insert_index, type_annotation)
484+
485+
if "return" in type_hints and not inspect.isclass(original_obj):
486+
if what == "method" and name.endswith(".__init__"): # avoid adding a return type for data class __init__
487+
return
488+
formatted_annotation = formatter(type_hints["return"])
489+
insert_index = len(lines)
490+
for at, line in enumerate(lines):
491+
if line.startswith(":rtype:"):
492+
insert_index = None
493+
break
494+
elif line.startswith(":return:") or line.startswith(":returns:"):
495+
insert_index = at
496+
497+
if insert_index is not None and app.config.typehints_document_rtype:
498+
if insert_index == len(lines): # ensure that :rtype: doesn't get joined with a paragraph of text
499+
lines.append("")
500+
insert_index += 1
501+
lines.insert(insert_index, f":rtype: {formatted_annotation}")
519502

520503

521504
def builder_ready(app: Sphinx) -> None:
@@ -541,3 +524,15 @@ def setup(app: Sphinx) -> dict[str, bool]:
541524
app.connect("autodoc-process-signature", process_signature)
542525
app.connect("autodoc-process-docstring", process_docstring)
543526
return {"parallel_read_safe": True}
527+
528+
529+
__all__ = [
530+
"__version__",
531+
"format_annotation",
532+
"get_annotation_args",
533+
"get_annotation_class_name",
534+
"get_annotation_module",
535+
"normalize_source_lines",
536+
"process_docstring",
537+
"process_signature",
538+
]

tests/test_sphinx_autodoc_typehints.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import re
55
import sys
66
import typing
7+
from functools import cmp_to_key
78
from io import StringIO
89
from textwrap import dedent, indent
910
from types import ModuleType
@@ -24,10 +25,12 @@
2425
TypeVar,
2526
Union,
2627
)
27-
from unittest.mock import patch
28+
from unittest.mock import create_autospec, patch
2829

2930
import pytest
3031
import typing_extensions
32+
from sphinx.application import Sphinx
33+
from sphinx.config import Config
3134
from sphinx.testing.util import SphinxTestApp
3235
from sphobjinv import Inventory
3336

@@ -255,13 +258,14 @@ def test_format_annotation_both_libs(library: ModuleType, annotation: str, param
255258

256259
def test_process_docstring_slot_wrapper() -> None:
257260
lines: list[str] = []
258-
process_docstring(None, "class", "SlotWrapper", Slotted, None, lines) # type: ignore # first argument is not Sphinx
261+
config = create_autospec(Config, typehints_fully_qualified=False, simplify_optional_unions=False)
262+
app: Sphinx = create_autospec(Sphinx, config=config)
263+
process_docstring(app, "class", "SlotWrapper", Slotted, None, lines)
259264
assert not lines
260265

261266

262267
def set_python_path() -> None:
263268
test_path = pathlib.Path(__file__).parent
264-
265269
# Add test directory to sys.path to allow imports of dummy module.
266270
if str(test_path) not in sys.path:
267271
sys.path.insert(0, str(test_path))
@@ -715,3 +719,12 @@ def __init__(bound_args): # noqa: N805
715719
"""
716720

717721
assert normalize_source_lines(dedent(source)) == dedent(expected)
722+
723+
724+
@pytest.mark.parametrize("obj", [cmp_to_key, 1])
725+
def test_default_no_signature(obj: Any) -> None:
726+
config = create_autospec(Config, typehints_fully_qualified=False, simplify_optional_unions=False)
727+
app: Sphinx = create_autospec(Sphinx, config=config)
728+
lines: list[str] = []
729+
process_docstring(app, "what", "name", obj, None, lines)
730+
assert lines == []

whitelist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ exc
1212
fget
1313
fmt
1414
fn
15+
formatter
1516
func
1617
getmodule
1718
getsource

0 commit comments

Comments
 (0)