From 06e731ee51a98efed2621b241a505de93f5181eb Mon Sep 17 00:00:00 2001
From: "allcontributors[bot]"
<46447321+allcontributors[bot]@users.noreply.github.com>
Date: Thu, 10 Jun 2021 12:09:26 +0000
Subject: [PATCH 1/6] docs: update CONTRIBUTORS.md [skip ci]
---
CONTRIBUTORS.md | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 1656769..ff6f2e2 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -8,12 +8,13 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
From 4798882f72bf70eaaa10038cd9f2fc56cc55b914 Mon Sep 17 00:00:00 2001
From: "allcontributors[bot]"
<46447321+allcontributors[bot]@users.noreply.github.com>
Date: Thu, 10 Jun 2021 12:09:27 +0000
Subject: [PATCH 2/6] docs: update .all-contributorsrc [skip ci]
---
.all-contributorsrc | 11 ++++++++++-
1 file changed, 10 insertions(+), 1 deletion(-)
diff --git a/.all-contributorsrc b/.all-contributorsrc
index 3e6ca7f..39d3440 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -1,6 +1,6 @@
{
"projectName": "Ensemble-Pytorch",
- "projectOwner": "xuyxu",
+ "projectOwner": "TorchEnsemble-Community",
"repoType": "github",
"repoHost": "https://github.com",
"files": [
@@ -69,6 +69,15 @@
"contributions": [
"code"
]
+ },
+ {
+ "login": "e-eight",
+ "name": "Soham Pal",
+ "avatar_url": "https://avatars.githubusercontent.com/u/3883241?v=4",
+ "profile": "https://soham.dev",
+ "contributions": [
+ "code"
+ ]
}
],
"contributorsPerLine": 7,
From 673eb0933e48de0fa2315b459d9d58c7dd22535a Mon Sep 17 00:00:00 2001
From: e-eight
Date: Mon, 14 Jun 2021 10:55:19 -0500
Subject: [PATCH 3/6] Added LBFGS optimizer for Fusion
---
torchensemble/_constants.py | 2 +-
torchensemble/fusion.py | 56 ++++++++++++++++-------
torchensemble/tests/test_set_optimizer.py | 15 ++++--
torchensemble/utils/set_module.py | 1 +
4 files changed, 51 insertions(+), 23 deletions(-)
diff --git a/torchensemble/_constants.py b/torchensemble/_constants.py
index be276e6..66796db 100644
--- a/torchensemble/_constants.py
+++ b/torchensemble/_constants.py
@@ -64,7 +64,7 @@
optimizer_name : string
The name of the optimizer, should be one of {``Adadelta``, ``Adagrad``,
``Adam``, ``AdamW``, ``Adamax``, ``ASGD``, ``RMSprop``, ``Rprop``,
- ``SGD``}.
+ ``SGD``, ``LBFGS``}.
**kwargs : keyword arguments
Keyword arguments on setting the optimizer, should be in the form:
``lr=1e-3, weight_decay=5e-4, ...``. These keyword arguments
diff --git a/torchensemble/fusion.py b/torchensemble/fusion.py
index 3cee94d..6138f49 100644
--- a/torchensemble/fusion.py
+++ b/torchensemble/fusion.py
@@ -10,12 +10,10 @@
import torch.nn as nn
import torch.nn.functional as F
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["FusionClassifier", "FusionRegressor"]
@@ -99,11 +97,20 @@ def fit(
data, target = io.split_data_target(elem, self.device)
batch_size = data[0].size(0)
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = self._forward(*data)
+ loss = criterion(output, target)
+ if loss.requires_grad():
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
output = self._forward(*data)
- loss = criterion(output, target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -117,12 +124,16 @@ def fit(
)
self.logger.info(
msg.format(
- epoch, batch_idx, loss, correct, batch_size
+ epoch,
+ batch_idx,
+ loss.item(),
+ correct,
+ batch_size,
)
)
if self.tb_logger:
self.tb_logger.add_scalar(
- "fusion/Train_Loss", loss, total_iters
+ "fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1
@@ -237,20 +248,31 @@ def fit(
data, target = io.split_data_target(elem, self.device)
- optimizer.zero_grad()
- output = self.forward(*data)
- loss = criterion(output, target)
- loss.backward()
- optimizer.step()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = self._forward(*data)
+ loss = criterion(output, target)
+ if loss.requires_grad():
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
+ output = self._forward(*data)
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
with torch.no_grad():
msg = "Epoch: {:03d} | Batch: {:03d} | Loss: {:.5f}"
- self.logger.info(msg.format(epoch, batch_idx, loss))
+ self.logger.info(
+ msg.format(epoch, batch_idx, loss.item())
+ )
if self.tb_logger:
self.tb_logger.add_scalar(
- "fusion/Train_Loss", loss, total_iters
+ "fusion/Train_Loss", loss.item(), total_iters
)
total_iters += 1
diff --git a/torchensemble/tests/test_set_optimizer.py b/torchensemble/tests/test_set_optimizer.py
index 393c08f..1991ede 100644
--- a/torchensemble/tests/test_set_optimizer.py
+++ b/torchensemble/tests/test_set_optimizer.py
@@ -1,7 +1,6 @@
import pytest
-import torchensemble
import torch.nn as nn
-
+import torchensemble
optimizer_list = [
"Adadelta",
@@ -13,6 +12,7 @@
"RMSprop",
"Rprop",
"SGD",
+ "LBFGS",
]
@@ -33,9 +33,14 @@ def forward(self, X):
@pytest.mark.parametrize("optimizer_name", optimizer_list)
def test_set_optimizer_normal(optimizer_name):
model = MLP()
- torchensemble.utils.set_module.set_optimizer(
- model, optimizer_name, lr=1e-3
- )
+ if optimizer_name != "LBFGS":
+ torchensemble.utils.set_module.set_optimizer(
+ model, optimizer_name, lr=1e-3
+ )
+ else:
+ torchensemble.utils.set_module.set_optimizer(
+ model, optimizer_name, history_size=7, max_iter=10
+ )
def test_set_optimizer_Unknown():
diff --git a/torchensemble/utils/set_module.py b/torchensemble/utils/set_module.py
index 750ffe8..3db6dc6 100644
--- a/torchensemble/utils/set_module.py
+++ b/torchensemble/utils/set_module.py
@@ -18,6 +18,7 @@ def set_optimizer(model, optimizer_name, **kwargs):
"RMSprop",
"Rprop",
"SGD",
+ "LBFGS"
]
if optimizer_name not in torch_optim_optimizers:
msg = "Unrecognized optimizer: {}, should be one of {}."
From 084a76f07923160f75dd4ccf58cddfd240e87642 Mon Sep 17 00:00:00 2001
From: e-eight
Date: Sun, 20 Jun 2021 23:56:08 -0500
Subject: [PATCH 4/6] Fixed code-quality errors.
---
torchensemble/fusion.py | 8 ++++----
torchensemble/utils/set_module.py | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/torchensemble/fusion.py b/torchensemble/fusion.py
index 6138f49..0c8973e 100644
--- a/torchensemble/fusion.py
+++ b/torchensemble/fusion.py
@@ -102,7 +102,7 @@ def closure():
optimizer.zero_grad()
output = self._forward(*data)
loss = criterion(output, target)
- if loss.requires_grad():
+ if loss.requires_grad:
loss.backward()
return loss
@@ -251,16 +251,16 @@ def fit(
def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
- output = self._forward(*data)
+ output = self.forward(*data)
loss = criterion(output, target)
- if loss.requires_grad():
+ if loss.requires_grad:
loss.backward()
return loss
optimizer.step(closure)
# Calculate loss for logging
- output = self._forward(*data)
+ output = self.forward(*data)
loss = closure()
# Print training status
diff --git a/torchensemble/utils/set_module.py b/torchensemble/utils/set_module.py
index 3db6dc6..dcbb312 100644
--- a/torchensemble/utils/set_module.py
+++ b/torchensemble/utils/set_module.py
@@ -18,7 +18,7 @@ def set_optimizer(model, optimizer_name, **kwargs):
"RMSprop",
"Rprop",
"SGD",
- "LBFGS"
+ "LBFGS",
]
if optimizer_name not in torch_optim_optimizers:
msg = "Unrecognized optimizer: {}, should be one of {}."
From 44ffcee1c7f7142df07e0c64f42de08a8e216cc4 Mon Sep 17 00:00:00 2001
From: e-eight
Date: Wed, 14 Jul 2021 15:09:00 -0500
Subject: [PATCH 5/6] Added LBFGS optimizer to bagging and voting
---
torchensemble/bagging.py | 36 ++++++++++++++++++++++++------------
torchensemble/voting.py | 31 +++++++++++++++++++------------
2 files changed, 43 insertions(+), 24 deletions(-)
diff --git a/torchensemble/bagging.py b/torchensemble/bagging.py
index ab44199..e10a600 100644
--- a/torchensemble/bagging.py
+++ b/torchensemble/bagging.py
@@ -6,19 +6,17 @@
"""
+import warnings
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-import warnings
from joblib import Parallel, delayed
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["BaggingClassifier", "BaggingRegressor"]
@@ -59,11 +57,20 @@ def _parallel_fit_per_epoch(
sampling_data = [tensor[sampling_mask] for tensor in data]
sampling_target = target[sampling_mask]
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ sampling_output = estimator(*sampling_data)
+ loss = criterion(sampling_output, sampling_target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
sampling_output = estimator(*sampling_data)
- loss = criterion(sampling_output, sampling_target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -79,7 +86,12 @@ def _parallel_fit_per_epoch(
)
print(
msg.format(
- idx, epoch, batch_idx, loss, correct, subsample_size
+ idx,
+ epoch,
+ batch_idx,
+ loss.item(),
+ correct,
+ subsample_size,
)
)
else:
@@ -87,7 +99,7 @@ def _parallel_fit_per_epoch(
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
- print(msg.format(idx, epoch, batch_idx, loss))
+ print(msg.format(idx, epoch, batch_idx, loss.item()))
return estimator, optimizer
diff --git a/torchensemble/voting.py b/torchensemble/voting.py
index 5edd07a..6ede86f 100644
--- a/torchensemble/voting.py
+++ b/torchensemble/voting.py
@@ -5,19 +5,17 @@
"""
+import warnings
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-
-import warnings
from joblib import Parallel, delayed
-from ._base import BaseClassifier, BaseRegressor
-from ._base import torchensemble_model_doc
+from ._base import BaseClassifier, BaseRegressor, torchensemble_model_doc
from .utils import io
-from .utils import set_module
from .utils import operator as op
-
+from .utils import set_module
__all__ = ["VotingClassifier", "VotingRegressor"]
@@ -49,11 +47,20 @@ def _parallel_fit_per_epoch(
data, target = io.split_data_target(elem, device)
batch_size = data[0].size(0)
- optimizer.zero_grad()
+ def closure():
+ if torch.is_grad_enabled():
+ optimizer.zero_grad()
+ output = estimator(*data)
+ loss = criterion(output, target)
+ if loss.requires_grad:
+ loss.backward()
+ return loss
+
+ optimizer.step(closure)
+
+ # Calculate loss for logging
output = estimator(*data)
- loss = criterion(output, target)
- loss.backward()
- optimizer.step()
+ loss = closure()
# Print training status
if batch_idx % log_interval == 0:
@@ -69,7 +76,7 @@ def _parallel_fit_per_epoch(
)
print(
msg.format(
- idx, epoch, batch_idx, loss, correct, batch_size
+ idx, epoch, batch_idx, loss.item(), correct, batch_size
)
)
# Regression
@@ -78,7 +85,7 @@ def _parallel_fit_per_epoch(
"Estimator: {:03d} | Epoch: {:03d} | Batch: {:03d}"
" | Loss: {:.5f}"
)
- print(msg.format(idx, epoch, batch_idx, loss))
+ print(msg.format(idx, epoch, batch_idx, loss.item()))
return estimator, optimizer
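
Note on this patch: in bagging and voting the closure is built inside _parallel_fit_per_epoch, so each joblib worker trains its own estimator/optimizer pair independently. A simplified sketch of that per-estimator dispatch (fit_one is a hypothetical stand-in; the real function in torchensemble carries more arguments):

    import torch
    import torch.nn as nn
    from joblib import Parallel, delayed

    def fit_one(estimator, optimizer, criterion, data, target):
        # fit_one is a simplified stand-in for _parallel_fit_per_epoch.
        # Same closure pattern as in the hunks above, one copy per worker.
        def closure():
            if torch.is_grad_enabled():
                optimizer.zero_grad()
            loss = criterion(estimator(data), target)
            if loss.requires_grad:
                loss.backward()
            return loss

        optimizer.step(closure)
        return estimator, optimizer

    data, target = torch.randn(64, 10), torch.randn(64, 1)
    estimators = [nn.Linear(10, 1) for _ in range(3)]
    optimizers = [torch.optim.LBFGS(e.parameters(), max_iter=10) for e in estimators]

    # Trained copies are returned from the workers, as in _parallel_fit_per_epoch.
    results = Parallel(n_jobs=2)(
        delayed(fit_one)(est, opt, nn.MSELoss(), data, target)
        for est, opt in zip(estimators, optimizers)
    )
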
From 39da90c83566cfa9a2025303f9905c60e4835da3 Mon Sep 17 00:00:00 2001
From: xuyxu
Date: Sun, 12 Sep 2021 10:49:13 +0800
Subject: [PATCH 6/6] Update fusion.py
---
torchensemble/fusion.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/torchensemble/fusion.py b/torchensemble/fusion.py
index 0869aad..a573968 100644
--- a/torchensemble/fusion.py
+++ b/torchensemble/fusion.py
@@ -111,7 +111,7 @@ def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
output = self._forward(*data)
- loss = criterion(output, target)
+ loss = self._criterion(output, target)
if loss.requires_grad:
loss.backward()
return loss
@@ -272,7 +272,7 @@ def closure():
if torch.is_grad_enabled():
optimizer.zero_grad()
output = self.forward(*data)
- loss = criterion(output, target)
+ loss = self._criterion(output, target)
if loss.requires_grad:
loss.backward()
return loss
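
With the whole series applied, LBFGS can be requested by name like any other supported optimizer. A hedged end-to-end sketch of the intended usage (the constructor and fit arguments follow the torchensemble API around the time of this series and may differ in other versions; the toy MLP and data are illustrative):

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torchensemble import FusionRegressor

    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))

        def forward(self, x):
            return self.net(x)

    train_loader = DataLoader(
        TensorDataset(torch.randn(128, 10), torch.randn(128, 1)), batch_size=32
    )

    model = FusionRegressor(estimator=MLP, n_estimators=5, cuda=False)
    # No lr is required for LBFGS; history_size/max_iter mirror the new test case.
    # Exact set_optimizer/fit kwargs may vary by torchensemble version.
    model.set_optimizer("LBFGS", history_size=7, max_iter=10)
    model.fit(train_loader, epochs=5)
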