Commit a400f08

Updated the internal links with external links in customize_input_pipeline.md
and In earlier read_custom_datasets.md, the format is misaligned. Now it is modified. PiperOrigin-RevId: 639599840
1 parent 15cafce commit a400f08

2 files changed

+394
-8
lines changed

Lines changed: 370 additions & 0 deletions
@@ -0,0 +1,370 @@
1+
# Customize Input Pipeline
2+
3+
4+
5+
6+
7+
8+
## Overview
9+
10+
11+
A task is a class that encapsulates the logic of loading data, building models,
12+
performing one-step training and validation, etc. It connects all components
13+
together and is called by the base
14+
[Trainer](https://github.com/tensorflow/models/blob/master/official/core/base_trainer.py).
15+
You can create your own task by inheriting from base
16+
[Task](https://github.com/tensorflow/models/blob/master/official/core/base_task.py),
17+
or from one of the
18+
[tasks](https://github.com/tensorflow/models/tree/master/official/vision/tasks)
19+
we already defined, if most of the operations can be reused. An `ExampleTask`
20+
inheriting from
21+
[ImageClassificationTask](https://github.com/tensorflow/models/blob/master/official/vision/tasks/image_classification.py#L31)
22+
can be found
23+
[here](https://github.com/tensorflow/models/blob/master/official/vision/examples/starter/example_task.py).
24+
25+
26+
In a task class, the `build_inputs` method is responsible for building the input
27+
pipeline for training and evaluation. Specifically, it will instantiate a
28+
Decoder object and a Parser object, which are used to create an `InputReader`
29+
that will generate a `tf.data.Dataset` object.
30+
31+
32+
Here's an example code snippet that demonstrates how to create a custom
33+
`build_inputs` method:
34+
35+
36+
```python
37+
def build_inputs(
    self,
    params: exp_cfg.DataConfig,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  ....

  decoder = sample_input.Decoder()
  parser = sample_input.Parser(
      output_size=..., num_classes=...)
  reader = input_reader_factory.input_reader_generator(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=decoder.decode,
      parser_fn=parser.parse_fn(params.is_training))
  ....

  dataset = reader.read(input_context=input_context)
  return dataset
58+
```
59+
60+
61+
The class responsible for building the input pipeline is
[InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/core/input_reader.py#L214),
which has the following interface:
64+
65+
```python
66+
class InputReader:
  """Input reader that returns a tf.data.Dataset instance."""

  def __init__(
      self,
      params: cfg.DataConfig,
      dataset_fn=tf.data.TFRecordDataset,
      decoder_fn: Optional[Callable[..., Any]] = None,
      combine_fn: Optional[Callable[..., Any]] = None,
      sample_fn: Optional[Callable[..., Any]] = None,
      parser_fn: Optional[Callable[..., Any]] = None,
      filter_fn: Optional[Callable[..., tf.Tensor]] = None,
      transform_and_batch_fn: Optional[
          Callable[
              [tf.data.Dataset, Optional[tf.distribute.InputContext]],
              tf.data.Dataset,
          ]
      ] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
  ):
    ....

  def read(self,
           input_context: Optional[tf.distribute.InputContext] = None,
           dataset: Optional[tf.data.Dataset] = None) -> tf.data.Dataset:
    """Generates a tf.data.Dataset object."""
    if dataset is None:
      dataset = self._read_data_source(self._matched_files, self._dataset_fn,
                                       input_context)
    dataset = self._decode_and_parse_dataset(dataset, self._global_batch_size,
                                             input_context)
    dataset = _maybe_map_fn(dataset, self._postprocess_fn)
    if not (self._enable_shared_tf_data_service_between_parallel_trainers and
            self._apply_tf_data_service_before_batching):
      dataset = self._maybe_apply_data_service(dataset, input_context)

    if self._deterministic is not None:
      options = tf.data.Options()
      options.deterministic = self._deterministic
      dataset = dataset.with_options(options)
    if self._autotune_algorithm:
      options = tf.data.Options()
      options.autotune.autotune_algorithm = (
          tf.data.experimental.AutotuneAlgorithm[self._autotune_algorithm])
      dataset = dataset.with_options(options)
    return dataset.prefetch(self._prefetch_buffer_size)
112+
```
113+
114+
Therefore, customizing the input pipeline amounts to supplying customized
versions of `dataset_fn`, `decoder_fn`, etc. The functions are generally
executed in this order:
117+
118+
```
119+
dataset_fn -> decoder_fn -> combine_fn -> parser_fn -> filter_fn ->
120+
transform_and_batch_fn -> postprocess_fn
121+
```
122+
123+
The `transform_and_batch_fn` is an optional function that merges multiple
examples into a batch; if it is not specified, the pipeline defaults to
`dataset.batch`. In this workflow, the functions before
`transform_and_batch_fn`, e.g. `dataset_fn` and `decoder_fn`, consume tensors
without the batch dimension, while `postprocess_fn` consumes tensors with the
batch dimension.
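
The batch-dimension distinction can be illustrated without TensorFlow. The following is a minimal sketch in which plain Python lists stand in for a `tf.data.Dataset`; the helpers `batch` and `run_pipeline` and the toy stage functions are hypothetical and are not part of `InputReader`. It only shows that the stages before batching see one example at a time, while `postprocess_fn` sees a whole batch.

```python
from typing import Any, Callable, Dict, List

Example = Dict[str, Any]


def batch(examples: List[Example], batch_size: int) -> List[Dict[str, List[Any]]]:
    """Groups consecutive examples into batches (stand-in for dataset.batch)."""
    batches = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        # Each field becomes a list, i.e. it gains a leading batch dimension.
        batches.append({key: [ex[key] for ex in chunk] for key in chunk[0]})
    return batches


def run_pipeline(records: List[str],
                 decoder_fn: Callable[[str], Example],
                 parser_fn: Callable[[Example], Example],
                 filter_fn: Callable[[Example], bool],
                 postprocess_fn: Callable[[Dict[str, List[Any]]], Dict[str, Any]],
                 batch_size: int) -> List[Dict[str, Any]]:
    # Per-example stages: no batch dimension yet.
    decoded = [decoder_fn(r) for r in records]
    parsed = [parser_fn(d) for d in decoded]
    kept = [p for p in parsed if filter_fn(p)]
    # Batching stage (the default role of transform_and_batch_fn).
    batched = batch(kept, batch_size)
    # Post-batch stage: receives whole batches.
    return [postprocess_fn(b) for b in batched]


records = ["3,cat", "5,dog", "-1,bad", "7,cat"]
result = run_pipeline(
    records,
    decoder_fn=lambda r: {"value": int(r.split(",")[0]), "label": r.split(",")[1]},
    parser_fn=lambda d: {"value": d["value"] * 2, "label": d["label"]},
    filter_fn=lambda p: p["value"] >= 0,
    postprocess_fn=lambda b: {**b, "batch_size": len(b["value"])},
    batch_size=2,
)
```

Here `result` holds two batches: the first contains two examples and the second contains the single remaining example, because the record with a negative value was filtered out before batching.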
128+
129+
The
[decoder_fn](https://github.com/tensorflow/models/blob/master/official/vision/docs/read_custom_datasets.md#decoder)
has already been covered. `parser_fn` is another very important function: it
is executed after `decoder_fn`, takes the dict of decoded raw tensors, and
parses them into a dictionary of tensors that can be consumed by the model.
134+
135+
It is also worth noting that input pipeline optimizations such as batching,
shuffling and prefetching are implemented in this class as well.
137+
138+
## Parser
139+
140+
A custom data loader can also be useful if you want to take advantage of
141+
features such as data augmentation.
142+
143+
Customizing preprocessing is useful because it allows the user to tailor the
144+
preprocessing steps to suit the specific requirements of the task. While there
145+
are standard preprocessing techniques that are commonly used, different
146+
applications may require different preprocessing steps. Additionally, custom
147+
preprocessing can also improve the efficiency and accuracy of the model by
148+
removing unnecessary steps, reducing computational resources or adding steps
149+
that are important to the specific task being addressed.
150+
151+
For example, tasks such as object detection or segmentation may require
152+
additional preprocessing steps such as resizing, cropping, or data augmentation
153+
to improve the robustness of the model. Below are some essential steps to
154+
customize a parser.
155+
156+
### Instructions
157+
158+
* **Create a Subclass**
159+
<br>
160+
161+
<dd><dl>
162+
163+
Like Decoder, create `class Parser(parser.Parser)` in the same file. The
`Parser` class should be a subclass of the
[generic parser interface](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/parser.py)
and must implement the abstract methods `_parse_train_data` and
`_parse_eval_data`, which generate images and labels for model training and
evaluation respectively. The example below takes only two arguments, but you
can add as many arguments as needed.
170+
171+
```python
172+
class Parser(parser.Parser):

  def __init__(self, output_size: List[int], num_classes: int):

    self._output_size = output_size
    self._num_classes = num_classes
    self._dtype = tf.float32

  ....
181+
```
182+
183+
<br>
184+
185+
Refer to the data parser and processing [class](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/maskrcnn_input.py) for Mask R-CNN for more complex cases. The class has multiple parameters related to data augmentation, masking, anchor boxes, data type of output image and more.
186+
187+
</dd></dl>
188+
189+
<br>
190+
191+
* **Complete Abstract Methods**<br>
192+
193+
<dd><dl>
194+
195+
To define your own Parser, override the abstract methods
[_parse_train_data](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/vision/dataloaders/parser.py#L26)
and
[_parse_eval_data](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/vision/dataloaders/parser.py#L39)
of the
[parser](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/parser.py)
interface in the subclass. These methods apply the pre-processing steps to the
decoded tensors for training and evaluation respectively. Their output can be
any structure, such as a tuple, list or dictionary.
204+
205+
```python
206+
@abc.abstractmethod
def _parse_train_data(self, decoded_tensors):
  """Generates images and labels that are usable for model training.

  Args:
    decoded_tensors: a dict of Tensors produced by the decoder.

  Returns:
    images: the image tensor.
    labels: a dict of Tensors that contains labels.
  """
  pass

@abc.abstractmethod
def _parse_eval_data(self, decoded_tensors):
  """Generates images and labels that are usable for model evaluation.

  Args:
    decoded_tensors: a dict of Tensors produced by the decoder.

  Returns:
    images: the image tensor.
    labels: a dict of Tensors that contains labels.
  """
  pass
231+
232+
```
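
To make the relationship between the two abstract methods and `parse_fn(is_training)` concrete, here is a minimal, framework-free sketch. The `Parser` base class below is a simplified stand-in written for illustration, not the actual class in `parser.py`, and `ToyParser`, its one-hot helper and the reversed-list "augmentation" are hypothetical. Plain Python lists are used in place of tensors.

```python
import abc
from typing import Any, Callable, Dict, List, Tuple


class Parser(abc.ABC):
    """Simplified stand-in for the generic parser interface (assumption)."""

    @abc.abstractmethod
    def _parse_train_data(self, decoded_tensors: Dict[str, Any]):
        """Returns (image, labels) for training."""

    @abc.abstractmethod
    def _parse_eval_data(self, decoded_tensors: Dict[str, Any]):
        """Returns (image, labels) for evaluation."""

    def parse_fn(self, is_training: bool) -> Callable[[Dict[str, Any]], Any]:
        """Picks the train or eval parsing method once, up front."""
        def parse(decoded_tensors: Dict[str, Any]):
            if is_training:
                return self._parse_train_data(decoded_tensors)
            return self._parse_eval_data(decoded_tensors)
        return parse


class ToyParser(Parser):
    """Hypothetical parser: normalizes pixels and one-hot encodes the label."""

    def __init__(self, num_classes: int):
        self._num_classes = num_classes

    def _one_hot(self, label: int) -> List[int]:
        return [1 if i == label else 0 for i in range(self._num_classes)]

    def _parse_train_data(self, decoded: Dict[str, Any]) -> Tuple[List[float], Dict[str, Any]]:
        image = [p / 255.0 for p in decoded["image"]]
        image = image[::-1]  # Stand-in for a training-only augmentation.
        return image, {"label": self._one_hot(decoded["label"])}

    def _parse_eval_data(self, decoded: Dict[str, Any]) -> Tuple[List[float], Dict[str, Any]]:
        image = [p / 255.0 for p in decoded["image"]]
        return image, {"label": self._one_hot(decoded["label"])}


decoded_example = {"image": [0, 255], "label": 1}
train_output = ToyParser(num_classes=3).parse_fn(True)(decoded_example)
eval_output = ToyParser(num_classes=3).parse_fn(False)(decoded_example)
```

Both outputs share the same structure, a tuple of an image and a dict of labels, which is what lets the rest of the pipeline treat training and evaluation uniformly.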
233+
234+
The input of `_parse_train_data` and `_parse_eval_data` is a dict of Tensors
produced by the decoder; the output of these two functions is typically a tuple
of (processed_image, processed_label). You may perform any processing steps in
these two functions as long as the interface is respected. Note that the
processing steps in `_parse_train_data` and `_parse_eval_data` typically
differ, since data augmentation is usually applied only during training. For
example, refer to the
[data parser](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/vision/dataloaders/classification_input.py#L166)
and processing steps for classification. We can observe that:
243+
244+
<dd><dl>
245+
246+
-For `_parse_train_data`, the following steps are performed</dd></dl>
247+
248+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Image decoding<br>
249+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Random cropping<br>
250+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Random flipping<br>
251+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Color jittering<br>
252+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Image resizing<br>
253+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Auto-augmentation with autoaug, randaug etc.<br>
254+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Image normalization<br>
255+
256+
<dd><dl><dd><dl>
257+
258+
-For `_parse_eval_data`, the following steps are performed</dd></dl></dd></dl>
259+
260+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;
261+
Image decoding<br>
262+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Center cropping<br>
263+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Image resizing<br>
264+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Image normalization<br>
265+
266+
</dd></dl>
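
The difference between the training and evaluation steps listed above can be sketched in a few lines. This is an illustrative toy, not the classification parser: it covers only cropping, flipping and normalization, operates on nested Python lists rather than tensors, and all function names (`crop`, `normalize`, `parse_train`, `parse_eval`) are hypothetical.

```python
import random
from typing import List

Image = List[List[float]]


def crop(image: Image, top: int, left: int, size: int) -> Image:
    """Returns a size x size window starting at (top, left)."""
    return [row[left:left + size] for row in image[top:top + size]]


def normalize(image: Image, mean: float, std: float) -> Image:
    return [[(p - mean) / std for p in row] for row in image]


def parse_train(image: Image, size: int, rng: random.Random) -> Image:
    # Random cropping: the window position changes from call to call.
    top = rng.randint(0, len(image) - size)
    left = rng.randint(0, len(image[0]) - size)
    out = crop(image, top, left, size)
    # Random horizontal flipping with probability 0.5.
    if rng.random() < 0.5:
        out = [row[::-1] for row in out]
    return normalize(out, mean=0.0, std=255.0)


def parse_eval(image: Image, size: int) -> Image:
    # Center cropping: deterministic, so evaluation is repeatable.
    top = (len(image) - size) // 2
    left = (len(image[0]) - size) // 2
    return normalize(crop(image, top, left, size), mean=0.0, std=255.0)


# A 4x4 "image" whose pixel values are 0..15.
image = [[float(4 * r + c) for c in range(4)] for r in range(4)]
eval_out = parse_eval(image, size=2)
train_out = parse_train(image, size=2, rng=random.Random(0))
```

The evaluation path always returns the same central window, whereas the training path depends on the random number generator passed in.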
267+
268+
**Additional Methods**
269+
270+
The file containing the subclasses (say `sample_input.py`) must implement all
of the abstract methods defined in the
[Decoder](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/decoder.py)
and
[Parser](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/parser.py)
interfaces, as well as any additional methods that the subclasses need for
their functionality.
277+
278+
For example, in
[object detection](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/vision/dataloaders/tf_example_decoder.py#L72),
the decoder takes the serialized example and outputs a dictionary of tensors
with multiple fields, which are then processed and analyzed to detect objects
and determine their location and orientation in the image. Writing a separate
method for each of these fields can make the code easier to read and maintain,
especially when the class contains a large number of methods.
285+
286+
Refer to the
[data parser](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/retinanet_input.py)
for object detection.
289+
290+
### Example
291+
292+
Creating a Parser is an optional step and it varies with the use case. Below are
293+
some use cases where we have included the Decoder and Parser based on the
294+
requirements.
295+
296+
Use case | Decoder/Parser
-------- | --------------
[Classification](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/classification_input.py) | Both Decoder and Parser
[Segmentation](https://github.com/tensorflow/models/blob/master/official/vision/dataloaders/retinanet_input.py) | Only Parser
300+
301+
## Input Pipeline
302+
303+
The Decoder and Parser discussed previously define how to decode and parse a
single data point, e.g. an image. However, a complete input pipeline also needs
to handle reading data from files in a distributed system, applying random
perturbations, batching, etc. You may find more details about these concepts
[here](https://www.tensorflow.org/guide/data_performance#optimize_performance).
308+
309+
We have established a well-tuned input pipeline, defined in the [InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/core/input_reader.py#L214) class, so that the user won't need to modify it in most cases. The input pipeline roughly follows these steps:<br>
310+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Shuffling the files<br>
311+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Decoding<br>
312+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Parsing<br>
313+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Caching<br>
314+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;If training: repeat and shuffle<br>
315+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Batching<br>
316+
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;- &nbsp;&nbsp;&nbsp;Prefetching<br>
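
The ordering of these steps can be sketched with the standard library alone. This is a rough illustration under stated assumptions, not the `InputReader` implementation: files are modeled as a dict of record lists, "decoding" is `int(...)`, "parsing" is a multiplication, everything is processed eagerly, and prefetching is omitted because it only affects performance. The function name `input_pipeline` is hypothetical.

```python
import itertools
import random
from typing import Dict, Iterator, List


def input_pipeline(files: Dict[str, List[str]],
                   is_training: bool,
                   batch_size: int,
                   seed: int = 0) -> Iterator[List[int]]:
    rng = random.Random(seed)
    names = sorted(files)
    if is_training:
        rng.shuffle(names)                       # 1. Shuffle the files.
    records = [r for name in names for r in files[name]]
    decoded = [int(r) for r in records]          # 2. Decoding.
    parsed = [d * 10 for d in decoded]           # 3. Parsing.
    cache = list(parsed)                         # 4. Caching.

    def examples() -> Iterator[int]:
        if not is_training:
            yield from cache
            return
        while True:                              # 5. Repeat and shuffle.
            epoch = list(cache)
            rng.shuffle(epoch)
            yield from epoch

    stream = examples()
    while True:                                  # 6. Batching.
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return
        yield chunk


files = {"a": ["1", "2"], "b": ["3"]}
eval_batches = list(input_pipeline(files, is_training=False, batch_size=2))
# The training stream repeats forever, so only take the first three batches.
train_batches = list(
    itertools.islice(input_pipeline(files, is_training=True, batch_size=2), 3))
```

In evaluation mode the stream ends after one pass and the last batch may be smaller; in training mode the stream never ends, so every batch is full.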
317+
318+
For the rest of this section, we discuss one particular use case that requires
modifying the typical pipeline, for example by creating a subclass of the
[InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/core/input_reader.py#L214).
322+
323+
### Combine multiple datasets
324+
325+
Create a custom InputReader by subclassing the
[InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/core/input_reader.py#L214)
interface. A custom InputReader class lets you combine multiple datasets, for
example to mix a labeled and a pseudo-labeled dataset. The business logic is
implemented in the `read()` method, which ultimately generates a
`tf.data.Dataset` object.
331+
332+
The exact implementation of an InputReader varies with the specific
requirements of your task, the type and format of the input data, and the
preprocessing requirements.
335+
336+
Here is an example of how to create a custom InputReader by subclassing the
[InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/core/input_reader.py#L214)
interface:
339+
340+
```python
341+
class CustomInputReader(input_reader.InputReader):

  def __init__(self,
               params: cfg.DataConfig,
               dataset_fn=tf.data.TFRecordDataset,
               pseudo_label_dataset_fn=tf.data.TFRecordDataset,
               ....):

  def read(
      self,
      input_context: Optional[tf.distribute.InputContext] = None
  ) -> tf.data.Dataset:

    labeled_dataset = ....
    pseudo_labeled_dataset = ....
    dataset_concat = tf.data.Dataset.zip(
        (labeled_dataset, pseudo_labeled_dataset))
    ....

    return dataset_concat.prefetch(tf.data.experimental.AUTOTUNE)
362+
363+
```
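
What zipping two batched datasets produces, and one plausible way to merge each pair into a single batch, can be sketched in plain Python. This is only an illustration: lists of dicts stand in for batched `tf.data.Dataset` objects, and the `combine` function and the `is_pseudo` field are hypothetical, not necessarily what `CombinationDatasetInputReader` does.

```python
from typing import Any, Dict, Iterable, Iterator, List

Batch = Dict[str, List[Any]]


def combine(labeled_batches: Iterable[Batch],
            pseudo_batches: Iterable[Batch]) -> Iterator[Batch]:
    """Pairs batches one-to-one and merges each pair into a single batch."""
    # zip stops at the shorter input, as tf.data.Dataset.zip also does.
    for labeled, pseudo in zip(labeled_batches, pseudo_batches):
        yield {
            "image": labeled["image"] + pseudo["image"],
            "label": labeled["label"] + pseudo["label"],
            # Marks which examples came from the pseudo-labeled source.
            "is_pseudo": ([False] * len(labeled["image"]) +
                          [True] * len(pseudo["image"])),
        }


labeled = [
    {"image": ["img0", "img1"], "label": [0, 1]},
    {"image": ["img2", "img3"], "label": [1, 0]},
]
pseudo = [
    {"image": ["p0"], "label": [1]},
    {"image": ["p1"], "label": [0]},
    {"image": ["p2"], "label": [1]},
]
combined = list(combine(labeled, pseudo))
```

Because the pairing is one-to-one, the third pseudo-labeled batch is dropped here; in a training setup both sources would typically be repeated so that neither runs out first.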
364+
365+
### Example
366+
367+
Refer to the
368+
[InputReader](https://github.com/tensorflow/models/blob/b1a7752c5137822a32bd0dd70a0cb96e807ea411/official/vision/dataloaders/input_reader.py#L124)
369+
for vision in TFM. The `CombinationDatasetInputReader` class mixes a labeled and
370+
pseudo-labeled dataset and returns a `tf.data.Dataset` instance.
