Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit 3f26f8a

Browse files
authored
Merge pull request #23 from google/dev
Merging Dev for release 0.1.1
2 parents f989f23 + d696fae commit 3f26f8a

File tree

9 files changed

+72
-32
lines changed

9 files changed

+72
-32
lines changed

README.md

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ df.tensorflow.to_tfr(output_dir='gs://my/bucket')
3535

3636
##### Running on Cloud Dataflow
3737

38+
Google Cloud Platform Dataflow workers need to be supplied with the tfrecorder
39+
package that you would like to run remotely. To do so first download or build
40+
the package (a python wheel file) and then specify the path the the file when
41+
tfrecorder is called.
42+
43+
Step 1: Download or create the wheel file.
44+
45+
To download the wheel from pip:
46+
`pip download tfrecorder --no-deps`
47+
48+
To build from source/git:
49+
`python setup.py sdist`
50+
51+
Step 2:
52+
Specify the project, region, and path to the tfrecorder wheel for remote execution.
53+
3854
```python
3955
import pandas as pd
4056
import tfrecorder
@@ -44,9 +60,11 @@ df.tensorflow.to_tfr(
4460
output_dir='gs://my/bucket',
4561
runner='DataFlowRunner',
4662
project='my-project',
47-
region='us-central1')
63+
region='us-central1'
64+
tfrecorder_wheel='/path/to/my/tfrecorder.whl')
4865
```
4966

67+
5068
#### From CSV
5169

5270
Using Python interpreter:

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@ pylint >= 2.5.3
99
fire >= 0.3.1
1010
jupyter >= 1.0.0
1111
tensorflow >= 2.2.0
12-
pyarrow < 0.17
13-
frozendict >= 1.2
12+
pyarrow >= 0.17
13+
frozendict >= 1.2

tfrecorder/accessor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def to_tfr(
4343
runner: str = 'DirectRunner',
4444
project: Optional[str] = None,
4545
region: Optional[str] = None,
46+
tfrecorder_wheel: Optional[str] = None,
4647
dataflow_options: Union[Dict[str, Any], None] = None,
4748
job_label: str = 'to-tfr',
4849
compression: Optional[str] = 'gzip',
@@ -66,6 +67,9 @@ def to_tfr(
6667
runner: Beam runner. Can be DirectRunner or DataFlowRunner.
6768
project: GCP project name (Required if DataFlowRunner).
6869
region: GCP region name (Required if DataFlowRunner).
70+
tfrecorder_wheel: Path to the tfrecorder wheel DataFlow will run.
71+
(create with 'python setup.py sdist' or
72+
'pip download tfrecorder --no-deps')
6973
dataflow_options: Optional dictionary containing DataFlow options.
7074
job_label: User supplied description for the beam job name.
7175
compression: Can be 'gzip' or None for no compression.
@@ -84,6 +88,7 @@ def to_tfr(
8488
runner=runner,
8589
project=project,
8690
region=region,
91+
tfrecorder_wheel=tfrecorder_wheel,
8792
dataflow_options=dataflow_options,
8893
job_label=job_label,
8994
compression=compression,

tfrecorder/beam_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ def load(image_uri):
6565
try:
6666
with tf.io.gfile.GFile(image_uri, 'rb') as f:
6767
return Image.open(f)
68-
except tf.python.framework.errors_impl.NotFoundError:
69-
raise OSError('File {} was not found.'.format(image_uri))
68+
except tf.python.framework.errors_impl.NotFoundError as e:
69+
raise OSError('File {} was not found.'.format(image_uri)) from e
7070

7171

7272
# pylint: disable=abstract-method

tfrecorder/beam_pipeline.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,16 +35,6 @@
3535
from tfrecorder import constants
3636

3737

38-
def _get_setup_py_filepath() -> str:
39-
"""Returns the file path to the setup.py file.
40-
41-
The location of the setup.py file is needed to run Dataflow jobs.
42-
"""
43-
44-
return os.path.join(
45-
os.path.dirname(os.path.abspath(__file__)), '..', 'setup.py')
46-
47-
4838
def _get_job_name(job_label: str = None) -> str:
4939
"""Returns Beam runner job name.
5040
@@ -76,6 +66,7 @@ def _get_pipeline_options(
7666
job_dir: str,
7767
project: str,
7868
region: str,
69+
tfrecorder_wheel: str,
7970
dataflow_options: Union[Dict[str, Any], None]
8071
) -> beam.pipeline.PipelineOptions:
8172
"""Returns Beam pipeline options."""
@@ -95,7 +86,7 @@ def _get_pipeline_options(
9586
if region:
9687
options_dict['region'] = region
9788
if runner == 'DataflowRunner':
98-
options_dict['setup_file'] = _get_setup_py_filepath()
89+
options_dict['extra_packages'] = tfrecorder_wheel
9990
if dataflow_options:
10091
options_dict.update(dataflow_options)
10192

@@ -199,6 +190,7 @@ def build_pipeline(
199190
output_dir: str,
200191
compression: str,
201192
num_shards: int,
193+
tfrecorder_wheel: str,
202194
dataflow_options: dict,
203195
integer_label: bool) -> beam.Pipeline:
204196
"""Runs TFRecorder Beam Pipeline.
@@ -212,6 +204,7 @@ def build_pipeline(
212204
output_dir: GCS or Local Path for output.
213205
compression: gzip or None.
214206
num_shards: Number of shards.
207+
tfrecorder_wheel: Path to TFRecorder wheel for DataFlow
215208
dataflow_options: Dataflow Runner Options (optional)
216209
integer_label: Flags if label is already an integer.
217210
@@ -229,6 +222,7 @@ def build_pipeline(
229222
job_dir,
230223
project,
231224
region,
225+
tfrecorder_wheel,
232226
dataflow_options)
233227

234228
#with beam.Pipeline(runner, options=options) as p:

tfrecorder/beam_pipeline_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
"""Tests for beam_pipeline."""
1818

19-
import os
2019
import unittest
2120
from unittest import mock
2221

@@ -78,12 +77,6 @@ def test_partition_fn(self):
7877
index, i,
7978
'{} should be index {} but was index {}'.format(part, i, index))
8079

81-
def test_get_setup_py_filepath(self):
82-
"""Tests `_get_setup_py_filepath`."""
83-
filepath = beam_pipeline._get_setup_py_filepath()
84-
self.assertTrue(os.path.isfile(filepath))
85-
self.assertTrue(os.path.isabs(filepath))
86-
8780

8881
if __name__ == '__main__':
8982
unittest.main()

tfrecorder/client.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def _validate_runner(
5454
df: pd.DataFrame,
5555
runner: str,
5656
project: str,
57-
region: str):
57+
region: str,
58+
tfrecorder_wheel: str):
5859
"""Validates an appropriate beam runner is chosen."""
5960
if runner not in ['DataflowRunner', 'DirectRunner']:
6061
raise AttributeError('Runner {} is not supported.'.format(runner))
@@ -70,6 +71,9 @@ def _validate_runner(
7071
'DataflowRunner requires valid `project` and `region` to be specified.'
7172
'The `project` is {} and `region` is {}'.format(project, region))
7273

74+
if (runner == 'DataflowRunner') & (not tfrecorder_wheel):
75+
raise AttributeError(
76+
'DataflowRunner requires a tfrecorder whl file for remote execution.')
7377
# def read_image_directory(dirpath) -> pd.DataFrame:
7478
# """Reads image data from a directory into a Pandas DataFrame."""
7579
#
@@ -164,6 +168,7 @@ def create_tfrecords(
164168
runner: str = 'DirectRunner',
165169
project: Optional[str] = None,
166170
region: Optional[str] = None,
171+
tfrecorder_wheel: Optional[str] = None,
167172
dataflow_options: Optional[Dict[str, Any]] = None,
168173
job_label: str = 'create-tfrecords',
169174
compression: Optional[str] = 'gzip',
@@ -190,6 +195,7 @@ def create_tfrecords(
190195
runner: Beam runner. Can be 'DirectRunner' or 'DataFlowRunner'
191196
project: GCP project name (Required if DataflowRunner)
192197
region: GCP region name (Required if DataflowRunner)
198+
tfrecorder_wheel: Required for GCP Runs, path to the tfrecorder whl.
193199
dataflow_options: Options dict for DataflowRunner
194200
job_label: User supplied description for the Beam job name.
195201
compression: Can be 'gzip' or None for no compression.
@@ -206,7 +212,7 @@ def create_tfrecords(
206212
df = to_dataframe(input_data, header, names)
207213

208214
_validate_data(df)
209-
_validate_runner(df, runner, project, region)
215+
_validate_runner(df, runner, project, region, tfrecorder_wheel)
210216

211217
logfile = os.path.join('/tmp', constants.LOGFILE)
212218
_configure_logging(logfile)
@@ -222,6 +228,7 @@ def create_tfrecords(
222228
output_dir=output_dir,
223229
compression=compression,
224230
num_shards=num_shards,
231+
tfrecorder_wheel=tfrecorder_wheel,
225232
dataflow_options=dataflow_options,
226233
integer_label=integer_label)
227234

tfrecorder/client_test.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def setUp(self):
3838
self.test_df = test_utils.get_test_df()
3939
self.test_region = 'us-central1'
4040
self.test_project = 'foo'
41+
self.test_wheel = '/my/path/wheel.whl'
4142

4243
@mock.patch('tfrecorder.client.beam_pipeline')
4344
def test_create_tfrecords_direct_runner(self, mock_beam):
@@ -71,7 +72,8 @@ def test_create_tfrecords_dataflow_runner(self, mock_beam):
7172
runner='DataflowRunner',
7273
output_dir=outdir,
7374
region=self.test_region,
74-
project=self.test_project)
75+
project=self.test_project,
76+
tfrecorder_wheel=self.test_wheel)
7577
self.assertEqual(r, expected)
7678

7779

@@ -84,6 +86,7 @@ def setUp(self):
8486
self.test_df = test_utils.get_test_df()
8587
self.test_region = 'us-central1'
8688
self.test_project = 'foo'
89+
self.test_wheel = '/my/path/wheel.whl'
8790

8891
def test_valid_dataframe(self):
8992
"""Tests valid DataFrame input."""
@@ -126,7 +129,8 @@ def test_valid_runner(self):
126129
self.test_df,
127130
runner='DirectRunner',
128131
project=self.test_project,
129-
region=self.test_region))
132+
region=self.test_region,
133+
tfrecorder_wheel=None))
130134

131135
def test_invalid_runner(self):
132136
"""Tests invalid runner."""
@@ -135,7 +139,8 @@ def test_invalid_runner(self):
135139
self.test_df,
136140
runner='FooRunner',
137141
project=self.test_project,
138-
region=self.test_region)
142+
region=self.test_region,
143+
tfrecorder_wheel=None)
139144

140145
def test_local_path_with_dataflow_runner(self):
141146
"""Tests DataflowRunner conflict with local path."""
@@ -144,7 +149,8 @@ def test_local_path_with_dataflow_runner(self):
144149
self.df_test,
145150
runner='DataflowRunner',
146151
project=self.test_project,
147-
region=self.test_region)
152+
region=self.test_region,
153+
tfrecorder_wheel=self.test_wheel)
148154

149155
def test_gcs_path_with_dataflow_runner(self):
150156
"""Tests DataflowRunner with GCS path."""
@@ -155,7 +161,8 @@ def test_gcs_path_with_dataflow_runner(self):
155161
df2,
156162
runner='DataflowRunner',
157163
project=self.test_project,
158-
region=self.test_region))
164+
region=self.test_region,
165+
tfrecorder_wheel=self.test_wheel))
159166

160167
def test_gcs_path_with_dataflow_runner_missing_param(self):
161168
"""Tests DataflowRunner with missing required parameter."""
@@ -168,11 +175,27 @@ def test_gcs_path_with_dataflow_runner_missing_param(self):
168175
df2,
169176
runner='DataflowRunner',
170177
project=p,
171-
region=r)
178+
region=r,
179+
tfrecorder_wheel=self.test_wheel)
172180
self.assertTrue('DataflowRunner requires valid `project` and `region`'
173181
in repr(context.exception))
174182

175183

184+
def test_gcs_path_with_dataflow_runner_missing_wheel(self):
185+
"""Tests DataflowRunner with missing required whl path."""
186+
df2 = self.test_df.copy()
187+
df2[constants.IMAGE_URI_KEY] = 'gs://' + df2[constants.IMAGE_URI_KEY]
188+
with self.assertRaises(AttributeError) as context:
189+
client._validate_runner(
190+
df2,
191+
runner='DataflowRunner',
192+
project=self.test_project,
193+
region=self.test_region,
194+
tfrecorder_wheel=None)
195+
self.assertTrue('requires a tfrecorder whl file for remote execution.'
196+
in repr(context.exception))
197+
198+
176199
def _make_csv_tempfile(data: List[List[str]]) -> tempfile.NamedTemporaryFile:
177200
"""Returns `NamedTemporaryFile` representing an image CSV."""
178201

tfrecorder/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,4 @@ def copy_logfile_to_gcs(logfile: str, output_dir: str):
3939
gcs_logfile.write(log)
4040
except FileNotFoundError as e:
4141
raise FileNotFoundError("Unable to copy log file {} to gcs.".format(
42-
e.filename))
42+
e.filename)) from e

0 commit comments

Comments
 (0)