Skip to content

Commit 6a31f33

Browse files
Add test workflow with unittest (#189)
* Add `.gitignore` * Add github test workflow * Add push trigger * Add expected failures and skips * Also test on python version 3.11 * Remove python version 3.11 due to dependency error * Add note on how to run tests * Add `tf-keras` as a dependency * Add expect failure for several tests and test classes all suffering from `AttributeError: module 'tensorflow_model_analysis' has no attribute 'EvalConfig'` These are to be addressed in a future PR * Remove import to nonexistent modules * Install `libprotobuf-c-dev` for unit tests in CI * Temporarily remove `expectedFailure`s * Fix `EvalConfig` imports * Add more `expectedFailure`s * Remove `unexpectedFailure` from unexpected success * Add Python 3.11 to CI tests * Remove `libprotobuf-c-dev` * Remove unnecessary import * Add `expectedFailure` * Use `skip` instead of `expectedFailure` * Fix code comment * Run for all users --------- Co-authored-by: Peyton Murray <[email protected]>
1 parent a349d15 commit 6a31f33

17 files changed

+120
-36
lines changed

.github/workflows/ci-test.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Github action definitions for unit-tests with PRs.
2+
3+
name: tfma-unit-tests
4+
on:
5+
push:
6+
pull_request:
7+
branches: [ master ]
8+
paths-ignore:
9+
- '**.md'
10+
- 'docs/**'
11+
workflow_dispatch:
12+
13+
jobs:
14+
unit-tests:
15+
runs-on: ubuntu-latest
16+
17+
strategy:
18+
matrix:
19+
python-version: ['3.9', '3.10', '3.11']
20+
21+
steps:
22+
- name: Checkout repository
23+
uses: actions/checkout@v4
24+
25+
- name: Set up Python ${{ matrix.python-version }}
26+
uses: actions/setup-python@v5
27+
with:
28+
python-version: ${{ matrix.python-version }}
29+
cache: 'pip'
30+
cache-dependency-path: |
31+
setup.py
32+
33+
- name: Install dependencies
34+
run: |
35+
sudo apt update
36+
sudo apt install -y protobuf-compiler
37+
pip install .
38+
39+
- name: Run unit tests
40+
shell: bash
41+
run: |
42+
python -m unittest discover -p "*_test.py"

README.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ cd dist
8585
pip3 install tensorflow_model_analysis-<version>-py3-none-any.whl
8686
```
8787

88+
### Running tests
89+
90+
To run tests, run
91+
92+
```
93+
python -m unittest discover -p "*_test.py"
94+
```
95+
96+
from the root project directory.
97+
8898
### Jupyter Lab
8999

90100
As of writing, because of https://github.com/pypa/pip/issues/9187, `pip install`

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,7 @@ def select_constraint(default, nightly=None, git_master=None):
342342
nightly='>=1.18.0.dev',
343343
git_master='@git+https://github.com/tensorflow/tfx-bsl@master',
344344
),
345+
'tf-keras',
345346
],
346347
'extras_require': {
347348
'all': [*_make_extra_packages_tfjs(), *_make_docs_packages()],

tensorflow_model_analysis/api/model_eval_lib_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import os
1818
import tempfile
19+
import unittest
1920

2021
from absl.testing import absltest
2122
from absl.testing import parameterized
@@ -1122,6 +1123,9 @@ def testRunModelAnalysisWithQueryBasedMetrics(self):
11221123
for k in expected_metrics[group]:
11231124
self.assertIn(k, got_metrics[group])
11241125

1126+
# PR 189: Remove the `skip` mark if the test passes for all supported versions
1127+
# of python
1128+
@unittest.skip('Fails for some versions of Python, including 3.9')
11251129
def testRunModelAnalysisWithUncertainty(self):
11261130
examples = [
11271131
self._makeExample(age=3.0, language='english', label=1.0),
@@ -1391,6 +1395,8 @@ def testRunModelAnalysisWithSchema(self):
13911395
self.assertEqual(1.0, got_buckets[1]['lowerThresholdInclusive'])
13921396
self.assertEqual(2.0, got_buckets[-2]['upperThresholdExclusive'])
13931397

1398+
# PR 189: Remove the `expectedFailure` mark if the test passes
1399+
@unittest.expectedFailure
13941400
def testLoadValidationResult(self):
13951401
result = validation_result_pb2.ValidationResult(validation_ok=True)
13961402
path = os.path.join(absltest.get_default_test_tmpdir(), 'results.tfrecord')
@@ -1399,6 +1405,8 @@ def testLoadValidationResult(self):
13991405
loaded_result = model_eval_lib.load_validation_result(path)
14001406
self.assertTrue(loaded_result.validation_ok)
14011407

1408+
# PR 189: Remove the `expectedFailure` mark if the test passes
1409+
@unittest.expectedFailure
14021410
def testLoadValidationResultDir(self):
14031411
result = validation_result_pb2.ValidationResult(validation_ok=True)
14041412
path = os.path.join(
@@ -1409,6 +1417,8 @@ def testLoadValidationResultDir(self):
14091417
loaded_result = model_eval_lib.load_validation_result(os.path.dirname(path))
14101418
self.assertTrue(loaded_result.validation_ok)
14111419

1420+
# PR 189: Remove the `expectedFailure` mark if the test passes
1421+
@unittest.expectedFailure
14121422
def testLoadValidationResultEmptyFile(self):
14131423
path = os.path.join(
14141424
absltest.get_default_test_tmpdir(), constants.VALIDATIONS_KEY

tensorflow_model_analysis/export_only/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,3 @@ def eval_input_receiver_fn():
2929
tfma_export.export.export_eval_saved_model(...)
3030
"""
3131

32-
from tensorflow_model_analysis.eval_saved_model import export
33-
from tensorflow_model_analysis.eval_saved_model import exporter

tensorflow_model_analysis/extractors/inference_base_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from tensorflow_serving.apis import logging_pb2
3535
from tensorflow_serving.apis import prediction_log_pb2
3636

37+
import unittest
38+
3739

3840
class TfxBslPredictionsExtractorTest(testutil.TensorflowModelAnalysisTest):
3941

@@ -70,6 +72,8 @@ def _create_tfxio_and_feature_extractor(
7072
)
7173
return tfx_io, feature_extractor
7274

75+
# PR 189: Remove the `expectedFailure` mark if the test passes
76+
@unittest.expectedFailure
7377
def testIsValidConfigForBulkInferencePass(self):
7478
saved_model_proto = text_format.Parse(
7579
"""
@@ -129,6 +133,8 @@ def testIsValidConfigForBulkInferencePass(self):
129133
)
130134
)
131135

136+
# PR 189: Remove the `expectedFailure` mark if the test passes
137+
@unittest.expectedFailure
132138
def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self):
133139
saved_model_proto = text_format.Parse(
134140
"""
@@ -184,6 +190,8 @@ def testIsValidConfigForBulkInferencePassDefaultSignatureLookUp(self):
184190
)
185191
)
186192

193+
# PR 189: Remove the `expectedFailure` mark if the test passes
194+
@unittest.expectedFailure
187195
def testIsValidConfigForBulkInferenceFailNoSignatureFound(self):
188196
saved_model_proto = text_format.Parse(
189197
"""
@@ -239,6 +247,8 @@ def testIsValidConfigForBulkInferenceFailNoSignatureFound(self):
239247
)
240248
)
241249

250+
# PR 189: Remove the `expectedFailure` mark if the test passes
251+
@unittest.expectedFailure
242252
def testIsValidConfigForBulkInferenceFailKerasModel(self):
243253
saved_model_proto = text_format.Parse(
244254
"""
@@ -296,6 +306,8 @@ def testIsValidConfigForBulkInferenceFailKerasModel(self):
296306
)
297307
)
298308

309+
# PR 189: Remove the `expectedFailure` mark if the test passes
310+
@unittest.expectedFailure
299311
def testIsValidConfigForBulkInferenceFailWrongInputType(self):
300312
saved_model_proto = text_format.Parse(
301313
"""

tensorflow_model_analysis/metrics/bleu_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import tensorflow as tf
2121
import tensorflow_model_analysis as tfma
22+
from tensorflow_model_analysis.proto import config_pb2
2223
from tensorflow_model_analysis import constants
2324
from tensorflow_model_analysis.evaluators import metrics_plots_and_validations_evaluator
2425
from tensorflow_model_analysis.metrics import bleu
@@ -573,7 +574,7 @@ def test_bleu_end_2_end(self):
573574
}
574575
}
575576
""",
576-
tfma.EvalConfig(),
577+
config_pb2.EvalConfig(),
577578
)
578579

579580
example1 = {

tensorflow_model_analysis/metrics/example_count_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy as np
2020
import tensorflow as tf
2121
import tensorflow_model_analysis as tfma
22+
from tensorflow_model_analysis.proto import config_pb2
2223
from tensorflow_model_analysis.metrics import example_count
2324
from tensorflow_model_analysis.metrics import metric_types
2425
from tensorflow_model_analysis.metrics import metric_util
@@ -109,7 +110,7 @@ def testExampleCountsWithoutLabelPredictions(self):
109110
}
110111
}
111112
""",
112-
tfma.EvalConfig(),
113+
config_pb2.EvalConfig(),
113114
)
114115
name_list = ['example_count']
115116
expected_results = [0.6]

tensorflow_model_analysis/metrics/object_detection_confusion_matrix_metrics_test.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
from apache_beam.testing import util
1919
import numpy as np
2020
import tensorflow_model_analysis as tfma
21+
from tensorflow_model_analysis.proto import config_pb2
2122
from tensorflow_model_analysis.metrics import metric_types
2223
from google.protobuf import text_format
2324

24-
2525
class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
2626

2727
@parameterized.named_parameters(('_max_recall',
@@ -41,7 +41,7 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
4141
'"max_num_detections":100, "name":"maxrecall"'
4242
}
4343
}
44-
""", tfma.EvalConfig()), ['maxrecall'], [2 / 3]),
44+
""", config_pb2.EvalConfig()), ['maxrecall'], [2 / 3]),
4545
('_precision_at_recall',
4646
text_format.Parse(
4747
"""
@@ -59,7 +59,7 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
5959
'"max_num_detections":100, "name":"precisionatrecall"'
6060
}
6161
}
62-
""", tfma.EvalConfig()), ['precisionatrecall'], [3 / 5]),
62+
""", config_pb2.EvalConfig()), ['precisionatrecall'], [3 / 5]),
6363
('_recall',
6464
text_format.Parse(
6565
"""
@@ -77,7 +77,7 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
7777
'"max_num_detections":100, "name":"recall"'
7878
}
7979
}
80-
""", tfma.EvalConfig()), ['recall'], [2 / 3]), ('_precision',
80+
""", config_pb2.EvalConfig()), ['recall'], [2 / 3]), ('_precision',
8181
text_format.Parse(
8282
"""
8383
model_specs {
@@ -94,7 +94,7 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
9494
'"max_num_detections":100, "name":"precision"'
9595
}
9696
}
97-
""", tfma.EvalConfig()), ['precision'], [0.5]), ('_threshold_at_recall',
97+
""", config_pb2.EvalConfig()), ['precision'], [0.5]), ('_threshold_at_recall',
9898
text_format.Parse(
9999
"""
100100
model_specs {
@@ -111,7 +111,7 @@ class ObjectDetectionConfusionMatrixMetricsTest(parameterized.TestCase):
111111
'"max_num_detections":100, "name":"thresholdatrecall"'
112112
}
113113
}
114-
""", tfma.EvalConfig()), ['thresholdatrecall'], [0.3]))
114+
""", config_pb2.EvalConfig()), ['thresholdatrecall'], [0.3]))
115115
def testObjectDetectionMetrics(self, eval_config, name_list,
116116
expected_results):
117117

tensorflow_model_analysis/metrics/object_detection_confusion_matrix_plot_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from apache_beam.testing import util
1919
import numpy as np
2020
import tensorflow_model_analysis as tfma
21+
from tensorflow_model_analysis.proto import config_pb2
2122
from tensorflow_model_analysis.metrics import metric_types
2223
from tensorflow_model_analysis.utils import test_util
2324

@@ -45,7 +46,7 @@ def testConfusionMatrixPlot(self):
4546
'"max_num_detections":100, "name":"iou0.5"'
4647
}
4748
}
48-
""", tfma.EvalConfig())
49+
""", config_pb2.EvalConfig())
4950
extracts = [
5051
# The match at iou_threshold = 0.5 is
5152
# gt_matches: [[0]] dt_matches: [[0, -1]]

0 commit comments

Comments
 (0)