@@ -19,20 +19,22 @@
 This file implements the full Beam pipeline for TFRecorder.
 """
 
-from typing import Any, Dict, Generator, Union
+from typing import Any, Callable, Dict, Generator, List, Optional, Union
 
 import functools
 import logging
 import os
 
 import apache_beam as beam
+from apache_beam import pvalue
 import pandas as pd
 import tensorflow_transform as tft
 from tensorflow_transform import beam as tft_beam
 
 from tfrecorder import beam_image
 from tfrecorder import common
 from tfrecorder import constants
+from tfrecorder import types
 
 
 def _get_job_name(job_label: str = None) -> str:
@@ -138,7 +140,7 @@ def _get_write_to_tfrecord(output_dir: str,
       num_shards=num_shards,
   )
 
-def _preprocessing_fn(inputs, integer_label: bool = False):
+def _preprocessing_fn(inputs: Dict[str, Any], integer_label: bool = False):
   """TensorFlow Transform preprocessing function."""
 
   outputs = inputs.copy()
@@ -166,7 +168,7 @@ def __init__(self):
   # pylint: disable=arguments-differ
   def process(
       self,
-      element: Dict[str, Any]
+      element: List[str],
   ) -> Generator[Dict[str, Any], None, None]:
     """Loads image and creates image features.
 
@@ -178,6 +180,43 @@ def process(
     yield element
 
 
+def get_split_counts(df: pd.DataFrame):
+  """Returns number of rows for each data split type given dataframe."""
+  assert constants.SPLIT_KEY in df.columns
+  return df[constants.SPLIT_KEY].value_counts().to_dict()
+
+
+def _transform_and_write_tfr(
+    dataset: pvalue.PCollection,
+    tfr_writer: Callable = None,
+    preprocessing_fn: Optional[Callable] = None,
+    transform_fn: Optional[types.TransformFn] = None,
+    label: str = 'data'):
+  """Applies TF Transform to dataset and outputs it as TFRecords."""
+
+  dataset_metadata = (dataset, constants.RAW_METADATA)
+
+  if transform_fn:
+    transformed_dataset, transformed_metadata = (
+        (dataset_metadata, transform_fn)
+        | f'Transform{label}' >> tft_beam.TransformDataset())
+  else:
+    if not preprocessing_fn:
+      preprocessing_fn = lambda x: x
+    (transformed_dataset, transformed_metadata), transform_fn = (
+        dataset_metadata
+        | f'AnalyzeAndTransform{label}' >>
+        tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
+
+  transformed_data_coder = tft.coders.ExampleProtoCoder(
+      transformed_metadata.schema)
+  _ = (
+      transformed_dataset
+      | f'Encode{label}' >> beam.Map(transformed_data_coder.encode)
+      | f'Write{label}' >> tfr_writer(prefix=label.lower()))
+
+  return transform_fn
+
 
 # pylint: disable=too-many-arguments
 # pylint: disable=too-many-locals
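The helper added above factors out TF Transform's analyze-once / transform-many pattern: `AnalyzeAndTransformDataset` computes dataset statistics over a single split and returns a `transform_fn`, which `TransformDataset` then replays over any other split. A minimal, self-contained sketch of that pattern (the toy schema, data, and `preprocessing_fn` below are illustrative stand-ins, not part of this change):

import apache_beam as beam
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils

# Toy stand-in for TFRecorder's constants.RAW_METADATA.
RAW_METADATA = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec(
        {'x': tf.io.FixedLenFeature([], tf.float32)}))

def preprocessing_fn(inputs):
  # Scale x to [0, 1] using min/max statistics gathered during analysis.
  return {'x_scaled': tft.scale_to_0_1(inputs['x'])}

with beam.Pipeline() as p, tft_beam.Context(temp_dir='/tmp/tft_tmp'):
  train = (p | 'Train' >> beam.Create([{'x': 1.0}, {'x': 2.0}]), RAW_METADATA)
  val = (p | 'Val' >> beam.Create([{'x': 3.0}]), RAW_METADATA)

  # Analyze the training split once; transform_fn captures the statistics.
  (train_xf, train_md), transform_fn = (
      train | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))

  # Replay the same transform_fn on the other split, as
  # _transform_and_write_tfr does when transform_fn is passed in.
  val_xf, val_md = ((val, transform_fn) | tft_beam.TransformDataset())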
@@ -246,71 +285,49 @@ def build_pipeline(
         | 'ReadImage' >> beam.ParDo(extract_images_fn)
     )
 
-    # Split dataset into train and validation.
+    # Note: this will not always reflect the actual number of samples per
+    # dataset written as TFRecords. The succeeding `Partition` operation may
+    # mark additional samples from other splits as discarded. If a split has
+    # all its samples discarded, the pipeline will still generate a TFRecord
+    # file for that split, albeit empty.
+    split_counts = get_split_counts(df)
+
+    # Require a training set in the input data. The transform_fn and
+    # transformed_metadata will be generated from the training set and
+    # applied to the other datasets, if any.
+    assert 'TRAIN' in split_counts
+
     train_data, val_data, test_data, discard_data = (
         image_csv_data | 'SplitDataset' >> beam.Partition(
             _partition_fn, len(constants.SPLIT_VALUES))
     )
 
-    train_dataset = (train_data, constants.RAW_METADATA)
-    val_dataset = (val_data, constants.RAW_METADATA)
-    test_dataset = (test_data, constants.RAW_METADATA)
-
-    # TensorFlow Transform applied to all datasets.
     preprocessing_fn = functools.partial(
         _preprocessing_fn,
         integer_label=integer_label)
-    transformed_train_dataset, transform_fn = (
-        train_dataset
-        | 'AnalyzeAndTransformTrain' >> tft_beam.AnalyzeAndTransformDataset(
-            preprocessing_fn))
-
-    transformed_train_data, transformed_metadata = transformed_train_dataset
-    transformed_data_coder = tft.coders.ExampleProtoCoder(
-        transformed_metadata.schema)
-
-    transformed_val_data, _ = (
-        (val_dataset, transform_fn)
-        | 'TransformVal' >> tft_beam.TransformDataset()
-    )
 
-    transformed_test_data, _ = (
-        (test_dataset, transform_fn)
-        | 'TransformTest' >> tft_beam.TransformDataset()
-    )
+    tfr_writer = functools.partial(
+        _get_write_to_tfrecord, output_dir=job_dir, compress=compression,
+        num_shards=num_shards)
+    transform_fn = _transform_and_write_tfr(
+        train_data, tfr_writer, preprocessing_fn=preprocessing_fn,
+        label='Train')
 
-    # Sinks for TFRecords and metadata.
-    tfr_writer = functools.partial(_get_write_to_tfrecord,
-                                   output_dir=job_dir,
-                                   compress=compression,
-                                   num_shards=num_shards)
+    if 'VALIDATION' in split_counts:
+      _transform_and_write_tfr(
+          val_data, tfr_writer, transform_fn=transform_fn, label='Validation')
 
-    _ = (
-        transformed_train_data
-        | 'EncodeTrainData' >> beam.Map(transformed_data_coder.encode)
-        | 'WriteTrainData' >> tfr_writer(prefix='train'))
-
-    _ = (
-        transformed_val_data
-        | 'EncodeValData' >> beam.Map(transformed_data_coder.encode)
-        | 'WriteValData' >> tfr_writer(prefix='val'))
-
-    _ = (
-        transformed_test_data
-        | 'EncodeTestData' >> beam.Map(transformed_data_coder.encode)
-        | 'WriteTestData' >> tfr_writer(prefix='test'))
+    if 'TEST' in split_counts:
+      _transform_and_write_tfr(
+          test_data, tfr_writer, transform_fn=transform_fn, label='Test')
 
     _ = (
         discard_data
-        | 'DiscardDataWriter' >> beam.io.WriteToText(
+        | 'WriteDiscardedData' >> beam.io.WriteToText(
             os.path.join(job_dir, 'discarded-data')))
 
-    # Output transform function and metadata
+    # Note: `transform_fn` already contains the transformed metadata.
     _ = (transform_fn | 'WriteTransformFn' >> tft_beam.WriteTransformFn(
         job_dir))
 
-    # Output metadata schema
-    _ = (transformed_metadata | 'WriteMetadata' >> tft_beam.WriteMetadata(
-        job_dir, pipeline=p))
-
   return p
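With the standalone `WriteMetadata` step removed, the transformed schema now travels inside the transform-function directory that `WriteTransformFn` writes. A sketch of how downstream code could recover it via `tft.TFTransformOutput`; the `job_dir` path is a placeholder for whatever output directory the pipeline ran with:

import tensorflow_transform as tft

# Placeholder: the output directory build_pipeline wrote to.
job_dir = '/path/to/job_dir'

tft_output = tft.TFTransformOutput(job_dir)
schema = tft_output.transformed_metadata.schema        # recovered metadata
feature_spec = tft_output.transformed_feature_spec()   # for parsing TFRecords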