diff --git a/ray_beam_runner/portability/execution.py b/ray_beam_runner/portability/execution.py index 4db8940..311b08f 100644 --- a/ray_beam_runner/portability/execution.py +++ b/ray_beam_runner/portability/execution.py @@ -46,10 +46,15 @@ from apache_beam.runners.worker import bundle_processor from ray_beam_runner.portability.state import RayStateManager +from ray_beam_runner.portability.translations import StageTags _LOGGER = logging.getLogger(__name__) +# TODO(pabloem): Stop hardcoding the number of blocks per task +BLOCKS_PER_TASK = 10 + + @ray.remote(num_returns="dynamic") def ray_execute_bundle( runner_context: "RayRunnerExecutionContext", @@ -59,11 +64,22 @@ def ray_execute_bundle( stage_timers: Mapping[translations.TimerFamilyId, bytes], instruction_request_repr: Mapping[str, typing.Any], dry_run=False, + stage_tags=None, ) -> Generator: - # generator returns: - # (serialized InstructionResponse, ouputs, - # repeat of pcoll, data, - # delayed applications, repeat of pcoll, data) + """Execute a Beam bundle as a ray task. + + :returns A `Generator` with the following values: + - serialized InstructionResponse, + - dictionary of timers + - dictionary of delayed applications + - count of output pcollections, + - repeat of + - pcoll name + - pcoll block count + - repeat of pcoll block + """ + + stage_tags = stage_tags or set() instruction_request = beam_fn_api_pb2.InstructionRequest( instruction_id=instruction_request_repr["instruction_id"], @@ -74,9 +90,19 @@ def ray_execute_bundle( cache_tokens=[instruction_request_repr["cache_token"]], ), ) + + # TODO(pabloem): CHECK THIS TO MAKE SURE IS GOOD + expects_group = any(k.startswith('group') for k in expected_outputs.keys()) + output_buffers: Mapping[ - typing.Union[str, translations.TimerFamilyId], list - ] = collections.defaultdict(list) + str, typing.Union[KeyBlockBasedDataBuffer, RandomBlockBasedDataBuffer] + ] = collections.defaultdict( + KeyBlockBasedDataBuffer if expects_group else RandomBlockBasedDataBuffer + ) + + output_timer_buffers: Mapping[ + translations.TimerFamilyId, list] = collections.defaultdict(list) + process_bundle_id = instruction_request.instruction_id worker_handler = _get_worker_handler( @@ -99,26 +125,26 @@ def ray_execute_bundle( data_out = worker_handler.data_conn.output_stream( process_bundle_id, transform_id ) - for byte_stream in elements: - data_out.write(byte_stream) - data_out.close() - - expect_reads: List[typing.Union[str, translations.TimerFamilyId]] = list( - expected_outputs.keys() - ) - expect_reads.extend(list(stage_timers.keys())) + try: + for byte_stream in elements: + data_out.write(byte_stream) + data_out.close() + except: + # raise + import ray + ray.util.pdb.set_trace() result_future = worker_handler.control_conn.push(instruction_request) for output in worker_handler.data_conn.input_elements( process_bundle_id, - expect_reads, + list(stage_timers.keys()) + list(expected_outputs.keys()), abort_callback=lambda: ( result_future.is_done() and bool(result_future.get().error) ), ): if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run: - output_buffers[ + output_timer_buffers[ stage_timers[(output.transform_id, output.timer_family_id)] ].append(output.timers) if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run: @@ -138,10 +164,8 @@ def ray_execute_bundle( returns = [result.SerializeToString()] - returns.append(len(output_buffers)) - for pcoll, buffer in output_buffers.items(): - returns.append(pcoll) - returns.append(buffer) + # We pass output timers as a single full object, as these are smaller data + returns.append(output_timer_buffers) # Now we collect all the deferred inputs remaining from bundle execution. # Deferred inputs can be: @@ -157,15 +181,46 @@ def ray_execute_bundle( runner_context, ) - returns.append(len(delayed_applications)) - for pcoll, buffer in delayed_applications.items(): + # We pass delayed applications as a single full object, as these are smaller data + returns.append(delayed_applications) + + returns.append(len(output_buffers)) + for pcoll, buffer in output_buffers.items(): returns.append(pcoll) - returns.append(buffer) + returns.append(buffer.num_blocks()) + for i in range(buffer.num_blocks()): + returns.append(buffer.blocks[i]) for ret in returns: yield ret +class RandomBlockBasedDataBuffer: + def __init__(self): + self._num_blocks = BLOCKS_PER_TASK + self.blocks = [[] for _ in range(self._num_blocks)] + self._total_data = 0 + + def num_blocks(self): + return min(self._total_data, self._num_blocks) + + def append(self, data): + self.blocks[self._total_data % len(self.blocks)].append(data) + self._total_data +=1 + + +class KeyBlockBasedDataBuffer: + def __init__(self): + self._num_blocks = 1 + self.blocks = [[] for _ in range(self._num_blocks)] + + def num_blocks(self): + return 1 + + def append(self, data): + # TODO: Figure out how to get the Key for the data. + self.blocks[0].append(data) + def _get_source_transform_name( process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor, transform_id: str, @@ -263,6 +318,7 @@ def _fetch_decode_data( data_references: List[ray.ObjectRef], ): """Fetch a PCollection's data and decode it.""" + logging.warning("pabloem - Buffer is %s" % buffer_id) if buffer_id.startswith(b"group"): _, pcoll_id = translations.split_buffer_id(buffer_id) transform = runner_context.pipeline_components.transforms[pcoll_id] @@ -288,15 +344,10 @@ def _fetch_decode_data( windowing=apache_beam.Windowing.from_runner_api(windowing_strategy, None), ) else: - buffer = fn_execution.ListBuffer( - coder_impl=runner_context.pipeline_context.coders[coder_id].get_impl() - ) + buffer = [] for block in ray.get(data_references): - # TODO(pabloem): Stop using ListBuffer, and use different - # buffers to pass data to Beam. - for elm in block: - buffer.append(elm) + buffer.extend(block) return buffer diff --git a/ray_beam_runner/portability/ray_fn_runner.py b/ray_beam_runner/portability/ray_fn_runner.py index d5694d1..18df309 100644 --- a/ray_beam_runner/portability/ray_fn_runner.py +++ b/ray_beam_runner/portability/ray_fn_runner.py @@ -19,6 +19,7 @@ # pytype: skip-file # mypy: check-untyped-defs import collections +import concurrent.futures import copy import logging import typing @@ -48,6 +49,7 @@ import ray from ray_beam_runner.portability.context_management import RayBundleContextManager from ray_beam_runner.portability.execution import Bundle, _get_input_id +from ray_beam_runner.portability import translations as ray_translations from ray_beam_runner.portability.execution import ( ray_execute_bundle, merge_stage_results, @@ -170,7 +172,8 @@ def run_pipeline( translations.pack_combiners, translations.lift_combiners, translations.expand_sdf, - translations.expand_gbk, + ray_translations.expand_gbk, + ray_translations.expand_reshuffle, translations.sink_flattens, translations.greedily_fuse, translations.read_to_impulse, @@ -183,6 +186,7 @@ def run_pipeline( [ common_urns.primitives.FLATTEN.urn, common_urns.primitives.GROUP_BY_KEY.urn, + common_urns.composites.RESHUFFLE.urn, ] ), use_state_iterables=False, @@ -248,6 +252,7 @@ def _run_stage( final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse] + logging.warning("Executing stage %s", bundle_context_manager.stage.name) while True: ( last_result, @@ -315,57 +320,82 @@ def _run_bundle( process_bundle_id = "bundle_%s" % process_bundle_descriptor.id pbd_id = process_bundle_descriptor.id - result_generator_ref = ray_execute_bundle.remote( - runner_execution_context, - input_bundle, - transform_to_buffer_coder, - data_output, - stage_timers, - instruction_request_repr={ - "instruction_id": process_bundle_id, - "process_descriptor_id": pbd_id, - "cache_token": next(cache_token_generator), - }, - ) - result_generator = iter(ray.get(result_generator_ref)) - result = beam_fn_api_pb2.InstructionResponse.FromString( - ray.get(next(result_generator)) - ) - output = [] - num_outputs = ray.get(next(result_generator)) - for _ in range(num_outputs): - pcoll = ray.get(next(result_generator)) - data_ref = next(result_generator) - output.append(pcoll) - runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) - - delayed_applications = {} - num_delayed_applications = ray.get(next(result_generator)) - for _ in range(num_delayed_applications): - pcoll = ray.get(next(result_generator)) - data_ref = next(result_generator) - delayed_applications[pcoll] = data_ref - runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) + input_data = input_bundle.input_data + result_generator_futures = [] + if len(input_data) > 1: + raise RuntimeError( + "pabloem - Stage has multiple main input PCollections " + "which is unusual: %s" + % bundle_context_manager.stage.name) + + input_id, obj_refs = list(input_data.items())[0] + logging.warning("pabloem - Running stage in PARALLEL AS WE HOPED - %d blocks", len(obj_refs)) + # TODO(pabloem): This is an awful hack. HOW DO WE FREAKIN KEEP KEYED DATA TOGETHER?! + # TODO(pableom): DO GROUPING PER KEY. + if 'GroupByKey/Read' in input_id: + obj_refs = [obj_refs] + for i, obj_ref in enumerate(obj_refs): + result_generator_futures.append(ray_execute_bundle.remote( + runner_execution_context, + Bundle(input_timers=input_bundle.input_timers if i == 0 else {}, + input_data={input_id: [obj_ref] if not isinstance(obj_ref, list) else obj_ref}), + transform_to_buffer_coder, + data_output, + stage_timers, + instruction_request_repr={ + "instruction_id": process_bundle_id, + "process_descriptor_id": pbd_id, + "cache_token": next(cache_token_generator), + }, + stage_tags=getattr(bundle_context_manager.stage, "tags", None) + )) + + final_result = None + final_output = set() + while True: + ready_results, result_generator_futures = ray.wait(result_generator_futures) + for ready_res in ready_results: + new_result, new_output, new_delayed_applications = self._fetch_execution_output(runner_execution_context, ready_res) + final_result = merge_stage_results(final_result, new_result) if final_result else new_result + final_output = final_output.union(new_output) + if not result_generator_futures: + break ( watermarks_by_transform_and_timer_family, newly_set_timers, ) = self._collect_written_timers(bundle_context_manager) - # TODO(pabloem): Add support for splitting of results. + # TODO: Set delayed applications somehow + return final_result, newly_set_timers, new_delayed_applications, final_output + + def _fetch_execution_output(self, runner_execution_context: RayRunnerExecutionContext, result_generator_ref): + result_generator = iter(ray.get(result_generator_ref)) + response_str = ray.get(next(result_generator)) + result = beam_fn_api_pb2.InstructionResponse.FromString( + response_str + ) + + output_timers = ray.get(next(result_generator)) + delayed_applications = ray.get(next(result_generator)) + + for timer_id, timer_data in output_timers.items(): + runner_execution_context.pcollection_buffers.put(timer_id, timer_data) + for pcoll, data_ref in delayed_applications.items(): + runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) - # After collecting deferred inputs, we 'pad' the structure with empty - # buffers for other expected inputs. - # if deferred_inputs or newly_set_timers: - # # The worker will be waiting on these inputs as well. - # for other_input in data_input: - # if other_input not in deferred_inputs: - # deferred_inputs[other_input] = ListBuffer( - # coder_impl=bundle_context_manager.get_input_coder_impl( - # other_input)) + output = [] + num_outputs = ray.get(next(result_generator)) + for _1 in range(num_outputs): + pcoll = ray.get(next(result_generator)) + output.append(pcoll) + blocks_per_pcoll = ray.get(next(result_generator)) + for _2 in range(blocks_per_pcoll): + data_ref = next(result_generator) + runner_execution_context.pcollection_buffers.put(pcoll, [data_ref]) - return result, newly_set_timers, delayed_applications, output + return result, output, delayed_applications @staticmethod def _collect_written_timers( diff --git a/ray_beam_runner/portability/translations.py b/ray_beam_runner/portability/translations.py new file mode 100644 index 0000000..e51c266 --- /dev/null +++ b/ray_beam_runner/portability/translations.py @@ -0,0 +1,116 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Set of utilities that perform static analysis for a Beam graph ahead of execution.""" +import logging +import typing + +from apache_beam.portability import common_urns +from apache_beam.portability.api import beam_runner_api_pb2 +from apache_beam.runners.portability.fn_api_runner import translations as beam_translations +from apache_beam.runners.portability.fn_api_runner.translations import Stage +from apache_beam.runners.portability.fn_api_runner.translations import TransformContext +from apache_beam.runners.worker import bundle_processor + + +def expand_reshuffle(stages: typing.Iterable[Stage], pipeline_context: TransformContext) -> typing.Iterator[Stage]: + for s in stages: + t = beam_translations.only_transform(s.transforms) + if t.spec.urn == common_urns.composites.RESHUFFLE.urn: + reshuffle_buffer = beam_translations.create_buffer_id(s.name) + reshuffle_write = Stage( + t.unique_name + '/Write', + [ + beam_runner_api_pb2.PTransform( + unique_name=t.unique_name + '/Write', + inputs=t.inputs, + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_OUTPUT_URN, + payload=reshuffle_buffer)) + ], + downstream_side_inputs=frozenset(), + must_follow=s.must_follow, + ) + yield reshuffle_write + + yield Stage( + t.unique_name + '/Read', + [ + beam_runner_api_pb2.PTransform( + unique_name=t.unique_name + '/Read', + outputs=t.outputs, + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_INPUT_URN, + payload=reshuffle_buffer)) + ], + downstream_side_inputs=s.downstream_side_inputs, + must_follow=beam_translations.union(frozenset([reshuffle_write]), s.must_follow)) + else: + yield s + + +def expand_gbk(stages, pipeline_context): + # type: (Iterable[Stage], TransformContext) -> Iterator[Stage] + + """Transforms each GBK into a write followed by a read.""" + for stage in stages: + transform = beam_translations.only_transform(stage.transforms) + if transform.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn: + for pcoll_id in transform.inputs.values(): + pipeline_context.length_prefix_pcoll_coders(pcoll_id) + for pcoll_id in transform.outputs.values(): + if pipeline_context.use_state_iterables: + pipeline_context.components.pcollections[ + pcoll_id].coder_id = pipeline_context.with_state_iterables( + pipeline_context.components.pcollections[pcoll_id].coder_id) + pipeline_context.length_prefix_pcoll_coders(pcoll_id) + + # This is used later to correlate the read and write. + transform_id = stage.name + if transform != pipeline_context.components.transforms.get(transform_id): + transform_id = beam_translations.unique_name( + pipeline_context.components.transforms, stage.name) + pipeline_context.components.transforms[transform_id].CopyFrom(transform) + gbk_buffer = beam_translations.create_buffer_id(transform_id, kind='group') + gbk_write = Stage( + transform.unique_name + '/Write', + [ + beam_runner_api_pb2.PTransform( + unique_name=transform.unique_name + '/Write', + inputs=transform.inputs, + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_OUTPUT_URN, + payload=gbk_buffer)) + ], + downstream_side_inputs=frozenset(), + must_follow=stage.must_follow) + yield gbk_write + + yield Stage( + transform.unique_name + '/Read', + [ + beam_runner_api_pb2.PTransform( + unique_name=transform.unique_name + '/Read', + outputs=transform.outputs, + spec=beam_runner_api_pb2.FunctionSpec( + urn=bundle_processor.DATA_INPUT_URN, + payload=gbk_buffer)) + ], + downstream_side_inputs=stage.downstream_side_inputs, + must_follow=beam_translations.union(frozenset([gbk_write]), stage.must_follow)) + else: + yield stage