diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 51f85d898..ebdce6acb 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -98,15 +98,6 @@ def _maybe_compute_lengths( return lengths -@torch.fx.wrap -def _maybe_compute_max_length(lengths: torch.Tensor, max_length: Optional[int]) -> int: - if max_length is None: - if lengths.numel() == 0: - return 0 - max_length = int(lengths.max().item()) - return max_length - - def _maybe_compute_offsets( lengths: Optional[torch.Tensor], offsets: Optional[torch.Tensor] ) -> torch.Tensor: @@ -590,7 +581,7 @@ class JaggedTensor(Pipelineable, metaclass=JaggedTensorMeta): offsets. """ - _fields = ["_values", "_weights", "_lengths", "_offsets", "_max_length"] + _fields = ["_values", "_weights", "_lengths", "_offsets"] def __init__( self, @@ -598,7 +589,6 @@ def __init__( weights: Optional[torch.Tensor] = None, lengths: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None, - max_length: Optional[int] = None, ) -> None: self._values: torch.Tensor = values @@ -610,7 +600,6 @@ def __init__( _assert_tensor_has_no_elements_or_has_integers(lengths, "lengths") self._lengths: Optional[torch.Tensor] = lengths self._offsets: Optional[torch.Tensor] = offsets - self._max_length: Optional[int] = max_length @staticmethod def empty( @@ -641,7 +630,6 @@ def empty( offsets=torch.empty(0, dtype=lengths_dtype, device=device), lengths=torch.empty(0, dtype=lengths_dtype, device=device), weights=weights, - max_length=0, ) @staticmethod @@ -924,26 +912,6 @@ def lengths_or_none(self) -> Optional[torch.Tensor]: """ return self._lengths - def max_length(self) -> int: - """ - Get the maximum length of the JaggedTensor. - - Returns: - int: the maximum length of the JaggedTensor. - """ - _max_length = _maybe_compute_max_length(self.lengths(), self._max_length) - self._max_length = _max_length - return _max_length - - def max_length_or_none(self) -> Optional[int]: - """ - Get the maximum length of the JaggedTensor. If not computed, return None. - - Returns: - Optional[int]: the maximum length of the JaggedTensor. - """ - return self._max_length - def offsets(self) -> torch.Tensor: """ Get JaggedTensor offsets. If not computed, compute it from lengths. @@ -1005,7 +973,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor" weights = self._weights lengths = self._lengths offsets = self._offsets - max_length = self._max_length return JaggedTensor( values=self._values.to(device, non_blocking=non_blocking), weights=( @@ -1023,7 +990,6 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "JaggedTensor" if offsets is not None else None ), - max_length=max_length, ) @torch.jit.unused diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index b6273915e..1f15cbeaf 100644 --- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -568,14 +568,6 @@ def test_length_vs_offset(self) -> None: self.assertTrue(torch.equal(j_offset.lengths(), j_lens.lengths())) self.assertTrue(torch.equal(j_offset.offsets(), j_lens.offsets().int())) - def test_max_length(self) -> None: - values = torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) - offsets = torch.IntTensor([0, 2, 2, 3, 4, 5, 8]) - jt = JaggedTensor(values=values, offsets=offsets) - self.assertIsNone(jt.max_length_or_none()) - self.assertEqual(jt.max_length(), 3) - self.assertEqual(jt.max_length_or_none(), 3) - def test_empty(self) -> None: jt = JaggedTensor.empty(values_dtype=torch.int64)