|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | | -import dataclasses |
8 | 7 | import numbers |
9 | | -from dataclasses import dataclass |
10 | 8 | from pathlib import Path |
11 | | -from typing import Iterable, Iterator, Literal, Tuple, Union |
| 9 | +from typing import Literal, Tuple, Union |
12 | 10 |
|
13 | 11 | from torch import Tensor |
14 | 12 |
|
| 13 | +from torchcodec import Frame, FrameBatch |
15 | 14 | from torchcodec.decoders import _core as core |
16 | 15 |
|
17 | | - |
18 | | -def _frame_repr(self): |
19 | | - # Utility to replace Frame and FrameBatch __repr__ method. This prints the |
20 | | - # shape of the .data tensor rather than printing the (potentially very long) |
21 | | - # data tensor itself. |
22 | | - s = self.__class__.__name__ + ":\n" |
23 | | - spaces = " " |
24 | | - for field in dataclasses.fields(self): |
25 | | - field_name = field.name |
26 | | - field_val = getattr(self, field_name) |
27 | | - if field_name == "data": |
28 | | - field_name = "data (shape)" |
29 | | - field_val = field_val.shape |
30 | | - s += f"{spaces}{field_name}: {field_val}\n" |
31 | | - return s |
32 | | - |
33 | | - |
34 | | -@dataclass |
35 | | -class Frame(Iterable): |
36 | | - """A single video frame with associated metadata.""" |
37 | | - |
38 | | - data: Tensor |
39 | | - """The frame data as (3-D ``torch.Tensor``).""" |
40 | | - pts_seconds: float |
41 | | - """The :term:`pts` of the frame, in seconds (float).""" |
42 | | - duration_seconds: float |
43 | | - """The duration of the frame, in seconds (float).""" |
44 | | - |
45 | | - def __iter__(self) -> Iterator[Union[Tensor, float]]: |
46 | | - for field in dataclasses.fields(self): |
47 | | - yield getattr(self, field.name) |
48 | | - |
49 | | - def __repr__(self): |
50 | | - return _frame_repr(self) |
51 | | - |
52 | | - |
53 | | -@dataclass |
54 | | -class FrameBatch(Iterable): |
55 | | - """Multiple video frames with associated metadata.""" |
56 | | - |
57 | | - data: Tensor |
58 | | - """The frames data as (4-D ``torch.Tensor``).""" |
59 | | - pts_seconds: Tensor |
60 | | - """The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" |
61 | | - duration_seconds: Tensor |
62 | | - """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" |
63 | | - |
64 | | - def __iter__(self) -> Iterator[Union[Tensor, float]]: |
65 | | - for field in dataclasses.fields(self): |
66 | | - yield getattr(self, field.name) |
67 | | - |
68 | | - def __repr__(self): |
69 | | - return _frame_repr(self) |
70 | | - |
71 | | - |
72 | 16 | _ERROR_REPORTING_INSTRUCTIONS = """ |
73 | 17 | This should never happen. Please report an issue following the steps in |
74 | 18 | https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml. |
|
0 commit comments