Commit f08e5e1

Author: Samuel Boïté

stable matrix sqrt using closed-form diff

1 parent: be211ac


ot/backend.py

Lines changed: 26 additions & 6 deletions
@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable
 
         # define a function that takes inputs val and grads
         # and returns a val tensor with proper gradients
@@ -1951,8 +1952,32 @@ def forward(ctx, val, grads, *inputs):
             def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
+
+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = .5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = .5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
 
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction
 
     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
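
Where the closed-form rule in `backward` comes from (notation follows the code above: `a` is the SPD input, `y` its square root, `V` and `s` the eigenvectors and clamped eigenvalue square roots). Differentiating $Y^2 = A$ at $Y = A^{1/2}$ gives the Sylvester equation

$$ \mathrm{d}Y \, Y + Y \, \mathrm{d}Y = \mathrm{d}A , $$

and conjugating by the eigenbasis $V$ decouples it entrywise:

$$ (V^\top \mathrm{d}Y \, V)_{ij} = \frac{(V^\top \mathrm{d}A \, V)_{ij}}{s_i + s_j} . $$

The map $\mathrm{d}A \mapsto \mathrm{d}Y$ is self-adjoint, so the VJP is the same solve applied to the symmetrized incoming gradient, which is exactly the `ghat`, `d`, `xhat` sequence in `backward`. Only the sums $s_i + s_j$ appear as denominators, never the differences $\lambda_i - \lambda_j$ that appear in the VJP of `eigh` itself, so the gradient stays finite under repeated eigenvalues; the `masked_fill` zeroes the entries with $s_i + s_j = 0$, which occur only on the null space of a PSD input.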
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)
 
     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)
 
     def eigh(self, a):
         return torch.linalg.eigh(a)
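
A quick way to see the stability gain, as a self-contained sketch: the snippet below assumes only PyTorch, copies `MatrixSqrtFunction` from the hunk above, and uses a hypothetical `naive_sqrtm` as a stand-in for the deleted eigh-based implementation. At an input with repeated eigenvalues, plain autograd through `torch.linalg.eigh` yields non-finite gradients (its VJP divides by pairwise eigenvalue differences), while the custom Function stays finite and passes a finite-difference check.

import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MatrixSqrtFunction(Function):
    # copied from the commit: SPD matrix sqrt with a closed-form VJP
    @staticmethod
    def forward(ctx, a):
        a_sym = .5 * (a + a.transpose(-2, -1))
        L, V = torch.linalg.eigh(a_sym)
        s = L.clamp_min(0).sqrt()
        y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
        ctx.save_for_backward(s, V)
        return y

    @staticmethod
    @once_differentiable
    def backward(ctx, g):
        s, V = ctx.saved_tensors
        g_sym = .5 * (g + g.transpose(-2, -1))
        ghat = V.transpose(-2, -1) @ g_sym @ V
        d = s.unsqueeze(-1) + s.unsqueeze(-2)
        xhat = ghat / d
        xhat = xhat.masked_fill(d == 0, 0)
        return V @ xhat @ V.transpose(-2, -1)


def naive_sqrtm(a):
    # stand-in for the deleted implementation: plain autograd through eigh
    L, V = torch.linalg.eigh(a)
    return (V * L.sqrt().unsqueeze(-2)) @ V.transpose(-2, -1)


# the identity has a triply repeated eigenvalue, the worst case for
# eigh's VJP, which divides by pairwise eigenvalue differences
a = torch.eye(3, dtype=torch.float64, requires_grad=True)

g_naive, = torch.autograd.grad(naive_sqrtm(a).sum(), a)
print(torch.isfinite(g_naive).all())   # tensor(False)

g_stable, = torch.autograd.grad(MatrixSqrtFunction.apply(a).sum(), a)
print(torch.isfinite(g_stable).all())  # tensor(True)

# finite-difference check of the closed-form VJP on a generic SPD input
b = torch.randn(4, 4, dtype=torch.float64)
spd = (b @ b.transpose(-2, -1) + torch.eye(4, dtype=torch.float64)).requires_grad_()
print(torch.autograd.gradcheck(MatrixSqrtFunction.apply, (spd,)))  # True

The `once_differentiable` decorator also makes the intent explicit: attempting to differentiate through this backward raises an error instead of silently producing an incomplete second derivative.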
