diff --git a/src/useq/_utils.py b/src/useq/_utils.py index 9a0cc875..2d768a8b 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -137,7 +137,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: phases = tplan.phases if hasattr(tplan, "phases") else [tplan] tot_duration = 0.0 for phase in phases: - phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint) + phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint, seq) tot_duration += phase_duration t_interval_exceeded = t_interval_exceeded or exceeded else: @@ -147,7 +147,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: def _time_phase_duration( - phase: SinglePhaseTimePlan, s_per_timepoint: float + phase: SinglePhaseTimePlan, s_per_timepoint: float, seq: useq.MDASequence ) -> tuple[float, bool]: """Calculate duration for a single time plan phase.""" time_interval_s = phase.interval.total_seconds() @@ -158,6 +158,28 @@ def _time_phase_duration( # to actually acquire the data time_interval_s = s_per_timepoint + axis = list(seq.axis_order) + # if there are no position and grid axes, then the time interval is not + # exceeded + if Axis.POSITION not in axis and Axis.GRID not in axis: + time_interval_exceeded = False + # if there are both position and grid axes, then the time interval, is not + # exceeded if the position and grid axes are before the time axis + elif Axis.POSITION in axis and Axis.GRID in axis: + if axis.index(Axis.POSITION) < axis.index(Axis.TIME) and axis.index( + Axis.GRID + ) < axis.index(Axis.TIME): + time_interval_exceeded = False + # if there is only one of position or grid axes, then the time interval is + # not exceeded if that axis is before the time axis + elif Axis.POSITION in axis: + if axis.index(Axis.POSITION) < axis.index(Axis.TIME): + time_interval_exceeded = False + elif axis.index(Axis.GRID) < axis.index(Axis.TIME): + time_interval_exceeded = False + + # TODO: add cases with a single pos or a single fov grid + tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint return tot_duration, time_interval_exceeded diff --git a/tests/test_sequence.py b/tests/test_sequence.py index bd8e180f..f3fd0465 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -429,3 +429,35 @@ def test_z_plan_num_position(): def test_channel_str(): assert MDAEvent(channel="DAPI") == MDAEvent(channel={"config": "DAPI"}) + + +def test_time_interval_exceeded(): + main_seq = MDASequence( + axis_order="tc", + channels=[{"config": "DAPI", "exposure": 100}], + time_plan=TIntervalLoops(loops=10, interval=0), + ) + assert not main_seq.estimate_duration().time_interval_exceeded + + p_seq = main_seq.replace(axis_order="tpc", stage_position=[(1, 2, 3)]) + assert p_seq.estimate_duration().time_interval_exceeded + + p_seq = p_seq.replace(axis_order="ptc") + assert not p_seq.estimate_duration().time_interval_exceeded + + g_seq = main_seq.replace( + axis_order="tgc", grid_plan=GridRelative(rows=1, columns=2) + ) + assert g_seq.estimate_duration().time_interval_exceeded + + g_seq = g_seq.replace(axis_order="gtc") + assert not g_seq.estimate_duration().time_interval_exceeded + + pg_seq = g_seq.replace(axis_order="tpgc", stage_position=[(1, 2, 3)]) + assert pg_seq.estimate_duration().time_interval_exceeded + + pg_seq = pg_seq.replace(axis_order="ptcg") + assert pg_seq.estimate_duration().time_interval_exceeded + + pg_seq = pg_seq.replace(axis_order="pgtc") + assert not pg_seq.estimate_duration().time_interval_exceeded