Skip to content

Commit b97abd1

Browse files
Merge branch 'master' into fix-tests
2 parents 5d4bd7d + 1a1b0ec commit b97abd1

File tree

4 files changed

+47
-34
lines changed

4 files changed

+47
-34
lines changed

.github/workflows/python-package-test.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ jobs:
3333
pip install build
3434
- name: Build package
3535
run: python -m build
36-
- name: Black
37-
run: |
38-
# stop the build if there are Python syntax errors or undefined names
39-
black --check KD_Lib
40-
black --check tests
4136
- name: Test with pytest
4237
run: |
4338
pytest

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ install:
1313
- pip install -U tox-travis codecov black
1414
- python setup.py install
1515

16-
1716
jobs:
1817
include:
1918
# Deploy Documentation

KD_Lib/KD/common/base_class.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -53,27 +53,25 @@ def __init__(
5353
if self.log:
5454
self.writer = SummaryWriter(logdir)
5555

56-
try:
57-
torch.Tensor(0).to(device)
58-
self.device = device
59-
except:
60-
print(
61-
"Either an invalid device or CUDA is not available. Defaulting to CPU."
62-
)
56+
if device == "cpu":
6357
self.device = torch.device("cpu")
58+
elif device == "cuda":
59+
if torch.cuda.is_available():
60+
self.device = torch.device("cuda")
61+
else:
62+
print(
63+
"Either an invalid device or CUDA is not available. Defaulting to CPU."
64+
)
65+
self.device = torch.device("cpu")
6466

65-
try:
67+
if teacher_model:
6668
self.teacher_model = teacher_model.to(self.device)
67-
except:
69+
else:
6870
print("Warning!!! Teacher is NONE.")
71+
6972
self.student_model = student_model.to(self.device)
70-
try:
71-
self.loss_fn = loss_fn.to(self.device)
72-
self.ce_fn = nn.CrossEntropyLoss().to(self.device)
73-
except:
74-
self.loss_fn = loss_fn
75-
self.ce_fn = nn.CrossEntropyLoss()
76-
print("Warning: Loss Function can't be moved to device.")
73+
self.loss_fn = loss_fn.to(self.device)
74+
self.ce_fn = nn.CrossEntropyLoss().to(self.device)
7775

7876
def train_teacher(
7977
self,
@@ -142,7 +140,7 @@ def train_teacher(
142140
)
143141

144142
loss_arr.append(epoch_loss)
145-
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
143+
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))
146144

147145
self.post_epoch_call(ep)
148146

@@ -224,7 +222,7 @@ def _train_student(
224222
)
225223

226224
loss_arr.append(epoch_loss)
227-
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
225+
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))
228226

229227
self.student_model.load_state_dict(self.best_student_model_weights)
230228
if save_model:
@@ -290,7 +288,7 @@ def _evaluate_model(self, model, verbose=True):
290288

291289
if verbose:
292290
print("-" * 80)
293-
print(f"Validation Accuracy: {accuracy}")
291+
print("Validation Accuracy: {}".format(accuracy))
294292
return outputs, accuracy
295293

296294
def evaluate(self, teacher=False):
@@ -315,8 +313,8 @@ def get_parameters(self):
315313
student_params = sum(p.numel() for p in self.student_model.parameters())
316314

317315
print("-" * 80)
318-
print(f"Total parameters for the teacher network are: {teacher_params}")
319-
print(f"Total parameters for the student network are: {student_params}")
316+
print("Total parameters for the teacher network are: {}".format(teacher_params))
317+
print("Total parameters for the student network are: {}".format(student_params))
320318

321319
def post_epoch_call(self, epoch):
322320
"""

tests/test_KD_Lib.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
ProbShift,
2424
LabelSmoothReg,
2525
DML,
26+
BaseClass
2627
)
2728

2829
from KD_Lib.models import (
@@ -94,6 +95,10 @@ def test_resnet():
9495
ResNet50(params)
9596
ResNet101(params)
9697
ResNet152(params)
98+
ResNet34(params, att=True)
99+
ResNet34(params, mean=True)
100+
ResNet101(params, att=True)
101+
ResNet101(params, mean=True)
97102

98103

99104
def test_attention_model():
@@ -159,6 +164,22 @@ def test_LSTMNet():
159164
# Strategy TESTS
160165
#
161166

167+
def test_BaseClass()
168+
teac = Shallow(hidden_size=400)
169+
stud = Shallow(hidden_size=100)
170+
171+
t_optimizer = optim.SGD(teac.parameters(), 0.01)
172+
s_optimizer = optim.SGD(stud.parameters(), 0.01)
173+
174+
distiller = BaseClass(
175+
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer, log=True
176+
)
177+
178+
distiller.train_teacher(epochs=1, plot_losses=True, save_model=True)
179+
distiller.train_student(epochs=1, plot_losses=True, save_model=True)
180+
distiller.evaluate(teacher=False)
181+
distiller.get_parameters()
182+
162183

163184
def test_VanillaKD():
164185
teac = Shallow(hidden_size=400)
@@ -168,11 +189,11 @@ def test_VanillaKD():
168189
s_optimizer = optim.SGD(stud.parameters(), 0.01)
169190

170191
distiller = VanillaKD(
171-
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer
192+
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer, log=True
172193
)
173194

174-
distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
175-
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
195+
distiller.train_teacher(epochs=1, plot_losses=True, save_model=True)
196+
distiller.train_student(epochs=1, plot_losses=True, save_model=True)
176197
distiller.evaluate(teacher=False)
177198
distiller.get_parameters()
178199

@@ -289,8 +310,8 @@ def test_SelfTraining():
289310

290311

291312
def test_mean_teacher():
292-
teacher_params = [4, 4, 8, 4, 4]
293-
student_params = [4, 4, 4, 4, 4]
313+
teacher_params = [16, 16, 32, 16, 16]
314+
student_params = [16, 16, 16, 16, 16]
294315
teacher_model = ResNet50(teacher_params, 1, 10, mean=True)
295316
student_model = ResNet18(student_params, 1, 10, mean=True)
296317

@@ -488,7 +509,7 @@ def test_lottery_tickets():
488509
teacher_params = [4, 4, 8, 4, 4]
489510
teacher_model = ResNet50(teacher_params, 1, 10, True)
490511
pruner = Lottery_Tickets_Pruner(teacher_model, train_loader, test_loader)
491-
pruner.prune(num_iterations=1, train_iterations=1, valid_freq=1, print_freq=1)
512+
pruner.prune(num_iterations=2, train_iterations=2, valid_freq=1, print_freq=1)
492513

493514

494515
#
@@ -539,6 +560,6 @@ def test_qat_quantization():
539560
model.fc.out_features = 10
540561
optimizer = torch.optim.Adam(model.parameters())
541562
quantizer = QAT_Quantizer(model, cifar_trainloader, cifar_testloader, optimizer)
542-
quantized_model = quantizer.quantize(1, 1, -1, -1)
563+
quantized_model = quantizer.quantize(1, 1, 1, 1)
543564
quantizer.get_model_sizes()
544565
quantizer.get_performance_statistics()

0 commit comments

Comments
 (0)