|
18 | 18 |
|
19 | 19 | from gel._internal import _qb |
20 | 20 | from gel._internal._schemapath import ( |
21 | | - TypeNameIntersection, |
22 | 21 | TypeNameExpr, |
| 22 | + TypeNameIntersection, |
| 23 | + TypeNameUnion, |
23 | 24 | ) |
24 | 25 | from gel._internal import _type_expression |
25 | 26 | from gel._internal._xmethod import classonlymethod |
@@ -270,6 +271,25 @@ class BaseGelModelIntersectionBacklinks( |
270 | 271 | rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] |
271 | 272 |
|
272 | 273 |
|
| 274 | +class BaseGelModelUnion( |
| 275 | + BaseGelModel, |
| 276 | + _type_expression.Union, |
| 277 | + Generic[_T_Lhs, _T_Rhs], |
| 278 | +): |
| 279 | + __gel_type_class__: ClassVar[type] |
| 280 | + |
| 281 | + lhs: ClassVar[type[AbstractGelModel]] |
| 282 | + rhs: ClassVar[type[AbstractGelModel]] |
| 283 | + |
| 284 | + |
| 285 | +class BaseGelModelUnionBacklinks( |
| 286 | + AbstractGelObjectBacklinksModel, |
| 287 | + _type_expression.Intersection, |
| 288 | +): |
| 289 | + lhs: ClassVar[type[AbstractGelObjectBacklinksModel]] |
| 290 | + rhs: ClassVar[type[AbstractGelObjectBacklinksModel]] |
| 291 | + |
| 292 | + |
273 | 293 | T = TypeVar('T') |
274 | 294 | U = TypeVar('U') |
275 | 295 |
|
@@ -318,6 +338,17 @@ def combine_dicts( |
318 | 338 | return result |
319 | 339 |
|
320 | 340 |
|
| 341 | +def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: |
| 342 | + if lhs == rhs: |
| 343 | + return (lhs,) |
| 344 | + elif issubclass(lhs, rhs): |
| 345 | + return (lhs, rhs) |
| 346 | + elif issubclass(rhs, lhs): |
| 347 | + return (rhs, lhs) |
| 348 | + else: |
| 349 | + return (lhs, rhs) |
| 350 | + |
| 351 | + |
321 | 352 | _type_intersection_cache: weakref.WeakKeyDictionary[ |
322 | 353 | type[AbstractGelModel], |
323 | 354 | weakref.WeakKeyDictionary[ |
@@ -430,17 +461,6 @@ def object( |
430 | 461 | return result |
431 | 462 |
|
432 | 463 |
|
433 | | -def _order_base_types(lhs: type, rhs: type) -> tuple[type, ...]: |
434 | | - if lhs == rhs: |
435 | | - return (lhs,) |
436 | | - elif issubclass(lhs, rhs): |
437 | | - return (lhs, rhs) |
438 | | - elif issubclass(rhs, lhs): |
439 | | - return (rhs, lhs) |
440 | | - else: |
441 | | - return (lhs, rhs) |
442 | | - |
443 | | - |
444 | 464 | def create_intersection_backlinks( |
445 | 465 | lhs_backlinks: type[AbstractGelObjectBacklinksModel], |
446 | 466 | rhs_backlinks: type[AbstractGelObjectBacklinksModel], |
@@ -500,3 +520,162 @@ def create_intersection_backlinks( |
500 | 520 | ) |
501 | 521 |
|
502 | 522 | return backlinks |
| 523 | + |
| 524 | + |
| 525 | +_type_union_cache: weakref.WeakKeyDictionary[ |
| 526 | + type[AbstractGelModel], |
| 527 | + weakref.WeakKeyDictionary[ |
| 528 | + type[AbstractGelModel], |
| 529 | + type[BaseGelModelUnion[AbstractGelModel, AbstractGelModel]], |
| 530 | + ], |
| 531 | +] = weakref.WeakKeyDictionary() |
| 532 | + |
| 533 | + |
| 534 | +def create_optional_union( |
| 535 | + lhs: type[_T_Lhs] | None, |
| 536 | + rhs: type[_T_Rhs] | None, |
| 537 | +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs] | AbstractGelModel] | None: |
| 538 | + if lhs is None: |
| 539 | + return rhs |
| 540 | + elif rhs is None: |
| 541 | + return lhs |
| 542 | + else: |
| 543 | + return create_union(lhs, rhs) |
| 544 | + |
| 545 | + |
| 546 | +def create_union( |
| 547 | + lhs: type[_T_Lhs], |
| 548 | + rhs: type[_T_Rhs], |
| 549 | +) -> type[BaseGelModelUnion[_T_Lhs, _T_Rhs]]: |
| 550 | + """Create a runtime union type which acts like a GelModel.""" |
| 551 | + |
| 552 | + if (lhs_entry := _type_union_cache.get(lhs)) and ( |
| 553 | + rhs_entry := lhs_entry.get(rhs) |
| 554 | + ): |
| 555 | + return rhs_entry # type: ignore[return-value] |
| 556 | + |
| 557 | + # Combine pointer reflections from args |
| 558 | + ptr_reflections: dict[str, _qb.GelPointerReflection] = { |
| 559 | + p_name: p_refl |
| 560 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 561 | + if p_name in rhs.__gel_reflection__.pointers |
| 562 | + } |
| 563 | + |
| 564 | + # Create type reflection for union type |
| 565 | + class __gel_reflection__(_qb.GelObjectTypeExprMetadata.__gel_reflection__): # noqa: N801 |
| 566 | + expr_object_types: set[type[AbstractGelModel]] = getattr( |
| 567 | + lhs.__gel_reflection__, 'expr_object_types', {lhs} |
| 568 | + ) | getattr(rhs.__gel_reflection__, 'expr_object_types', {rhs}) |
| 569 | + |
| 570 | + type_name = TypeNameUnion( |
| 571 | + args=( |
| 572 | + lhs.__gel_reflection__.type_name, |
| 573 | + rhs.__gel_reflection__.type_name, |
| 574 | + ) |
| 575 | + ) |
| 576 | + |
| 577 | + pointers = ptr_reflections |
| 578 | + |
| 579 | + @classmethod |
| 580 | + def object( |
| 581 | + cls, |
| 582 | + ) -> Any: |
| 583 | + raise NotImplementedError( |
| 584 | + "Type expressions schema objects are inaccessible" |
| 585 | + ) |
| 586 | + |
| 587 | + # Create the resulting union type |
| 588 | + result = type( |
| 589 | + f"({lhs.__name__} | {rhs.__name__})", |
| 590 | + (BaseGelModelUnion,), |
| 591 | + { |
| 592 | + 'lhs': lhs, |
| 593 | + 'rhs': rhs, |
| 594 | + '__gel_reflection__': __gel_reflection__, |
| 595 | + "__gel_proxied_dunders__": frozenset( |
| 596 | + { |
| 597 | + "__backlinks__", |
| 598 | + } |
| 599 | + ), |
| 600 | + }, |
| 601 | + ) |
| 602 | + |
| 603 | + # Generate field descriptors. |
| 604 | + descriptors: dict[str, ModelFieldDescriptor] = { |
| 605 | + p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__) |
| 606 | + for p_name, p_refl in lhs.__gel_reflection__.pointers.items() |
| 607 | + if ( |
| 608 | + hasattr(lhs, p_name) |
| 609 | + and (l_path_alias := getattr(lhs, p_name, None)) is not None |
| 610 | + and isinstance(l_path_alias, _qb.PathAlias) |
| 611 | + ) |
| 612 | + if ( |
| 613 | + hasattr(rhs, p_name) |
| 614 | + and (r_path_alias := getattr(rhs, p_name, None)) is not None |
| 615 | + and isinstance(r_path_alias, _qb.PathAlias) |
| 616 | + ) |
| 617 | + } |
| 618 | + for p_name, descriptor in descriptors.items(): |
| 619 | + setattr(result, p_name, descriptor) |
| 620 | + |
| 621 | + if lhs not in _type_union_cache: |
| 622 | + _type_union_cache[lhs] = weakref.WeakKeyDictionary() |
| 623 | + _type_union_cache[lhs][rhs] = result |
| 624 | + |
| 625 | + return result |
| 626 | + |
| 627 | + |
| 628 | +def create_union_backlinks( |
| 629 | + lhs_backlinks: type[AbstractGelObjectBacklinksModel], |
| 630 | + rhs_backlinks: type[AbstractGelObjectBacklinksModel], |
| 631 | + result: type[BaseGelModelIntersection[Any, Any]], |
| 632 | + result_type_name: TypeNameExpr, |
| 633 | +) -> type[AbstractGelObjectBacklinksModel]: |
| 634 | + reflection = type( |
| 635 | + "__gel_reflection__", |
| 636 | + _order_base_types( |
| 637 | + lhs_backlinks.__gel_reflection__, |
| 638 | + rhs_backlinks.__gel_reflection__, |
| 639 | + ), |
| 640 | + { |
| 641 | + "name": result_type_name, |
| 642 | + "type_name": result_type_name, |
| 643 | + "pointers": { |
| 644 | + p_name: p_refl |
| 645 | + for p_name, p_refl in lhs_backlinks.__gel_reflection__.pointers |
| 646 | + if p_name in rhs_backlinks.__gel_reflection__.pointers |
| 647 | + }, |
| 648 | + }, |
| 649 | + ) |
| 650 | + |
| 651 | + # Generate field descriptors for backlinks. |
| 652 | + field_descriptors: dict[str, ModelFieldDescriptor] = { |
| 653 | + p_name: field_descriptor(result, p_name, l_path_alias.__gel_origin__) |
| 654 | + for p_name in lhs_backlinks.__gel_reflection__.pointers |
| 655 | + if ( |
| 656 | + hasattr(lhs_backlinks, p_name) |
| 657 | + and (l_path_alias := getattr(lhs_backlinks, p_name, None)) |
| 658 | + is not None |
| 659 | + and isinstance(l_path_alias, _qb.PathAlias) |
| 660 | + ) |
| 661 | + if ( |
| 662 | + hasattr(rhs_backlinks, p_name) |
| 663 | + and (r_path_alias := getattr(rhs_backlinks, p_name, None)) |
| 664 | + is not None |
| 665 | + and isinstance(r_path_alias, _qb.PathAlias) |
| 666 | + ) |
| 667 | + } |
| 668 | + |
| 669 | + backlinks = type( |
| 670 | + f"__{result_type_name.name}_backlinks__", |
| 671 | + (BaseGelModelUnionBacklinks,), |
| 672 | + { |
| 673 | + 'lhs': lhs_backlinks, |
| 674 | + 'rhs': rhs_backlinks, |
| 675 | + '__gel_reflection__': reflection, |
| 676 | + '__module__': __name__, |
| 677 | + **field_descriptors, |
| 678 | + }, |
| 679 | + ) |
| 680 | + |
| 681 | + return backlinks |
0 commit comments