Skip to content

Commit afc3f1f

Browse files
committed
Fix passing type intersections to overloaded functions.
1 parent 840845f commit afc3f1f

File tree

2 files changed

+81
-6
lines changed

2 files changed

+81
-6
lines changed

gel/_internal/_typing_dispatch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from gel._internal import _typing_inspect
3939
from gel._internal import _typing_parametric
4040
from gel._internal._utils import type_repr
41+
from gel._internal._qbmodel._abstract._methods import BaseGelModelIntersection
4142

4243
_P = ParamSpec("_P")
4344
_R_co = TypeVar("_R_co", covariant=True)
@@ -66,6 +67,10 @@ def _issubclass(lhs: Any, tp: Any, fn: Any) -> bool:
6667
# subtypes of the variable bounds.
6768
# This lets us handle cases like:
6869
# std.array[Object] <: std.array[_T_anytype].
70+
71+
if issubclass(lhs, BaseGelModelIntersection):
72+
return any(_issubclass(c, tp, fn) for c in (lhs.lhs, lhs.rhs))
73+
6974
if _typing_inspect.is_generic_alias(tp):
7075
origin = typing.get_origin(tp)
7176
args = typing.get_args(tp)

tests/test_qb.py

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1695,9 +1695,7 @@ def test_qb_is_type_basic_07(self):
16951695
# Link TypeIntersection
16961696
from models.orm_qb import default
16971697

1698-
result = self.client.query(
1699-
default.Link_Inh_A.l.is_(default.Inh_B)
1700-
)
1698+
result = self.client.query(default.Link_Inh_A.l.is_(default.Inh_B))
17011699

17021700
self._assertObjectsWithFields(
17031701
result,
@@ -1900,9 +1898,9 @@ def test_qb_is_type_for_01(self):
19001898
from models.orm_qb import default, std
19011899

19021900
result = self.client.query(
1903-
std.for_(
1904-
default.Inh_A.is_(default.Inh_B), lambda x: x
1905-
).select(a=True)
1901+
std.for_(default.Inh_A.is_(default.Inh_B), lambda x: x).select(
1902+
a=True
1903+
)
19061904
)
19071905

19081906
self._assertObjectsWithFields(
@@ -2014,6 +2012,78 @@ def test_qb_is_type_for_03(self):
20142012
excluded_fields={'b', 'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
20152013
)
20162014

2015+
def test_qb_is_type_as_function_arg_01(self):
2016+
# Test that type exprs produced by is_ can be passed as function args
2017+
from models.orm_qb import default, std
2018+
2019+
result = self.client.query(
2020+
std.distinct(default.Inh_A.is_(default.Inh_B)).select('*')
2021+
)
2022+
2023+
self._assertObjectsWithFields(
2024+
result,
2025+
"a",
2026+
[
2027+
(
2028+
default.Inh_AB,
2029+
{
2030+
"a": 4,
2031+
"b": 5,
2032+
},
2033+
),
2034+
(
2035+
default.Inh_ABC,
2036+
{
2037+
"a": 13,
2038+
"b": 14,
2039+
},
2040+
),
2041+
(
2042+
default.Inh_AB_AC,
2043+
{
2044+
"a": 17,
2045+
"b": 18,
2046+
},
2047+
),
2048+
],
2049+
excluded_fields={'c', 'ab', 'ac', 'bc', 'abc', 'ab_ac'},
2050+
)
2051+
2052+
def test_qb_is_type_as_function_arg_02(self):
2053+
# Test that complex type exprs produced by is_ can be passed as
2054+
# function args
2055+
from models.orm_qb import default, std
2056+
2057+
result = self.client.query(
2058+
std.distinct(
2059+
default.Inh_A.is_(default.Inh_B).is_(default.Inh_C)
2060+
).select('*')
2061+
)
2062+
2063+
self._assertObjectsWithFields(
2064+
result,
2065+
"a",
2066+
[
2067+
(
2068+
default.Inh_ABC,
2069+
{
2070+
"a": 13,
2071+
"b": 14,
2072+
"c": 15,
2073+
},
2074+
),
2075+
(
2076+
default.Inh_AB_AC,
2077+
{
2078+
"a": 17,
2079+
"b": 18,
2080+
"c": 19,
2081+
},
2082+
),
2083+
],
2084+
excluded_fields={'ab', 'ac', 'bc', 'abc', 'ab_ac'},
2085+
)
2086+
20172087

20182088
class TestQueryBuilderModify(tb.ModelTestCase):
20192089
"""This test suite is for data manipulation using QB."""

0 commit comments

Comments
 (0)