Skip to content

Commit 21aef92

Browse files
authored
Refactor Frame and FrameBatch into their own source file (#247)
1 parent b65882e commit 21aef92

File tree

10 files changed

+101
-72
lines changed

10 files changed

+101
-72
lines changed

docs/source/api_ref_decoders.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,4 @@ torchcodec.decoders
2020
:nosignatures:
2121
:template: dataclass.rst
2222

23-
Frame
24-
FrameBatch
2523
VideoStreamMetadata

docs/source/api_ref_torchcodec.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
.. _torchcodec:
2+
3+
===================
4+
torchcodec
5+
===================
6+
7+
.. currentmodule:: torchcodec
8+
9+
10+
.. autosummary::
11+
:toctree: generated/
12+
:nosignatures:
13+
:template: dataclass.rst
14+
15+
Frame
16+
FrameBatch

docs/source/index.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ We achieve these capabilities through:
4545
.. grid-item-card:: :octicon:`file-code;1em`
4646
API Reference
4747
:img-top: _static/img/card-background.svg
48-
:link: api_ref_decoders.html
48+
:link: api_ref_torchcodec.html
4949
:link-type: url
5050

5151
The API reference for TorchCodec
@@ -73,4 +73,5 @@ We achieve these capabilities through:
7373
:caption: API Reference
7474
:hidden:
7575

76+
api_ref_torchcodec
7677
api_ref_decoders

examples/basic_example.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
121121
# This can be achieved using the
122122
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
123123
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
124-
# will return a :class:`~torchcodec.decoders.Frame` and
125-
# :class:`~torchcodec.decoders.FrameBatch` objects respectively.
124+
# will return a :class:`~torchcodec.Frame` and
125+
# :class:`~torchcodec.FrameBatch` objects respectively.
126126

127127
last_frame = decoder.get_frame_at(len(decoder) - 1)
128128
print(f"{type(last_frame) = }")
@@ -138,12 +138,12 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
138138
plot(middle_frames.data, "Middle frames")
139139

140140
# %%
141-
# Both :class:`~torchcodec.decoders.Frame` and
142-
# :class:`~torchcodec.decoders.FrameBatch` have a ``data`` field, which contains
141+
# Both :class:`~torchcodec.Frame` and
142+
# :class:`~torchcodec.FrameBatch` have a ``data`` field, which contains
143143
# the decoded tensor data. They also have the ``pts_seconds`` and
144144
# ``duration_seconds`` fields which are single ints for
145-
# :class:`~torchcodec.decoders.Frame`, and 1-D :class:`torch.Tensor` for
146-
# :class:`~torchcodec.decoders.FrameBatch` (one value per frame in the batch).
145+
# :class:`~torchcodec.Frame`, and 1-D :class:`torch.Tensor` for
146+
# :class:`~torchcodec.FrameBatch` (one value per frame in the batch).
147147

148148
# %%
149149
# Using time-based indexing
@@ -153,7 +153,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
153153
# frames based on *when* they are displayed with
154154
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
155155
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
156-
# also returns :class:`~torchcodec.decoders.Frame` and :class:`~torchcodec.decoders.FrameBatch`
156+
# also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
157157
# respectively.
158158

159159
frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2)

src/torchcodec/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from . import decoders, samplers # noqa # noqa
7+
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
8+
# but that results in circular import.
9+
from ._frame import Frame, FrameBatch # usort:skip # noqa
10+
from . import decoders, samplers # noqa
811

912
__version__ = "0.0.4.dev"

src/torchcodec/_frame.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import dataclasses
8+
from dataclasses import dataclass
9+
from typing import Iterable, Iterator, Union
10+
11+
from torch import Tensor
12+
13+
14+
def _frame_repr(self):
15+
# Utility to replace Frame and FrameBatch __repr__ method. This prints the
16+
# shape of the .data tensor rather than printing the (potentially very long)
17+
# data tensor itself.
18+
s = self.__class__.__name__ + ":\n"
19+
spaces = " "
20+
for field in dataclasses.fields(self):
21+
field_name = field.name
22+
field_val = getattr(self, field_name)
23+
if field_name == "data":
24+
field_name = "data (shape)"
25+
field_val = field_val.shape
26+
s += f"{spaces}{field_name}: {field_val}\n"
27+
return s
28+
29+
30+
@dataclass
31+
class Frame(Iterable):
32+
"""A single video frame with associated metadata."""
33+
34+
data: Tensor
35+
"""The frame data as (3-D ``torch.Tensor``)."""
36+
pts_seconds: float
37+
"""The :term:`pts` of the frame, in seconds (float)."""
38+
duration_seconds: float
39+
"""The duration of the frame, in seconds (float)."""
40+
41+
def __iter__(self) -> Iterator[Union[Tensor, float]]:
42+
for field in dataclasses.fields(self):
43+
yield getattr(self, field.name)
44+
45+
def __repr__(self):
46+
return _frame_repr(self)
47+
48+
49+
@dataclass
50+
class FrameBatch(Iterable):
51+
"""Multiple video frames with associated metadata."""
52+
53+
data: Tensor
54+
"""The frames data as (4-D ``torch.Tensor``)."""
55+
pts_seconds: Tensor
56+
"""The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
57+
duration_seconds: Tensor
58+
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
59+
60+
def __iter__(self) -> Iterator[Union[Tensor, float]]:
61+
for field in dataclasses.fields(self):
62+
yield getattr(self, field.name)
63+
64+
def __repr__(self):
65+
return _frame_repr(self)

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from ._core import VideoStreamMetadata
8-
from ._video_decoder import Frame, FrameBatch, VideoDecoder # noqa
8+
from ._video_decoder import VideoDecoder # noqa
99

1010
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_video_decoder.py

Lines changed: 2 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -4,71 +4,15 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import dataclasses
87
import numbers
9-
from dataclasses import dataclass
108
from pathlib import Path
11-
from typing import Iterable, Iterator, Literal, Tuple, Union
9+
from typing import Literal, Tuple, Union
1210

1311
from torch import Tensor
1412

13+
from torchcodec import Frame, FrameBatch
1514
from torchcodec.decoders import _core as core
1615

17-
18-
def _frame_repr(self):
19-
# Utility to replace Frame and FrameBatch __repr__ method. This prints the
20-
# shape of the .data tensor rather than printing the (potentially very long)
21-
# data tensor itself.
22-
s = self.__class__.__name__ + ":\n"
23-
spaces = " "
24-
for field in dataclasses.fields(self):
25-
field_name = field.name
26-
field_val = getattr(self, field_name)
27-
if field_name == "data":
28-
field_name = "data (shape)"
29-
field_val = field_val.shape
30-
s += f"{spaces}{field_name}: {field_val}\n"
31-
return s
32-
33-
34-
@dataclass
35-
class Frame(Iterable):
36-
"""A single video frame with associated metadata."""
37-
38-
data: Tensor
39-
"""The frame data as (3-D ``torch.Tensor``)."""
40-
pts_seconds: float
41-
"""The :term:`pts` of the frame, in seconds (float)."""
42-
duration_seconds: float
43-
"""The duration of the frame, in seconds (float)."""
44-
45-
def __iter__(self) -> Iterator[Union[Tensor, float]]:
46-
for field in dataclasses.fields(self):
47-
yield getattr(self, field.name)
48-
49-
def __repr__(self):
50-
return _frame_repr(self)
51-
52-
53-
@dataclass
54-
class FrameBatch(Iterable):
55-
"""Multiple video frames with associated metadata."""
56-
57-
data: Tensor
58-
"""The frames data as (4-D ``torch.Tensor``)."""
59-
pts_seconds: Tensor
60-
"""The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
61-
duration_seconds: Tensor
62-
"""The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
63-
64-
def __iter__(self) -> Iterator[Union[Tensor, float]]:
65-
for field in dataclasses.fields(self):
66-
yield getattr(self, field.name)
67-
68-
def __repr__(self):
69-
return _frame_repr(self)
70-
71-
7216
_ERROR_REPORTING_INSTRUCTIONS = """
7317
This should never happen. Please report an issue following the steps in
7418
https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml.

src/torchcodec/samplers/_implem.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
import torch
44

5-
from torchcodec.decoders import Frame, FrameBatch, VideoDecoder
5+
from torchcodec import Frame, FrameBatch
6+
from torchcodec.decoders import VideoDecoder
67

78

89
def _validate_params(

test/samplers/test_samplers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import pytest
66
import torch
77

8-
from torchcodec.decoders import FrameBatch, VideoDecoder
8+
from torchcodec import FrameBatch
9+
from torchcodec.decoders import VideoDecoder
910
from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices
1011
from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS
1112

0 commit comments

Comments
 (0)