Skip to content

Commit 294512f

Browse files
author
The TensorFlow Datasets Authors
committed
Support huggingface Video feature
PiperOrigin-RevId: 770677201
1 parent 72f3778 commit 294512f

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

tensorflow_datasets/core/features/video_feature.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class Video(sequence_feature.Sequence):
9393

9494
def __init__(
9595
self,
96-
shape: Sequence[Optional[int]],
96+
shape: Sequence[Optional[int]] | None = None,
9797
encoding_format: str = 'png',
9898
ffmpeg_extra_args: Sequence[str] = (),
9999
use_colormap: bool = False,
@@ -103,8 +103,8 @@ def __init__(
103103
"""Initializes the connector.
104104
105105
Args:
106-
shape: tuple of ints, the shape of the video (num_frames, height, width,
107-
channels), where channels is 1 or 3.
106+
shape: The shape of the video (num_frames, height, width, channels), where
107+
channels is 1 or 3.
108108
encoding_format: The video is stored as a sequence of encoded images. You
109109
can use any encoding format supported by image_feature.Feature.
110110
ffmpeg_extra_args: A sequence of additional args to be passed to the
@@ -121,19 +121,22 @@ def __init__(
121121
ValueError: If the shape is invalid
122122
"""
123123
dtype = tf.dtypes.as_dtype(dtype)
124-
shape = tuple(shape)
125-
if len(shape) != 4:
126-
raise ValueError('Video shape should be of rank 4')
124+
frame_shape = None
125+
if shape:
126+
shape = tuple(shape)
127+
if len(shape) != 4:
128+
raise ValueError('Video shape should be of rank 4')
129+
frame_shape = shape[1:]
127130
self._encoding_format = encoding_format
128131
self._extra_ffmpeg_args = list(ffmpeg_extra_args or [])
129132
super(Video, self).__init__(
130133
image_feature.Image(
131-
shape=shape[1:],
134+
shape=frame_shape,
132135
dtype=dtype,
133136
encoding_format=encoding_format,
134137
use_colormap=use_colormap,
135138
),
136-
length=shape[0],
139+
length=shape[0] if shape else None,
137140
)
138141

139142
def _ffmpeg_decode(self, path_or_fobj):

tensorflow_datasets/core/features/video_feature_test.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,22 @@ def test_video_numpy(self):
4848
test_attributes=dict(_encoding_format='png', _extra_ffmpeg_args=[]),
4949
)
5050

51+
def test_video_with_none_shape(self):
52+
np_video = np.random.randint(256, size=(128, 64, 64, 3), dtype=np.uint8)
53+
54+
self.assertFeature(
55+
feature=features.Video(shape=None),
56+
shape=(None, None, None, 3),
57+
dtype=tf.uint8,
58+
tests=[
59+
testing.FeatureExpectationItem(
60+
value=np_video,
61+
expected=np_video,
62+
),
63+
],
64+
test_attributes=dict(_encoding_format='png', _extra_ffmpeg_args=[]),
65+
)
66+
5167
def test_video_concatenated_frames(self):
5268
video_shape = (None, 400, 640, 3)
5369
lsun_examples_path = os.path.join(self._test_data_path, 'lsun_examples')
@@ -119,6 +135,5 @@ def read(self, *args, **kwargs):
119135
],
120136
)
121137

122-
123138
if __name__ == '__main__':
124139
testing.test_main()

tensorflow_datasets/core/utils/huggingface_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def convert_hf_features(hf_features) -> feature_lib.FeatureConnector:
119119
sample_rate=hf_features.sampling_rate,
120120
dtype=np.int32,
121121
)
122+
case hf_datasets.Video():
123+
return feature_lib.Video()
122124

123125
raise TypeError(f'Type {type(hf_features)} is not supported.')
124126

0 commit comments

Comments
 (0)