Skip to content

Commit 9cb9230

Browse files
committed
Fix union and coalesce expressions not decoding to the correct type.
1 parent ca84f90 commit 9cb9230

File tree

5 files changed

+242
-12
lines changed

5 files changed

+242
-12
lines changed

gel/_internal/_codegen/_models/_pydantic.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6157,6 +6157,45 @@ def resolve(
61576157
f"# type: ignore [assignment, misc, unused-ignore]"
61586158
)
61596159

6160+
if function.schemapath in {
6161+
SchemaPath('std', 'UNION'),
6162+
SchemaPath('std', 'IF'),
6163+
SchemaPath('std', '??'),
6164+
}:
6165+
# Special case for the UNION, IF and ?? operators
6166+
# Produce a union type instead of just taking the first
6167+
# valid type.
6168+
#
6169+
# See gel: edb.compiler.func.compile_operator
6170+
create_union = self.import_name(
6171+
BASE_IMPL, "create_optional_union"
6172+
)
6173+
6174+
tvars: list[str] = []
6175+
for param, path in sources:
6176+
if (
6177+
param.name in required_generic_params
6178+
or param.name in optional_generic_params
6179+
):
6180+
pn = param_vars[param.name]
6181+
tvar = f"__t_{pn}__"
6182+
6183+
resolve(pn, path, tvar)
6184+
tvars.append(tvar)
6185+
6186+
self.write(
6187+
f"{gtvar} = {tvars[0]} "
6188+
f"# type: ignore [assignment, misc, unused-ignore]"
6189+
)
6190+
for tvar in tvars[1:]:
6191+
self.write(
6192+
f"{gtvar} = {create_union}({gtvar}, {tvar}) "
6193+
f"# type: ignore ["
6194+
f"assignment, misc, unused-ignore]"
6195+
)
6196+
6197+
continue
6198+
61606199
# Try to infer generic type from required params first
61616200
for param, path in sources:
61626201
if param.name in required_generic_params:

gel/_internal/_qbmodel/_abstract/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@
6868
from ._methods import (
6969
BaseGelModel,
7070
BaseGelModelIntersection,
71+
BaseGelModelUnion,
72+
create_optional_union,
73+
create_union,
7174
)
7275

7376

@@ -138,6 +141,7 @@
138141
"ArrayMeta",
139142
"BaseGelModel",
140143
"BaseGelModelIntersection",
144+
"BaseGelModelUnion",
141145
"ComputedLinkSet",
142146
"ComputedLinkWithPropsSet",
143147
"ComputedMultiLinkDescriptor",
@@ -181,6 +185,8 @@
181185
"TupleMeta",
182186
"UUIDImpl",
183187
"copy_or_ref_lprops",
188+
"create_optional_union",
189+
"create_union",
184190
"empty_set_if_none",
185191
"field_descriptor",
186192
"get_base_scalars_backed_by_py_type",

gel/_internal/_qbmodel/_abstract/_methods.py

Lines changed: 191 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818

1919
from gel._internal import _qb
2020
from gel._internal._schemapath import (
21-
TypeNameIntersection,
2221
TypeNameExpr,
22+
TypeNameIntersection,
23+
TypeNameUnion,
2324
)
2425
from gel._internal import _type_expression
2526
from gel._internal._xmethod import classonlymethod
@@ -270,6 +271,25 @@ class BaseGelModelIntersectionBacklinks(
270271
rhs: ClassVar[type[AbstractGelObjectBacklinksModel]]
271272

272273

274+
class BaseGelModelUnion(
275+
BaseGelModel,
276+
_type_expression.Union,
277+
Generic[_T_Lhs, _T_Rhs],
278+
):
279+
__gel_type_class__: ClassVar[type]
280+
281+
lhs: ClassVar[type[AbstractGelModel]]
282+
rhs: ClassVar[type[AbstractGelModel]]
283+
284+
285+
class BaseGelModelUnionBacklinks(
286+
AbstractGelObjectBacklinksModel,
287+
_type_expression.Intersection,
288+
):
289+
lhs: ClassVar[type[AbstractGelObjectBacklinksModel]]
290+
rhs: ClassVar[type[AbstractGelObjectBacklinksModel]]
291+
292+
273293
T = TypeVar('T')
274294
U = TypeVar('U')
275295

@@ -318,6 +338,17 @@ def combine_dicts(
318338
return result
319339

320340

341+
def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
342+
if lhs == rhs:
343+
return (lhs,)
344+
elif issubclass(lhs, rhs):
345+
return (lhs, rhs)
346+
elif issubclass(rhs, lhs):
347+
return (rhs, lhs)
348+
else:
349+
return (lhs, rhs)
350+
351+
321352
_type_intersection_cache: weakref.WeakKeyDictionary[
322353
type[AbstractGelModel],
323354
weakref.WeakKeyDictionary[
@@ -430,17 +461,6 @@ def object(
430461
return result
431462

432463

433-
def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]:
434-
if lhs == rhs:
435-
return (lhs,)
436-
elif issubclass(lhs, rhs):
437-
return (lhs, rhs)
438-
elif issubclass(rhs, lhs):
439-
return (rhs, lhs)
440-
else:
441-
return (lhs, rhs)
442-
443-
444464
def create_intersection_backlinks(
445465
lhs_backlinks: type[AbstractGelObjectBacklinksModel],
446466
rhs_backlinks: type[AbstractGelObjectBacklinksModel],
@@ -500,3 +520,162 @@ def create_intersection_backlinks(
500520
)
501521

502522
return backlinks
523+
524+
525+
_type_union_cache: weakref.WeakKeyDictionary[
526+
type[AbstractGelModel],
527+
weakref.WeakKeyDictionary[
528+
type[AbstractGelModel],
529+
type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]],
530+
],
531+
] = weakref.WeakKeyDictionary()
532+
533+
534+
def create_optional_union(
535+
lhs: type[_T_Lhs] | None,
536+
rhs: type[_T_Rhs] | None,
537+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None:
538+
if lhs is None:
539+
return rhs
540+
elif rhs is None:
541+
return lhs
542+
else:
543+
return create_union(lhs, rhs)
544+
545+
546+
def create_union(
547+
lhs: type[_T_Lhs],
548+
rhs: type[_T_Rhs],
549+
) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]:
550+
"""Create a runtime union type which acts like a GelModel."""
551+
552+
if (lhs_entry := _type_union_cache.get(lhs)) and (
553+
rhs_entry := lhs_entry.get(rhs)
554+
):
555+
return rhs_entry # type: ignore[return-value]
556+
557+
# Combine pointer reflections from args
558+
ptr_reflections: dict[str, _qb.GelPointerReflection] = {
559+
p_name: p_refl
560+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
561+
if p_name in rhs.__gel_reflection__.pointers
562+
}
563+
564+
# Create type reflection for union type
565+
class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801
566+
expr_object_types: set[type[AbstractGelModel]] = getattr(
567+
lhs.__gel_reflection__, 'expr_object_types', {lhs}
568+
) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs})
569+
570+
type_name = TypeNameUnion(
571+
args=(
572+
lhs.__gel_reflection__.type_name,
573+
rhs.__gel_reflection__.type_name,
574+
)
575+
)
576+
577+
pointers = ptr_reflections
578+
579+
@classmethod
580+
def object(
581+
cls,
582+
) -> Any:
583+
raise NotImplementedError(
584+
"Type expressions schema objects are inaccessible"
585+
)
586+
587+
# Create the resulting union type
588+
result = type(
589+
f"({lhs.__name__} | {rhs.__name__})",
590+
(BaseGelModelUnion,),
591+
{
592+
'lhs': lhs,
593+
'rhs': rhs,
594+
'__gel_reflection__': __gel_reflection__,
595+
"__gel_proxied_dunders__": frozenset(
596+
{
597+
"__backlinks__",
598+
}
599+
),
600+
},
601+
)
602+
603+
# Generate field descriptors.
604+
descriptors: dict[str, ModelFieldDescriptor] = {
605+
p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__)
606+
for p_name, p_refl in lhs.__gel_reflection__.pointers.items()
607+
if (
608+
hasattr(lhs, p_name)
609+
and (l_path_alias := getattr(lhs, p_name, None)) is not None
610+
and isinstance(l_path_alias, _qb.PathAlias)
611+
)
612+
if (
613+
hasattr(rhs, p_name)
614+
and (r_path_alias := getattr(rhs, p_name, None)) is not None
615+
and isinstance(r_path_alias, _qb.PathAlias)
616+
)
617+
}
618+
for p_name, descriptor in descriptors.items():
619+
setattr(result, p_name, descriptor)
620+
621+
if lhs not in _type_union_cache:
622+
_type_union_cache[lhs] = weakref.WeakKeyDictionary()
623+
_type_union_cache[lhs][rhs] = result
624+
625+
return result
626+
627+
628+
def create_union_backlinks(
629+
lhs_backlinks: type[AbstractGelObjectBacklinksModel],
630+
rhs_backlinks: type[AbstractGelObjectBacklinksModel],
631+
result: type[BaseGelModelIntersection[Any, Any]],
632+
result_type_name: TypeNameExpr,
633+
) -> type[AbstractGelObjectBacklinksModel]:
634+
reflection = type(
635+
"__gel_reflection__",
636+
_order_base_types(
637+
lhs_backlinks.__gel_reflection__,
638+
rhs_backlinks.__gel_reflection__,
639+
),
640+
{
641+
"name": result_type_name,
642+
"type_name": result_type_name,
643+
"pointers": {
644+
p_name: p_refl
645+
for p_name, p_refl in lhs_backlinks.__gel_reflection__.pointers
646+
if p_name in rhs_backlinks.__gel_reflection__.pointers
647+
},
648+
},
649+
)
650+
651+
# Generate field descriptors for backlinks.
652+
field_descriptors: dict[str, ModelFieldDescriptor] = {
653+
p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__)
654+
for p_name in lhs_backlinks.__gel_reflection__.pointers
655+
if (
656+
hasattr(lhs_backlinks, p_name)
657+
and (l_path_alias := getattr(lhs_backlinks, p_name, None))
658+
is not None
659+
and isinstance(l_path_alias, _qb.PathAlias)
660+
)
661+
if (
662+
hasattr(rhs_backlinks, p_name)
663+
and (r_path_alias := getattr(rhs_backlinks, p_name, None))
664+
is not None
665+
and isinstance(r_path_alias, _qb.PathAlias)
666+
)
667+
}
668+
669+
backlinks = type(
670+
f"__{result_type_name.name}_backlinks__",
671+
(BaseGelModelUnionBacklinks,),
672+
{
673+
'lhs': lhs_backlinks,
674+
'rhs': rhs_backlinks,
675+
'__gel_reflection__': reflection,
676+
'__module__': __name__,
677+
**field_descriptors,
678+
},
679+
)
680+
681+
return backlinks

gel/_internal/_typing_dispatch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
7070

7171
if issubclass(lhs, _type_expression.Intersection):
7272
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
73+
elif issubclass(lhs, _type_expression.Union):
74+
return all(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
7375

7476
if _typing_inspect.is_generic_alias(tp):
7577
origin = typing.get_origin(tp)

gel/models/pydantic.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@
7676
PyTypeScalarConstraint,
7777
RangeMeta,
7878
UUIDImpl,
79+
create_optional_union,
80+
create_union,
7981
empty_set_if_none,
8082
)
8183

@@ -215,6 +217,8 @@
215217
"classonlymethod",
216218
"computed_field",
217219
"construct_infix_op_chain",
220+
"create_optional_union",
221+
"create_union",
218222
"dispatch_overload",
219223
"empty_set_if_none",
220224
)

0 commit comments

Comments
 (0)