Skip to content

Commit e22e03c

Browse files
committed
Make tests pass.
1 parent 2d25f02 commit e22e03c

34 files changed

+2616
-88
lines changed

check_shapes/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,9 @@
451451
452452
* Python built-in scalars: ``bool``, ``int``, ``float`` and ``str``.
453453
* Python built-in sequences: ``tuple`` and ``list``.
454-
* NumPy ``ndarray``\ s.
455-
* TensorFlow ``Tensor``\ s and ``Variable``\ s.
456-
* TensorFlow Probability ``DeferredTensor``\ s, including ``TransformedVariable`` and
454+
* NumPy ``ndarray``\\ s.
455+
* TensorFlow ``Tensor``\\ s and ``Variable``\\ s.
456+
* TensorFlow Probability ``DeferredTensor``\\ s, including ``TransformedVariable`` and
457457
:class:`gpflow.Parameter`.
458458
459459
@@ -490,6 +490,8 @@
490490
from .inheritance import inherit_check_shapes
491491
from .shapes import get_shape, register_get_shape
492492

493+
__version__ = "0.1.0"
494+
493495
__all__ = [
494496
"Dimension",
495497
"DocstringFormat",

check_shapes/accessors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def maybe_get_check_shapes(func: Callable[..., Any]) -> Optional[Callable[[C], C
3434
:returns: The ``check_shapes`` that is wrapping ``func``, and ``None`` if no ``check_shapes`` is
3535
not wrapping ``func``.
3636
"""
37-
return getattr(func, "__check_shapes__", None)
37+
return getattr(func, "__check_shapes__", None) # type: ignore[no-any-return]
3838

3939

4040
def get_check_shapes(func: Callable[..., Any]) -> Callable[[C], C]:

check_shapes/argument_ref.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __repr__(self) -> str:
9696

9797
@dataclass(frozen=True) # type: ignore[misc]
9898
class DelegatingArgumentRef(ArgumentRef):
99-
""" Abstract base class for :class:`ArgumentRef`\ s that delegates to a source. """
99+
""" Abstract base class for :class:`ArgumentRef`\\ s that delegates to a source. """
100100

101101
source: ArgumentRef
102102

check_shapes/base_types.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
"""
1515
Definitions of commonly used types.
1616
"""
17-
from typing import Any, Callable, Optional, Tuple, TypeVar
17+
from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, TypeVar, Union
18+
19+
import numpy as np
20+
21+
from .type_flags import GENERIC_NP_ARRAYS, NP_TYPE_CHECKING
1822

1923
C = TypeVar("C", bound=Callable[..., Any])
2024

@@ -33,3 +37,14 @@
3337
3438
Raise an exception if objects of that type can never have a shape.
3539
"""
40+
41+
if TYPE_CHECKING and (not NP_TYPE_CHECKING): # pragma: no cover
42+
AnyNDArray = Any
43+
else:
44+
if GENERIC_NP_ARRAYS:
45+
# It would be nice to use something more interesting than `Any` here, but it looks like
46+
# the infrastructure in the rest of the ecosystem isn't really set up for this
47+
# yet. Maybe when we get Python 3.11?
48+
AnyNDArray = np.ndarray[Any, Any] # type: ignore[misc]
49+
else:
50+
AnyNDArray = Union[np.ndarray] # type: ignore[misc]

check_shapes/bool_specs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from typing import Any, Callable, Mapping, Tuple, cast
2121

2222
from .argument_ref import ArgumentRef
23-
from .error_contexts import ErrorContext, ObjectValueContext, ParallelContext, StackContext
23+
from .error_contexts import (
24+
ErrorContext,
25+
ObjectValueContext,
26+
ParallelContext,
27+
StackContext,
28+
)
2429

2530

2631
class ParsedBoolSpec(ABC):

check_shapes/checker.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,19 @@
1515
Class responsible for remembering and checking shapes.
1616
"""
1717
from dataclasses import dataclass, field
18-
from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union
18+
from typing import (
19+
Any,
20+
DefaultDict,
21+
Dict,
22+
Iterable,
23+
List,
24+
Optional,
25+
Set,
26+
Tuple,
27+
TypeVar,
28+
Union,
29+
)
1930

20-
from ..utils import experimental
2131
from .base_types import Dimension, Shape
2232
from .config import get_enable_check_shapes
2333
from .error_contexts import (
@@ -227,7 +237,6 @@ class ShapeChecker:
227237
:dedent:
228238
"""
229239

230-
@experimental
231240
def __init__(self) -> None:
232241
self._variables: DefaultDict[str, _VariableState] = DefaultDict(_VariableState)
233242
self._observed_shapes: List[Tuple[Shape, ParsedTensorSpec, ErrorContext]] = []

check_shapes/decorator.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
import tensorflow as tf
2222

23-
from ..utils import experimental
2423
from .accessors import set_check_shapes
2524
from .argument_ref import RESULT_TOKEN
2625
from .base_types import C
@@ -50,7 +49,6 @@ def null_check_shapes(func: C) -> C:
5049
return func
5150

5251

53-
@experimental
5452
def check_shapes(*specs: str, tf_decorator: bool = False) -> Callable[[C], C]:
5553
"""
5654
Decorator that checks the shapes of tensor arguments.

check_shapes/error_contexts.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
The :class:`ErrorContext` is a reusable bit of information about where/why an error occurred that
2323
can be written to a :class:`MessageBuilder`.
2424
25-
:class:`ErrorContext`\ s can be composed using the :class:`StackContext` and
25+
:class:`ErrorContext`\\ s can be composed using the :class:`StackContext` and
2626
:class:`ParallelContext`.
2727
2828
This allows reusable error messages in a consistent format.
@@ -59,7 +59,7 @@
5959
_UNKNOWN_LINE = "<Unknown line>"
6060
_DISABLED_FILE_AND_LINE = (
6161
f"{_UNKNOWN_FILE}:{_UNKNOWN_LINE}"
62-
" (Disabled. Call gpflow.experimental.check_shapes.set_enable_function_call_precompute(True) to"
62+
" (Disabled. Call check_shapes.set_enable_function_call_precompute(True) to"
6363
" see this. (Slow.))"
6464
)
6565
_NONE_SHAPE = "<Tensor is None or has unknown shape>"
@@ -225,7 +225,9 @@ def _split_head(context: StackContext) -> Tuple[ErrorContext, ErrorContext]:
225225
by_head: Dict[ErrorContext, List[ErrorContext]] = {}
226226
for child in flat:
227227
if isinstance(child, StackContext):
228-
head, body = _split_head(child)
228+
split_child = _split_head(child)
229+
head: ErrorContext = split_child[0]
230+
body: Optional[ErrorContext] = split_child[1]
229231
else:
230232
head = child
231233
body = None

check_shapes/inheritance.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@
1717
import inspect
1818
from typing import Callable, Optional, cast
1919

20-
from ..utils import experimental
2120
from .accessors import maybe_get_check_shapes
2221
from .base_types import C
2322
from .config import get_enable_check_shapes
2423
from .decorator import null_check_shapes
2524

2625

27-
@experimental
2826
def inherit_check_shapes(func: C) -> C:
2927
"""
3028
Decorator that inherits the :func:`check_shapes` decoration from any overridden method in a

check_shapes/mypy_flags.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2022 The GPflow Contributors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Code for printing flags for mypy, depending on library versions.
16+
"""
17+
from .type_flags import compute_mypy_flags
18+
19+
20+
def print_mypy_flags() -> None:
21+
print(compute_mypy_flags())
22+
23+
24+
if __name__ == "__main__":
25+
print_mypy_flags()

0 commit comments

Comments
 (0)