Skip to content

Commit 60fd710

Browse files
committed
Flatten the implementation of the pipeline() decorator.
It seems easier to follow in a single function rather than being spread out in multiple helpers. (It also makes the signature of `pipeline()` explicit without having to duplicate it.)
1 parent 8669c26 commit 60fd710

File tree

1 file changed

+27
-98
lines changed

1 file changed

+27
-98
lines changed

slicerator/__init__.py

Lines changed: 27 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import collections.abc
55
import itertools
6-
from functools import wraps
6+
from functools import partial, wraps
77
from copy import copy
88
import inspect
99

@@ -503,7 +503,7 @@ def __setstate__(self, data_as_list):
503503
return self.__init__(lambda x: x, data_as_list)
504504

505505

506-
def pipeline(func=None, **kwargs):
506+
def pipeline(func=None, *, retain_doc=False, ancestor_count=1):
507507
"""Decorator to enable lazy evaluation of a function.
508508
509509
When the function is applied to a Slicerator or Pipeline object, it
@@ -540,8 +540,8 @@ def pipeline(func=None, **kwargs):
540540
Apply the pipeline decorator to your image processing function.
541541
542542
>>> @pipeline
543-
... def color_channel(image, channel):
544-
... return image[channel, :, :]
543+
... def color_channel(image, channel):
544+
... return image[channel, :, :]
545545
...
546546
547547
@@ -583,94 +583,19 @@ def pipeline(func=None, **kwargs):
583583
... def sum_offset(img1, img2, offset):
584584
... return img1 + img2 + offset
585585
"""
586-
def wrapper(f):
587-
return _pipeline(f, **kwargs)
588-
589586
if func is None:
590-
return wrapper
591-
else:
592-
return wrapper(func)
587+
return partial(
588+
pipeline, retain_doc=retain_doc, ancestor_count=ancestor_count)
593589

590+
if ancestor_count == 'all':
591+
ancestor_count = len(
592+
p for p in inspect.signature(func).parameters
593+
if p.kind.name in ["POSITIONAL_ONLY", "POSITIONAL_OR_KEYWORD"])
594594

595-
def _pipeline(func_or_class, **kwargs):
596595
try:
597-
is_class = issubclass(func_or_class, Pipeline)
596+
is_class = issubclass(func, Pipeline)
598597
except TypeError:
599598
is_class = False
600-
if is_class:
601-
return _pipeline_fromclass(func_or_class, **kwargs)
602-
else:
603-
return _pipeline_fromfunc(func_or_class, **kwargs)
604-
605-
606-
def _pipeline_fromclass(cls, retain_doc=False, ancestor_count=1):
607-
"""Actual `pipeline` implementation
608-
609-
Parameters
610-
----------
611-
func : class
612-
Class for lazy evaluation
613-
retain_doc : bool
614-
If True, don't modify `func`'s doc string to say that it has been
615-
made lazy
616-
ancestor_count : int or 'all', optional
617-
Number of inputs to the pipeline. Defaults to 1.
618-
619-
Returns
620-
-------
621-
Pipeline
622-
Lazy function evaluation :py:class:`Pipeline` for `func`.
623-
"""
624-
if ancestor_count == 'all':
625-
# subtract 1 for `self`
626-
ancestor_count = len(inspect.getfullargspec(cls).args) - 1
627-
628-
@wraps(cls)
629-
def process(*args, **kwargs):
630-
ancestors = args[:ancestor_count]
631-
args = args[ancestor_count:]
632-
all_pipe = all(hasattr(a, '_slicerator_flag') or
633-
isinstance(a, Slicerator) or
634-
isinstance(a, Pipeline) for a in ancestors)
635-
if all_pipe:
636-
return cls(*(ancestors + args), **kwargs)
637-
else:
638-
# Fall back on normal behavior of func, interpreting input
639-
# as a single image.
640-
return cls(*(tuple([a] for a in ancestors) + args), **kwargs)[0]
641-
642-
if not retain_doc:
643-
if process.__doc__ is None:
644-
process.__doc__ = ''
645-
process.__doc__ = ("This function has been made lazy. When passed\n"
646-
"a Slicerator, it will return a \n"
647-
"Pipeline of the results. When passed \n"
648-
"any other objects, its behavior is "
649-
"unchanged.\n\n") + process.__doc__
650-
process.__name__ = cls.__name__
651-
return process
652-
653-
654-
def _pipeline_fromfunc(func, retain_doc=False, ancestor_count=1):
655-
"""Actual `pipeline` implementation
656-
657-
Parameters
658-
----------
659-
func : callable
660-
Function for lazy evaluation
661-
retain_doc : bool
662-
If True, don't modify `func`'s doc string to say that it has been
663-
made lazy
664-
ancestor_count : int or 'all', optional
665-
Number of inputs to the pipeline. Defaults to 1.
666-
667-
Returns
668-
-------
669-
Pipeline
670-
Lazy function evaluation :py:class:`Pipeline` for `func`.
671-
"""
672-
if ancestor_count == 'all':
673-
ancestor_count = len(inspect.getfullargspec(func).args)
674599

675600
@wraps(func)
676601
def process(*args, **kwargs):
@@ -679,24 +604,28 @@ def process(*args, **kwargs):
679604
all_pipe = all(hasattr(a, '_slicerator_flag') or
680605
isinstance(a, Slicerator) or
681606
isinstance(a, Pipeline) for a in ancestors)
682-
if all_pipe:
683-
def proc_func(*x):
684-
return func(*(x + args), **kwargs)
685607

686-
return Pipeline(proc_func, *ancestors)
608+
if is_class:
609+
return (func(*ancestors, *args, **kwargs)
610+
if all_pipe else
611+
# Fall back on normal behavior of func, interpreting input
612+
# as a single image.
613+
func(*[[a] for a in ancestors], *args, **kwargs)[0])
614+
687615
else:
688-
# Fall back on normal behavior of func, interpreting input
689-
# as a single image.
690-
return func(*(ancestors + args), **kwargs)
616+
return (Pipeline(lambda *x: func(*x, *args, **kwargs), *ancestors)
617+
if all_pipe else
618+
# Fall back on normal behavior of func, interpreting input
619+
# as a single image.
620+
func(*ancestors, *args, **kwargs))
691621

692622
if not retain_doc:
693623
if process.__doc__ is None:
694624
process.__doc__ = ''
695-
process.__doc__ = ("This function has been made lazy. When passed\n"
696-
"a Slicerator, it will return a \n"
697-
"Pipeline of the results. When passed \n"
698-
"any other objects, its behavior is "
699-
"unchanged.\n\n") + process.__doc__
625+
process.__doc__ = (
626+
"This function has been made lazy. When passed a Slicerator, it \n"
627+
"will return a Pipeline of the results. When passed any other \n"
628+
"objects, its behavior is unchanged.\n\n" + process.__doc__)
700629
process.__name__ = func.__name__
701630
return process
702631

0 commit comments

Comments
 (0)