4242from  apache_beam .runners .portability .fn_api_runner  import  translations 
4343from  apache_beam .runners .portability .fn_api_runner .execution  import  ListBuffer 
4444from  apache_beam .transforms  import  environments 
45- from  apache_beam .utils  import  proto_utils 
45+ from  apache_beam .utils  import  proto_utils ,  timestamp 
4646
4747import  ray 
4848from  ray_beam_runner .portability .context_management  import  RayBundleContextManager 
@@ -227,7 +227,9 @@ def _run_stage(
227227          bundle_context_manager (execution.BundleContextManager): A description of 
228228            the stage to execute, and its context. 
229229        """ 
230+ 
230231        bundle_context_manager .setup ()
232+ 
231233        runner_execution_context .worker_manager .register_process_bundle_descriptor (
232234            bundle_context_manager .process_bundle_descriptor 
233235        )
@@ -246,6 +248,8 @@ def _run_stage(
246248            for  k  in  bundle_context_manager .transform_to_buffer_coder 
247249        }
248250
251+         watermark_manager  =  runner_execution_context .watermark_manager 
252+ 
249253        final_result  =  None   # type: Optional[beam_fn_api_pb2.InstructionResponse] 
250254
251255        while  True :
@@ -262,19 +266,26 @@ def _run_stage(
262266
263267            final_result  =  merge_stage_results (final_result , last_result )
264268            if  not  delayed_applications  and  not  fired_timers :
269+                 # Processing has completed; marking all outputs as completed 
270+                 for  output_pc  in  bundle_outputs :
271+                     _ , update_output_pc  =  translations .split_buffer_id (output_pc )
272+                     watermark_manager .set_pcoll_produced_watermark .remote (
273+                         update_output_pc , timestamp .MAX_TIMESTAMP 
274+                     )
265275                break 
266276            else :
267-                 # TODO: Enable following assertion after watermarking is implemented 
268-                 # assert (ray.get( 
269-                 # runner_execution_context.watermark_manager 
270-                 # .get_stage_node.remote( 
271-                 #     bundle_context_manager.stage.name)).output_watermark() 
272-                 #         < timestamp.MAX_TIMESTAMP), ( 
273-                 #     'wrong timestamp for %s. ' 
274-                 #     % ray.get( 
275-                 #     runner_execution_context.watermark_manager 
276-                 #     .get_stage_node.remote( 
277-                 #     bundle_context_manager.stage.name))) 
277+                 assert  (
278+                     ray .get (
279+                         watermark_manager .get_stage_node .remote (
280+                             bundle_context_manager .stage .name 
281+                         )
282+                     ).output_watermark ()
283+                     <  timestamp .MAX_TIMESTAMP 
284+                 ), "wrong timestamp for %s. "  %  ray .get (
285+                     watermark_manager .get_stage_node .remote (
286+                         bundle_context_manager .stage .name 
287+                     )
288+                 )
278289                input_data  =  delayed_applications 
279290                input_timers  =  fired_timers 
280291
@@ -288,6 +299,20 @@ def _run_stage(
288299        # TODO(pabloem): Make sure that side inputs are being stored somewhere. 
289300        # runner_execution_context.commit_side_inputs_to_state(data_side_input) 
290301
302+         # assert that the output watermark was correctly set for this stage 
303+         stage_node  =  ray .get (
304+             runner_execution_context .watermark_manager .get_stage_node .remote (
305+                 bundle_context_manager .stage .name 
306+             )
307+         )
308+         assert  (
309+             stage_node .output_watermark () ==  timestamp .MAX_TIMESTAMP 
310+         ), "wrong output watermark for %s. Expected %s, but got %s."  %  (
311+             stage_node ,
312+             timestamp .MAX_TIMESTAMP ,
313+             stage_node .output_watermark (),
314+         )
315+ 
291316        return  final_result 
292317
293318    def  _run_bundle (
@@ -346,6 +371,21 @@ def _run_bundle(
346371        #           coder_impl=bundle_context_manager.get_input_coder_impl( 
347372        #               other_input)) 
348373
374+         # TODO: replace placeholder sets when timers are implemented 
375+         watermark_updates  =  fn_runner .FnApiRunner ._build_watermark_updates (
376+             runner_execution_context ,
377+             transform_to_buffer_coder .keys (),
378+             set (),  # expected_timers 
379+             set (),  # pcolls_with_da 
380+             delayed_applications .keys (),
381+             set (),  # watermarks_by_transform_and_timer_family 
382+         )
383+ 
384+         for  pc_name , watermark  in  watermark_updates .items ():
385+             runner_execution_context .watermark_manager .set_pcoll_watermark .remote (
386+                 pc_name , watermark 
387+             )
388+ 
349389        newly_set_timers  =  {}
350390        return  result , newly_set_timers , delayed_applications , output 
351391
0 commit comments