2424import  itertools 
2525import  logging 
2626import  random 
27+ import  threading 
28+ import  time 
2729import  typing 
28- from  typing  import  List 
30+ from  typing  import  List ,  MutableMapping 
2931from  typing  import  Mapping 
3032from  typing  import  Optional 
3133from  typing  import  Generator 
@@ -57,6 +59,7 @@ def ray_execute_bundle(
5759    transform_buffer_coder : Mapping [str , typing .Tuple [bytes , str ]],
5860    expected_outputs : translations .DataOutput ,
5961    stage_timers : Mapping [translations .TimerFamilyId , bytes ],
62+     split_manager ,
6063    instruction_request_repr : Mapping [str , typing .Any ],
6164    dry_run = False ,
6265) ->  Generator :
@@ -83,8 +86,6 @@ def ray_execute_bundle(
8386        runner_context , instruction_request_repr ["process_descriptor_id" ]
8487    )
8588
86-     _send_timers (worker_handler , input_bundle , stage_timers , process_bundle_id )
87- 
8889    input_data  =  {
8990        k : _fetch_decode_data (
9091            runner_context ,
@@ -95,19 +96,30 @@ def ray_execute_bundle(
9596        for  k , objrefs  in  input_bundle .input_data .items ()
9697    }
9798
98-     for  transform_id , elements  in  input_data .items ():
99-         data_out  =  worker_handler .data_conn .output_stream (
100-             process_bundle_id , transform_id 
101-         )
102-         for  byte_stream  in  elements :
103-             data_out .write (byte_stream )
104-         data_out .close ()
105- 
10699    expect_reads : List [typing .Union [str , translations .TimerFamilyId ]] =  list (
107100        expected_outputs .keys ()
108101    )
109102    expect_reads .extend (list (stage_timers .keys ()))
110103
104+     split_results  =  []
105+     split_manager_thread  =  None 
106+     if  split_manager :
107+         split_manager_thread  =  threading .Thread (
108+             target = _run_split_manager ,
109+             args = (
110+                 runner_context ,
111+                 worker_handler ,
112+                 split_manager ,
113+                 input_data ,
114+                 transform_buffer_coder ,
115+                 instruction_request ,
116+                 split_results ,
117+             ),
118+         )
119+         split_manager_thread .start ()
120+ 
121+     _send_timers (worker_handler , input_bundle , stage_timers , process_bundle_id )
122+     _send_input_data (worker_handler , input_data , process_bundle_id )
111123    result_future  =  worker_handler .control_conn .push (instruction_request )
112124
113125    for  output  in  worker_handler .data_conn .input_elements (
@@ -125,6 +137,8 @@ def ray_execute_bundle(
125137            output_buffers [expected_outputs [output .transform_id ]].append (output .data )
126138
127139    result : beam_fn_api_pb2 .InstructionResponse  =  result_future .get ()
140+     if  split_manager_thread :
141+         split_manager_thread .join ()
128142
129143    if  result .process_bundle .requires_finalization :
130144        finalize_request  =  beam_fn_api_pb2 .InstructionRequest (
@@ -151,14 +165,27 @@ def ray_execute_bundle(
151165    process_bundle_descriptor  =  runner_context .worker_manager .process_bundle_descriptor (
152166        instruction_request_repr ["process_descriptor_id" ]
153167    )
154-     delayed_applications  =  _retrieve_delayed_applications (
168+ 
169+     deferred_inputs  =  {}
170+ 
171+     _add_delayed_applications_to_deferred_inputs (
155172        result ,
156173        process_bundle_descriptor ,
157174        runner_context ,
175+         deferred_inputs ,
158176    )
159177
160-     returns .append (len (delayed_applications ))
161-     for  pcoll , buffer  in  delayed_applications .items ():
178+     _add_residuals_and_channel_splits_to_deferred_inputs (
179+         runner_context ,
180+         input_bundle .input_data ,
181+         transform_buffer_coder ,
182+         process_bundle_descriptor ,
183+         split_results ,
184+         deferred_inputs ,
185+     )
186+ 
187+     returns .append (len (deferred_inputs ))
188+     for  pcoll , buffer  in  deferred_inputs .items ():
162189        returns .append (pcoll )
163190        returns .append (buffer )
164191
@@ -206,37 +233,101 @@ def _get_source_transform_name(
206233    raise  RuntimeError ("No IO transform feeds %s"  %  transform_id )
207234
208235
209- def  _retrieve_delayed_applications (
236+ def  _add_delayed_application_to_deferred_inputs (
237+     process_bundle_descriptor : beam_fn_api_pb2 .ProcessBundleDescriptor ,
238+     delayed_application : beam_fn_api_pb2 .DelayedBundleApplication ,
239+     deferred_inputs : MutableMapping [str , List [bytes ]],
240+ ):
241+     # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. 
242+     # time_delay = delayed_application.requested_time_delay 
243+     source_transform  =  _get_source_transform_name (
244+         process_bundle_descriptor ,
245+         delayed_application .application .transform_id ,
246+         delayed_application .application .input_id ,
247+     )
248+ 
249+     if  source_transform  not  in   deferred_inputs :
250+         deferred_inputs [source_transform ] =  []
251+     deferred_inputs [source_transform ].append (delayed_application .application .element )
252+ 
253+ 
254+ def  _add_delayed_applications_to_deferred_inputs (
210255    bundle_result : beam_fn_api_pb2 .InstructionResponse ,
211256    process_bundle_descriptor : beam_fn_api_pb2 .ProcessBundleDescriptor ,
212257    runner_context : "RayRunnerExecutionContext" ,
258+     deferred_inputs : MutableMapping [str , List [bytes ]],
213259):
214260    """Extract delayed applications from a bundle run. 
215261
216262    A delayed application represents a user-initiated checkpoint, where user code 
217263    delays the consumption of a data element to checkpoint the previous elements 
218264    in a bundle. 
219265    """ 
220-     delayed_bundles  =  {}
221266    for  delayed_application  in  bundle_result .process_bundle .residual_roots :
222-         # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it. 
223-         # time_delay = delayed_application.requested_time_delay 
224-         source_transform  =  _get_source_transform_name (
267+         _add_delayed_application_to_deferred_inputs (
225268            process_bundle_descriptor ,
226-             delayed_application . application . transform_id ,
227-             delayed_application . application . input_id ,
269+             delayed_application ,
270+             deferred_inputs ,
228271        )
229272
230-         if  source_transform  not  in   delayed_bundles :
231-             delayed_bundles [source_transform ] =  []
232-         delayed_bundles [source_transform ].append (
233-             delayed_application .application .element 
234-         )
235273
236-     for  consumer , data  in  delayed_bundles .items ():
237-         delayed_bundles [consumer ] =  [data ]
274+ def  _add_residuals_and_channel_splits_to_deferred_inputs (
275+     runner_context : "RayRunnerExecutionContext" ,
276+     raw_inputs : Mapping [str , List [ray .ObjectRef ]],
277+     transform_buffer_coder : Mapping [str , typing .Tuple [bytes , str ]],
278+     process_bundle_descriptor : beam_fn_api_pb2 .ProcessBundleDescriptor ,
279+     splits : List [beam_fn_api_pb2 .ProcessBundleSplitResponse ],
280+     deferred_inputs : MutableMapping [str , List [bytes ]],
281+ ):
282+     prev_split_point  =  {}  # transform id -> first residual offset 
283+     for  split  in  splits :
284+         for  delayed_application  in  split .residual_roots :
285+             _add_delayed_application_to_deferred_inputs (
286+                 process_bundle_descriptor ,
287+                 delayed_application ,
288+                 deferred_inputs ,
289+             )
290+         for  channel_split  in  split .channel_splits :
291+             # Decode all input elements 
292+             byte_stream  =  b"" .join (
293+                 (
294+                     element 
295+                     for  block  in  ray .get (raw_inputs [channel_split .transform_id ])
296+                     for  element  in  block 
297+                 )
298+             )
299+             input_coder_id  =  transform_buffer_coder [channel_split .transform_id ][1 ]
300+             input_coder  =  runner_context .pipeline_context .coders [input_coder_id ]
301+ 
302+             buffer_id  =  transform_buffer_coder [channel_split .transform_id ][0 ]
303+             if  buffer_id .startswith (b"group:" ):
304+                 coder_impl  =  coders .WindowedValueCoder (
305+                     coders .TupleCoder (
306+                         (
307+                             input_coder .wrapped_value_coder ._coders [0 ],
308+                             input_coder .wrapped_value_coder ._coders [1 ]._elem_coder ,
309+                         )
310+                     ),
311+                     input_coder .window_coder ,
312+                 ).get_impl ()
313+             else :
314+                 coder_impl  =  input_coder .get_impl ()
315+ 
316+             all_elements  =  list (coder_impl .decode_all (byte_stream ))
317+ 
318+             # split at first_residual_element index 
319+             end  =  prev_split_point .get (channel_split .transform_id , len (all_elements ))
320+             residual_elements  =  all_elements [channel_split .first_residual_element  : end ]
321+             prev_split_point [
322+                 channel_split .transform_id 
323+             ] =  channel_split .first_residual_element 
238324
239-     return  delayed_bundles 
325+             if  residual_elements :
326+                 encoded_residual  =  coder_impl .encode_all (residual_elements )
327+ 
328+                 if  channel_split .transform_id  not  in   deferred_inputs :
329+                     deferred_inputs [channel_split .transform_id ] =  []
330+                 deferred_inputs [channel_split .transform_id ].append (encoded_residual )
240331
241332
242333def  _get_input_id (buffer_id , transform_name ):
@@ -316,6 +407,94 @@ def _send_timers(
316407        timer_out .close ()
317408
318409
410+ def  _send_input_data (
411+     worker_handler : worker_handlers .WorkerHandler ,
412+     input_data : Mapping [str , fn_execution .PartitionableBuffer ],
413+     process_bundle_id ,
414+ ):
415+     for  transform_id , elements  in  input_data .items ():
416+         data_out  =  worker_handler .data_conn .output_stream (
417+             process_bundle_id , transform_id 
418+         )
419+         for  byte_stream  in  elements :
420+             data_out .write (byte_stream )
421+         data_out .close ()
422+ 
423+ 
424+ def  _run_split_manager (
425+     runner_context : "RayRunnerExecutionContext" ,
426+     worker_handler : worker_handlers .WorkerHandler ,
427+     split_manager ,
428+     inputs : Mapping [str , fn_execution .PartitionableBuffer ],
429+     transform_buffer_coder : Mapping [str , typing .Tuple [bytes , str ]],
430+     instruction_request ,
431+     split_results_buf : List [beam_fn_api_pb2 .ProcessBundleSplitResponse ],
432+ ):
433+     read_transform_id , buffer_data  =  translations .only_element (inputs .items ())
434+     byte_stream  =  b"" .join (buffer_data  or  [])
435+     coder_id  =  transform_buffer_coder [read_transform_id ][1 ]
436+     coder_impl  =  runner_context .pipeline_context .coders [coder_id ].get_impl ()
437+     num_elements  =  len (list (coder_impl .decode_all (byte_stream )))
438+ 
439+     # Start the split manager in case it wants to set any breakpoints. 
440+     split_manager_generator  =  split_manager (num_elements )
441+     try :
442+         split_fraction  =  next (split_manager_generator )
443+         done  =  False 
444+     except  StopIteration :
445+         split_fraction  =  None 
446+         done  =  True 
447+ 
448+     assert  worker_handler  is  not   None 
449+ 
450+     # Execute the requested splits. 
451+     while  not  done :
452+         if  split_fraction  is  None :
453+             split_result  =  None 
454+         else :
455+             DesiredSplit  =  beam_fn_api_pb2 .ProcessBundleSplitRequest .DesiredSplit 
456+             split_request  =  beam_fn_api_pb2 .InstructionRequest (
457+                 process_bundle_split = beam_fn_api_pb2 .ProcessBundleSplitRequest (
458+                     instruction_id = instruction_request .instruction_id ,
459+                     desired_splits = {
460+                         read_transform_id : DesiredSplit (
461+                             fraction_of_remainder = split_fraction ,
462+                             estimated_input_elements = num_elements ,
463+                         )
464+                     },
465+                 )
466+             )
467+             split_response  =  worker_handler .control_conn .push (
468+                 split_request 
469+             ).get ()  # type: beam_fn_api_pb2.InstructionResponse 
470+             for  t  in  (0.05 , 0.1 , 0.2 ):
471+                 if  (
472+                     "Unknown process bundle"  in  split_response .error 
473+                     or  split_response .process_bundle_split 
474+                     ==  beam_fn_api_pb2 .ProcessBundleSplitResponse ()
475+                 ):
476+                     time .sleep (t )
477+                     split_response  =  worker_handler .control_conn .push (
478+                         split_request 
479+                     ).get ()
480+             if  (
481+                 "Unknown process bundle"  in  split_response .error 
482+                 or  split_response .process_bundle_split 
483+                 ==  beam_fn_api_pb2 .ProcessBundleSplitResponse ()
484+             ):
485+                 # It may have finished too fast. 
486+                 split_result  =  None 
487+             elif  split_response .error :
488+                 raise  RuntimeError (split_response .error )
489+             else :
490+                 split_result  =  split_response .process_bundle_split 
491+                 split_results_buf .append (split_result )
492+         try :
493+             split_fraction  =  split_manager_generator .send (split_result )
494+         except  StopIteration :
495+             break 
496+ 
497+ 
319498@ray .remote  
320499class  _RayRunnerStats :
321500    def  __init__ (self ):
0 commit comments