diff --git a/KD_Lib/KD/common/base_class.py b/KD_Lib/KD/common/base_class.py index fd83687c..064f1dbc 100755 --- a/KD_Lib/KD/common/base_class.py +++ b/KD_Lib/KD/common/base_class.py @@ -206,7 +206,7 @@ def _train_student( epoch_acc = correct / length_of_dataset - _, epoch_val_acc = self._evaluate_model(self.student_model, verbose=True) + epoch_val_acc = self.evaluate() if epoch_val_acc > best_acc: best_acc = epoch_val_acc diff --git a/KD_Lib/KD/vision/DML/dml.py b/KD_Lib/KD/vision/DML/dml.py index 8b37347d..32e0b91a 100755 --- a/KD_Lib/KD/vision/DML/dml.py +++ b/KD_Lib/KD/vision/DML/dml.py @@ -178,14 +178,12 @@ def _evaluate_model(self, model, verbose=True): def evaluate(self): """ Evaluate method for printing accuracies of the trained student networks - """ - for i, student in enumerate(self.student_cohort): print("-" * 80) model = deepcopy(student).to(self.device) print(f"Evaluating student {i}") - _ = self._evaluate_model(model) + _, _ = self._evaluate_model(model) def get_parameters(self): """ diff --git a/KD_Lib/Pruning/lottery_tickets/lottery_tickets.py b/KD_Lib/Pruning/lottery_tickets/lottery_tickets.py index 99b9932f..6c46cb97 100755 --- a/KD_Lib/Pruning/lottery_tickets/lottery_tickets.py +++ b/KD_Lib/Pruning/lottery_tickets/lottery_tickets.py @@ -181,7 +181,7 @@ def _train_pruned_model(self): eps = 1e-6 self.model.train() correct = 0 - + training_loss = 0.0 step = 0 for data, targets in self.train_loader: self.optimizer.zero_grad() @@ -192,6 +192,7 @@ def _train_pruned_model(self): outputs = outputs[0] train_loss = self.loss_fn(outputs, targets) + training_loss += train_loss.item() train_loss.backward() pred = outputs.argmax(dim=1, keepdim=True) @@ -212,7 +213,7 @@ def _train_pruned_model(self): step += 1 train_acc = 100.0 * correct / len(self.train_loader.dataset) - return train_loss.item(), train_acc + return training_loss / len(self.train_loader), train_acc def _save_model(self, prune_it, best_weights): file_name = f"{os.getcwd()}/pruned_model_{prune_it}.pth.tar" diff --git a/setup.py b/setup.py index dd5add73..d840bb84 100755 --- a/setup.py +++ b/setup.py @@ -13,67 +13,74 @@ LONG_DESCRIPTION = f.read() # Define the keywords -KEYWORDS = ["Knowledge Distillation", "Pruning", "Quantization", "pytorch", "machine learning", "deep learning"] +KEYWORDS = [ + "Knowledge Distillation", + "Pruning", + "Quantization", + "pytorch", + "machine learning", + "deep learning", +] REQUIRE_PATH = "requirements.txt" PROJECT = os.path.abspath(os.path.dirname(__file__)) -setup_requirements = ['pytest-runner'] +setup_requirements = ["pytest-runner"] -test_requirements = ['pytest', 'pytest-cov'] +test_requirements = ["pytest", "pytest-cov"] requirements = [ -'pip==19.3.1', -'transformers==4.6.1', -'sacremoses', -'tokenizers==0.10.1', -'huggingface-hub==0.0.8', -'torchtext==0.9.1', -'bumpversion==0.5.3', -'wheel==0.32.1', -'watchdog==0.9.0', -'flake8==3.5.0', -'tox==3.5.2', -'coverage==4.5.1', -'Sphinx==1.8.1', -'twine==1.12.1', -'pytest==3.8.2', -'pytest-runner==4.2', -'pytest-cov==2.6.1', -'matplotlib==3.2.1', -'torch==1.8.1', -'torchvision==0.9.1', -'tensorboard==2.2.1', -'contextlib2==0.6.0.post1', -'pandas==1.0.1', -'tqdm==4.42.1', -'numpy==1.18.1', -'sphinx-rtd-theme==0.5.0', + "pip==19.3.1", + "transformers==4.6.1", + "sacremoses", + "tokenizers==0.10.1", + "huggingface-hub==0.0.8", + "torchtext==0.9.1", + "bumpversion==0.5.3", + "wheel==0.32.1", + "watchdog==0.9.0", + "flake8==3.5.0", + "tox==3.5.2", + "coverage==4.5.1", + "Sphinx==1.8.1", + "twine==1.12.1", + "pytest==3.8.2", + "pytest-runner==4.2", + "pytest-cov==2.6.1", + "matplotlib==3.2.1", + "torch==1.8.1", + "torchvision==0.9.1", + "tensorboard==2.2.1", + "contextlib2==0.6.0.post1", + "pandas==1.0.1", + "tqdm==4.42.1", + "numpy==1.18.1", + "sphinx-rtd-theme==0.5.0", ] if __name__ == "__main__": setup( - author="Het Shah", - author_email='divhet163@gmail.com', - classifiers=[ - 'Development Status :: 2 - Pre-Alpha', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Natural Language :: English', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - ], - description="A Pytorch Library to help extend all Knowledge Distillation works", - install_requires=requirements, - license="MIT license", - long_description=LONG_DESCRIPTION, - include_package_data=True, - keywords=KEYWORDS, - name='KD_Lib', - packages=find_packages(where=PROJECT), - setup_requires=setup_requirements, - test_suite="tests", - tests_require=test_requirements, - url="https://github.com/SforAiDL/KD_Lib", - version='0.0.29', - zip_safe=False, -) + author="Het Shah", + author_email="divhet163@gmail.com", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + ], + description="A Pytorch Library to help extend all Knowledge Distillation works", + install_requires=requirements, + license="MIT license", + long_description=LONG_DESCRIPTION, + include_package_data=True, + keywords=KEYWORDS, + name="KD_Lib", + packages=find_packages(where=PROJECT), + setup_requires=setup_requirements, + test_suite="tests", + tests_require=test_requirements, + url="https://github.com/SforAiDL/KD_Lib", + version="0.0.29", + zip_safe=False, + ) diff --git a/tests/test_KD_Lib.py b/tests/test_KD_Lib.py index 95df44c3..0aab0882 100644 --- a/tests/test_KD_Lib.py +++ b/tests/test_KD_Lib.py @@ -179,7 +179,6 @@ def test_BaseClass(): distiller.train_teacher(epochs=1, plot_losses=True, save_model=True) distiller.train_student(epochs=1, plot_losses=True, save_model=True) distiller.evaluate(teacher=False) - distiller.get_parameters()