|
22 | 22 | from tensorflow_transform.tf_metadata import dataset_metadata
|
23 | 23 | from tensorflow_transform.tf_metadata import schema_utils
|
24 | 24 |
|
| 25 | +_RAW_DATA_METADATA = dataset_metadata.DatasetMetadata( |
| 26 | + schema_utils.schema_from_feature_spec({ |
| 27 | + 's': tf.io.FixedLenFeature([], tf.string), |
| 28 | + 'y': tf.io.FixedLenFeature([], tf.float32), |
| 29 | + 'x': tf.io.FixedLenFeature([], tf.float32), |
| 30 | + })) |
| 31 | + |
| 32 | +_RAW_DATA = [{ |
| 33 | + 'x': 1, |
| 34 | + 'y': 1, |
| 35 | + 's': 'hello' |
| 36 | +}, { |
| 37 | + 'x': 2, |
| 38 | + 'y': 2, |
| 39 | + 's': 'world' |
| 40 | +}, { |
| 41 | + 'x': 3, |
| 42 | + 'y': 3, |
| 43 | + 's': 'hello' |
| 44 | +}] |
25 | 45 |
|
26 |
| -def main(): |
27 |
| - def preprocessing_fn(inputs): |
28 |
| - """Preprocess input columns into transformed columns.""" |
29 |
| - x = inputs['x'] |
30 |
| - y = inputs['y'] |
31 |
| - s = inputs['s'] |
32 |
| - x_centered = x - tft.mean(x) |
33 |
| - y_normalized = tft.scale_to_0_1(y) |
34 |
| - s_integerized = tft.compute_and_apply_vocabulary(s) |
35 |
| - x_centered_times_y_normalized = (x_centered * y_normalized) |
36 |
| - return { |
37 |
| - 'x_centered': x_centered, |
38 |
| - 'y_normalized': y_normalized, |
39 |
| - 'x_centered_times_y_normalized': x_centered_times_y_normalized, |
40 |
| - 's_integerized': s_integerized |
41 |
| - } |
42 | 46 |
|
43 |
| - raw_data = [ |
44 |
| - {'x': 1, 'y': 1, 's': 'hello'}, |
45 |
| - {'x': 2, 'y': 2, 's': 'world'}, |
46 |
| - {'x': 3, 'y': 3, 's': 'hello'} |
47 |
| - ] |
| 47 | +def _preprocessing_fn(inputs): |
| 48 | + """Preprocess input columns into transformed columns.""" |
| 49 | + x = inputs['x'] |
| 50 | + y = inputs['y'] |
| 51 | + s = inputs['s'] |
| 52 | + x_centered = x - tft.mean(x) |
| 53 | + y_normalized = tft.scale_to_0_1(y) |
| 54 | + s_integerized = tft.compute_and_apply_vocabulary(s) |
| 55 | + x_centered_times_y_normalized = (x_centered * y_normalized) |
| 56 | + return { |
| 57 | + 'x_centered': x_centered, |
| 58 | + 'y_normalized': y_normalized, |
| 59 | + 'x_centered_times_y_normalized': x_centered_times_y_normalized, |
| 60 | + 's_integerized': s_integerized |
| 61 | + } |
48 | 62 |
|
49 |
| - raw_data_metadata = dataset_metadata.DatasetMetadata( |
50 |
| - schema_utils.schema_from_feature_spec({ |
51 |
| - 's': tf.io.FixedLenFeature([], tf.string), |
52 |
| - 'y': tf.io.FixedLenFeature([], tf.float32), |
53 |
| - 'x': tf.io.FixedLenFeature([], tf.float32), |
54 |
| - })) |
| 63 | + |
| 64 | +def main(): |
55 | 65 |
|
56 | 66 | with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
|
57 | 67 | transformed_dataset, transform_fn = ( # pylint: disable=unused-variable
|
58 |
| - (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset( |
59 |
| - preprocessing_fn)) |
| 68 | + (_RAW_DATA, _RAW_DATA_METADATA) |
| 69 | + | tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn)) |
60 | 70 |
|
61 | 71 | transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable
|
62 | 72 |
|
|
0 commit comments