3
3
4
4
import datetime
5
5
import inspect
6
+ import itertools
6
7
import os
7
8
import pathlib
8
9
import sys
@@ -366,6 +367,65 @@ def _validate_return_callback(func: Callable) -> None:
366
367
_T = TypeVar ("_T" , bound = type )
367
368
368
369
370
+ def _register_type_callback (
371
+ resolved_type : _T ,
372
+ return_callback : ReturnCallback | None = None ,
373
+ ) -> list [type ]:
374
+ modified_callbacks = []
375
+ if return_callback is None :
376
+ return []
377
+ _validate_return_callback (return_callback )
378
+ # if the type is a Union, add the callback to all of the types in the union
379
+ # (except NoneType)
380
+ if get_origin (resolved_type ) is Union :
381
+ for type_per in _generate_union_variants (resolved_type ):
382
+ if return_callback not in _RETURN_CALLBACKS [type_per ]:
383
+ _RETURN_CALLBACKS [type_per ].append (return_callback )
384
+ modified_callbacks .append (type_per )
385
+
386
+ for t in get_args (resolved_type ):
387
+ if not _is_none_type (t ) and return_callback not in _RETURN_CALLBACKS [t ]:
388
+ _RETURN_CALLBACKS [t ].append (return_callback )
389
+ modified_callbacks .append (t )
390
+ elif return_callback not in _RETURN_CALLBACKS [resolved_type ]:
391
+ _RETURN_CALLBACKS [resolved_type ].append (return_callback )
392
+ modified_callbacks .append (resolved_type )
393
+ return modified_callbacks
394
+
395
+
396
+ def _register_widget (
397
+ resolved_type : _T ,
398
+ widget_type : WidgetRef | None = None ,
399
+ ** options : Any ,
400
+ ) -> WidgetTuple | None :
401
+ _options = cast (dict , options )
402
+
403
+ previous_widget = _TYPE_DEFS .get (resolved_type )
404
+
405
+ if "choices" in _options :
406
+ _TYPE_DEFS [resolved_type ] = (widgets .ComboBox , _options )
407
+ if widget_type is not None :
408
+ warnings .warn (
409
+ "Providing `choices` overrides `widget_type`. Categorical widget "
410
+ f"will be used for type { resolved_type } " ,
411
+ stacklevel = 2 ,
412
+ )
413
+ elif widget_type is not None :
414
+ if not isinstance (widget_type , (str , WidgetProtocol )) and not (
415
+ inspect .isclass (widget_type ) and issubclass (widget_type , widgets .Widget )
416
+ ):
417
+ raise TypeError (
418
+ '"widget_type" must be either a string, WidgetProtocol, or '
419
+ "Widget subclass"
420
+ )
421
+ _TYPE_DEFS [resolved_type ] = (widget_type , _options )
422
+ elif "bind" in _options :
423
+ # if we're binding a value to this parameter, it doesn't matter what type
424
+ # of ValueWidget is used... it usually won't be shown
425
+ _TYPE_DEFS [resolved_type ] = (widgets .EmptyWidget , _options )
426
+ return previous_widget
427
+
428
+
369
429
@overload
370
430
def register_type (
371
431
type_ : _T ,
@@ -435,43 +495,11 @@ def register_type(
435
495
"must be provided."
436
496
)
437
497
438
- def _deco (type_ : _T ) -> _T :
439
- resolved_type = resolve_single_type (type_ )
440
- if return_callback is not None :
441
- _validate_return_callback (return_callback )
442
- # if the type is a Union, add the callback to all of the types in the union
443
- # (except NoneType)
444
- if get_origin (resolved_type ) is Union :
445
- for t in get_args (resolved_type ):
446
- if not _is_none_type (t ):
447
- _RETURN_CALLBACKS [t ].append (return_callback )
448
- else :
449
- _RETURN_CALLBACKS [resolved_type ].append (return_callback )
450
-
451
- _options = cast (dict , options )
452
-
453
- if "choices" in _options :
454
- _TYPE_DEFS [resolved_type ] = (widgets .ComboBox , _options )
455
- if widget_type is not None :
456
- warnings .warn (
457
- "Providing `choices` overrides `widget_type`. Categorical widget "
458
- f"will be used for type { resolved_type } " ,
459
- stacklevel = 2 ,
460
- )
461
- elif widget_type is not None :
462
- if not isinstance (widget_type , (str , WidgetProtocol )) and not (
463
- inspect .isclass (widget_type ) and issubclass (widget_type , widgets .Widget )
464
- ):
465
- raise TypeError (
466
- '"widget_type" must be either a string, WidgetProtocol, or '
467
- "Widget subclass"
468
- )
469
- _TYPE_DEFS [resolved_type ] = (widget_type , _options )
470
- elif "bind" in _options :
471
- # if we're binding a value to this parameter, it doesn't matter what type
472
- # of ValueWidget is used... it usually won't be shown
473
- _TYPE_DEFS [resolved_type ] = (widgets .EmptyWidget , _options )
474
- return type_
498
+ def _deco (type__ : _T ) -> _T :
499
+ resolved_type = resolve_single_type (type__ )
500
+ _register_type_callback (resolved_type , return_callback )
501
+ _register_widget (resolved_type , widget_type , ** options )
502
+ return type__
475
503
476
504
return _deco if type_ is None else _deco (type_ )
477
505
@@ -507,23 +535,19 @@ def type_registered(
507
535
"""
508
536
resolved_type = resolve_single_type (type_ )
509
537
510
- # check if return_callback is already registered
511
- rc_was_present = return_callback in _RETURN_CALLBACKS .get (resolved_type , [])
512
538
# store any previous widget_type and options for this type
513
- prev_type_def : WidgetTuple | None = _TYPE_DEFS .get (resolved_type , None )
514
- resolved_type = register_type (
515
- resolved_type ,
516
- widget_type = widget_type ,
517
- return_callback = return_callback ,
518
- ** options ,
519
- )
539
+
540
+ revert_list = _register_type_callback (resolved_type , return_callback )
541
+ prev_type_def = _register_widget (resolved_type , widget_type , ** options )
542
+
520
543
new_type_def : WidgetTuple | None = _TYPE_DEFS .get (resolved_type , None )
521
544
try :
522
545
yield
523
546
finally :
524
547
# restore things to before the context
525
- if return_callback is not None and not rc_was_present :
526
- _RETURN_CALLBACKS [resolved_type ].remove (return_callback )
548
+ if return_callback is not None : # this if is only for mypy
549
+ for return_callback_type in revert_list :
550
+ _RETURN_CALLBACKS [return_callback_type ].remove (return_callback )
527
551
528
552
if _TYPE_DEFS .get (resolved_type , None ) is not new_type_def :
529
553
warnings .warn ("Type definition changed during context" , stacklevel = 2 )
@@ -537,9 +561,6 @@ def type_registered(
537
561
def type2callback (type_ : type ) -> list [ReturnCallback ]:
538
562
"""Return any callbacks that have been registered for ``type_``.
539
563
540
- Note that if the return type is X, then the callbacks registered for Optional[X]
541
- will be returned also be returned.
542
-
543
564
Parameters
544
565
----------
545
566
type_ : type
@@ -555,7 +576,7 @@ def type2callback(type_: type) -> list[ReturnCallback]:
555
576
556
577
# look for direct hits ...
557
578
# if it's an Optional, we need to look for the type inside the Optional
558
- _ , type_ = _is_optional ( resolve_single_type (type_ ) )
579
+ type_ = resolve_single_type (type_ )
559
580
if type_ in _RETURN_CALLBACKS :
560
581
return _RETURN_CALLBACKS [type_ ]
561
582
@@ -566,10 +587,8 @@ def type2callback(type_: type) -> list[ReturnCallback]:
566
587
return []
567
588
568
589
569
- def _is_optional (type_ : Any ) -> tuple [bool , type ]:
570
- # TODO: this function is too similar to _type_optional above... need to combine
571
- if get_origin (type_ ) is Union :
572
- args = get_args (type_ )
573
- if len (args ) == 2 and any (_is_none_type (i ) for i in args ):
574
- return True , next (i for i in args if not _is_none_type (i ))
575
- return False , type_
590
+ def _generate_union_variants (type_ : Any ) -> Iterator [type ]:
591
+ type_args = get_args (type_ )
592
+ for i in range (2 , len (type_args ) + 1 ):
593
+ for per in itertools .combinations (type_args , i ):
594
+ yield cast (type , Union [per ])
0 commit comments