Sparse MatMul Support in TorchJD for Seamless Compatibility with torch.vmap() #388
base: main
Conversation
Hi Emile! Thanks a lot for all the work! We didn't know that torchjd didn't work for sparse tensors, and we never thought about the graph neural network use-case, so it's really good that you brought this to our attention. I also went through your code and learned a lot from it. We have already had a number of problems in the past due to `vmap`. This means that another way to run the code example that you provided in your PR message is the following:

```python
import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

A = torch.sparse_coo_tensor(
    indices=[[0, 1], [1, 0]],
    values=[1., 1.],
    size=(2, 2),
).coalesce()  # fixed graph structure during training, no grads

p = torch.randn(2, requires_grad=True)
y1 = (A @ p.unsqueeze(1)).sum()  # works out-of-the-box
y2 = (p ** 2).sum()

backward([y1, y2], UPGrad(), parallel_chunk_size=1)  # avoid vmap by setting parallel_chunk_size=1
print("p.grad =", p.grad)
```

Since this seems to work, I don't think we want to add so much extra code in torchjd just to make it possible to do this in parallel (using `vmap`). We would definitely benefit from:

What do you think of this? Would you be up to the task? Also, do you think it could happen that people want to differentiate with respect to a sparse tensor (e.g. have a sparse parameter tensor in their model)? In this case, we may also want to test that, to at least know if it works.
Awesome! Turning off batching to drop back to serial execution works. I agree the root cause is upstream.

Happy to help with that, of course! Just let me know how you'd like me to proceed, or I can improvise some tests in the meantime.
Yep, there are definitely practical cases where the values stored in a sparse tensor need to be learnable parameters. Core PyTorch can only back-propagate through COO tensors (values only). The pattern recommended by the PyG maintainer is to keep the indices fixed and make only the values a learnable parameter.

Provided we avoid `vmap`, this works. I've used a small snippet for quick testing of sparse autograd, and I get this:

```python
import torch
from torchjd import backward
from torchjd.aggregation import Sum

# fixed COO indices
idx = torch.tensor([[0, 1], [1, 0]])               # two off-diagonal entries
vals = torch.nn.Parameter(torch.tensor([1., 1.]))  # learnable values

def forward():
    W = torch.sparse_coo_tensor(idx, vals, (2, 2)).coalesce()  # sparse weight
    y = torch.sparse.mm(W, torch.eye(2))                       # dense output
    return y.sum()

loss = forward()
backward([loss], Sum(), parallel_chunk_size=1)  # serial JD
print(vals.grad)
```

Output:
However, if we feed multiple losses at once and leave the default (`vmap`-based) parallelism, it fails. I could add the COO-values test to the test suite. Let me know if I answered your question well and how you'd like to proceed.
I agree. We're currently working on the engine (#387), meaning that the way we use vmap will likely change, so we'll wait for that before documenting more heavily how to find vmap workarounds.
I think you could just start with the COO-values test you described. That will be a good start, and we can already see from there if things seem to work and when they work. If all goes well, later steps could follow.
Very interesting, thank you. I think it's a nice workaround when dealing with sparse inputs. I tried a slightly different variation of your code, with an explicitly sparse input:

```python
import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

# fixed COO indices
idx = torch.tensor([[0, 1], [1, 0]])  # two off-diagonal entries
vals = torch.tensor([1., 1.])         # values of the sparse tensor
sparse_tensor = torch.sparse_coo_tensor(idx, vals, requires_grad=True)

def forward():
    y = torch.sparse.mm(sparse_tensor, torch.eye(2))  # dense output
    return y.sum()

loss = forward()
backward([loss], UPGrad(), parallel_chunk_size=1)  # serial JD
print(vals.grad)
```

It does not work:
I think that it will be too much effort to make TorchJD work seamlessly with sparse inputs, unless pytorch itself starts supporting more operations on sparse tensors. So for now, let's focus on sparse intermediate tensors, i.e. a new test exercising them.
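For what it's worth, here is a minimal pure-torch sketch (no torchjd involved, names and values chosen purely for illustration) of what such a sparse-intermediate-tensor test could exercise: a dense learnable parameter feeds the values of a sparse tensor built in the middle of the forward pass, and gradients must flow back through it.

```python
import torch

# Hypothetical sketch: dense learnable values, sparse tensor created as an
# intermediate step of the forward pass, gradient flowing back to the values.
idx = torch.tensor([[0, 1], [1, 0]])                       # fixed COO indices
vals = torch.nn.Parameter(torch.tensor([2., 3.]))          # dense, learnable

W = torch.sparse_coo_tensor(idx, vals, (2, 2)).coalesce()  # sparse intermediate
loss = torch.sparse.mm(W, torch.ones(2, 2)).sum()          # dense scalar loss
loss.backward()

# Each value multiplies a row of ones with two entries, so its gradient is 2.
print(vals.grad)  # tensor([2., 2.])
```

A torchjd version of this would simply replace `loss.backward()` with a serial `backward([loss], ..., parallel_chunk_size=1)` call.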
Hi folks, it's me again.

I had noticed a few months back that TorchJD bugged when trying to input sparse matrices into the custom JD backward pass, due to incompatibility with `torch.vmap()`. Indeed, many graph-based models store their adjacency matrices in sparse format for efficiency. Feeding such tensors through TorchJD raised an error inside the custom backward transform before this patch.
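A minimal sketch of the kind of incompatibility meant here (illustrative only; the exact error message depends on the torch version): `torch.vmap` has no batching rules for sparse operations, so pushing a batched dense operand through a sparse matmul typically errors out, while the unbatched call is fine.

```python
import torch

# A fixed sparse adjacency-like matrix (no grads on the structure).
A = torch.sparse_coo_tensor([[0, 1], [1, 0]], [1., 1.], (2, 2)).coalesce()

def per_sample_loss(v):
    # sparse (2, 2) @ dense (2, 1) -> scalar
    return torch.sparse.mm(A, v.unsqueeze(1)).sum()

direct = per_sample_loss(torch.randn(2))        # unbatched call works
try:
    torch.vmap(per_sample_loss)(torch.randn(3, 2))
    print("vmap path succeeded")
except Exception as err:                        # typically raises under vmap
    print("vmap path failed:", type(err).__name__)
```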
In this PR, I'm suggesting a minimalistic implementation that may help tackle this problem efficiently.
These new add-ons aim at letting users pass sparse adjacency matrices (COO, CSR, or `torch_sparse.SparseTensor`) into Jacobian-Descent workflows without any code changes.

What's in the patch, in a nutshell:

- `torchjd.sparse._autograd.SparseMatMul`: a vmap-aware autograd Function with custom `forward`, `backward`, and `vmap` implementations.
- `torchjd.sparse.enable_seamless_sparse()`: monkey-patches `torch.sparse.mm`, `Tensor.__matmul__`, and `torch_sparse.SparseTensor.matmul`.
- Automatic patch application: it runs on first import of `torchjd.sparse`, so users rarely need to invoke it directly.
- `torchjd.sparse._utils.to_coalesced_coo`: converts `torch_sparse` and SciPy matrices to coalesced PyTorch COO format.
- 15 unit tests added.
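To make the idea concrete, here is a hedged sketch (my own illustration, not the actual patch code) of what a vmap-aware sparse-matmul `autograd.Function` can look like. The custom `vmap` rule folds the batch dimension into the columns of the dense operand so a single sparse matmul handles the whole batch; this sketch assumes the sparse operand itself is never batched and treats its structure as fixed (no gradient for it).

```python
import torch

class SparseMatMul(torch.autograd.Function):
    """Hypothetical sketch of a vmap-aware sparse @ dense product."""

    @staticmethod
    def forward(A, X):
        # A: sparse COO (m, n); X: dense (n, k) -> dense (m, k)
        return torch.sparse.mm(A, X)

    @staticmethod
    def setup_context(ctx, inputs, output):
        A, _ = inputs
        ctx.save_for_backward(A)

    @staticmethod
    def backward(ctx, grad_out):
        (A,) = ctx.saved_tensors
        # Gradient w.r.t. the dense operand only; the sparse operand is
        # treated as a fixed structure in this sketch (no grad returned).
        grad_X = torch.sparse.mm(A.t().coalesce(), grad_out)
        return None, grad_X

    @staticmethod
    def vmap(info, in_dims, A, X):
        A_dim, X_dim = in_dims
        assert A_dim is None, "sketch assumes the sparse operand is unbatched"
        X = X.movedim(X_dim, 0)                       # (b, n, k)
        b, n, k = X.shape
        flat = X.permute(1, 0, 2).reshape(n, b * k)   # fold batch into columns
        out = torch.sparse.mm(A, flat)                # (m, b*k), one matmul
        out = out.reshape(out.shape[0], b, k).permute(1, 0, 2)  # (b, m, k)
        return out, 0
```

Usage would look like `torch.vmap(lambda x: SparseMatMul.apply(A, x))(X)`, batching only the dense side while `A` stays shared across the batch.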
Cheers,
Emile