Conversation

@EmileAydar commented Jun 12, 2025

Hi folks, it's me again,

A few months back, I noticed that TorchJD broke when sparse matrices were fed into the custom JD backward pass, due to an incompatibility with torch.vmap().

Indeed, many graph-based models store their adjacency matrices in sparse format for efficiency. Before this patch, feeding such tensors through TorchJD raised an error inside the custom backward transform.

In this PR, I'm proposing a minimal implementation that tackles this problem.

These additions let users pass sparse adjacency matrices (COO, CSR, or torch_sparse.SparseTensor) into Jacobian-descent workflows without any code changes:

import torch
from torchjd import backward
from torchjd.aggregation import UPGrad 
from torchjd.sparse import sparse_mm          # import auto-patches torch
A = torch.sparse_coo_tensor(
        indices=[[0, 1], [1, 0]],
        values=[1., 1.],
        size=(2, 2)
).coalesce()    # fixed graph structure during training, no grads

p = torch.randn(2, requires_grad=True)

y1 = (A @ p.unsqueeze(1)).sum()        # works out-of-the-box
y2 = (p ** 2).sum()           

backward([y1, y2], UPGrad())
print("p.grad =", p.grad)

What's in the patch:

  • torchjd.sparse._autograd.SparseMatMul
    A vmap-aware autograd Function with custom forward, backward, and vmap implementations (a rough sketch of the idea follows this list).

  • torchjd.sparse.enable_seamless_sparse()
    Monkey-patches:

    • torch.sparse.mm
    • Tensor.__matmul__
    • (optional) torch_sparse.SparseTensor.matmul
  • Automatic patch application
    It runs on first import of torchjd.sparse, so users rarely need to invoke it directly.

  • torchjd.sparse._utils.to_coalesced_coo
    It converts torch_sparse and SciPy matrices to coalesced PyTorch COO format.

  • 15 new unit tests, in a nutshell:

    • Forward/backward parity (dense vs sparse, batched and vmap paths)
    • Patch idempotency, warning branch, wrapper coverage
    • Overall 98% line coverage for the new "torchjd/sparse" module.
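
Roughly, the idea looks like this. This is a simplified sketch, not the exact code in the patch; it follows PyTorch's "extending torch.func with autograd.Function" protocol, and its vmap rule folds the batch dimension of the dense operand into its column dimension so that a single 2-D sparse mm suffices:

import torch

class SparseMatMul(torch.autograd.Function):
    @staticmethod
    def forward(A, x):
        # A: fixed sparse COO matrix (no grad), x: dense 2-D tensor
        return torch.sparse.mm(A, x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        A, _ = inputs
        ctx.A = A

    @staticmethod
    def backward(ctx, grad_out):
        # VJP of x -> A @ x is g -> A^T @ g; no gradient for A itself.
        return None, torch.sparse.mm(ctx.A.t(), grad_out)

    @staticmethod
    def vmap(info, in_dims, A, x):
        # Custom batching rule: move the batched dim of x last, fold it
        # into the column dim, run one 2-D sparse mm, then unfold.
        A_dim, x_dim = in_dims
        assert A_dim is None  # the sparse matrix is never batched
        if x_dim is None:
            return torch.sparse.mm(A, x), None
        x = x.movedim(x_dim, -1)                   # shape (n, k, B)
        n, k, B = x.shape
        out = torch.sparse.mm(A, x.reshape(n, k * B))
        return out.reshape(-1, k, B), 2            # batch dim of output is 2

enable_seamless_sparse() then boils down to wrapping the stock entry points, roughly like this (again a sketch; the actual patch also wraps Tensor.__matmul__ and torch_sparse.SparseTensor.matmul):

_original_mm = torch.sparse.mm

def _patched_mm(a, b):
    # Route fixed (non-differentiable) sparse matrices through the
    # vmap-aware Function; fall back to the stock implementation otherwise.
    if a.is_sparse and not a.requires_grad:
        return SparseMatMul.apply(a, b)
    return _original_mm(a, b)

torch.sparse.mm = _patched_mm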

Cheers,
Emile

codecov bot commented Jun 12, 2025

Codecov Report

Attention: Patch coverage is 98.01980% with 2 lines in your changes missing coverage. Please review.

Files with missing lines      Patch %   Lines
src/torchjd/sparse/_patch.py  94.28%    2 Missing ⚠️

Files with missing lines            Coverage Δ
src/torchjd/_autojac/__init__.py    100.00% <100.00%> (ø)
src/torchjd/sparse/__init__.py      100.00% <100.00%> (ø)
src/torchjd/sparse/_autograd.py     100.00% <100.00%> (ø)
src/torchjd/sparse/_registry.py     100.00% <100.00%> (ø)
src/torchjd/sparse/_utils.py        100.00% <100.00%> (ø)
src/torchjd/sparse/_patch.py        94.28% <94.28%> (ø)
@ValerianRey (Contributor) commented

Hi Emile!

Thanks a lot for all the work! We didn't know that torchjd didn't work with sparse tensors, and we never thought about the graph-neural-network use case, so it's really good that you brought this to our attention. I also went through your code and learned a lot from it.


We already had a number of problems in the past due to vmap, so we decided to add the possibility to simply avoid it by setting parallel_chunk_size=1 when calling backward or mtl_backward (see #221, and especially this change).

This means that another way to run the code example that you provided in your PR message is the following:

import torch
from torchjd import backward
from torchjd.aggregation import UPGrad
A = torch.sparse_coo_tensor(
        indices=[[0, 1], [1, 0]],
        values=[1., 1.],
        size=(2, 2)
).coalesce()    # fixed graph structure during training, no grads

p = torch.randn(2, requires_grad=True)

y1 = (A @ p.unsqueeze(1)).sum()        # works out-of-the-box
y2 = (p ** 2).sum()

backward([y1, y2], UPGrad(), parallel_chunk_size=1)  # <- Avoid using vmap by setting parallel_chunk_size=1
print("p.grad =", p.grad)

Since this seems to work, I don't think we want to add so much extra code in torchjd just to make it possible to do this in parallel (using vmap this time). Is my understanding correct or am I missing something else that your patch enables?
Also, the fix you're proposing really seems to address an issue that torch itself has. To me, it would be more appropriate to fix it directly in torch (e.g. open an issue there). Of course we depend on torch, so their issues are heavily reflected in TorchJD, but our strategy so far has been to "dodge" torch's problems (for instance by having a way to avoid vmap altogether) rather than to fix them, unless they can be easily fixed by a PR on their side.

We would definitely benefit from:

  • New tests involving different kinds of sparse matrices (coo, csr, csc, bsr, bsc) and using parallel_chunk_size=1. I'm not sure exactly what will work and what will not, but this is exactly the point of those tests. We could also try with and without coalescing the tensors prior to applying the torchjd functions to them. I think we would like to test this in tests/unit/autojac/test_backward.py and tests/unit/autojac/test_mtl_backward.py.
  • A new usage example specifically made for sparse matrices, maybe from the viewpoint of somebody trying to do something with a graph neural network, or at least with an adjacency matrix, telling people to use parallel_chunk_size=1 when dealing with them and explaining why. For this, I would really like that we use the example that you provided in the message of your PR (and maybe slightly adapt it, but I think it's almost ready). As for the explanation about vmap, we could adapt the note provided at the end of the documentation page about RNNs (https://torchjd.org/stable/examples/rnn/).
  • In the warning note about vmap in https://torchjd.org/stable/docs/autojac/backward/ and https://torchjd.org/stable/docs/autojac/mtl_backward/, it would be nice to add "when some tensors are sparse", with a link to this PR, in the list of examples of things that break vmap.

What do you think of this? Would you be up to the task?

Also, do you think it could happen that people want to differentiate with respect to a sparse tensor (e.g. have a sparse parameter tensor in their model)? In this case, we may also want to test that, to at least know if it works.

@EmileAydar (Author) commented Jun 14, 2025

> We already had a number of problems in the past due to vmap, so we decided to add the possibility to simply avoid it by setting parallel_chunk_size=1 when calling backward or mtl_backward (see #221, and especially this change).

> Since this seems to work, I don't think we want to add so much extra code in torchjd just to make it possible to do this in parallel (using vmap this time). Is my understanding correct or am I missing something else that your patch enables?

Awesome! Turning off batching to drop back to serial grad calls is a clean escape hatch! In the common "one-loss" or "few-losses" scenario, that's probably all anyone needs.

> Also, the fix you're proposing really seems to address an issue that torch itself has. To me, it would be more appropriate to fix it directly in torch (e.g. open an issue there). Of course we depend on torch, so their issues are heavily reflected in TorchJD, but our strategy so far has been to "dodge" torch's problems (for instance by having a way to avoid vmap altogether) rather than to fix them, unless they can be easily fixed by a PR on their side.

I agree the root cause is upstream: _sparse_mm/_sparse_mv have no batching rule. I did consider opening an issue on torch, but my patch isn't a long-term solution, and the issue is already widely known.
torch-sparse / torch-scatter sit on top of PyTorch, so the functorch batching rule really has to live inside ATen.
My patch is really just another kind of workaround, and TorchJD's is simpler and safer. Overall, I'd say this remains an open problem: a proper batching rule in the kernel would make my patch obsolete.
It might be worth mentioning in the docs that custom autograd Functions are a viable workaround for users who want vmap()-compatible code and control over their model code.

> We would definitely benefit from:
>
>   • New tests involving different kinds of sparse matrices (coo, csr, csc, bsr, bsc) and using parallel_chunk_size=1. I'm not sure exactly what will work and what will not, but this is exactly the point of those tests. We could also try with and without coalescing the tensors prior to applying the torchjd functions to them. I think we would like to test this in tests/unit/autojac/test_backward.py and tests/unit/autojac/test_mtl_backward.py.
>   • A new usage example specifically made for sparse matrices, maybe from the viewpoint of somebody trying to do something with a graph neural network, or at least with an adjacency matrix, telling people to use parallel_chunk_size=1 when dealing with them and explaining why. For this, I would really like that we use the example that you provided in the message of your PR (and maybe slightly adapt it, but I think it's almost ready). As for the explanation about vmap, we could adapt the note provided at the end of the documentation page about RNNs (https://torchjd.org/stable/examples/rnn/).
>   • In the warning note about vmap in https://torchjd.org/stable/docs/autojac/backward/ and https://torchjd.org/stable/docs/autojac/mtl_backward/, it would be nice to add "when some tensors are sparse", with a link to this PR, in the list of examples of things that break vmap.
>
> What do you think of this? Would you be up to the task?

Happy to help with that, of course! Just let me know how you'd like me to proceed, or I can improvise some tests in the meantime.

@EmileAydar (Author) commented Jun 14, 2025

> Also, do you think it could happen that people want to differentiate with respect to a sparse tensor (e.g. have a sparse parameter tensor in their model)? In this case, we may also want to test that, to at least know if it works.

Yep, there are definitely practical cases where the values stored in a sparse tensor need to be learnable parameters.
Many GNN variants treat edge weights (or even adjacency masks) as trainable parameters; in that case, you would naturally store the weights in a sparse COO matrix rather than a dense NxN tensor. There are probably other use cases involving dynamic pruning or network-compression schemes, but overall the concept is relevant.
In these cases, you need the partial derivatives of the losses with respect to the sparse tensor.

Core PyTorch can only back-propagate through COO tensors (values only).
Gradients for integer indices are undefined, and gradients for CSR/CSC/BSR layouts aren't implemented yet either.

The pattern recommended by the PyG maintainer is:

  1. Keep the index tensor as fixed data
  2. Wrap the value vector in nn.Parameter
  3. Rebuild the COO tensor each forward pass
    (see the maintainer discussion here).

Provided we avoid vmap, this does work with TorchJD as well. I tested with parallel_chunk_size=1, and TorchJD passes gradients back to the value vector exactly as plain PyTorch does.

I used a small snippet to quickly test sparse autograd; here's what I get:

import torch
from torchjd import backward
from torchjd.aggregation import Sum

# fixed COO indices
idx = torch.tensor([[0, 1], [1, 0]])  # two off-diagonal entries
vals = torch.nn.Parameter(torch.tensor([1., 1.]))  # learnable 

def forward():
    W = torch.sparse_coo_tensor(idx, vals, (2, 2)).coalesce()  # sparse weight
    y = torch.sparse.mm(W, torch.eye(2))  # dense output
    return y.sum()

loss = forward()
backward([loss], Sum(), parallel_chunk_size=1)  # serial JD
print(vals.grad) 

Output:

tensor([1., 1.])

However, if we feed multiple losses at once and leave parallel_chunk_size=None, TorchJD will use vmap, functorch will try to batch the sparse ops, and we hit the usual "no batching rule for aten::sparse_mask / aten::_sparse_mm" error, leading to the same conclusions about vmap() as above.

I could add the COO-values test to tests/unit/autojac/test_backward.py and extend the vmap-limitation bullet in the docs.

Let me know if I answered your question well and how you'd like to proceed.

@ValerianRey (Contributor) commented

> It might be worth mentioning in the docs that custom autograd Functions are a viable workaround for users who want vmap()-compatible code and control over their model code.

I agree. We're currently working on the engine (#387), meaning that the way we use vmap will likely change, so we'll wait for that before documenting more heavily how to find vmap workarounds.

> Happy to help with that, of course! Just let me know how you'd like me to proceed, or I can improvise some tests in the meantime.

I think you could just start with test_backward.py, and create a new test specifically for sparse intermediate values (not inputs or outputs, just like in the example you gave at the beginning of the PR). This test could be based on the current test_value_is_correct, but with a fixed aggregator (UPGrad), shape, automatically specified inputs (i.e. you just use the default inputs when calling backward), and a chunk size of 1. J could be created similarly but cast to a "sparse" tensor afterwards (it won't really be sparse, but it will have a sparse layout, which is what we want to test). I'm thinking something like J = torch.randn(shape).to_sparse(layout=torch.sparse_coo).

I think that will be a good start, and from there we can already see what works and when.
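
Here is a rough sketch of what I have in mind (hypothetical code, to be adapted to the conventions of tests/unit/autojac/test_backward.py):

import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

def test_backward_with_sparse_intermediate_value():
    p = torch.randn(3, requires_grad=True)
    # Dense data cast to a sparse layout, used only as an intermediate
    # value of the computation (not an input or output of backward).
    J = torch.randn(2, 3).to_sparse(layout=torch.sparse_coo)
    ys = (J @ p.unsqueeze(1)).squeeze(1)
    backward([ys[0], ys[1]], UPGrad(), parallel_chunk_size=1)
    grad_sparse = p.grad.clone()

    # Parity check: the same computation with a dense J should produce
    # the same aggregated gradient.
    p.grad = None
    ys = (J.to_dense() @ p.unsqueeze(1)).squeeze(1)
    backward([ys[0], ys[1]], UPGrad(), parallel_chunk_size=1)
    assert torch.allclose(grad_sparse, p.grad)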

If all goes well, I guess that later steps could be:

  • Parametrize this test on the kind of sparse layout that we want J to be cast to.
  • Parametrize this test on whether we coalesce or not, using a boolean.
  • Add a similar test in test_mtl_backward.

> Yep, there are definitely practical cases where the values stored in a sparse tensor need to be learnable parameters.
> [...]
> Output:
> tensor([1., 1.])

Very interesting, thank you. I think it's a nice workaround when dealing with sparse inputs. I tried a slightly different variation of your code, with an explicitly sparse input:

import torch
from torchjd import backward
from torchjd.aggregation import UPGrad

# fixed COO indices
idx = torch.tensor([[0, 1], [1, 0]])  # two off-diagonal entries
vals = torch.tensor([1., 1.])  # plain values; the sparse tensor below is the differentiable leaf
sparse_tensor = torch.sparse_coo_tensor(idx, vals, requires_grad=True)

def forward():
    y = torch.sparse.mm(sparse_tensor, torch.eye(2))  # dense output
    return y.sum()

loss = forward()
backward([loss], UPGrad(), parallel_chunk_size=1)  # serial JD
print(sparse_tensor.grad)

It does not work. Summary of the trace:

../src/torchjd/_autojac/_transform/_jac.py:82: in _get_vjp
    return torch.concatenate([grad.reshape([-1]) for grad in grads])
                              ^^^^^^^^^^^^^^^^^^
RuntimeError: reshape is not implemented for sparse tensors

I think it would be too much effort to make TorchJD work seamlessly with sparse inputs, unless PyTorch itself starts supporting more operations on sparse tensors.

So for now, let's focus on sparse intermediate tensors: a new test in test_backward.py, extending the vmap-limitation bullet, and possibly a new usage example if things seem to work.
