From bf55bdad2aee7f9f3452fe98e3b0d58a12dca81a Mon Sep 17 00:00:00 2001 From: Talley Lambert Date: Fri, 22 Mar 2024 10:34:24 -0400 Subject: [PATCH] feat: allow explicit location to module --- pybind11_stubgen/__init__.py | 42 ++++++++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 4 deletions(-) diff --git a/pybind11_stubgen/__init__.py b/pybind11_stubgen/__init__.py index f3a92152..7f266237 100644 --- a/pybind11_stubgen/__init__.py +++ b/pybind11_stubgen/__init__.py @@ -1,10 +1,13 @@ from __future__ import annotations import importlib +import importlib.util import logging import re -from argparse import ArgumentParser, Namespace +import sys +from argparse import ArgumentParser, ArgumentTypeError, Namespace from pathlib import Path +from types import ModuleType from pybind11_stubgen.parser.interface import IParser from pybind11_stubgen.parser.mixins.error_handlers import ( @@ -77,6 +80,7 @@ class CLIArgs(Namespace): dry_run: bool stub_extension: str module_name: str + location: Path | None def arg_parser() -> ArgumentParser: @@ -215,6 +219,22 @@ def regex_colon_path(regex_path: str) -> tuple[re.Pattern, str]: "Must be 'pyi' (default) or 'py'", ) + def existing_file(path: str) -> Path | None: + if path is None: + return None + try: + return Path(path).resolve(strict=True) + except FileNotFoundError: + raise ArgumentTypeError(f"Path {path!r} does not exist.") + + parser.add_argument( + "--location", + type=existing_file, + default=None, + dest="location", + help="Explicit filesytem location for module", + ) + parser.add_argument( "module_name", metavar="MODULE_NAME", @@ -324,6 +344,7 @@ def main(): sub_dir=sub_dir, dry_run=args.dry_run, writer=Writer(stub_ext=args.stub_extension), + location=args.location, ) @@ -345,6 +366,16 @@ def to_output_and_subdir( return out_dir.joinpath(*module_path[:-1]), sub_dir +def import_module_from_path(module_name: str, location: Path) -> ModuleType: + spec = importlib.util.spec_from_file_location(module_name, location) + if not (spec and spec.loader): + raise ImportError(f"Can't import {module_name} from {location}") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + sys.modules[module_name] = module + return module + + def run( parser: IParser, printer: Printer, @@ -353,10 +384,13 @@ def run( sub_dir: Path | None, dry_run: bool, writer: Writer, + location: Path | None, ): - module = parser.handle_module( - QualifiedName.from_str(module_name), importlib.import_module(module_name) - ) + if location: + pymodule = import_module_from_path(module_name, location) + else: + pymodule = importlib.import_module(module_name) + module = parser.handle_module(QualifiedName.from_str(module_name), pymodule) parser.finalize() if module is None: