Skip to content

Commit 6790427

Browse files
committed
implement runner-initiated split path
1 parent 5dd1eb4 commit 6790427

File tree

3 files changed

+306
-80
lines changed

3 files changed

+306
-80
lines changed

ray_beam_runner/portability/execution.py

Lines changed: 208 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
import itertools
2525
import logging
2626
import random
27+
import threading
28+
import time
2729
import typing
28-
from typing import List
30+
from typing import List, MutableMapping
2931
from typing import Mapping
3032
from typing import Optional
3133
from typing import Generator
@@ -57,6 +59,7 @@ def ray_execute_bundle(
5759
transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
5860
expected_outputs: translations.DataOutput,
5961
stage_timers: Mapping[translations.TimerFamilyId, bytes],
62+
split_manager,
6063
instruction_request_repr: Mapping[str, typing.Any],
6164
dry_run=False,
6265
) -> Generator:
@@ -83,8 +86,6 @@ def ray_execute_bundle(
8386
runner_context, instruction_request_repr["process_descriptor_id"]
8487
)
8588

86-
_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)
87-
8889
input_data = {
8990
k: _fetch_decode_data(
9091
runner_context,
@@ -95,19 +96,30 @@ def ray_execute_bundle(
9596
for k, objrefs in input_bundle.input_data.items()
9697
}
9798

98-
for transform_id, elements in input_data.items():
99-
data_out = worker_handler.data_conn.output_stream(
100-
process_bundle_id, transform_id
101-
)
102-
for byte_stream in elements:
103-
data_out.write(byte_stream)
104-
data_out.close()
105-
10699
expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list(
107100
expected_outputs.keys()
108101
)
109102
expect_reads.extend(list(stage_timers.keys()))
110103

104+
split_results = []
105+
split_manager_thread = None
106+
if split_manager:
107+
split_manager_thread = threading.Thread(
108+
target=_run_split_manager,
109+
args=(
110+
runner_context,
111+
worker_handler,
112+
split_manager,
113+
input_data,
114+
transform_buffer_coder,
115+
instruction_request,
116+
split_results,
117+
),
118+
)
119+
split_manager_thread.start()
120+
121+
_send_timers(worker_handler, input_bundle, stage_timers, process_bundle_id)
122+
_send_input_data(worker_handler, input_data, process_bundle_id)
111123
result_future = worker_handler.control_conn.push(instruction_request)
112124

113125
for output in worker_handler.data_conn.input_elements(
@@ -125,6 +137,8 @@ def ray_execute_bundle(
125137
output_buffers[expected_outputs[output.transform_id]].append(output.data)
126138

127139
result: beam_fn_api_pb2.InstructionResponse = result_future.get()
140+
if split_manager_thread:
141+
split_manager_thread.join()
128142

129143
if result.process_bundle.requires_finalization:
130144
finalize_request = beam_fn_api_pb2.InstructionRequest(
@@ -151,14 +165,27 @@ def ray_execute_bundle(
151165
process_bundle_descriptor = runner_context.worker_manager.process_bundle_descriptor(
152166
instruction_request_repr["process_descriptor_id"]
153167
)
154-
delayed_applications = _retrieve_delayed_applications(
168+
169+
deferred_inputs = {}
170+
171+
_add_delayed_applications_to_deferred_inputs(
155172
result,
156173
process_bundle_descriptor,
157174
runner_context,
175+
deferred_inputs,
158176
)
159177

160-
returns.append(len(delayed_applications))
161-
for pcoll, buffer in delayed_applications.items():
178+
_add_residuals_and_channel_splits_to_deferred_inputs(
179+
runner_context,
180+
input_bundle.input_data,
181+
transform_buffer_coder,
182+
process_bundle_descriptor,
183+
split_results,
184+
deferred_inputs,
185+
)
186+
187+
returns.append(len(deferred_inputs))
188+
for pcoll, buffer in deferred_inputs.items():
162189
returns.append(pcoll)
163190
returns.append(buffer)
164191

@@ -206,37 +233,101 @@ def _get_source_transform_name(
206233
raise RuntimeError("No IO transform feeds %s" % transform_id)
207234

208235

209-
def _retrieve_delayed_applications(
236+
def _add_delayed_application_to_deferred_inputs(
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
    delayed_application: beam_fn_api_pb2.DelayedBundleApplication,
    deferred_inputs: MutableMapping[str, List[bytes]],
):
    """Record one delayed application under the transform that feeds it.

    The source transform is resolved with ``_get_source_transform_name`` from
    the application's ``transform_id`` and ``input_id``; the application's
    encoded ``element`` is then appended to that transform's list in
    ``deferred_inputs`` (the list is created on first use).

    Args:
        process_bundle_descriptor: Descriptor of the bundle that produced the
            delayed application.
        delayed_application: The delayed application to record.
        deferred_inputs: Mapping of source transform name to the list of
            encoded elements to process later. Mutated in place.
    """
    # TODO(pabloem): Time delay needed for streaming
    # (delayed_application.requested_time_delay). For now we'll ignore it.
    application = delayed_application.application
    source_transform = _get_source_transform_name(
        process_bundle_descriptor,
        application.transform_id,
        application.input_id,
    )
    deferred_inputs.setdefault(source_transform, []).append(application.element)
252+
253+
254+
def _add_delayed_applications_to_deferred_inputs(
    bundle_result: beam_fn_api_pb2.InstructionResponse,
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
    runner_context: "RayRunnerExecutionContext",
    deferred_inputs: MutableMapping[str, List[bytes]],
):
    """Add the delayed applications of a bundle run to ``deferred_inputs``.

    A delayed application represents a user-initiated checkpoint, where user
    code delays the consumption of a data element to checkpoint the previous
    elements in a bundle.

    Args:
        bundle_result: Response of the process-bundle instruction; its
            ``process_bundle.residual_roots`` are the delayed applications.
        process_bundle_descriptor: Descriptor used to resolve which source
            transform feeds each delayed application.
        runner_context: Not used by this function.
        deferred_inputs: Mapping of source transform name to encoded elements
            to process later. Mutated in place; nothing is returned.
    """
    residual_roots = bundle_result.process_bundle.residual_roots
    for residual_root in residual_roots:
        _add_delayed_application_to_deferred_inputs(
            process_bundle_descriptor=process_bundle_descriptor,
            delayed_application=residual_root,
            deferred_inputs=deferred_inputs,
        )
229272

230-
if source_transform not in delayed_bundles:
231-
delayed_bundles[source_transform] = []
232-
delayed_bundles[source_transform].append(
233-
delayed_application.application.element
234-
)
235273

236-
for consumer, data in delayed_bundles.items():
237-
delayed_bundles[consumer] = [data]
274+
def _add_residuals_and_channel_splits_to_deferred_inputs(
    runner_context: "RayRunnerExecutionContext",
    raw_inputs: Mapping[str, List[ray.ObjectRef]],
    transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
    splits: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
    deferred_inputs: MutableMapping[str, List[bytes]],
):
    """Add the unprocessed work described by split responses to ``deferred_inputs``.

    For every split response two things are recorded:

    * each entry of ``residual_roots`` is added through
      ``_add_delayed_application_to_deferred_inputs``;
    * for each entry of ``channel_splits`` the transform's input is decoded,
      the elements from ``first_residual_element`` up to the residual start of
      the previously handled split of the same transform (or the end of the
      input for the first one) are re-encoded and appended under the
      transform id.

    Args:
        runner_context: Provides ``pipeline_context.coders`` for coder lookup.
        raw_inputs: Transform id -> object refs of the blocks that were sent
            as input to that transform.
        transform_buffer_coder: Transform id -> ``(buffer_id, coder_id)``.
        process_bundle_descriptor: Descriptor used to resolve the source
            transform of residual roots.
        splits: Split responses, in the order they were collected.
        deferred_inputs: Transform id -> encoded elements to process later.
            Mutated in place; nothing is returned.
    """
    prev_split_point = {}  # transform id -> first residual offset
    # transform id -> (coder_impl, decoded elements). The inputs do not change
    # while the splits are walked, so each transform is fetched and decoded
    # once instead of once per channel split.
    decoded_inputs = {}

    for split in splits:
        for delayed_application in split.residual_roots:
            _add_delayed_application_to_deferred_inputs(
                process_bundle_descriptor,
                delayed_application,
                deferred_inputs,
            )

        for channel_split in split.channel_splits:
            transform_id = channel_split.transform_id
            if transform_id not in decoded_inputs:
                decoded_inputs[transform_id] = _decode_channel_split_input(
                    runner_context,
                    raw_inputs,
                    transform_buffer_coder,
                    transform_id,
                )
            coder_impl, all_elements = decoded_inputs[transform_id]

            # Split at the first_residual_element index; a previously handled
            # split of the same transform already claimed everything from its
            # own residual start onwards.
            first_residual = channel_split.first_residual_element
            end = prev_split_point.get(transform_id, len(all_elements))
            residual_elements = all_elements[first_residual:end]
            prev_split_point[transform_id] = first_residual

            if residual_elements:
                encoded_residual = coder_impl.encode_all(residual_elements)
                deferred_inputs.setdefault(transform_id, []).append(
                    encoded_residual
                )


def _decode_channel_split_input(
    runner_context: "RayRunnerExecutionContext",
    raw_inputs: Mapping[str, List[ray.ObjectRef]],
    transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
    transform_id: str,
):
    """Fetch and decode all input elements of ``transform_id``.

    Returns:
        A ``(coder_impl, elements)`` tuple: the coder implementation used for
        decoding (reused by the caller for re-encoding) and the list of
        decoded elements.
    """
    byte_stream = b"".join(
        element
        for block in ray.get(raw_inputs[transform_id])
        for element in block
    )

    buffer_id, input_coder_id = transform_buffer_coder[transform_id]
    input_coder = runner_context.pipeline_context.coders[input_coder_id]

    if buffer_id.startswith(b"group:"):
        # Grouped buffers are decoded with a windowed (key, value) tuple coder
        # assembled from the parts of the declared coder.
        # NOTE(review): presumably because these buffers hold ungrouped
        # key/value pairs rather than (key, iterable) -- confirm.
        coder_impl = coders.WindowedValueCoder(
            coders.TupleCoder(
                (
                    input_coder.wrapped_value_coder._coders[0],
                    input_coder.wrapped_value_coder._coders[1]._elem_coder,
                )
            ),
            input_coder.window_coder,
        ).get_impl()
    else:
        coder_impl = input_coder.get_impl()

    return coder_impl, list(coder_impl.decode_all(byte_stream))
240331

241332

242333
def _get_input_id(buffer_id, transform_name):
@@ -316,6 +407,94 @@ def _send_timers(
316407
timer_out.close()
317408

318409

410+
def _send_input_data(
    worker_handler: worker_handlers.WorkerHandler,
    input_data: Mapping[str, fn_execution.PartitionableBuffer],
    process_bundle_id,
):
    """Write the input of every transform to the worker's data connection.

    For each transform id in ``input_data`` an output stream is opened on
    ``worker_handler.data_conn`` for ``process_bundle_id``, every byte chunk
    of that transform's input is written to it in order, and the stream is
    closed.

    Args:
        worker_handler: Handler whose data connection receives the input.
        input_data: Transform id -> iterable of encoded byte chunks.
        process_bundle_id: Instruction id of the bundle the data belongs to.
    """
    for read_transform_id, chunks in input_data.items():
        stream = worker_handler.data_conn.output_stream(
            process_bundle_id,
            read_transform_id,
        )
        for chunk in chunks:
            stream.write(chunk)
        stream.close()
422+
423+
424+
def _run_split_manager(
    runner_context: "RayRunnerExecutionContext",
    worker_handler: worker_handlers.WorkerHandler,
    split_manager,
    inputs: Mapping[str, fn_execution.PartitionableBuffer],
    transform_buffer_coder: Mapping[str, typing.Tuple[bytes, str]],
    instruction_request,
    split_results_buf: List[beam_fn_api_pb2.ProcessBundleSplitResponse],
):
    """Drive runner-initiated splits of one bundle.

    ``split_manager`` is called with the number of input elements and must
    return a generator. Each value it yields is a split fraction (or ``None``
    for "no split this round"); the outcome of the round -- a
    ``ProcessBundleSplitResponse`` or ``None`` -- is sent back into the
    generator. The loop ends when the generator is exhausted.

    ``ray_execute_bundle`` runs this function on a separate thread, so an
    exception raised here does not propagate to the code executing the bundle.

    Args:
        runner_context: Provides ``pipeline_context.coders`` for coder lookup.
        worker_handler: Handler whose control connection receives the split
            requests.
        split_manager: Callable ``num_elements -> generator`` as described
            above.
        inputs: Mapping with exactly one entry: read transform id -> encoded
            input chunks.
        transform_buffer_coder: Transform id -> ``(buffer_id, coder_id)``.
        instruction_request: The process-bundle request being split; only its
            ``instruction_id`` is read.
        split_results_buf: Output list; every successful split response is
            appended to it.

    Raises:
        RuntimeError: If the worker answers a split request with an error
            other than "Unknown process bundle".
    """
    read_transform_id, buffer_data = translations.only_element(inputs.items())
    byte_stream = b"".join(buffer_data or [])
    coder_id = transform_buffer_coder[read_transform_id][1]
    # NOTE(review): unlike _add_residuals_and_channel_splits_to_deferred_inputs
    # there is no special handling of b"group:" buffers here -- confirm the
    # plain coder is right for counting elements of such inputs.
    coder_impl = runner_context.pipeline_context.coders[coder_id].get_impl()
    num_elements = len(list(coder_impl.decode_all(byte_stream)))

    # Start the split manager in case it wants to set any breakpoints.
    split_manager_generator = split_manager(num_elements)
    try:
        split_fraction = next(split_manager_generator)
        done = False
    except StopIteration:
        split_fraction = None
        done = True

    assert worker_handler is not None

    # Execute the requested splits.
    while not done:
        if split_fraction is None:
            split_result = None
        else:
            DesiredSplit = beam_fn_api_pb2.ProcessBundleSplitRequest.DesiredSplit
            split_request = beam_fn_api_pb2.InstructionRequest(
                process_bundle_split=beam_fn_api_pb2.ProcessBundleSplitRequest(
                    instruction_id=instruction_request.instruction_id,
                    desired_splits={
                        read_transform_id: DesiredSplit(
                            fraction_of_remainder=split_fraction,
                            estimated_input_elements=num_elements,
                        )
                    },
                )
            )
            split_response = worker_handler.control_conn.push(
                split_request
            ).get()  # type: beam_fn_api_pb2.InstructionResponse

            # The bundle may not have started yet; retry a few times with
            # growing pauses until the worker acts on the request.
            for delay in (0.05, 0.1, 0.2):
                if not _split_request_not_applied(split_response):
                    break
                time.sleep(delay)
                split_response = worker_handler.control_conn.push(
                    split_request
                ).get()

            if _split_request_not_applied(split_response):
                # It may have finished too fast.
                split_result = None
            elif split_response.error:
                raise RuntimeError(split_response.error)
            else:
                split_result = split_response.process_bundle_split
                split_results_buf.append(split_result)

        try:
            split_fraction = split_manager_generator.send(split_result)
        except StopIteration:
            break


def _split_request_not_applied(split_response) -> bool:
    """Return True if the worker did not (yet) act on a split request.

    That is the case when the worker reports "Unknown process bundle", or when
    it reports no error and an empty split response. Any other error is *not*
    treated as "not applied": an error response leaves ``process_bundle_split``
    unset, which compares equal to an empty message, so the error has to be
    inspected first or it would be mistaken for an empty response and
    silently dropped.
    """
    if split_response.error:
        return "Unknown process bundle" in split_response.error
    return (
        split_response.process_bundle_split
        == beam_fn_api_pb2.ProcessBundleSplitResponse()
    )
496+
497+
319498
@ray.remote
320499
class _RayRunnerStats:
321500
def __init__(self):

0 commit comments

Comments
 (0)