diff --git a/CHANGELOG.rst b/CHANGELOG.rst index c2bcc7d..47b8dd2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -18,6 +18,7 @@ Changelog Ver 0.1.* --------- +* |Feature| |API| Add :meth:`vectorize` for faster ensemble model inference using :mod:`functorch` (requiring :mod:`torch` version >= 1.13.0) | `@xuyxu `__ * |Feature| |API| Add ``voting_strategy`` parameter for :class:`VotingClassifer`, :class:`NeuralForestClassifier`, and :class:`SnapshotEnsembleClassifier` | `@LukasGardberg `__ * |Fix| Fix the sampling issue in :class:`BaggingClassifier` and :class:`BaggingRegressor` | `@SunHaozhe `__ * |Feature| |API| Add :class:`NeuralForestClassifier` and :class:`NeuralForestRegressor` | `@xuyxu `__ diff --git a/docs/advanced.rst b/docs/advanced.rst new file mode 100644 index 0000000..ceb7988 --- /dev/null +++ b/docs/advanced.rst @@ -0,0 +1,15 @@ +Advanced Usage +============== + +The following sections outline advanced usage in :mod:`torchensemble`. + +Faster inference using functorch +-------------------------------- + +:mod:`functorch`, which provides JAX-like composable function transforms for PyTorch, has been integrated into PyTorch since the release of version 1.13. To enable faster inference of ensembles in :mod:`torchensemble`, you can use the :meth:`vectorize` method of the ensemble to convert it into a stateless model (fmodel) together with stacked parameters and buffers. + +The stateless model, together with the stacked parameters and buffers, can be passed to :meth:`vmap` in :mod:`functorch` to reduce the inference time. More details are available at `functorch documentation `__. The code snippet below demonstrates how to pass the results of :meth:`ensemble.vectorize` into :meth:`functorch.vmap`. + +.. 
def vectorize(self):
    """Convert the ensemble into the stateless form used by functorch.

    The user-facing docstring is attached by downstream ensembles; this
    base implementation only documents the mechanics.

    The ensemble is switched to evaluation mode (``self.eval()``) before
    its base estimators are combined.

    Returns
    -------
    fmodel, params, buffers
        The three values returned by
        ``functorch.combine_state_for_ensemble`` when it is called on
        ``self.estimators_``.

    Raises
    ------
    RuntimeError
        If ``functorch`` cannot be imported, or if ``self.estimators_``
        is empty.
    """
    try:
        from functorch import combine_state_for_ensemble
    except Exception as err:
        # Kept deliberately broad: any failure while importing functorch
        # is reported to callers as the same RuntimeError. The original
        # failure is chained explicitly so it is not lost.
        msg = (
            "Failed to import functorch utils, please make sure the"
            " Pytorch version >= 1.13.0."
        )
        raise RuntimeError(msg) from err

    # Fail fast, before touching the train/eval mode, when there is
    # nothing to combine.
    if not self.estimators_:
        msg = (
            "There are no base estimators in the ensemble to vectorize,"
            " please fit the ensemble first."
        )
        raise RuntimeError(msg)

    self.eval()
    fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)
    return fmodel, params, buffers
def _assert_vectorize_matches_estimators(model, test_loader):
    """Check that vmap over ``model.vectorize()`` matches each estimator.

    For every batch in ``test_loader`` the output of the vmapped functional
    model is compared against the stacked outputs of the individual
    estimators in ``model.estimators_``.
    """
    fmodel, params, buffers = model.vectorize()

    # Build the transform once; it does not depend on the batch.
    # in_dims=(0, 0, None): map over dim 0 of params and buffers, and pass
    # the same data batch to every estimator.
    batched_forward = vmap(fmodel, in_dims=(0, 0, None))

    with torch.no_grad():
        for data, _ in test_loader:
            vmap_output = batched_forward(params, buffers, data)
            pytorch_output = torch.stack(
                [estimator(data) for estimator in model.estimators_]
            )
            assert torch.allclose(
                vmap_output, pytorch_output, atol=1e-3, rtol=1e-5
            )


@pytest.mark.parametrize("clf", all_clf)
def test_clf_vectorize_same_output(clf):
    """
    This unit test checks the inference with/without vectorize for all
    classifiers.
    """
    epochs = 2
    n_estimators = 2

    model = clf(estimator=MLP_clf, n_estimators=n_estimators, cuda=False)

    # Optimizer
    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # Prepare data
    train = TensorDataset(X_train, y_train_clf)
    train_loader = DataLoader(train, batch_size=2, shuffle=False)
    test = TensorDataset(X_test, y_test_clf)
    test_loader = DataLoader(test, batch_size=2, shuffle=False)

    # Train
    model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    _assert_vectorize_matches_estimators(model, test_loader)


@pytest.mark.parametrize("reg", all_reg)
def test_reg_vectorize_same_output(reg):
    """
    This unit test checks the inference with/without vectorize for all
    regressors.
    """
    epochs = 2
    n_estimators = 2

    model = reg(estimator=MLP_reg, n_estimators=n_estimators, cuda=False)

    # Optimizer
    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)

    # Prepare data
    train = TensorDataset(X_train, y_train_reg)
    train_loader = DataLoader(train, batch_size=2, shuffle=False)
    test = TensorDataset(X_test, y_test_reg)
    test_loader = DataLoader(test, batch_size=2, shuffle=False)

    # Train
    model.fit(train_loader, epochs=epochs, test_loader=test_loader)

    _assert_vectorize_matches_estimators(model, test_loader)
# Thin override: decorated with item="vectorize" and otherwise delegating
# entirely to the parent class implementation.
# NOTE(review): the decorator presumably attaches the shared "vectorize"
# docstring, so no literal docstring is written here - confirm.
@torchensemble_model_doc(item="vectorize")
def vectorize(self):
    vectorized = super().vectorize()
    return vectorized