@@ -125,6 +125,17 @@ def ray_execute_bundle(
125125 output_buffers [expected_outputs [output .transform_id ]].append (output .data )
126126
127127 result : beam_fn_api_pb2 .InstructionResponse = result_future .get ()
128+
129+ if result .process_bundle .requires_finalization :
130+ finalize_request = beam_fn_api_pb2 .InstructionRequest (
131+ finalize_bundle = beam_fn_api_pb2 .FinalizeBundleRequest (
132+ instruction_id = process_bundle_id
133+ )
134+ )
135+ finalize_response = worker_handler .control_conn .push (finalize_request ).get ()
136+ if finalize_response .error :
137+ raise RuntimeError (finalize_response .error )
138+
128139 returns = [result .SerializeToString ()]
129140
130141 returns .append (len (output_buffers ))
@@ -155,6 +166,46 @@ def ray_execute_bundle(
155166 yield ret
156167
157168
def _get_source_transform_name(
    process_bundle_descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
    transform_id: str,
    input_id: str,
) -> str:
    """Return the id of the data-input transform feeding ``(transform_id, input_id)``.

    The PCollection consumed by ``transform_id`` on ``input_id`` is looked up,
    and the descriptor's transforms are scanned for its producer. Two layouts
    are recognised:

    * a transform whose URN is ``DATA_INPUT_URN`` produces that PCollection
      directly; or
    * a ``TRUNCATE_SIZED_RESTRICTION`` transform produces it, and a
      ``DATA_INPUT_URN`` transform in turn produces the truncate's input.

    Args:
        process_bundle_descriptor: Descriptor whose ``transforms`` map is
            searched.
        transform_id: Key in ``transforms`` of the consuming transform.
        input_id: Key in that transform's ``inputs`` naming the PCollection.

    Returns:
        The key in ``transforms`` of the matching data-input transform.

    Raises:
        RuntimeError: If no data-input transform is found via either layout.
    """
    all_transforms = process_bundle_descriptor.transforms
    read_urn = bundle_processor.DATA_INPUT_URN
    truncate_urn = common_urns.sdf_components.TRUNCATE_SIZED_RESTRICTION.urn

    def _reader_producing(pcoll_id):
        # First data-input transform that lists pcoll_id among its outputs,
        # or None when there is no such transform.
        return next(
            (
                reader_id
                for reader_id, reader in all_transforms.items()
                if reader.spec.urn == read_urn
                and pcoll_id in reader.outputs.values()
            ),
            None,
        )

    consumed_pcoll = all_transforms[transform_id].inputs[input_id]

    for candidate_id, candidate in all_transforms.items():
        # Only a producer of the consumed PCollection is of interest.
        if consumed_pcoll not in candidate.outputs.values():
            continue

        urn = candidate.spec.urn

        # The GrpcRead is directly followed by the SDF/Process.
        if urn == read_urn:
            return candidate_id

        # The GrpcRead is followed by SDF/Truncate -> SDF/Process: step over
        # the truncate node and look for the reader that feeds it.
        if urn == truncate_urn:
            # NOTE(review): only_element presumably requires the truncate
            # transform to have exactly one input -- confirm in translations.
            upstream_pcoll = translations.only_element(candidate.inputs.values())
            reader_id = _reader_producing(upstream_pcoll)
            if reader_id is not None:
                return reader_id

    raise RuntimeError("No IO transform feeds %s" % transform_id)
207+
208+
158209def _retrieve_delayed_applications (
159210 bundle_result : beam_fn_api_pb2 .InstructionResponse ,
160211 process_bundle_descriptor : beam_fn_api_pb2 .ProcessBundleDescriptor ,
@@ -170,22 +221,15 @@ def _retrieve_delayed_applications(
170221 for delayed_application in bundle_result .process_bundle .residual_roots :
171222 # TODO(pabloem): Time delay needed for streaming. For now we'll ignore it.
172223 # time_delay = delayed_application.requested_time_delay
173- transform = process_bundle_descriptor .transforms [
174- delayed_application .application .transform_id
175- ]
176- pcoll_name = transform .inputs [delayed_application .application .input_id ]
177-
178- consumer_transform = translations .only_element (
179- [
180- read_id
181- for read_id , proto in process_bundle_descriptor .transforms .items ()
182- if proto .spec .urn == bundle_processor .DATA_INPUT_URN
183- and pcoll_name in proto .outputs .values ()
184- ]
224+ source_transform = _get_source_transform_name (
225+ process_bundle_descriptor ,
226+ delayed_application .application .transform_id ,
227+ delayed_application .application .input_id ,
185228 )
186- if consumer_transform not in delayed_bundles :
187- delayed_bundles [consumer_transform ] = []
188- delayed_bundles [consumer_transform ].append (
229+
230+ if source_transform not in delayed_bundles :
231+ delayed_bundles [source_transform ] = []
232+ delayed_bundles [source_transform ].append (
189233 delayed_application .application .element
190234 )
191235
0 commit comments