Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion KD_Lib/KD/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _train_student(

epoch_acc = correct / length_of_dataset

_, epoch_val_acc = self._evaluate_model(self.student_model, verbose=True)
epoch_val_acc = self.evaluate()

if epoch_val_acc > best_acc:
best_acc = epoch_val_acc
Expand Down
4 changes: 1 addition & 3 deletions KD_Lib/KD/vision/DML/dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,12 @@ def _evaluate_model(self, model, verbose=True):
def evaluate(self):
"""
Evaluate method for printing accuracies of the trained student networks

"""

for i, student in enumerate(self.student_cohort):
print("-" * 80)
model = deepcopy(student).to(self.device)
print(f"Evaluating student {i}")
_ = self._evaluate_model(model)
_, _ = self._evaluate_model(model)

def get_parameters(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions KD_Lib/Pruning/lottery_tickets/lottery_tickets.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _train_pruned_model(self):
eps = 1e-6
self.model.train()
correct = 0

training_loss = 0.0
step = 0
for data, targets in self.train_loader:
self.optimizer.zero_grad()
Expand All @@ -192,6 +192,7 @@ def _train_pruned_model(self):
outputs = outputs[0]

train_loss = self.loss_fn(outputs, targets)
training_loss += train_loss.item()
train_loss.backward()

pred = outputs.argmax(dim=1, keepdim=True)
Expand All @@ -212,7 +213,7 @@ def _train_pruned_model(self):
step += 1

train_acc = 100.0 * correct / len(self.train_loader.dataset)
return train_loss.item(), train_acc
return training_loss / len(self.train_loader), train_acc

def _save_model(self, prune_it, best_weights):
file_name = f"{os.getcwd()}/pruned_model_{prune_it}.pth.tar"
Expand Down
115 changes: 61 additions & 54 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,74 @@
LONG_DESCRIPTION = f.read()

# Define the keywords
KEYWORDS = ["Knowledge Distillation", "Pruning", "Quantization", "pytorch", "machine learning", "deep learning"]
KEYWORDS = [
"Knowledge Distillation",
"Pruning",
"Quantization",
"pytorch",
"machine learning",
"deep learning",
]
REQUIRE_PATH = "requirements.txt"
PROJECT = os.path.abspath(os.path.dirname(__file__))
setup_requirements = ['pytest-runner']
setup_requirements = ["pytest-runner"]

test_requirements = ['pytest', 'pytest-cov']
test_requirements = ["pytest", "pytest-cov"]

requirements = [
'pip==19.3.1',
'transformers==4.6.1',
'sacremoses',
'tokenizers==0.10.1',
'huggingface-hub==0.0.8',
'torchtext==0.9.1',
'bumpversion==0.5.3',
'wheel==0.32.1',
'watchdog==0.9.0',
'flake8==3.5.0',
'tox==3.5.2',
'coverage==4.5.1',
'Sphinx==1.8.1',
'twine==1.12.1',
'pytest==3.8.2',
'pytest-runner==4.2',
'pytest-cov==2.6.1',
'matplotlib==3.2.1',
'torch==1.8.1',
'torchvision==0.9.1',
'tensorboard==2.2.1',
'contextlib2==0.6.0.post1',
'pandas==1.0.1',
'tqdm==4.42.1',
'numpy==1.18.1',
'sphinx-rtd-theme==0.5.0',
"pip==19.3.1",
"transformers==4.6.1",
"sacremoses",
"tokenizers==0.10.1",
"huggingface-hub==0.0.8",
"torchtext==0.9.1",
"bumpversion==0.5.3",
"wheel==0.32.1",
"watchdog==0.9.0",
"flake8==3.5.0",
"tox==3.5.2",
"coverage==4.5.1",
"Sphinx==1.8.1",
"twine==1.12.1",
"pytest==3.8.2",
"pytest-runner==4.2",
"pytest-cov==2.6.1",
"matplotlib==3.2.1",
"torch==1.8.1",
"torchvision==0.9.1",
"tensorboard==2.2.1",
"contextlib2==0.6.0.post1",
"pandas==1.0.1",
"tqdm==4.42.1",
"numpy==1.18.1",
"sphinx-rtd-theme==0.5.0",
]


if __name__ == "__main__":
setup(
author="Het Shah",
author_email='[email protected]',
classifiers=[
'Development Status :: 2 - Pre-Alpha',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Natural Language :: English',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
],
description="A Pytorch Library to help extend all Knowledge Distillation works",
install_requires=requirements,
license="MIT license",
long_description=LONG_DESCRIPTION,
include_package_data=True,
keywords=KEYWORDS,
name='KD_Lib',
packages=find_packages(where=PROJECT),
setup_requires=setup_requirements,
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/SforAiDL/KD_Lib",
version='0.0.29',
zip_safe=False,
)
author="Het Shah",
author_email="[email protected]",
classifiers=[
"Development Status :: 2 - Pre-Alpha",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
],
description="A Pytorch Library to help extend all Knowledge Distillation works",
install_requires=requirements,
license="MIT license",
long_description=LONG_DESCRIPTION,
include_package_data=True,
keywords=KEYWORDS,
name="KD_Lib",
packages=find_packages(where=PROJECT),
setup_requires=setup_requirements,
test_suite="tests",
tests_require=test_requirements,
url="https://github.com/SforAiDL/KD_Lib",
version="0.0.29",
zip_safe=False,
)
1 change: 0 additions & 1 deletion tests/test_KD_Lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ def test_BaseClass():
distiller.train_teacher(epochs=1, plot_losses=True, save_model=True)
distiller.train_student(epochs=1, plot_losses=True, save_model=True)
distiller.evaluate(teacher=False)

distiller.get_parameters()


Expand Down