Skip to content

Commit 2d8bda9

Browse files
authored
Update per_sample_grads.py
1 parent 9cce023 commit 2d8bda9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

intermediate_source/per_sample_grads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def compute_loss(params, buffers, sample, target):
181181
ft_per_sample_grad = ft_per_sample_grad.view(per_sample_grad.shape)
182182

183183
# Use a higher tolerance for comparison
184-
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=1e-2, rtol=1e-2), \
184+
assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=2e-2, rtol=2e-2), \
185185
f"Mismatch in {name}: max diff {(per_sample_grad - ft_per_sample_grad).abs().max().item()}"
186186

187187

0 commit comments

Comments
 (0)