11"""Extension of torchrec.dataset.utils.Batch to cover any dataset.
22"""
33# flake8: noqa
4- from __future__ import annotations
5- from typing import Dict
4+ from __future__ import (
5+ annotations ,
6+ )
7+
68import abc
7- from dataclasses import dataclass
89import dataclasses
10+ from collections import (
11+ UserDict ,
12+ )
13+ from dataclasses import (
14+ dataclass ,
15+ )
16+ from typing import (
17+ Any ,
18+ Dict ,
19+ List ,
20+ TypeVar ,
21+ )
922
1023import torch
11- from torchrec .streamable import Pipelineable
24+ from torchrec .streamable import (
25+ Pipelineable ,
26+ )
27+
28+ _KT = TypeVar ("_KT" ) # key type
29+ _VT = TypeVar ("_VT" ) # value type
1230
1331
1432class BatchBase (Pipelineable , abc .ABC ):
  @abc.abstractmethod
  def as_dict(self) -> Dict[str, Any]:
    """Return this batch's features as a feature-name -> feature-value mapping.

    Subclasses must implement this; the other BatchBase methods (device
    transfer, memory pinning, repr) all iterate over this mapping.
    """
    raise NotImplementedError
1836
19- def to (self , device : torch .device , non_blocking : bool = False ):
37+ def to (self , device : torch .device , non_blocking : bool = False ) -> BatchBase :
2038 args = {}
2139 for feature_name , feature_value in self .as_dict ().items ():
2240 args [feature_name ] = feature_value .to (device = device , non_blocking = non_blocking )
@@ -26,14 +44,14 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
2644 for feature_value in self .as_dict ().values ():
2745 feature_value .record_stream (stream )
2846
29- def pin_memory (self ):
47+ def pin_memory (self ) -> BatchBase :
3048 args = {}
3149 for feature_name , feature_value in self .as_dict ().items ():
3250 args [feature_name ] = feature_value .pin_memory ()
3351 return self .__class__ (** args )
3452
3553 def __repr__ (self ) -> str :
36- def obj2str (v ) :
54+ def obj2str (v : Any ) -> str :
3755 return f"{ v .size ()} " if hasattr (v , "size" ) else f"{ v .length_per_key ()} "
3856
3957 return "\n " .join ([f"{ k } : { obj2str (v )} ," for k , v in self .as_dict ().items ()])
@@ -52,18 +70,18 @@ def batch_size(self) -> int:
5270@dataclass
5371class DataclassBatch (BatchBase ):
5472 @classmethod
55- def feature_names (cls ):
73+ def feature_names (cls ) -> List [ str ] :
5674 return list (cls .__dataclass_fields__ .keys ())
5775
58- def as_dict (self ):
76+ def as_dict (self ) -> Dict [ str , Any ] :
5977 return {
6078 feature_name : getattr (self , feature_name )
6179 for feature_name in self .feature_names ()
6280 if hasattr (self , feature_name )
6381 }
6482
6583 @staticmethod
66- def from_schema (name : str , schema ) :
84+ def from_schema (name : str , schema : Any ) -> type :
6785 """Instantiates a custom batch subclass if all columns can be represented as a torch.Tensor."""
6886 return dataclasses .make_dataclass (
6987 cls_name = name ,
@@ -72,14 +90,14 @@ def from_schema(name: str, schema):
7290 )
7391
7492 @staticmethod
75- def from_fields (name : str , fields : dict ) :
93+ def from_fields (name : str , fields : Dict [ str , Any ]) -> type :
7694 return dataclasses .make_dataclass (
7795 cls_name = name ,
7896 fields = [(_name , _type , dataclasses .field (default = None )) for _name , _type in fields .items ()],
7997 bases = (DataclassBatch ,),
8098 )
8199
82100
class DictionaryBatch(BatchBase, UserDict[_KT, _VT]):
  """A batch backed directly by a mapping of feature name -> feature value."""

  def as_dict(self) -> Dict[str, Any]:
    # The batch *is* the mapping, so return it directly (no copy is made).
    return self
0 commit comments