Some aggregators may have additional dependencies. Please refer to the
[installation documentation](https://torchjd.org/stable/installation) for them.

## Usage
There are two main ways to use TorchJD. The first is to replace the usual call to
`loss.backward()` with a call to
[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
on the use case. This computes the Jacobian of the vector of losses with respect to the model
parameters, and aggregates it with the specified
[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
Whenever you want to optimize the vector of per-sample losses, you should instead use the
[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Instead of
computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
this Gramian using a
[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting),
and used to combine the losses of the batch. Assuming each element of the batch is processed
independently from the others, this approach is equivalent to
[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while generally
being much faster thanks to its lower memory usage.
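
For instance, here is a minimal sketch of the first approach; the toy objectives are made up, and
the positional `backward(tensors, aggregator)` call is assumed to match the linked documentation:

```python
import torch

from torchjd.autojac import backward
from torchjd.aggregation import UPGrad

param = torch.tensor([1.0, 2.0], requires_grad=True)

# Two toy objectives that depend on the same parameter (illustrative only)
loss1 = (param ** 2).sum()
loss2 = (param - 3.0).pow(2).sum()

# Compute the Jacobian of [loss1, loss2] w.r.t. param and aggregate it with UPGrad,
# filling param.grad with the aggregated gradient
backward([loss1, loss2], UPGrad())

print(param.grad)  # same shape as param
```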

The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
```python
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.autojac import mtl_backward
from torchjd.aggregation import UPGrad

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
```
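
The rest of such an example defines the task-specific heads, the data, and the training loop. Below
is a minimal sketch continuing the snippet above; the head sizes, tensor shapes, and the keyword
form of the `mtl_backward` call are illustrative assumptions rather than reference values:

```python
import torch

task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
aggregator = UPGrad()
mse = MSELoss()

inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shared representation, shape: [16, 3]
    loss1 = mse(task1_module(features), target1)
    loss2 = mse(task2_module(features), target2)

    optimizer.zero_grad()
    # Aggregate the Jacobian of [loss1, loss2] w.r.t. the shared parameters with UPGrad
    # and populate the .grad fields of both shared and task-specific parameters
    mtl_backward(losses=[loss1, loss2], features=features, aggregator=aggregator)
    optimizer.step()
```
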
> In this example, the Jacobian is only with respect to the shared parameters. The task-specific
> parameters are simply updated via the gradient of their task’s loss with respect to them.

The following example, presented as a diff with respect to a standard training loop, shows how to
use TorchJD to minimize the vector of per-instance losses with Jacobian descent, using
[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).

```diff
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

+ from torchjd.autogram import Engine
+ from torchjd.aggregation import UPGradWeighting

model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())

- loss_fn = MSELoss()
+ loss_fn = MSELoss(reduction="none")
optimizer = SGD(model.parameters(), lr=0.1)

+ weighting = UPGradWeighting()
+ engine = Engine(model, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task

for input, target in zip(inputs, targets):
    output = model(input).squeeze(dim=1)  # shape: [16]
-   loss = loss_fn(output, target)  # scalar
+   losses = loss_fn(output, target)  # shape: [16]

    optimizer.zero_grad()
-   loss.backward()
+   gramian = engine.compute_gramian(losses)  # shape: [16, 16]
+   weights = weighting(gramian)  # shape: [16]
+   losses.backward(weights)
    optimizer.step()
```
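
For intuition, the Gramian computed by the engine is mathematically the Jacobian of the losses
multiplied by its own transpose, even though the engine never materializes that Jacobian. A
conceptual sketch with a made-up Jacobian:

```python
import torch

num_losses, num_params = 16, 1000
J = torch.randn(num_losses, num_params)  # hypothetical Jacobian: one gradient per loss

gramian = J @ J.T  # shape: [16, 16]; entry (i, j) is the inner product of gradients i and j
```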

Lastly, you can even combine the two approaches by considering multiple tasks and each element of
the batch independently. We call that Instance-Wise Multitask Learning (IWMTL).

```python
import torch
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
task1_module = Linear(3, 1)
task2_module = Linear(3, 1)
params = [
    *shared_module.parameters(),
    *task1_module.parameters(),
    *task2_module.parameters(),
]

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task

for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
    features = shared_module(input)  # shape: [16, 3]
    out1 = task1_module(features).squeeze(1)  # shape: [16]
    out2 = task2_module(features).squeeze(1)  # shape: [16]

    # Compute the matrix of losses: one loss per element of the batch and per task
    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1)  # shape: [16, 2]

    # Compute the gramian (inner products between pairs of gradients of the losses)
    gramian = engine.compute_gramian(losses)  # shape: [16, 2, 2, 16]

    # Obtain the weights that lead to no conflict between reweighted gradients
    weights = weighting(gramian)  # shape: [16, 2]

    optimizer.zero_grad()
    # Do the standard backward pass, but weighted using the obtained weights
    losses.backward(weights)
    optimizer.step()
```

> [!NOTE]
> Here, because the losses form a matrix rather than a simple vector, we compute a *generalized
> Gramian* and extract weights from it using a
> [GeneralizedWeighting](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).
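
To make this concrete, here is a sketch of such a generalized Gramian built from made-up
per-(instance, task) gradients; the index convention of the Gramian, and the assumption that
`Flattening` simply applies the wrapped weighting to the flattened Gramian before reshaping the
weights, are guesses for illustration only:

```python
import torch
from torchjd.aggregation import Flattening, UPGradWeighting

# Hypothetical gradients: 16 instances, 2 tasks, 100 parameters (illustrative shapes)
jacobian = torch.randn(16, 2, 100)

# Generalized Gramian: inner products between all pairs of (instance, task) gradients
gramian = torch.einsum("itp,jsp->itsj", jacobian, jacobian)  # shape: [16, 2, 2, 16]

weights = Flattening(UPGradWeighting())(gramian)  # shape: [16, 2]
```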

More usage examples can be found [here](https://torchjd.org/stable/examples/).

## Supported Aggregators and Weightings
TorchJD provides many aggregators and weightings from the literature, listed in the following table.

<!-- recommended aggregators first, then alphabetical order -->
| Aggregator | Weighting | Publication |
|------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad.html#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | [SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - |
| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
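
In typical use you don't call an aggregator directly: you instantiate one and pass it to `backward`
or `mtl_backward`, which applies it to the Jacobian it computes; likewise, a weighting is applied to
the Gramian produced by the autogram engine. For intuition, here is how
[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) aggregates a small matrix `J` on its
own:

```python
from torch import tensor
from torchjd.aggregation import UPGrad

A = UPGrad()
J = tensor([[-4., 1., 1.], [6., 1., 1.]])

A(J)
# Output: tensor([0.2929, 1.9004, 1.9004])
```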

## Contribution
Please read the [Contribution page](CONTRIBUTING.md).