diff --git a/README.md b/README.md
index 5c03b421..7bb11dcd 100644
--- a/README.md
+++ b/README.md
@@ -55,8 +55,23 @@ Some aggregators may have additional dependencies. Please refer to the
 [installation documentation](https://torchjd.org/stable/installation) for them.
 
 ## Usage
-The main way to use TorchJD is to replace the usual call to `loss.backward()` by a call to
-`torchjd.backward` or `torchjd.mtl_backward`, depending on the use-case.
+There are two main ways to use TorchJD. The first is to replace the usual call to
+`loss.backward()` with a call to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) or
+[`torchjd.autojac.mtl_backward`](https://torchjd.org/stable/docs/autojac/mtl_backward/), depending
+on the use-case. This computes the Jacobian of the vector of losses with respect to the model
+parameters and aggregates it with the specified
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator).
+Whenever you want to optimize the vector of per-sample losses, you should instead use the
+[`torchjd.autogram.Engine`](https://torchjd.org/stable/docs/autogram/engine.html). Rather than
+computing the full Jacobian at once, it computes the Gramian of this Jacobian, layer by layer, in a
+memory-efficient way. A vector of weights (one per element of the batch) can then be extracted from
+this Gramian using a
+[`Weighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Weighting)
+and used to combine the losses of the batch. Assuming each element of the batch is processed
+independently of the others, this approach is equivalent to
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) while generally
+being much faster thanks to its lower memory usage.
 
 The following example shows how to use TorchJD to train a multi-task model with Jacobian descent,
 using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
@@ -66,7 +81,7 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 from torch.nn import Linear, MSELoss, ReLU, Sequential
 from torch.optim import SGD
 
-+ from torchjd import mtl_backward
++ from torchjd.autojac import mtl_backward
 + from torchjd.aggregation import UPGrad
 
 shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
@@ -104,49 +119,120 @@ using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 > In this example, the Jacobian is only with respect to the shared parameters. The task-specific
 > parameters are simply updated via the gradient of their task’s loss with respect to them.
 
-More usage examples can be found [here](https://torchjd.org/stable/examples/).
+The following example shows how to use TorchJD to minimize the vector of per-instance losses with
+Jacobian descent, using [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/).
 
-## Supported Aggregators
-TorchJD provides many existing aggregators from the literature, listed in the following table.
+```diff
+ import torch
+ from torch.nn import Linear, MSELoss, ReLU, Sequential
+ from torch.optim import SGD
-
-| Aggregator | Publication |
-|------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------|
-| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) (recommended) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
-| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl/) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
-| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad/) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
-| [ConFIG](https://torchjd.org/stable/docs/aggregation/config/) | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
-| [Constant](https://torchjd.org/stable/docs/aggregation/constant/) | - |
-| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj/) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
-| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop/) | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
-| [IMTL-G](https://torchjd.org/stable/docs/aggregation/imtl_g/) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
-| [Krum](https://torchjd.org/stable/docs/aggregation/krum/) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
-| [Mean](https://torchjd.org/stable/docs/aggregation/mean/) | - |
-| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda/) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
-| [Nash-MTL](https://torchjd.org/stable/docs/aggregation/nash_mtl/) | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
-| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad/) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
-| [Random](https://torchjd.org/stable/docs/aggregation/random/) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
-| [Sum](https://torchjd.org/stable/docs/aggregation/sum/) | - |
-| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean/) | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
-
-The following example shows how to instantiate
-[UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad/) and aggregate a simple matrix `J` with
-it.
-```python
-from torch import tensor
-from torchjd.aggregation import UPGrad
++ from torchjd.autogram import Engine
++ from torchjd.aggregation import UPGradWeighting
+
+ model = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU(), Linear(3, 1), ReLU())
-A = UPGrad()
-J = tensor([[-4., 1., 1.], [6., 1., 1.]])
+- loss_fn = MSELoss()
++ loss_fn = MSELoss(reduction="none")
+ optimizer = SGD(model.parameters(), lr=0.1)
-A(J)
-# Output: tensor([0.2929, 1.9004, 1.9004])
++ weighting = UPGradWeighting()
++ engine = Engine(model, batch_dim=0)
+
+ inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
+ targets = torch.randn(8, 16) # 8 batches of 16 targets
+
+ for input, target in zip(inputs, targets):
+     output = model(input).squeeze(dim=1) # shape: [16]
+-    loss = loss_fn(output, target) # scalar
++    losses = loss_fn(output, target) # shape: [16]
+
+     optimizer.zero_grad()
+-    loss.backward()
++    gramian = engine.compute_gramian(losses) # shape: [16, 16]
++    weights = weighting(gramian) # shape: [16]
++    losses.backward(weights)
+     optimizer.step()
 ```
-> [!TIP]
-> When using TorchJD, you generally don't have to use aggregators directly. You simply instantiate
-> one and pass it to the backward function (`torchjd.backward` or `torchjd.mtl_backward`), which
-> will in turn apply it to the Jacobian matrix that it will compute.
+Lastly, you can even combine the two approaches by treating each task and each element of the batch
+independently. We call this Instance-Wise Multitask Learning (IWMTL).
+
+```python
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
+from torchjd.aggregation import Flattening, UPGradWeighting
+from torchjd.autogram import Engine
+
+shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
+task1_module = Linear(3, 1)
+task2_module = Linear(3, 1)
+params = [
+    *shared_module.parameters(),
+    *task1_module.parameters(),
+    *task2_module.parameters(),
+]
+
+optimizer = SGD(params, lr=0.1)
+mse = MSELoss(reduction="none")
+weighting = Flattening(UPGradWeighting())
+engine = Engine(shared_module, batch_dim=0)
+
+inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
+task1_targets = torch.randn(8, 16) # 8 batches of 16 targets for the first task
+task2_targets = torch.randn(8, 16) # 8 batches of 16 targets for the second task
+
+for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
+    features = shared_module(input) # shape: [16, 3]
+    out1 = task1_module(features).squeeze(1) # shape: [16]
+    out2 = task2_module(features).squeeze(1) # shape: [16]
+
+    # Compute the matrix of losses: one loss per element of the batch and per task
+    losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2]
+
+    # Compute the Gramian (inner products between pairs of gradients of the losses)
+    gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16]
+
+    # Obtain the weights that lead to no conflict between the reweighted gradients
+    weights = weighting(gramian) # shape: [16, 2]
+
+    optimizer.zero_grad()
+    # Do the standard backward pass, but weighted using the obtained weights
+    losses.backward(weights)
+    optimizer.step()
+```
+
+> [!NOTE]
+> Here, because the losses form a matrix instead of a simple vector, we compute a *generalized
+> Gramian* and extract weights from it using a
+> [`GeneralizedWeighting`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.GeneralizedWeighting).
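+
+For completeness, the first approach described at the top of this section (computing the full
+Jacobian of the per-sample losses and aggregating it with an
+[`Aggregator`](https://torchjd.org/stable/docs/aggregation/index.html#torchjd.aggregation.Aggregator))
+can be sketched as follows. This is only a minimal sketch: it assumes that
+[`torchjd.autojac.backward`](https://torchjd.org/stable/docs/autojac/backward/) accepts the vector
+of losses followed by the aggregator; refer to its documentation for the exact signature.
+
+```python
+import torch
+from torch.nn import Linear, MSELoss, ReLU, Sequential
+from torch.optim import SGD
+
+from torchjd.autojac import backward
+from torchjd.aggregation import UPGrad
+
+model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
+loss_fn = MSELoss(reduction="none")
+optimizer = SGD(model.parameters(), lr=0.1)
+aggregator = UPGrad()
+
+input = torch.randn(16, 10) # A batch of 16 random input vectors of length 10
+target = torch.randn(16, 1) # A batch of 16 targets
+
+output = model(input)
+losses = loss_fn(output, target).squeeze(dim=1) # shape: [16]
+
+optimizer.zero_grad()
+# Assumed call: compute the Jacobian of the losses with respect to the parameters, aggregate it
+# with UPGrad, and store the result in the parameters' .grad attributes.
+backward(losses, aggregator)
+optimizer.step()
+```
+
+Compared with the autogram-based examples above, this computes and stores the full Jacobian of the
+losses, which is simpler but uses more memory.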
+
+More usage examples can be found [here](https://torchjd.org/stable/examples/).
+
+## Supported Aggregators and Weightings
+TorchJD provides many existing aggregators from the literature, along with their corresponding
+weightings, listed in the following table.
+
+
+| Aggregator | Weighting | Publication |
+|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|
+| [UPGrad](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGrad) (recommended) | [UPGradWeighting](https://torchjd.org/stable/docs/aggregation/upgrad#torchjd.aggregation.UPGradWeighting) | [Jacobian Descent For Multi-Objective Optimization](https://arxiv.org/pdf/2406.16232) |
+| [AlignedMTL](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTL) | [AlignedMTLWeighting](https://torchjd.org/stable/docs/aggregation/aligned_mtl#torchjd.aggregation.AlignedMTLWeighting) | [Independent Component Alignment for Multi-Task Learning](https://arxiv.org/pdf/2305.19000) |
+| [CAGrad](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGrad) | [CAGradWeighting](https://torchjd.org/stable/docs/aggregation/cagrad#torchjd.aggregation.CAGradWeighting) | [Conflict-Averse Gradient Descent for Multi-task Learning](https://arxiv.org/pdf/2110.14048) |
+| [ConFIG](https://torchjd.org/stable/docs/aggregation/config#torchjd.aggregation.ConFIG) | - | [ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks](https://arxiv.org/pdf/2408.11104) |
+| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
+| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
+| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
+| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
+| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
+| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
+| [MGDA](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDA) | [MGDAWeighting](https://torchjd.org/stable/docs/aggregation/mgda#torchjd.aggregation.MGDAWeighting) | [Multiple-gradient descent algorithm (MGDA) for multiobjective optimization](https://www.sciencedirect.com/science/article/pii/S1631073X12000738) |
+| [NashMTL](https://torchjd.org/stable/docs/aggregation/nash_mtl#torchjd.aggregation.NashMTL) | - | [Multi-Task Learning as a Bargaining Game](https://arxiv.org/pdf/2202.01017) |
+| [PCGrad](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGrad) | [PCGradWeighting](https://torchjd.org/stable/docs/aggregation/pcgrad#torchjd.aggregation.PCGradWeighting) | [Gradient Surgery for Multi-Task Learning](https://arxiv.org/pdf/2001.06782) |
+| [Random](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.Random) | [RandomWeighting](https://torchjd.org/stable/docs/aggregation/random#torchjd.aggregation.RandomWeighting) | [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/pdf/2111.10603) |
+| [Sum](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.Sum) | [SumWeighting](https://torchjd.org/stable/docs/aggregation/sum#torchjd.aggregation.SumWeighting) | - |
+| [Trimmed Mean](https://torchjd.org/stable/docs/aggregation/trimmed_mean#torchjd.aggregation.TrimmedMean) | - | [Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates](https://proceedings.mlr.press/v80/yin18a/yin18a.pdf) |
 
 ## Contribution
 Please read the [Contribution page](CONTRIBUTING.md).