@@ -68,11 +68,11 @@ def __init__(
6868 self .teacher_model = teacher_model .to (self .device )
6969 else :
7070 print ("Warning!!! Teacher is NONE." )
71-
71+
7272 self .student_model = student_model .to (self .device )
7373 self .loss_fn = loss_fn .to (self .device )
7474 self .ce_fn = nn .CrossEntropyLoss ().to (self .device )
75-
75+
7676 def train_teacher (
7777 self ,
7878 epochs = 20 ,
@@ -140,7 +140,7 @@ def train_teacher(
140140 )
141141
142142 loss_arr .append (epoch_loss )
143- print (f "Epoch: { ep + 1 } , Loss: { epoch_loss } , Accuracy: { epoch_acc } " )
143+ print ("Epoch: {}, Loss: {}, Accuracy: {}" . format ( ep + 1 , epoch_loss , epoch_acc ) )
144144
145145 self .post_epoch_call (ep )
146146
@@ -222,7 +222,7 @@ def _train_student(
222222 )
223223
224224 loss_arr .append (epoch_loss )
225- print (f "Epoch: { ep + 1 } , Loss: { epoch_loss } , Accuracy: { epoch_acc } " )
225+ print ("Epoch: {}, Loss: {}, Accuracy: {}" . format ( ep + 1 , epoch_loss , epoch_acc ) )
226226
227227 self .student_model .load_state_dict (self .best_student_model_weights )
228228 if save_model :
@@ -288,7 +288,7 @@ def _evaluate_model(self, model, verbose=True):
288288
289289 if verbose :
290290 print ("-" * 80 )
291- print (f "Validation Accuracy: { accuracy } " )
291+ print ("Validation Accuracy: {}" . format ( accuracy ) )
292292 return outputs , accuracy
293293
294294 def evaluate (self , teacher = False ):
@@ -313,8 +313,8 @@ def get_parameters(self):
313313 student_params = sum (p .numel () for p in self .student_model .parameters ())
314314
315315 print ("-" * 80 )
316- print (f "Total parameters for the teacher network are: { teacher_params } " )
317- print (f "Total parameters for the student network are: { student_params } " )
316+ print ("Total parameters for the teacher network are: {}" . format ( teacher_params ) )
317+ print ("Total parameters for the student network are: {}" . format ( student_params ) )
318318
319319 def post_epoch_call (self , epoch ):
320320 """
0 commit comments