Skip to content

Commit 1a1b0ec

Browse files
committed
Update tests and apply black
1 parent d5bb703 commit 1a1b0ec

File tree

3 files changed

+8
-13
lines changed

3 files changed

+8
-13
lines changed

.github/workflows/python-package-test.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ jobs:
3333
pip install build
3434
- name: Build package
3535
run: python -m build
36-
- name: Black
37-
run: |
38-
# stop the build if there are Python syntax errors or undefined names
39-
black --check KD_Lib
40-
black --check tests
4136
- name: Test with pytest
4237
run: |
4338
pytest

KD_Lib/KD/common/base_class.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,11 @@ def __init__(
6868
self.teacher_model = teacher_model.to(self.device)
6969
else:
7070
print("Warning!!! Teacher is NONE.")
71-
71+
7272
self.student_model = student_model.to(self.device)
7373
self.loss_fn = loss_fn.to(self.device)
7474
self.ce_fn = nn.CrossEntropyLoss().to(self.device)
75-
75+
7676
def train_teacher(
7777
self,
7878
epochs=20,
@@ -140,7 +140,7 @@ def train_teacher(
140140
)
141141

142142
loss_arr.append(epoch_loss)
143-
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
143+
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))
144144

145145
self.post_epoch_call(ep)
146146

@@ -222,7 +222,7 @@ def _train_student(
222222
)
223223

224224
loss_arr.append(epoch_loss)
225-
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
225+
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))
226226

227227
self.student_model.load_state_dict(self.best_student_model_weights)
228228
if save_model:
@@ -288,7 +288,7 @@ def _evaluate_model(self, model, verbose=True):
288288

289289
if verbose:
290290
print("-" * 80)
291-
print(f"Validation Accuracy: {accuracy}")
291+
print("Validation Accuracy: {}".format(accuracy))
292292
return outputs, accuracy
293293

294294
def evaluate(self, teacher=False):
@@ -313,8 +313,8 @@ def get_parameters(self):
313313
student_params = sum(p.numel() for p in self.student_model.parameters())
314314

315315
print("-" * 80)
316-
print(f"Total parameters for the teacher network are: {teacher_params}")
317-
print(f"Total parameters for the student network are: {student_params}")
316+
print("Total parameters for the teacher network are: {}".format(teacher_params))
317+
print("Total parameters for the student network are: {}".format(student_params))
318318

319319
def post_epoch_call(self, epoch):
320320
"""

KD_Lib/KD/vision/DML/dml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def _evaluate_model(self, model, verbose=True):
172172
if verbose:
173173
print(f"Accuracy: {correct/length_of_dataset}")
174174

175-
epoch_val_acc = correct/length_of_dataset
175+
epoch_val_acc = correct / length_of_dataset
176176
return outputs, epoch_val_acc
177177

178178
def evaluate(self):

0 commit comments

Comments
 (0)