Commit 5dd1eb4

Complete user-initiated SDF functionality (#52)

* correctly set is_drain parameter
* enable passing runner tests
* support deferred applications in drain mode
* implement bundle finalization

1 parent 3339959 commit 5dd1eb4

3 files changed: +67 −25 lines
ray_beam_runner/portability/execution.py

Lines changed: 59 additions & 15 deletions
```diff
@@ -125,6 +125,17 @@ def ray_execute_bundle(
         output_buffers[expected_outputs[output.transform_id]].append(output.data)

     result: beam_fn_api_pb2.InstructionResponse = result_future.get()
+
+    if result.process_bundle.requires_finalization:
+        finalize_request = beam_fn_api_pb2.InstructionRequest(
+            finalize_bundle=beam_fn_api_pb2.FinalizeBundleRequest(
+                instruction_id=process_bundle_id
+            )
+        )
+        finalize_response = worker_handler.control_conn.push(finalize_request).get()
+        if finalize_response.error:
+            raise RuntimeError(finalize_response.error)
+
     returns = [result.SerializeToString()]

     returns.append(len(output_buffers))
```
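
This is the runner-side half of Beam's bundle finalization contract: when the SDK harness reports `requires_finalization`, the runner pushes a `FinalizeBundleRequest` back over the control connection. On the SDK side, a DoFn opts in via `beam.DoFn.BundleFinalizerParam`; a minimal sketch using the public Beam API (illustrative only, not code from this commit):

```python
import apache_beam as beam

class CommitOnFinalize(beam.DoFn):
    def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam):
        # The registered callback runs only after the runner has durably
        # committed the bundle's outputs, which is what the new finalize
        # path in ray_execute_bundle enables.
        bundle_finalizer.register(lambda: print("bundle committed"))
        yield element
```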
```diff
@@ -155,6 +166,46 @@ def ray_execute_bundle(
         yield ret


+def _get_source_transform_name(
+    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
+    transform_id: str,
+    input_id: str,
+) -> str:
+    """Find the name of the source PTransform that feeds into the given
+    (transform_id, input_id)."""
+    input_pcoll = process_bundle_descriptor.transforms[transform_id].inputs[input_id]
+    for ptransform_id, ptransform in process_bundle_descriptor.transforms.items():
+        # The GrpcRead is directly followed by the SDF/Process.
+        if (
+            ptransform.spec.urn == bundle_processor.DATA_INPUT_URN
+            and input_pcoll in ptransform.outputs.values()
+        ):
+            return ptransform_id
+
+        # The GrpcRead is followed by SDF/Truncate -> SDF/Process.
+        # We need to traverse the TRUNCATE_SIZED_RESTRICTION node in order
+        # to find the original source PTransform.
+        if (
+            ptransform.spec.urn
+            == common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn
+            and input_pcoll in ptransform.outputs.values()
+        ):
+            input_pcoll_ = translations.only_element(
+                process_bundle_descriptor.transforms[ptransform_id].inputs.values()
+            )
+            for (
+                ptransform_id_2,
+                ptransform_2,
+            ) in process_bundle_descriptor.transforms.items():
+                if (
+                    ptransform_2.spec.urn == bundle_processor.DATA_INPUT_URN
+                    and input_pcoll_ in ptransform_2.outputs.values()
+                ):
+                    return ptransform_id_2
+
+    raise RuntimeError("No IO transform feeds %s" % transform_id)
+
+
 def _retrieve_delayed_applications(
     bundle_result: beam_fn_api_pb2.InstructionResponse,
     process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
```
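
To make the traversal concrete, here is a hedged sketch that exercises `_get_source_transform_name` against a hand-built two-transform descriptor (the transform IDs `source` and `sdf_process` are hypothetical):

```python
from apache_beam.portability.api import beam_fn_api_pb2, beam_runner_api_pb2
from apache_beam.runners.worker import bundle_processor

# A GrpcRead-style data input feeding PCollection "pc0" into an SDF stage.
descriptor = beam_fn_api_pb2.ProcessBundleDescriptor(
    transforms={
        "source": beam_runner_api_pb2.PTransform(
            spec=beam_runner_api_pb2.FunctionSpec(urn=bundle_processor.DATA_INPUT_URN),
            outputs={"out": "pc0"},
        ),
        "sdf_process": beam_runner_api_pb2.PTransform(inputs={"in": "pc0"}),
    }
)

# The first branch matches: the read directly feeds the SDF stage.
assert _get_source_transform_name(descriptor, "sdf_process", "in") == "source"
```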
```diff
@@ -170,22 +221,15 @@ def _retrieve_delayed_applications(
     for delayed_application in bundle_result.process_bundle.residual_roots:
         # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
         # time_delay = delayed_application.requested_time_delay
-        transform = process_bundle_descriptor.transforms[
-            delayed_application.application.transform_id
-        ]
-        pcoll_name = transform.inputs[delayed_application.application.input_id]
-
-        consumer_transform = translations.only_element(
-            [
-                read_id
-                for read_id, proto in process_bundle_descriptor.transforms.items()
-                if proto.spec.urn == bundle_processor.DATA_INPUT_URN
-                and pcoll_name in proto.outputs.values()
-            ]
+        source_transform = _get_source_transform_name(
+            process_bundle_descriptor,
+            delayed_application.application.transform_id,
+            delayed_application.application.input_id,
         )
-        if consumer_transform not in delayed_bundles:
-            delayed_bundles[consumer_transform] = []
-        delayed_bundles[consumer_transform].append(
+
+        if source_transform not in delayed_bundles:
+            delayed_bundles[source_transform] = []
+        delayed_bundles[source_transform].append(
             delayed_application.application.element
         )

```

ray_beam_runner/portability/ray_fn_runner.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -121,18 +121,21 @@ def _pipeline_checks(
 class RayFnApiRunner(runner.PipelineRunner):
     def __init__(
         self,
+        is_drain=False,
     ) -> None:

         """Creates a new Ray Runner instance.

         Args:
           progress_request_frequency: The frequency (in seconds) that the runner
             waits before requesting progress from the SDK.
+          is_drain: whether to expand the SDF graph in drain mode.
         """
         super().__init__()
         # TODO: figure out if this is necessary (probably, later)
         self._progress_frequency = None
         self._cache_token_generator = fn_runner.FnApiRunner.get_cache_token_generator()
+        self._is_drain = is_drain

     @staticmethod
     def supported_requirements():
```
```diff
@@ -183,7 +186,7 @@ def run_pipeline(
                 ]
             ),
             use_state_iterables=False,
-            is_drain=False,
+            is_drain=self._is_drain,
         )
         return self.execute_pipeline(stage_context, stages)

```
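With the flag now plumbed through stage optimization, draining is requested at construction time. A minimal usage sketch (assumes a working Ray setup; the pipeline body is arbitrary):

```python
import apache_beam as beam
from ray_beam_runner.portability.ray_fn_runner import RayFnApiRunner

# is_drain=True expands SDFs with TRUNCATE_SIZED_RESTRICTION during
# stage translation, so unbounded restrictions are cut short on drain.
with beam.Pipeline(runner=RayFnApiRunner(is_drain=True)) as p:
    _ = p | beam.Create([1, 2, 3]) | beam.Map(print)
```
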
ray_beam_runner/portability/ray_runner_test.py

Lines changed: 4 additions & 9 deletions
```diff
@@ -102,7 +102,9 @@ def tearDown(self) -> None:

     def create_pipeline(self, is_drain=False):
         return beam.Pipeline(
-            runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner()
+            runner=ray_beam_runner.portability.ray_fn_runner.RayFnApiRunner(
+                is_drain=is_drain
+            )
         )

     def test_assert_that(self):
@@ -684,7 +686,6 @@ def process(
         actual = p | beam.Create(data) | beam.ParDo(ExpandingStringsDoFn())
         assert_that(actual, equal_to(list("".join(data))))

-    @unittest.skip("SDF not yet supported")
     def test_sdf_with_dofn_as_watermark_estimator(self):
         class ExpandingStringsDoFn(beam.DoFn, beam.WatermarkEstimatorProvider):
             def initial_estimator_state(self, element, restriction):
@@ -758,11 +759,9 @@ def process(
     def test_sdf_with_sdf_initiated_checkpointing(self):
         self.run_sdf_initiated_checkpointing(is_drain=False)

-    @unittest.skip("SDF not yet supported")
     def test_draining_sdf_with_sdf_initiated_checkpointing(self):
         self.run_sdf_initiated_checkpointing(is_drain=True)

-    @unittest.skip("SDF not yet supported")
     def test_sdf_default_truncate_when_bounded(self):
         class SimleSDF(beam.DoFn):
             def process(
@@ -782,7 +781,6 @@ def process(
         actual = p | beam.Create([10]) | beam.ParDo(SimleSDF())
         assert_that(actual, equal_to(range(10)))

-    @unittest.skip("SDF not yet supported")
     def test_sdf_default_truncate_when_unbounded(self):
         class SimleSDF(beam.DoFn):
             def process(
@@ -802,7 +800,6 @@ def process(
         actual = p | beam.Create([10]) | beam.ParDo(SimleSDF())
         assert_that(actual, equal_to([]))

-    @unittest.skip("SDF not yet supported")
     def test_sdf_with_truncate(self):
         class SimleSDF(beam.DoFn):
             def process(
```
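
For readers skimming the re-enabled tests: `SimleSDF` (name kept as-is from the test code) is a splittable DoFn. A self-contained sketch of the pattern with hypothetical names, using a standard offset-range restriction:

```python
import apache_beam as beam
from apache_beam.io.restriction_trackers import OffsetRange, OffsetRestrictionTracker
from apache_beam.transforms.core import RestrictionProvider

class OffsetProvider(RestrictionProvider):
    def initial_restriction(self, element):
        return OffsetRange(0, element)

    def create_tracker(self, restriction):
        return OffsetRestrictionTracker(restriction)

    def restriction_size(self, element, restriction):
        return restriction.size()

class CountingSDF(beam.DoFn):
    # Emits 0..element-1. Because every claim goes through the tracker,
    # the runner can split or truncate the restriction, which is exactly
    # what the drain-mode tests above exercise.
    def process(self, element, tracker=beam.DoFn.RestrictionParam(OffsetProvider())):
        cur = tracker.current_restriction().start
        while tracker.try_claim(cur):
            yield cur
            cur += 1
```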
```diff
@@ -1042,7 +1039,6 @@ def process(self, element, bundle_finalizer=beam.DoFn.BundleFinalizerParam):
         )
         assert_that(res, equal_to(["1", "2"]))

-    @unittest.skip("SDF not yet supported")
     def test_register_finalizations(self):
         event_recorder = EventRecorder(tempfile.gettempdir())

@@ -1086,7 +1082,6 @@ def process(

         event_recorder.cleanup()

-    @unittest.skip("Combiners not yet supported")
     def test_sdf_synthetic_source(self):
         common_attrs = {
             "key_size": 1,
@@ -1188,7 +1183,7 @@ def expand(self, pcoll):
             any(re.match(packed_step_name_regex, s) for s in step_names)
         )

-    @unittest.skip("Combiners not yet supported")
+    @unittest.skip("Metrics not yet supported")
     def test_pack_combiners(self):
         self._test_pack_combiners(assert_using_counter_names=True)

```