Skip to content

Commit 21ee452

Browse files
Czakitlambert03
andauthored
fix: prevent dupe calls, alternative (#546)
* use variant generation to simplify discover callbacks * add test checking that order is not important * remove obsolete file * fix tests * fix type_registered --------- Co-authored-by: Talley Lambert <[email protected]>
1 parent edf00f7 commit 21ee452

File tree

4 files changed

+157
-60
lines changed

4 files changed

+157
-60
lines changed

src/magicgui/type_map/_type_map.py

Lines changed: 78 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import datetime
55
import inspect
6+
import itertools
67
import os
78
import pathlib
89
import sys
@@ -366,6 +367,65 @@ def _validate_return_callback(func: Callable) -> None:
366367
_T = TypeVar("_T", bound=type)
367368

368369

370+
def _register_type_callback(
371+
resolved_type: _T,
372+
return_callback: ReturnCallback | None = None,
373+
) -> list[type]:
374+
modified_callbacks = []
375+
if return_callback is None:
376+
return []
377+
_validate_return_callback(return_callback)
378+
# if the type is a Union, add the callback to all of the types in the union
379+
# (except NoneType)
380+
if get_origin(resolved_type) is Union:
381+
for type_per in _generate_union_variants(resolved_type):
382+
if return_callback not in _RETURN_CALLBACKS[type_per]:
383+
_RETURN_CALLBACKS[type_per].append(return_callback)
384+
modified_callbacks.append(type_per)
385+
386+
for t in get_args(resolved_type):
387+
if not _is_none_type(t) and return_callback not in _RETURN_CALLBACKS[t]:
388+
_RETURN_CALLBACKS[t].append(return_callback)
389+
modified_callbacks.append(t)
390+
elif return_callback not in _RETURN_CALLBACKS[resolved_type]:
391+
_RETURN_CALLBACKS[resolved_type].append(return_callback)
392+
modified_callbacks.append(resolved_type)
393+
return modified_callbacks
394+
395+
396+
def _register_widget(
397+
resolved_type: _T,
398+
widget_type: WidgetRef | None = None,
399+
**options: Any,
400+
) -> WidgetTuple | None:
401+
_options = cast(dict, options)
402+
403+
previous_widget = _TYPE_DEFS.get(resolved_type)
404+
405+
if "choices" in _options:
406+
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
407+
if widget_type is not None:
408+
warnings.warn(
409+
"Providing `choices` overrides `widget_type`. Categorical widget "
410+
f"will be used for type {resolved_type}",
411+
stacklevel=2,
412+
)
413+
elif widget_type is not None:
414+
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
415+
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
416+
):
417+
raise TypeError(
418+
'"widget_type" must be either a string, WidgetProtocol, or '
419+
"Widget subclass"
420+
)
421+
_TYPE_DEFS[resolved_type] = (widget_type, _options)
422+
elif "bind" in _options:
423+
# if we're binding a value to this parameter, it doesn't matter what type
424+
# of ValueWidget is used... it usually won't be shown
425+
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)
426+
return previous_widget
427+
428+
369429
@overload
370430
def register_type(
371431
type_: _T,
@@ -435,43 +495,11 @@ def register_type(
435495
"must be provided."
436496
)
437497

438-
def _deco(type_: _T) -> _T:
439-
resolved_type = resolve_single_type(type_)
440-
if return_callback is not None:
441-
_validate_return_callback(return_callback)
442-
# if the type is a Union, add the callback to all of the types in the union
443-
# (except NoneType)
444-
if get_origin(resolved_type) is Union:
445-
for t in get_args(resolved_type):
446-
if not _is_none_type(t):
447-
_RETURN_CALLBACKS[t].append(return_callback)
448-
else:
449-
_RETURN_CALLBACKS[resolved_type].append(return_callback)
450-
451-
_options = cast(dict, options)
452-
453-
if "choices" in _options:
454-
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
455-
if widget_type is not None:
456-
warnings.warn(
457-
"Providing `choices` overrides `widget_type`. Categorical widget "
458-
f"will be used for type {resolved_type}",
459-
stacklevel=2,
460-
)
461-
elif widget_type is not None:
462-
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
463-
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
464-
):
465-
raise TypeError(
466-
'"widget_type" must be either a string, WidgetProtocol, or '
467-
"Widget subclass"
468-
)
469-
_TYPE_DEFS[resolved_type] = (widget_type, _options)
470-
elif "bind" in _options:
471-
# if we're binding a value to this parameter, it doesn't matter what type
472-
# of ValueWidget is used... it usually won't be shown
473-
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)
474-
return type_
498+
def _deco(type__: _T) -> _T:
499+
resolved_type = resolve_single_type(type__)
500+
_register_type_callback(resolved_type, return_callback)
501+
_register_widget(resolved_type, widget_type, **options)
502+
return type__
475503

476504
return _deco if type_ is None else _deco(type_)
477505

@@ -507,23 +535,19 @@ def type_registered(
507535
"""
508536
resolved_type = resolve_single_type(type_)
509537

510-
# check if return_callback is already registered
511-
rc_was_present = return_callback in _RETURN_CALLBACKS.get(resolved_type, [])
512538
# store any previous widget_type and options for this type
513-
prev_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
514-
resolved_type = register_type(
515-
resolved_type,
516-
widget_type=widget_type,
517-
return_callback=return_callback,
518-
**options,
519-
)
539+
540+
revert_list = _register_type_callback(resolved_type, return_callback)
541+
prev_type_def = _register_widget(resolved_type, widget_type, **options)
542+
520543
new_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
521544
try:
522545
yield
523546
finally:
524547
# restore things to before the context
525-
if return_callback is not None and not rc_was_present:
526-
_RETURN_CALLBACKS[resolved_type].remove(return_callback)
548+
if return_callback is not None: # this if is only for mypy
549+
for return_callback_type in revert_list:
550+
_RETURN_CALLBACKS[return_callback_type].remove(return_callback)
527551

528552
if _TYPE_DEFS.get(resolved_type, None) is not new_type_def:
529553
warnings.warn("Type definition changed during context", stacklevel=2)
@@ -537,9 +561,6 @@ def type_registered(
537561
def type2callback(type_: type) -> list[ReturnCallback]:
538562
"""Return any callbacks that have been registered for ``type_``.
539563
540-
Note that if the return type is X, then the callbacks registered for Optional[X]
541-
will be returned also be returned.
542-
543564
Parameters
544565
----------
545566
type_ : type
@@ -555,7 +576,7 @@ def type2callback(type_: type) -> list[ReturnCallback]:
555576

556577
# look for direct hits ...
557578
# if it's an Optional, we need to look for the type inside the Optional
558-
_, type_ = _is_optional(resolve_single_type(type_))
579+
type_ = resolve_single_type(type_)
559580
if type_ in _RETURN_CALLBACKS:
560581
return _RETURN_CALLBACKS[type_]
561582

@@ -566,10 +587,8 @@ def type2callback(type_: type) -> list[ReturnCallback]:
566587
return []
567588

568589

569-
def _is_optional(type_: Any) -> tuple[bool, type]:
570-
# TODO: this function is too similar to _type_optional above... need to combine
571-
if get_origin(type_) is Union:
572-
args = get_args(type_)
573-
if len(args) == 2 and any(_is_none_type(i) for i in args):
574-
return True, next(i for i in args if not _is_none_type(i))
575-
return False, type_
590+
def _generate_union_variants(type_: Any) -> Iterator[type]:
591+
type_args = get_args(type_)
592+
for i in range(2, len(type_args) + 1):
593+
for per in itertools.combinations(type_args, i):
594+
yield cast(type, Union[per])

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,12 @@ def always_qapp(qapp):
1717
for w in qapp.topLevelWidgets():
1818
w.close()
1919
w.deleteLater()
20+
21+
22+
@pytest.fixture(autouse=True, scope="function")
23+
def _clean_return_callbacks():
24+
from magicgui.type_map._type_map import _RETURN_CALLBACKS
25+
26+
yield
27+
28+
_RETURN_CALLBACKS.clear()

tests/test_magicgui.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import inspect
66
from enum import Enum
7-
from typing import NewType, Optional
7+
from typing import NewType, Optional, Union
88
from unittest.mock import Mock
99

1010
import pytest
@@ -901,3 +901,37 @@ def func_optional(a: bool) -> ReturnType:
901901
mock.reset_mock()
902902
func_optional(a=False)
903903
mock.assert_called_once_with(func_optional, None, ReturnType)
904+
905+
906+
@pytest.mark.parametrize("optional", [True, False])
907+
def test_no_duplication_call(optional):
908+
mock = Mock()
909+
mock2 = Mock()
910+
911+
NewInt = NewType("NewInt", int)
912+
register_type(Optional[NewInt], return_callback=mock)
913+
register_type(NewInt, return_callback=mock)
914+
register_type(NewInt, return_callback=mock2)
915+
ReturnType = Optional[NewInt] if optional else NewInt
916+
917+
@magicgui
918+
def func() -> ReturnType:
919+
return NewInt(1)
920+
921+
func()
922+
923+
mock.assert_called_once()
924+
assert mock2.call_count == (not optional)
925+
926+
927+
def test_no_order():
928+
mock = Mock()
929+
930+
register_type(Union[int, None], return_callback=mock)
931+
932+
@magicgui
933+
def func() -> Union[None, int]:
934+
return 1
935+
936+
func()
937+
mock.assert_called_once()

tests/test_types.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,41 @@ def test_type_registered_warns():
189189
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)
190190

191191

192+
def test_type_registered_optional_callbacks():
193+
assert not _RETURN_CALLBACKS[int]
194+
assert not _RETURN_CALLBACKS[Optional[int]]
195+
196+
@magicgui
197+
def func1(a: int) -> int:
198+
return a
199+
200+
@magicgui
201+
def func2(a: int) -> Optional[int]:
202+
return a
203+
204+
mock1 = Mock()
205+
mock2 = Mock()
206+
mock3 = Mock()
207+
208+
register_type(int, return_callback=mock2)
209+
210+
with type_registered(Optional[int], return_callback=mock1):
211+
func1(1)
212+
mock1.assert_called_once_with(func1, 1, int)
213+
mock1.reset_mock()
214+
func2(2)
215+
mock1.assert_called_once_with(func2, 2, Optional[int])
216+
mock1.reset_mock()
217+
mock2.assert_called_once_with(func1, 1, int)
218+
assert _RETURN_CALLBACKS[int] == [mock2, mock1]
219+
assert _RETURN_CALLBACKS[Optional[int]] == [mock1]
220+
register_type(Optional[int], return_callback=mock3)
221+
assert _RETURN_CALLBACKS[Optional[int]] == [mock1, mock3]
222+
223+
assert _RETURN_CALLBACKS[Optional[int]] == [mock3]
224+
assert _RETURN_CALLBACKS[int] == [mock2, mock3]
225+
226+
192227
def test_pick_widget_literal():
193228
from typing import Literal
194229

0 commit comments

Comments
 (0)