Skip to content

Commit f5d4fa4

Browse files
committed
add options for a signal aliases
adapt to SignalRelay
1 parent f6cebcb commit f5d4fa4

File tree

5 files changed

+496
-23
lines changed

5 files changed

+496
-23
lines changed

src/psygnal/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"EventedModel",
3030
"get_evented_namespace",
3131
"is_evented",
32+
"PSYGNAL_METADATA",
3233
"Signal",
3334
"SignalGroup",
3435
"SignalGroupDescriptor",
@@ -48,6 +49,7 @@
4849
stacklevel=2,
4950
)
5051

52+
from ._dataclass_utils import PSYGNAL_METADATA
5153
from ._evented_decorator import evented
5254
from ._exceptions import EmitLoopError
5355
from ._group import EmissionInfo, SignalGroup

src/psygnal/_dataclass_utils.py

Lines changed: 263 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,30 @@
44
import dataclasses
55
import sys
66
import types
7-
from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload
7+
from dataclasses import dataclass, fields
8+
from typing import (
9+
TYPE_CHECKING,
10+
Any,
11+
Callable,
12+
Iterator,
13+
List,
14+
Mapping,
15+
Protocol,
16+
cast,
17+
overload,
18+
)
819

920
if TYPE_CHECKING:
21+
from dataclasses import Field
22+
1023
import attrs
1124
import msgspec
1225
from pydantic import BaseModel
13-
from typing_extensions import TypeGuard # py310
26+
from typing_extensions import TypeAlias, TypeGuard # py310
27+
28+
EqOperator: TypeAlias = Callable[[Any, Any], bool]
29+
30+
PSYGNAL_METADATA = "__psygnal_metadata"
1431

1532

1633
class _DataclassParams(Protocol):
@@ -29,12 +46,11 @@ class AttrsType:
2946
__attrs_attrs__: tuple[attrs.Attribute, ...]
3047

3148

32-
_DATACLASS_PARAMS = "__dataclass_params__"
49+
KW_ONLY = object()
3350
with contextlib.suppress(ImportError):
34-
from dataclasses import _DATACLASS_PARAMS # type: ignore
51+
from dataclasses import KW_ONLY # py310
52+
_DATACLASS_PARAMS = "__dataclass_params__"
3553
_DATACLASS_FIELDS = "__dataclass_fields__"
36-
with contextlib.suppress(ImportError):
37-
from dataclasses import _DATACLASS_FIELDS # type: ignore
3854

3955

4056
class DataClassType:
@@ -171,8 +187,8 @@ def iter_fields(
171187
yield field_name, p_field.annotation
172188
else:
173189
for p_field in cls.__fields__.values(): # type: ignore [attr-defined]
174-
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore
175-
yield p_field.name, p_field.outer_type_ # type: ignore
190+
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined]
191+
yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined]
176192
return
177193

178194
if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None:
@@ -185,3 +201,242 @@ def iter_fields(
185201
type_ = cls.__annotations__.get(m_field, None)
186202
yield m_field, type_
187203
return
204+
205+
206+
@dataclass
207+
class FieldOptions:
208+
name: str
209+
type_: type | None = None
210+
# set KW_ONLY value for compatibility with python < 3.10
211+
_: KW_ONLY = KW_ONLY # type: ignore [valid-type]
212+
alias: str | None = None
213+
skip: bool | None = None
214+
eq: EqOperator | None = None
215+
disable_setattr: bool | None = None
216+
217+
218+
def is_kw_only(f: Field) -> bool:
219+
if hasattr(f, "kw_only"):
220+
return cast(bool, f.kw_only)
221+
# for python < 3.10
222+
if f.name not in ["name", "type_"]:
223+
return True
224+
return False
225+
226+
227+
def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]:
228+
field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)]
229+
return {k: v for k, v in d.items() if k in field_options_kws}
230+
231+
232+
def get_msgspec_metadata(
233+
cls: type[msgspec.Struct],
234+
m_field: str,
235+
) -> tuple[type | None, dict[str, Any]]:
236+
# Look for type in cls and super classes
237+
type_: type | None = None
238+
for super_cls in cls.__mro__:
239+
if not hasattr(super_cls, "__annotations__"):
240+
continue
241+
type_ = super_cls.__annotations__.get(m_field, None)
242+
if type_ is not None:
243+
break
244+
245+
msgspec = sys.modules.get("msgspec", None)
246+
if msgspec is None:
247+
return type_, {}
248+
249+
metadata_list = getattr(type_, "__metadata__", [])
250+
251+
metadata: dict[str, Any] = {}
252+
for meta in metadata_list:
253+
if not isinstance(meta, msgspec.Meta):
254+
continue
255+
single_meta: dict[str, Any] = getattr(meta, "extra", {}).get(
256+
PSYGNAL_METADATA, {}
257+
)
258+
metadata.update(single_meta)
259+
260+
return type_, metadata
261+
262+
263+
def iter_fields_with_options(
264+
cls: type, exclude_frozen: bool = True
265+
) -> Iterator[FieldOptions]:
266+
"""Iterate over all fields in the class, return a field description.
267+
268+
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
269+
models.
270+
271+
Parameters
272+
----------
273+
cls : type
274+
The class to iterate over.
275+
exclude_frozen : bool, optional
276+
If True, frozen fields will be excluded. By default True.
277+
278+
Yields
279+
------
280+
FieldOptions
281+
A dataclass instance with the name, type and metadata of each field.
282+
"""
283+
# Add metadata for dataclasses.dataclass
284+
dclass_fields = getattr(cls, "__dataclass_fields__", None)
285+
if dclass_fields is not None:
286+
"""
287+
Example
288+
-------
289+
from dataclasses import dataclass, field
290+
291+
292+
@dataclass
293+
class Foo:
294+
bar: int = field(metadata={"alias": "bar_alias"})
295+
296+
assert (
297+
Foo.__dataclass_fields__["bar"].metadata ==
298+
{"__psygnal_metadata": {"alias": "bar_alias"}}
299+
)
300+
301+
"""
302+
for d_field in dclass_fields.values():
303+
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined]
304+
metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {})
305+
metadata = sanitize_field_options_dict(metadata)
306+
options = FieldOptions(d_field.name, d_field.type, **metadata)
307+
yield options
308+
return
309+
310+
# Add metadata for pydantic dataclass
311+
if is_pydantic_model(cls):
312+
"""
313+
Example
314+
-------
315+
from typing import Annotated
316+
317+
from pydantic import BaseModel, Field
318+
319+
320+
# Only works with Pydantic v2
321+
class Foo(BaseModel):
322+
bar: Annotated[
323+
str,
324+
{'__psygnal_metadata': {"alias": "bar_alias"}}
325+
] = Field(...)
326+
327+
# Working with Pydantic v2 and partially with v1
328+
# Alternative, using Field `json_schema_extra` keyword argument
329+
class Bar(BaseModel):
330+
bar: str = Field(
331+
json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}}
332+
)
333+
334+
335+
assert (
336+
Foo.model_fields["bar"].metadata[0] ==
337+
{"__psygnal_metadata": {"alias": "bar_alias"}}
338+
)
339+
assert (
340+
Bar.model_fields["bar"].json_schema_extra ==
341+
{"__psygnal_metadata": {"alias": "bar_alias"}}
342+
)
343+
344+
"""
345+
if hasattr(cls, "model_fields"):
346+
# Pydantic v2
347+
for field_name, p_field in cls.model_fields.items():
348+
# skip frozen field
349+
if exclude_frozen and p_field.frozen:
350+
continue
351+
metadata_list = getattr(p_field, "metadata", [])
352+
metadata = {}
353+
for field in metadata_list:
354+
metadata.update(field.get(PSYGNAL_METADATA, {}))
355+
# Compat with using Field `json_schema_extra` keyword argument
356+
if isinstance(getattr(p_field, "json_schema_extra", None), Mapping):
357+
meta_dict = cast(Mapping, p_field.json_schema_extra)
358+
metadata.update(meta_dict.get(PSYGNAL_METADATA, {}))
359+
metadata = sanitize_field_options_dict(metadata)
360+
options = FieldOptions(field_name, p_field.annotation, **metadata)
361+
yield options
362+
return
363+
364+
else:
365+
# Pydantic v1, metadata is not always working
366+
for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined]
367+
# skip frozen field
368+
if exclude_frozen and not pv1_field.field_info.allow_mutation:
369+
continue
370+
meta_dict = getattr(pv1_field.field_info, "extra", {}).get(
371+
"json_schema_extra", {}
372+
)
373+
metadata = meta_dict.get(PSYGNAL_METADATA, {})
374+
375+
metadata = sanitize_field_options_dict(metadata)
376+
options = FieldOptions(
377+
pv1_field.name,
378+
pv1_field.outer_type_,
379+
**metadata,
380+
)
381+
yield options
382+
return
383+
384+
# Add metadata for attrs dataclass
385+
attrs_fields = getattr(cls, "__attrs_attrs__", None)
386+
if attrs_fields is not None:
387+
"""
388+
Example
389+
-------
390+
from attrs import define, field
391+
392+
393+
@define
394+
class Foo:
395+
bar: int = field(metadata={"alias": "bar_alias"})
396+
397+
assert (
398+
Foo.__attrs_attrs__.bar.metadata ==
399+
{"__psygnal_metadata": {"alias": "bar_alias"}}
400+
)
401+
402+
"""
403+
for a_field in attrs_fields:
404+
metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {})
405+
metadata = sanitize_field_options_dict(metadata)
406+
options = FieldOptions(a_field.name, a_field.type, **metadata)
407+
yield options
408+
return
409+
410+
# Add metadata for attrs dataclass
411+
if is_msgspec_struct(cls):
412+
"""
413+
Example
414+
-------
415+
from typing import Annotated
416+
417+
from msgspec import Meta, Struct
418+
419+
420+
class Foo(Struct):
421+
bar: Annotated[
422+
str,
423+
Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"}))
424+
] = ""
425+
426+
427+
print(Foo.__annotations__["bar"].__metadata__[0].extra)
428+
# {"__psygnal_metadata": {"alias": "bar_alias"}}
429+
430+
"""
431+
for m_field in cls.__struct_fields__:
432+
try:
433+
type_, metadata = get_msgspec_metadata(cls, m_field)
434+
metadata = sanitize_field_options_dict(metadata)
435+
except AttributeError:
436+
msg = f"Cannot parse field metadata for {m_field}: {type_}"
437+
# logger.exception(msg)
438+
print(msg)
439+
type_, metadata = None, {}
440+
options = FieldOptions(m_field, type_, **metadata)
441+
yield options
442+
return

src/psygnal/_evented_decorator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from psygnal._group_descriptor import SignalGroupDescriptor
66

77
if TYPE_CHECKING:
8-
from psygnal._group_descriptor import EqOperator, FieldAliasFunc
8+
from psygnal._group_descriptor import ( # type: ignore[attr-defined]
9+
EqOperator,
10+
FieldAliasFunc,
11+
)
912

1013
__all__ = ["evented"]
1114

0 commit comments

Comments
 (0)