From 2ab6bdb778f306b5703916eefcc8aa63c1859d6c Mon Sep 17 00:00:00 2001 From: fdrgsp Date: Tue, 3 Oct 2023 16:40:18 -0400 Subject: [PATCH 1/4] fix: time_interval_exceeded --- src/useq/_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/useq/_utils.py b/src/useq/_utils.py index 8b98701f..6467e197 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -136,7 +136,9 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: phases = tplan.phases if hasattr(tplan, "phases") else [tplan] tot_duration = 0.0 for phase in phases: - phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint) + phase_duration, exceeded = _time_phase_duration( + phase, s_per_timepoint, seq.axis_order + ) tot_duration += phase_duration t_interval_exceeded = t_interval_exceeded or exceeded else: @@ -146,7 +148,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: def _time_phase_duration( - phase: SinglePhaseTimePlan, s_per_timepoint: float + phase: SinglePhaseTimePlan, s_per_timepoint: float, axis_order: tuple[str, ...] ) -> tuple[float, bool]: """Calculate duration for a single time plan phase.""" time_interval_s = phase.interval.total_seconds() @@ -157,6 +159,10 @@ def _time_phase_duration( # to actually acquire the data time_interval_s = s_per_timepoint + # if p axes is before t axes, then the time interval is not exceeded + if list(axis_order).index("p") < list(axis_order).index("t"): + time_interval_exceeded = False + tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint return tot_duration, time_interval_exceeded From fc89ab02bd8394e5211f9ba415e9c2a4c838e210 Mon Sep 17 00:00:00 2001 From: fdrgsp Date: Tue, 3 Oct 2023 17:03:41 -0400 Subject: [PATCH 2/4] fix: update _time_phase_duration --- src/useq/_utils.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/useq/_utils.py b/src/useq/_utils.py index 6467e197..a909f3ee 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -159,8 +159,24 @@ def _time_phase_duration( # to actually acquire the data time_interval_s = s_per_timepoint - # if p axes is before t axes, then the time interval is not exceeded - if list(axis_order).index("p") < list(axis_order).index("t"): + axis = list(axis_order) + # if there are no position and grid axes, then the time interval is not + # exceeded + if Axis.POSITION not in axis and Axis.GRID not in axis: + time_interval_exceeded = False + # if there are both position and grid axes, then the time interval, is not + # exceeded if the position and grid axes are before the time axis + elif Axis.POSITION in axis and Axis.GRID in axis: + if axis.index(Axis.POSITION) < axis.index(Axis.TIME) and axis.index( + Axis.GRID + ) < axis.index(Axis.TIME): + time_interval_exceeded = False + # if there is only one of position or grid axes, then the time interval is + # not exceeded if that axis is before the time axis + elif Axis.POSITION in axis: + if axis.index(Axis.POSITION) < axis.index(Axis.TIME): + time_interval_exceeded = False + elif axis.index(Axis.GRID) < axis.index(Axis.TIME): time_interval_exceeded = False tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint From 3eddccfe9c00b4b09624bef1e01e15112fe9cf9c Mon Sep 17 00:00:00 2001 From: fdrgsp Date: Tue, 3 Oct 2023 17:22:32 -0400 Subject: [PATCH 3/4] test: add test --- tests/test_sequence.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 04ebd3b8..49b68866 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -439,3 +439,35 @@ def test_z_plan_num_position(): def test_channel_str(): assert MDAEvent(channel="DAPI") == MDAEvent(channel={"config": "DAPI"}) + + +def test_time_interval_exceeded(): + main_seq = MDASequence( + axis_order="tc", + channels=[{"config": "DAPI", "exposure": 100}], + time_plan=TIntervalLoops(loops=10, interval=0), + ) + assert not main_seq.estimate_duration().time_interval_exceeded + + p_seq = main_seq.replace(axis_order="tpc", stage_position=[(1, 2, 3)]) + assert p_seq.estimate_duration().time_interval_exceeded + + p_seq = p_seq.replace(axis_order="ptc") + assert not p_seq.estimate_duration().time_interval_exceeded + + g_seq = main_seq.replace( + axis_order="tgc", grid_plan=GridRelative(rows=1, columns=2) + ) + assert g_seq.estimate_duration().time_interval_exceeded + + g_seq = g_seq.replace(axis_order="gtc") + assert not g_seq.estimate_duration().time_interval_exceeded + + pg_seq = g_seq.replace(axis_order="tpgc", stage_position=[(1, 2, 3)]) + assert pg_seq.estimate_duration().time_interval_exceeded + + pg_seq = pg_seq.replace(axis_order="ptcg") + assert pg_seq.estimate_duration().time_interval_exceeded + + pg_seq = pg_seq.replace(axis_order="pgtc") + assert not pg_seq.estimate_duration().time_interval_exceeded From 1127cf63685170f3cdb4fb5d92d7565af7956a75 Mon Sep 17 00:00:00 2001 From: fdrgsp Date: Tue, 3 Oct 2023 20:17:21 -0400 Subject: [PATCH 4/4] fix: TODO --- src/useq/_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/useq/_utils.py b/src/useq/_utils.py index a909f3ee..2b8d2964 100644 --- a/src/useq/_utils.py +++ b/src/useq/_utils.py @@ -136,9 +136,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: phases = tplan.phases if hasattr(tplan, "phases") else [tplan] tot_duration = 0.0 for phase in phases: - phase_duration, exceeded = _time_phase_duration( - phase, s_per_timepoint, seq.axis_order - ) + phase_duration, exceeded = _time_phase_duration(phase, s_per_timepoint, seq) tot_duration += phase_duration t_interval_exceeded = t_interval_exceeded or exceeded else: @@ -148,7 +146,7 @@ def _estimate_simple_sequence_duration(seq: useq.MDASequence) -> TimeEstimate: def _time_phase_duration( - phase: SinglePhaseTimePlan, s_per_timepoint: float, axis_order: tuple[str, ...] + phase: SinglePhaseTimePlan, s_per_timepoint: float, seq: useq.MDASequence ) -> tuple[float, bool]: """Calculate duration for a single time plan phase.""" time_interval_s = phase.interval.total_seconds() @@ -159,7 +157,7 @@ def _time_phase_duration( # to actually acquire the data time_interval_s = s_per_timepoint - axis = list(axis_order) + axis = list(seq.axis_order) # if there are no position and grid axes, then the time interval is not # exceeded if Axis.POSITION not in axis and Axis.GRID not in axis: @@ -179,6 +177,8 @@ def _time_phase_duration( elif axis.index(Axis.GRID) < axis.index(Axis.TIME): time_interval_exceeded = False + # TODO: add cases with a single pos or a single fov grid + tot_duration = (phase.num_timepoints() - 1) * time_interval_s + s_per_timepoint return tot_duration, time_interval_exceeded