Skip to content

Commit 1c59e82

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Remove the stride_per_key_per_rank dim check (#3150)
Summary: Pull Request resolved: #3150 Fix test failures Created from CodeHub with https://fburl.com/edit-in-codehub Reviewed By: spmex Differential Revision: D77636705 fbshipit-source-id: b44e07dbd592561856fb2cc4ffc360dba3bdcfa2
1 parent da486e3 commit 1c59e82

File tree

1 file changed

+0
-5
lines changed

1 file changed

+0
-5
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,11 +1767,6 @@ def __init__(
17671767
# does not take List[List[int]]
17681768
assert not isinstance(stride_per_key_per_rank, list)
17691769

1770-
if isinstance(stride_per_key_per_rank, torch.IntTensor):
1771-
assert (
1772-
stride_per_key_per_rank.dim() == 2
1773-
), f"Expect 2D tensor with shape [len(keys), len(ranks)] for stride_per_key_per_rank, but got tensor with shape: {stride_per_key_per_rank.shape}"
1774-
17751770
self._stride_per_key_per_rank: Optional[torch.IntTensor] = (
17761771
torch.IntTensor(stride_per_key_per_rank, device="cpu")
17771772
if isinstance(stride_per_key_per_rank, list)

0 commit comments

Comments
 (0)