Skip to content

Commit 15c2e5d

Browse files
committed
add options for a signal aliases
adapt to SignalRelay
1 parent 025104e commit 15c2e5d

12 files changed

+962
-78
lines changed

src/psygnal/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
"EventedModel",
3030
"get_evented_namespace",
3131
"is_evented",
32+
"PSYGNAL_METADATA",
3233
"Signal",
3334
"SignalGroup",
3435
"SignalGroupDescriptor",
@@ -48,6 +49,7 @@
4849
stacklevel=2,
4950
)
5051

52+
from ._dataclass_utils import PSYGNAL_METADATA
5153
from ._evented_decorator import evented
5254
from ._exceptions import EmitLoopError
5355
from ._group import EmissionInfo, SignalGroup

src/psygnal/_dataclass_utils.py

Lines changed: 262 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,32 @@
44
import dataclasses
55
import sys
66
import types
7-
from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload
7+
from dataclasses import dataclass, fields
8+
from typing import (
9+
TYPE_CHECKING,
10+
Any,
11+
Callable,
12+
Iterator,
13+
List,
14+
Mapping,
15+
Protocol,
16+
cast,
17+
overload,
18+
)
819

920
if TYPE_CHECKING:
21+
from dataclasses import Field
22+
1023
import attrs
1124
import msgspec
1225
from pydantic import BaseModel
1326
from typing_extensions import TypeGuard # py310
1427

1528

29+
EqOperator = Callable[[Any, Any], bool]
30+
PSYGNAL_METADATA = "__psygnal_metadata"
31+
32+
1633
class _DataclassParams(Protocol):
1734
init: bool
1835
repr: bool
@@ -29,6 +46,9 @@ class AttrsType:
2946
__attrs_attrs__: tuple[attrs.Attribute, ...]
3047

3148

49+
KW_ONLY = object()
50+
with contextlib.suppress(ImportError):
51+
from dataclasses import KW_ONLY
3252
_DATACLASS_PARAMS = "__dataclass_params__"
3353
with contextlib.suppress(ImportError):
3454
from dataclasses import _DATACLASS_PARAMS # type: ignore
@@ -171,8 +191,8 @@ def iter_fields(
171191
yield field_name, p_field.annotation
172192
else:
173193
for p_field in cls.__fields__.values(): # type: ignore [attr-defined]
174-
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore
175-
yield p_field.name, p_field.outer_type_ # type: ignore
194+
if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined]
195+
yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined]
176196
return
177197

178198
if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None:
@@ -185,3 +205,242 @@ def iter_fields(
185205
type_ = cls.__annotations__.get(m_field, None)
186206
yield m_field, type_
187207
return
208+
209+
210+
@dataclass
211+
class FieldOptions:
212+
name: str
213+
type_: type | None = None
214+
# set KW_ONLY value for compatibility with python < 3.10
215+
_: KW_ONLY = KW_ONLY # type: ignore [valid-type]
216+
alias: str | None = None
217+
skip: bool | None = None
218+
eq: EqOperator | None = None
219+
disable_setattr: bool | None = None
220+
221+
222+
def is_kw_only(f: Field) -> bool:
223+
if hasattr(f, "kw_only"):
224+
return cast(bool, f.kw_only)
225+
# for python < 3.10
226+
if f.name not in ["name", "type_"]:
227+
return True
228+
return False
229+
230+
231+
def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]:
232+
field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)]
233+
return {k: v for k, v in d.items() if k in field_options_kws}
234+
235+
236+
def get_msgspec_metadata(
237+
cls: type[msgspec.Struct],
238+
m_field: str,
239+
) -> tuple[type | None, dict[str, Any]]:
240+
# Look for type in cls and super classes
241+
type_: type | None = None
242+
for super_cls in cls.__mro__:
243+
if not hasattr(super_cls, "__annotations__"):
244+
continue
245+
type_ = super_cls.__annotations__.get(m_field, None)
246+
if type_ is not None:
247+
break
248+
249+
msgspec = sys.modules.get("msgspec", None)
250+
if msgspec is None:
251+
return type_, {}
252+
253+
metadata_list = getattr(type_, "__metadata__", [])
254+
255+
metadata: dict[str, Any] = {}
256+
for meta in metadata_list:
257+
if not isinstance(meta, msgspec.Meta):
258+
continue
259+
single_meta: dict[str, Any] = getattr(meta, "extra", {}).get(
260+
PSYGNAL_METADATA, {}
261+
)
262+
metadata.update(single_meta)
263+
264+
return type_, metadata
265+
266+
267+
def iter_fields_with_options(
268+
cls: type, exclude_frozen: bool = True
269+
) -> Iterator[FieldOptions]:
270+
"""Iterate over all fields in the class, return a field description.
271+
272+
This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic
273+
models.
274+
275+
Parameters
276+
----------
277+
cls : type
278+
The class to iterate over.
279+
exclude_frozen : bool, optional
280+
If True, frozen fields will be excluded. By default True.
281+
282+
Yields
283+
------
284+
FieldOptions
285+
A dataclass instance with the name, type and metadata of each field.
286+
"""
287+
# Add metadata for dataclasses.dataclass
288+
dclass_fields = getattr(cls, "__dataclass_fields__", None)
289+
if dclass_fields is not None:
290+
"""
291+
Example
292+
-------
293+
from dataclasses import dataclass, field
294+
295+
296+
@dataclass
297+
class Foo:
298+
bar: int = field(metadata={"alias": "bar_alias"})
299+
300+
assert (
301+
Foo.__dataclass_fields__["bar"].metadata ==
302+
{"__psygnal_metadata": {"alias": "bar_alias"}}
303+
)
304+
305+
"""
306+
for d_field in dclass_fields.values():
307+
if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined]
308+
metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {})
309+
metadata = sanitize_field_options_dict(metadata)
310+
options = FieldOptions(d_field.name, d_field.type, **metadata)
311+
yield options
312+
return
313+
314+
# Add metadata for pydantic dataclass
315+
if is_pydantic_model(cls):
316+
"""
317+
Example
318+
-------
319+
from typing import Annotated
320+
321+
from pydantic import BaseModel, Field
322+
323+
324+
# Only works with Pydantic v2
325+
class Foo(BaseModel):
326+
bar: Annotated[
327+
str,
328+
{'__psygnal_metadata': {"alias": "bar_alias"}}
329+
] = Field(...)
330+
331+
# Working with Pydantic v2 and partially with v1
332+
# Alternative, using Field `json_schema_extra` keyword argument
333+
class Bar(BaseModel):
334+
bar: str = Field(
335+
json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}}
336+
)
337+
338+
339+
assert (
340+
Foo.model_fields["bar"].metadata[0] ==
341+
{"__psygnal_metadata": {"alias": "bar_alias"}}
342+
)
343+
assert (
344+
Bar.model_fields["bar"].json_schema_extra ==
345+
{"__psygnal_metadata": {"alias": "bar_alias"}}
346+
)
347+
348+
"""
349+
if hasattr(cls, "model_fields"):
350+
# Pydantic v2
351+
for field_name, p_field in cls.model_fields.items():
352+
# skip frozen field
353+
if exclude_frozen and p_field.frozen:
354+
continue
355+
metadata_list = getattr(p_field, "metadata", [])
356+
metadata = {}
357+
for field in metadata_list:
358+
metadata.update(field.get(PSYGNAL_METADATA, {}))
359+
# Compat with using Field `json_schema_extra` keyword argument
360+
if isinstance(getattr(p_field, "json_schema_extra", None), Mapping):
361+
meta_dict = cast(Mapping, p_field.json_schema_extra)
362+
metadata.update(meta_dict.get(PSYGNAL_METADATA, {}))
363+
metadata = sanitize_field_options_dict(metadata)
364+
options = FieldOptions(field_name, p_field.annotation, **metadata)
365+
yield options
366+
return
367+
368+
else:
369+
# Pydantic v1, metadata is not always working
370+
for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined]
371+
# skip frozen field
372+
if exclude_frozen and not pv1_field.field_info.allow_mutation:
373+
continue
374+
meta_dict = getattr(pv1_field.field_info, "extra", {}).get(
375+
"json_schema_extra", {}
376+
)
377+
metadata = meta_dict.get(PSYGNAL_METADATA, {})
378+
379+
metadata = sanitize_field_options_dict(metadata)
380+
options = FieldOptions(
381+
pv1_field.name,
382+
pv1_field.outer_type_,
383+
**metadata,
384+
)
385+
yield options
386+
return
387+
388+
# Add metadata for attrs dataclass
389+
attrs_fields = getattr(cls, "__attrs_attrs__", None)
390+
if attrs_fields is not None:
391+
"""
392+
Example
393+
-------
394+
from attrs import define, field
395+
396+
397+
@define
398+
class Foo:
399+
bar: int = field(metadata={"alias": "bar_alias"})
400+
401+
assert (
402+
Foo.__attrs_attrs__.bar.metadata ==
403+
{"__psygnal_metadata": {"alias": "bar_alias"}}
404+
)
405+
406+
"""
407+
for a_field in attrs_fields:
408+
metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {})
409+
metadata = sanitize_field_options_dict(metadata)
410+
options = FieldOptions(a_field.name, a_field.type, **metadata)
411+
yield options
412+
return
413+
414+
# Add metadata for attrs dataclass
415+
if is_msgspec_struct(cls):
416+
"""
417+
Example
418+
-------
419+
from typing import Annotated
420+
421+
from msgspec import Meta, Struct
422+
423+
424+
class Foo(Struct):
425+
bar: Annotated[
426+
str,
427+
Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"}))
428+
] = ""
429+
430+
431+
print(Foo.__annotations__["bar"].__metadata__[0].extra)
432+
# {"__psygnal_metadata": {"alias": "bar_alias"}}
433+
434+
"""
435+
for m_field in cls.__struct_fields__:
436+
try:
437+
type_, metadata = get_msgspec_metadata(cls, m_field)
438+
metadata = sanitize_field_options_dict(metadata)
439+
except AttributeError:
440+
msg = f"Cannot parse field metadata for {m_field}: {type_}"
441+
# logger.exception(msg)
442+
print(msg)
443+
type_, metadata = None, {}
444+
options = FieldOptions(m_field, type_, **metadata)
445+
yield options
446+
return

0 commit comments

Comments
 (0)