
Commit a20b69f

Author: Samuel Boïté
Commit message: pre commit
1 parent e69f989 · commit a20b69f

File tree: 2 files changed (+13 −13 lines changed)


ot/backend.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -1952,13 +1952,13 @@ def forward(ctx, val, grads, *inputs):
             def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
-
+
         # define a differentiable SPD matrix sqrt
         # with closed-form VJP
         class MatrixSqrtFunction(Function):
             @staticmethod
             def forward(ctx, a):
-                a_sym = .5 * (a + a.transpose(-2, -1))
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
                 L, V = torch.linalg.eigh(a_sym)
                 s = L.clamp_min(0).sqrt()
                 y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
@@ -1969,7 +1969,7 @@ def forward(ctx, a):
             @once_differentiable
             def backward(ctx, g):
                 s, V = ctx.saved_tensors
-                g_sym = .5 * (g + g.transpose(-2, -1))
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
                 ghat = V.transpose(-2, -1) @ g_sym @ V
                 d = s.unsqueeze(-1) + s.unsqueeze(-2)
                 xhat = ghat / d
```
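The hunk above is cut off inside `backward`, so neither what `forward` saves for the backward pass nor the final return of the VJP is visible. As a reading aid, here is a self-contained sketch of the same pattern: for a symmetric PSD input the square root is Y = V diag(s) Vᵀ with s the square roots of the eigenvalues, and the closed-form VJP solves the Sylvester equation Y·X + X·Y = G_sym, which in the eigenbasis of Y reduces to an elementwise division by sᵢ + sⱼ. The `save_for_backward` call and the final `return` line are assumptions filled in from that derivation, not lines of this diff.

```python
# Self-contained sketch of the differentiable SPD matrix sqrt from the hunk
# above. The save_for_backward call and the final return are assumed; the
# visible diff ends at `xhat = ghat / d`.
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MatrixSqrtFunction(Function):
    @staticmethod
    def forward(ctx, a):
        # symmetrize, then take the PSD square root via an eigendecomposition
        a_sym = 0.5 * (a + a.transpose(-2, -1))
        L, V = torch.linalg.eigh(a_sym)
        s = L.clamp_min(0).sqrt()  # eigenvalues of sqrt(a_sym)
        y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
        ctx.save_for_backward(s, V)  # assumption: matches `s, V = ctx.saved_tensors`
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, g):
        s, V = ctx.saved_tensors
        # the gradient x solves the Sylvester equation  y x + x y = g_sym;
        # in the eigenbasis of y this is an elementwise division by s_i + s_j
        g_sym = 0.5 * (g + g.transpose(-2, -1))
        ghat = V.transpose(-2, -1) @ g_sym @ V
        d = s.unsqueeze(-1) + s.unsqueeze(-2)  # zero only if two eigenvalues are zero
        xhat = ghat / d
        return V @ xhat @ V.transpose(-2, -1)  # assumption: rotate back to the input basis
```

Symmetrizing both the input and the incoming cotangent keeps `eigh` well defined and makes the returned gradient symmetric, which is consistent with the `A = A @ A.T` construction in the test below.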

test/test_backend.py

Lines changed: 10 additions & 10 deletions
```diff
@@ -823,16 +823,16 @@ def fun(a, b, d):
 
 
 def test_sqrtm_backward_torch():
-    if not torch:
-        pytest.skip("Torch not available")
-    nx = ot.backend.TorchBackend()
-    torch.manual_seed(42)
-    d = 5
-    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
-    A = A @ A.T
-    A.requires_grad_(True)
-    func = lambda x: nx.sqrtm(x).sum()
-    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+    if not torch:
+        pytest.skip("Torch not available")
+    nx = ot.backend.TorchBackend()
+    torch.manual_seed(42)
+    d = 5
+    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+    A = A @ A.T
+    A.requires_grad_(True)
+    func = lambda x: nx.sqrtm(x).sum()
+    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
 
 
 def test_get_backend_none():
```
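As a quick check outside the test file, the same gradcheck can be run against the sketched Function directly; the test above goes through `nx.sqrtm`, and the wiring from `nx.sqrtm` to the custom Function is not part of this diff, so `MatrixSqrtFunction.apply` is used here as a stand-in.

```python
# Usage sketch mirroring test_sqrtm_backward_torch, but calling the sketched
# MatrixSqrtFunction directly instead of nx.sqrtm (that dispatch is assumed,
# not shown in this diff).
import torch

torch.manual_seed(42)
d = 5
A = torch.randn(d, d, dtype=torch.float64)
A = A @ A.T  # symmetric positive semi-definite input
A.requires_grad_(True)

# gradcheck compares the closed-form VJP against finite differences in float64
func = lambda x: MatrixSqrtFunction.apply(x).sum()
assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
```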
