Skip to content

Commit ff04e67

Browse files
zoyahav authored and tfx-copybara committed
Showcase tft_unit in simple example, and validate the expected transformed data.
PiperOrigin-RevId: 422775631
1 parent cd42be7 commit ff04e67

File tree

2 files changed

+72
-30
lines changed

2 files changed

+72
-30
lines changed

examples/simple_example.py

Lines changed: 39 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -22,41 +22,51 @@
2222
from tensorflow_transform.tf_metadata import dataset_metadata
2323
from tensorflow_transform.tf_metadata import schema_utils
2424

25+
_RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
26+
schema_utils.schema_from_feature_spec({
27+
's': tf.io.FixedLenFeature([], tf.string),
28+
'y': tf.io.FixedLenFeature([], tf.float32),
29+
'x': tf.io.FixedLenFeature([], tf.float32),
30+
}))
31+
32+
_RAW_DATA = [{
33+
'x': 1,
34+
'y': 1,
35+
's': 'hello'
36+
}, {
37+
'x': 2,
38+
'y': 2,
39+
's': 'world'
40+
}, {
41+
'x': 3,
42+
'y': 3,
43+
's': 'hello'
44+
}]
2545

26-
def main():
27-
def preprocessing_fn(inputs):
28-
"""Preprocess input columns into transformed columns."""
29-
x = inputs['x']
30-
y = inputs['y']
31-
s = inputs['s']
32-
x_centered = x - tft.mean(x)
33-
y_normalized = tft.scale_to_0_1(y)
34-
s_integerized = tft.compute_and_apply_vocabulary(s)
35-
x_centered_times_y_normalized = (x_centered * y_normalized)
36-
return {
37-
'x_centered': x_centered,
38-
'y_normalized': y_normalized,
39-
'x_centered_times_y_normalized': x_centered_times_y_normalized,
40-
's_integerized': s_integerized
41-
}
4246

43-
raw_data = [
44-
{'x': 1, 'y': 1, 's': 'hello'},
45-
{'x': 2, 'y': 2, 's': 'world'},
46-
{'x': 3, 'y': 3, 's': 'hello'}
47-
]
47+
def _preprocessing_fn(inputs):
48+
"""Preprocess input columns into transformed columns."""
49+
x = inputs['x']
50+
y = inputs['y']
51+
s = inputs['s']
52+
x_centered = x - tft.mean(x)
53+
y_normalized = tft.scale_to_0_1(y)
54+
s_integerized = tft.compute_and_apply_vocabulary(s)
55+
x_centered_times_y_normalized = (x_centered * y_normalized)
56+
return {
57+
'x_centered': x_centered,
58+
'y_normalized': y_normalized,
59+
'x_centered_times_y_normalized': x_centered_times_y_normalized,
60+
's_integerized': s_integerized
61+
}
4862

49-
raw_data_metadata = dataset_metadata.DatasetMetadata(
50-
schema_utils.schema_from_feature_spec({
51-
's': tf.io.FixedLenFeature([], tf.string),
52-
'y': tf.io.FixedLenFeature([], tf.float32),
53-
'x': tf.io.FixedLenFeature([], tf.float32),
54-
}))
63+
64+
def main():
5565

5666
with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
5767
transformed_dataset, transform_fn = ( # pylint: disable=unused-variable
58-
(raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
59-
preprocessing_fn))
68+
(_RAW_DATA, _RAW_DATA_METADATA)
69+
| tft_beam.AnalyzeAndTransformDataset(_preprocessing_fn))
6070

6171
transformed_data, transformed_metadata = transformed_dataset # pylint: disable=unused-variable
6272

examples/simple_example_test.py

Lines changed: 33 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,10 +14,42 @@
1414
"""Tests for simple_example."""
1515

1616
import tensorflow as tf
17+
from tensorflow_transform.beam import tft_unit
1718
import simple_example
1819

1920

20-
class SimpleExampleTest(tf.test.TestCase):
21+
_EXPECTED_TRANSFORMED_OUTPUT = [
22+
{
23+
'x_centered': 1.0,
24+
'y_normalized': 1.0,
25+
'x_centered_times_y_normalized': 1.0,
26+
's_integerized': 0,
27+
},
28+
{
29+
'x_centered': 0.0,
30+
'y_normalized': 0.5,
31+
'x_centered_times_y_normalized': 0.0,
32+
's_integerized': 1,
33+
},
34+
{
35+
'x_centered': -1.0,
36+
'y_normalized': 0.0,
37+
'x_centered_times_y_normalized': -0.0,
38+
's_integerized': 0,
39+
},
40+
]
41+
42+
43+
class SimpleExampleTest(tft_unit.TransformTestCase):
44+
45+
def test_preprocessing_fn(self):
46+
self.assertAnalyzeAndTransformResults(simple_example._RAW_DATA,
47+
simple_example._RAW_DATA_METADATA,
48+
simple_example._preprocessing_fn,
49+
_EXPECTED_TRANSFORMED_OUTPUT)
50+
51+
52+
class SimpleMainTest(tf.test.TestCase):
2153

2254
def testMainDoesNotCrash(self):
2355
simple_example.main()

0 commit comments

Comments
 (0)