Skip to content

Commit e69f989

Browse files
author
Samuel Boïté
committed
torch matrix sqrt gradcheck
1 parent f08e5e1 commit e69f989

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

test/test_backend.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,19 @@ def fun(a, b, d):
822822
assert nx.allclose(dl_db, b)
823823

824824

825+
def test_sqrtm_backward_torch():
826+
if not torch:
827+
pytest.skip("Torch not available")
828+
nx = ot.backend.TorchBackend()
829+
torch.manual_seed(42)
830+
d = 5
831+
A = torch.randn(d, d, dtype=torch.float64, device="cpu")
832+
A = A @ A.T
833+
A.requires_grad_(True)
834+
func = lambda x: nx.sqrtm(x).sum()
835+
assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
836+
837+
825838
def test_get_backend_none():
826839
a, b = np.zeros((2, 3)), None
827840
nx = get_backend(a, b)

0 commit comments

Comments
 (0)