Skip to content
144 changes: 144 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import abc
import sys
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from itertools import count, islice, product
from typing import TypeVar, cast

from useq._mda_event import MDAEvent

T = TypeVar("T")


class AxisIterator(Iterable[T]):
INFINITE = -1

@property
@abc.abstractmethod
def axis_key(self) -> str:
"""A string id representing the axis."""

def __iter__(self) -> Iterator[T]:
"""Iterate over the axis."""

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
return self.INFINITE

@abc.abstractmethod
def create_event_kwargs(cls, val: T) -> dict: ...

def should_skip(cls, kwargs: dict) -> bool:
return False


class TimePlan(AxisIterator[float]):
def __init__(self, tpoints: Sequence[float]) -> None:
self._tpoints = tpoints

axis_key = "t"

def __iter__(self) -> Iterator[float]:
yield from self._tpoints

def length(self) -> int:
return len(self._tpoints)

def create_event_kwargs(cls, val: float) -> dict:
return {"min_start_time": val}


class ZPlan(AxisIterator[int]):
def __init__(self, stop: int | None = None) -> None:
self._stop = stop
self.acquire_every = 2

axis_key = "z"

def __iter__(self) -> Iterator[int]:
if self._stop is not None:
return iter(range(self._stop))
return count()

def length(self) -> int:
return self._stop or self.INFINITE

def create_event_kwargs(cls, val: int) -> dict:
return {"z_pos": val}

def should_skip(self, event: dict) -> bool:
index = event["index"]
if "t" in index and index["t"] % self.acquire_every:
return True
return False


@dataclass
class MySequence:
axes: tuple[AxisIterator, ...]
order: tuple[str, ...]
chunk_size = 1000

@property
def is_infinite(self) -> bool:
"""Return `True` if the sequence is infinite."""
return any(ax.length() == ax.INFINITE for ax in self.axes)

def _enumerate_ax(
self, key: str, ax: Iterable[T], start: int = 0
) -> Iterable[tuple[str, int, T]]:
"""Return the key for an enumerated axis."""
for idx, val in enumerate(ax, start):
yield key, idx, val

def __iter__(self) -> MDAEvent:
ax_map: dict[str, type[AxisIterator]] = {ax.axis_key: ax for ax in self.axes}
for item in self._iter_inner():
event: dict = {"index": {}}
for axis_key, index, value in item:
ax_type = ax_map[axis_key]
event["index"][axis_key] = index
event.update(ax_type.create_event_kwargs(value))

if not any(ax_type.should_skip(event) for ax_type in ax_map.values()):
yield MDAEvent(**event)

def _iter_inner(self) -> Iterator[tuple[str, int, T]]:
"""Iterate over the sequence."""
ax_map = {ax.axis_key: ax for ax in self.axes}
sorted_axes = [ax_map[key] for key in self.order]
if not self.is_infinite:
iterators = (self._enumerate_ax(ax.axis_key, ax) for ax in sorted_axes)
yield from product(*iterators)
else:
idx = 0
while True:
yield from self._iter_infinite_slice(sorted_axes, idx, self.chunk_size)
idx += self.chunk_size

def _iter_infinite_slice(
self, sorted_axes: list[AxisIterator], start: int, chunk_size: int
) -> Iterator[tuple[str, T]]:
"""Iterate over a slice of an infinite sequence."""
iterators = []
for ax in sorted_axes:
if ax.length() is not ax.INFINITE:
iterator, begin = cast("Iterable", ax), 0
else:
# use islice to avoid calling product with infinite iterators
iterator, begin = islice(ax, start, start + chunk_size), start
iterators.append(self._enumerate_ax(ax.axis_key, iterator, begin))

return product(*iterators)


if __name__ == "__main__":
seq = MySequence(axes=(TimePlan((0, 1, 2, 3, 4)), ZPlan(3)), order=("t", "z"))
if seq.is_infinite:
print("Infinite sequence")
sys.exit(0)
for event in seq:
print(event)
77 changes: 77 additions & 0 deletions src/useq/_axis_iterable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

from collections.abc import Iterator, Sized
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
NamedTuple,
Protocol,
TypeVar,
runtime_checkable,
)

from pydantic import BaseModel

if TYPE_CHECKING:
from useq._iter_sequence import MDAEventDict


# ------ Protocol that can be used as a field annotation in a Pydantic model ------

T = TypeVar("T")


class IterItem(NamedTuple):
"""An item in an iteration sequence."""

axis_key: str
axis_index: int
value: Any
axis_iterable: AxisIterable


@runtime_checkable
class AxisIterable(Protocol[T]):
@property
def axis_key(self) -> str:
"""A string id representing the axis. Prefer lowercase."""

def __iter__(self) -> Iterator[T]:
"""Iterate over the axis."""

def create_event_kwargs(self, val: T) -> MDAEventDict:
"""Convert a value from the iterator to kwargs for an MDAEvent."""

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""

def should_skip(self, kwargs: dict[str, IterItem]) -> bool:
"""Return True if the event should be skipped."""
return False


# ------- concrete base class/mixin that implements the above protocol -------


class AxisIterableBase(BaseModel):
axis_key: ClassVar[str]

def create_event_kwargs(self, val: T) -> MDAEventDict:
"""Convert a value from the iterator to kwargs for an MDAEvent."""
raise NotImplementedError

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
if isinstance(self, Sized):
return len(self)
raise NotImplementedError

def should_skip(self, kwargs: dict[str, IterItem]) -> bool:
return False
62 changes: 60 additions & 2 deletions src/useq/_channel.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from typing import Optional
from typing import TYPE_CHECKING, Any, ClassVar, Optional, cast

from pydantic import Field
from pydantic import Field, RootModel, model_validator

from useq._axis_iterable import AxisIterableBase, IterItem
from useq._base_model import FrozenModel
from useq._utils import Axis

if TYPE_CHECKING:
from useq._iter_sequence import MDAEventDict
from useq._z import ZPlan


class Channel(FrozenModel):
Expand Down Expand Up @@ -38,3 +44,55 @@ class Channel(FrozenModel):
z_offset: float = 0.0
acquire_every: int = Field(default=1, gt=0) # acquire every n frames
camera: Optional[str] = None

@model_validator(mode="before")
def _validate_model(cls, value: Any) -> Any:
if isinstance(value, str):
return {"config": value}
return value


class Channels(RootModel, AxisIterableBase):
root: tuple[Channel, ...]
axis_key: ClassVar[str] = "c"

def __iter__(self):
return iter(self.root)

def __getitem__(self, item):
return self.root[item]

def create_event_kwargs(self, val: Channel) -> "MDAEventDict":
"""Convert a value from the iterator to kwargs for an MDAEvent."""
from useq._mda_event import Channel

d: MDAEventDict = {"channel": Channel(config=val.config, group=val.group)}
if val.z_offset:
d["z_pos_rel"] = val.z_offset
return d

def length(self) -> int:
"""Return the number of axis values.

If the axis is infinite, return -1.
"""
return len(self.root)

def should_skip(self, kwargs: dict[str, IterItem]) -> bool:
if Axis.CHANNEL not in kwargs:
return False
channel = cast("Channel", kwargs[Axis.CHANNEL].value)

if Axis.TIME in kwargs:
if kwargs[Axis.TIME].axis_index % channel.acquire_every:
return True

# only acquire on the middle plane:
if not channel.do_stack:
if Axis.Z in kwargs:
z_plan = cast("ZPlan", kwargs[Axis.Z].axis_iterable)
z_index = kwargs[Axis.Z].axis_index
if z_index != z_plan.num_positions() // 2:
return True

return False
6 changes: 6 additions & 0 deletions src/useq/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

if TYPE_CHECKING:
from matplotlib.axes import Axes
from typing_extensions import Self, TypeAlias

PointGenerator: TypeAlias = Callable[
[np.random.RandomState, int, float, float], Iterable[tuple[float, float]]
Expand Down Expand Up @@ -76,6 +77,11 @@ class _GridPlan(_MultiPointPlan[PositionT]):
Engines MAY override this even if provided.
"""

@property
def axis_key(self) -> str:
"""A string id representing the axis. Prefer lowercase."""
return "g"

overlap: tuple[float, float] = Field((0.0, 0.0), frozen=True)
mode: OrderMode = Field(OrderMode.row_wise_snake, frozen=True)

Expand Down
34 changes: 10 additions & 24 deletions src/useq/_iter_sequence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from functools import cache
from itertools import product
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -16,7 +15,7 @@
from collections.abc import Iterator

from useq._mda_sequence import MDASequence
from useq._position import Position, PositionBase, RelativePosition
from useq._position import Position, RelativePosition


class MDAEventDict(TypedDict, total=False):
Expand All @@ -26,8 +25,11 @@ class MDAEventDict(TypedDict, total=False):
min_start_time: float | None
pos_name: str | None
x_pos: float | None
x_pos_rel: float | None
y_pos: float | None
y_pos_rel: float | None
z_pos: float | None
z_pos_rel: float | None
sequence: MDASequence | None
# properties: list[tuple] | None
metadata: dict
Expand All @@ -40,21 +42,6 @@ class PositionDict(TypedDict, total=False):
z_pos: float


@cache
def _iter_axis(seq: MDASequence, ax: str) -> tuple[Channel | float | PositionBase, ...]:
return tuple(seq.iter_axis(ax))


@cache
def _sizes(seq: MDASequence) -> dict[str, int]:
return {k: len(list(_iter_axis(seq, k))) for k in seq.axis_order}


@cache
def _used_axes(seq: MDASequence) -> str:
return "".join(k for k in seq.axis_order if _sizes(seq)[k])


def iter_sequence(sequence: MDASequence) -> Iterator[MDAEvent]:
"""Iterate over all events in the MDA sequence.'.

Expand Down Expand Up @@ -144,9 +131,8 @@ def _iter_sequence(
MDAEvent
Each event in the MDA sequence.
"""
order = _used_axes(sequence)
# this needs to be tuple(...) to work for mypyc
axis_iterators = tuple(enumerate(_iter_axis(sequence, ax)) for ax in order)
order = sequence.used_axes
axis_iterators = (enumerate(sequence.iter_axis(ax)) for ax in order)
for item in product(*axis_iterators):
if not item: # the case with no events
continue # pragma: no cover
Expand Down Expand Up @@ -267,11 +253,11 @@ def _position_offsets(
def _parse_axes(
event: zip[tuple[str, Any]],
) -> tuple[
dict[str, int],
dict[str, int], # index
float | None, # time
Position | None,
RelativePosition | None,
Channel | None,
Position | None, # position
RelativePosition | None, # grid
Channel | None, # channel
float | None, # z
]:
"""Parse an individual event from the product of axis iterators.
Expand Down
Loading
Loading