Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions ibis/common/annotations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
import functools
import inspect
import types
Expand Down Expand Up @@ -301,7 +302,13 @@ class Signature(inspect.Signature):
Primarily used in the implementation of `ibis.common.grounds.Annotable`.
"""

__slots__ = ()
__slots__ = ('_patterns', '_dataclass')

def __init__(self, *args, **kwargs):
    """Initialize the signature and precompute per-instance lookup helpers.

    All arguments are forwarded unchanged to the parent initializer.
    Afterwards two caches are stored on the instance:

    - ``_patterns``: maps a parameter name to ``annotation.pattern`` for
      every parameter whose annotation exposes a ``pattern`` attribute.
    - ``_dataclass``: the class returned by ``self.to_dataclass()``.
    """
    super().__init__(*args, **kwargs)

    # Prebuild the dict of patterns to avoid slow retrieval through the
    # `parameters` property and its MappingProxyType on every validation.
    patterns = {}
    for name, param in self.parameters.items():
        if hasattr(param.annotation, 'pattern'):
            patterns[name] = param.annotation.pattern
    self._patterns = patterns

    self._dataclass = self.to_dataclass()

@classmethod
def merge(cls, *signatures, **annotations):
Expand Down Expand Up @@ -509,15 +516,27 @@ def validate(self, func, args, kwargs):

return this

def validate_fast(self, func, args, kwargs):
    """Bind and validate call arguments using the prebuilt dataclass.

    The positional and keyword arguments are bound to parameter names by
    instantiating ``self._dataclass`` rather than by calling
    ``Signature.bind``; the resulting name -> value mapping is then handed
    to ``validate_nobind``.

    Parameters
    ----------
    func
        The callable being validated; passed to ``validate_nobind`` and
        used for error reporting.
    args
        Positional arguments of the call.
    kwargs
        Keyword arguments of the call.

    Returns
    -------
    Whatever ``self.validate_nobind(func, bound_arguments)`` returns.

    Raises
    ------
    SignatureValidationError
        If instantiating the dataclass raises ``TypeError``.
    """
    try:
        bound = self._dataclass(*args, **kwargs)
    except TypeError as err:
        error_context = dict(sig=self, func=func, args=args, kwargs=kwargs)
        raise SignatureValidationError(
            "{call} {cause}\n\nExpected signature: {sig}",
            **error_context,
        ) from err
    else:
        # the instance's attribute dict is the name -> value binding
        return self.validate_nobind(func, vars(bound))

def validate_nobind(self, func, kwargs):
"""Validate the arguments against the signature without binding."""
this, errors = {}, []
for name, param in self.parameters.items():
value = kwargs.get(name, param.default)
if value is EMPTY:
raise TypeError(f"missing required argument `{name!r}`")
for name, pattern in self._patterns.items():
value = kwargs[name]

pattern = param.annotation.pattern
result = pattern.match(value, this)
if result is NoMatch:
errors.append((name, value, pattern))
Expand Down Expand Up @@ -565,6 +584,36 @@ def validate_return(self, func, value):

return result

def to_dataclass(self, cls_name: str = 'SignatureDataclass') -> type:
    """Create a dataclass whose fields mirror this signature's parameters.

    Instantiating the returned class with ``*args, **kwargs`` and reading
    the instance ``__dict__`` binds arguments to parameter names; the
    original author reports this to be much faster (~10-20x) than
    ``Signature.bind``.

    Parameters
    ----------
    cls_name
        Name given to the generated dataclass.

    Returns
    -------
    type
        A dataclass with one field per parameter, in signature order.
        Parameters without a default become required fields.

    Notes
    -----
    Every parameter is turned into a plain field regardless of its kind.
    NOTE(review): dataclasses require fields without defaults to precede
    fields with defaults, so a signature with a required parameter after
    a defaulted one cannot be converted -- confirm callers never build one.
    """
    fields = []
    for name, param in self.parameters.items():
        default = param.default
        if default is inspect.Parameter.empty:
            fields.append((name, param.annotation))
        elif type(default).__hash__ is None:
            # dataclasses reject unhashable *default values* (e.g. a list),
            # so the hashability of the default -- not of the annotation --
            # decides whether a default factory is needed
            factory = DefaultFactory(default)
            fields.append(
                (name, param.annotation, dataclasses.field(default_factory=factory))
            )
        else:
            fields.append(
                (name, param.annotation, dataclasses.field(default=default))
            )
    return dataclasses.make_dataclass(cls_name, fields)


class DefaultFactory:
    """Zero-argument callable that always returns one stored object.

    Instances can be passed wherever a ``default_factory`` callable is
    expected. The stored object itself is returned on every call; no copy
    is made.
    """

    __slots__ = ("value",)

    def __init__(self, value: object) -> None:
        """Remember ``value`` as the object to return from each call."""
        self.value = value

    def __call__(self) -> object:
        """Return the stored object."""
        return self.value


def annotated(_1=None, _2=None, _3=None, **kwargs):
"""Create functions with arguments validated at runtime.
Expand Down
22 changes: 13 additions & 9 deletions ibis/common/grounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
def __or__(self, other):
    """Return ``Union[self, other]`` for the ``self | other`` expression."""
    # required to support `dt.Numeric | dt.Floating` annotation for python<3.10
    return Union[self, other]

__call__ = type.__call__


@dataclass_transform()
Expand All @@ -113,21 +115,22 @@ class Annotable(Abstract, metaclass=AnnotableMeta):
__match_args__: ClassVar[tuple[str, ...]]
"""Names of the arguments to be used for pattern matching."""

@classmethod
def __create__(cls, *args: Any, **kwargs: Any) -> Self:
# construct the instance by passing only validated keyword arguments
kwargs = cls.__signature__.validate(cls, args, kwargs)
return super().__create__(**kwargs)
#@classmethod
#def __create__(cls, *args: Any, **kwargs: Any) -> Self:
# # construct the instance by passing only validated keyword arguments
# validated_kwargs = cls.__signature__.validate_fast(cls, args, kwargs)
# return super().__create__(**validated_kwargs)

@classmethod
def __recreate__(cls, kwargs: Any) -> Self:
    """Build an instance from a mapping of keyword arguments only.

    The mapping is validated with ``cls.__signature__.validate_nobind`` and
    the validated result is forwarded to the parent ``__create__``.
    """
    # requiring keyword arguments only lets us bypass signature binding
    validated = cls.__signature__.validate_nobind(cls, kwargs)
    return super().__create__(**validated)

def __init__(self, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs: Any) -> None:
validated_kwargs = self.__signature__.validate_fast(self.__class__, args, kwargs)
# set the already validated arguments
for name, value in kwargs.items():
for name, value in validated_kwargs.items():
object.__setattr__(self, name, value)
# initialize the remaining attributes
for name, field in self.__attributes__.items():
Expand Down Expand Up @@ -192,11 +195,12 @@ class Concrete(Immutable, Comparable, Annotable):

__slots__ = ("__args__", "__precomputed_hash__")

def __init__(self, **kwargs: Any) -> None:
def __init__(self, *args, **kwargs: Any) -> None:
validated_kwargs = self.__signature__.validate_fast(self.__class__, args, kwargs)
# collect and set the arguments in a single pass
args = []
for name in self.__argnames__:
value = kwargs[name]
value = validated_kwargs[name]
args.append(value)
object.__setattr__(self, name, value)

Expand Down
2 changes: 1 addition & 1 deletion ibis/expr/operations/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class ScalarParameter(Scalar):

shape = ds.scalar

def __init__(self, dtype, counter):
def __init__(self, dtype, counter=None):
if counter is None:
counter = next(self._counter)
super().__init__(dtype=dtype, counter=counter)
Expand Down
8 changes: 4 additions & 4 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ class Field(Value):
shape = ds.columnar

def __init__(self, rel, name):
if name not in rel.schema:
columns_formatted = ", ".join(map(repr, rel.schema.names))
super().__init__(rel=rel, name=name)
if self.name not in self.rel.schema:
columns_formatted = ", ".join(map(repr, self.rel.schema.names))
raise IbisTypeError(
f"Column {name!r} is not found in table. "
f"Column {self.name!r} is not found in table. "
f"Existing columns: {columns_formatted}."
)
super().__init__(rel=rel, name=name)

@attribute
def dtype(self):
Expand Down
4 changes: 4 additions & 0 deletions ibis/expr/types/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def type(self) -> dt.DataType:
"""
return self.op().dtype

@property
def dtype(self) -> dt.DataType:
    """Property form of `type()`: returns the result of `self.type()`."""
    return self.type()

def hash(self) -> ir.IntegerValue:
"""Compute an integer hash value.

Expand Down
Loading