@@ -38,6 +38,7 @@ def setUp(self):
38
38
self .test_df = test_utils .get_test_df ()
39
39
self .test_region = 'us-central1'
40
40
self .test_project = 'foo'
41
+ self .test_wheel = '/my/path/wheel.whl'
41
42
42
43
@mock .patch ('tfrecorder.client.beam_pipeline' )
43
44
def test_create_tfrecords_direct_runner (self , mock_beam ):
@@ -71,7 +72,8 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
71
72
runner = 'DataflowRunner' ,
72
73
output_dir = outdir ,
73
74
region = self .test_region ,
74
- project = self .test_project )
75
+ project = self .test_project ,
76
+ tfrecorder_wheel = self .test_wheel )
75
77
self .assertEqual (r , expected )
76
78
77
79
@@ -84,6 +86,7 @@ def setUp(self):
84
86
self .test_df = test_utils .get_test_df ()
85
87
self .test_region = 'us-central1'
86
88
self .test_project = 'foo'
89
+ self .test_wheel = '/my/path/wheel.whl'
87
90
88
91
def test_valid_dataframe (self ):
89
92
"""Tests valid DataFrame input."""
@@ -126,7 +129,8 @@ def test_valid_runner(self):
126
129
self .test_df ,
127
130
runner = 'DirectRunner' ,
128
131
project = self .test_project ,
129
- region = self .test_region ))
132
+ region = self .test_region ,
133
+ tfrecorder_wheel = None ))
130
134
131
135
def test_invalid_runner (self ):
132
136
"""Tests invalid runner."""
@@ -135,7 +139,8 @@ def test_invalid_runner(self):
135
139
self .test_df ,
136
140
runner = 'FooRunner' ,
137
141
project = self .test_project ,
138
- region = self .test_region )
142
+ region = self .test_region ,
143
+ tfrecorder_wheel = None )
139
144
140
145
def test_local_path_with_dataflow_runner (self ):
141
146
"""Tests DataflowRunner conflict with local path."""
@@ -144,7 +149,8 @@ def test_local_path_with_dataflow_runner(self):
144
149
self .df_test ,
145
150
runner = 'DataflowRunner' ,
146
151
project = self .test_project ,
147
- region = self .test_region )
152
+ region = self .test_region ,
153
+ tfrecorder_wheel = self .test_wheel )
148
154
149
155
def test_gcs_path_with_dataflow_runner (self ):
150
156
"""Tests DataflowRunner with GCS path."""
@@ -155,7 +161,8 @@ def test_gcs_path_with_dataflow_runner(self):
155
161
df2 ,
156
162
runner = 'DataflowRunner' ,
157
163
project = self .test_project ,
158
- region = self .test_region ))
164
+ region = self .test_region ,
165
+ tfrecorder_wheel = self .test_wheel ))
159
166
160
167
def test_gcs_path_with_dataflow_runner_missing_param (self ):
161
168
"""Tests DataflowRunner with missing required parameter."""
@@ -168,11 +175,27 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
168
175
df2 ,
169
176
runner = 'DataflowRunner' ,
170
177
project = p ,
171
- region = r )
178
+ region = r ,
179
+ tfrecorder_wheel = self .test_wheel )
172
180
self .assertTrue ('DataflowRunner requires valid `project` and `region`'
173
181
in repr (context .exception ))
174
182
175
183
184
+ def test_gcs_path_with_dataflow_runner_missing_wheel (self ):
185
+ """Tests DataflowRunner with missing required whl path."""
186
+ df2 = self .test_df .copy ()
187
+ df2 [constants .IMAGE_URI_KEY ] = 'gs://' + df2 [constants .IMAGE_URI_KEY ]
188
+ with self .assertRaises (AttributeError ) as context :
189
+ client ._validate_runner (
190
+ df2 ,
191
+ runner = 'DataflowRunner' ,
192
+ project = self .test_project ,
193
+ region = self .test_region ,
194
+ tfrecorder_wheel = None )
195
+ self .assertTrue ('requires a tfrecorder whl file for remote execution.'
196
+ in repr (context .exception ))
197
+
198
+
176
199
def _make_csv_tempfile (data : List [List [str ]]) -> tempfile .NamedTemporaryFile :
177
200
"""Returns `NamedTemporaryFile` representing an image CSV."""
178
201
0 commit comments