@@ -1626,11 +1626,14 @@ def train(
             lr_scheduler=lr_scheduler,
         )
         # Evaluation
-        if (
-            neox_args.eval_interval
-            and iteration % neox_args.eval_interval == 0
-            and neox_args.do_valid
-        ):
+        is_eval_interval = neox_args.eval_interval and iteration % neox_args.eval_interval == 0
+        is_validation_configured = bool(neox_args.do_valid) or (isinstance(neox_args.eval_tasks, list) and len(neox_args.eval_tasks) > 0)
+        # if (
+        #     neox_args.eval_interval
+        #     and iteration % neox_args.eval_interval == 0
+        #     # and neox_args.do_valid
+        # ):
+        if is_eval_interval and is_validation_configured:
             prefix = "iteration {}".format(iteration)
             evaluate_and_print_results(
                 neox_args=neox_args,
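
To make the new trigger concrete, here is a minimal, self-contained sketch (not part of the diff); `should_evaluate` and the `SimpleNamespace` stand-in for `neox_args` are illustrative only. Evaluation now fires on interval boundaries whenever either a validation split or a non-empty list of eval-harness tasks is configured, rather than only when `do_valid` is set.

from types import SimpleNamespace

def should_evaluate(neox_args, iteration):
    # Fire only on eval_interval boundaries...
    is_eval_interval = bool(neox_args.eval_interval) and iteration % neox_args.eval_interval == 0
    # ...and only if there is something to evaluate: validation data or eval harness tasks.
    is_validation_configured = bool(neox_args.do_valid) or (
        isinstance(neox_args.eval_tasks, list) and len(neox_args.eval_tasks) > 0
    )
    return is_eval_interval and is_validation_configured

args = SimpleNamespace(eval_interval=100, do_valid=False, eval_tasks=["lambada"])
assert should_evaluate(args, iteration=200)      # harness tasks alone now trigger evaluation
assert not should_evaluate(args, iteration=150)  # off-interval iterations are skipped
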
@@ -1683,46 +1686,49 @@ def evaluate(
     if neox_args.char_level_ppl:
         data_iterator = CharCounter(data_iterator, neox_args.tokenizer)
 
-    with torch.no_grad():
-        iteration = 0
-        while iteration < neox_args.eval_iters:
-            iteration += 1
-            if verbose and iteration % neox_args.log_interval == 0:
-                print_rank_0(
-                    "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
-                )
+    eval_results = {}
+    if data_iterator is not None:
+        with torch.no_grad():
+            iteration = 0
+            while iteration < neox_args.eval_iters:
+                iteration += 1
+                if verbose and iteration % neox_args.log_interval == 0:
+                    print_rank_0(
+                        "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
+                    )
 
-            # although we're not accumulating gradients here, we count one iter as train_batch_size_per_gpu * g.a.s
-            # to be consistent with deepspeed's pipe parallel engine
-            # since pipe parallel already takes gradient_accumulation_steps into account - default to 1 here if pipe parallel is true
-            for _ in range(
-                1
-                if neox_args.is_pipe_parallel
-                else neox_args.gradient_accumulation_steps
-            ):
-                # Forward evaluation
-                loss, metric_dict = forward_step_fn(
-                    model=model,
-                    data_iterator=data_iterator,
-                    neox_args=neox_args,
-                    timers=timers,
-                    reference_model=reference_model,
-                )
-                losses.append(loss)
-                for key in metric_dict.keys():
-                    metric_dicts[key].append(metric_dict[key])
-            # When contiguous memory optimizations are enabled, the buffers
-            # allocated by the optimizations are deallocated during backward pass
-            # in the absence of backward pass the buffers should be reset after each
-            # forward pass
-            if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing:
-                deepspeed.checkpointing.reset()
-
-    # reduces losses across processes for logging & run eval harness tasks
-    eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
-    for key in metric_dicts.keys():
-        eval_results[key] = reduce_losses(metric_dicts[key]).mean().item()
-    eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])
+                # although we're not accumulating gradients here, we count one iter as train_batch_size_per_gpu * g.a.s
+                # to be consistent with deepspeed's pipe parallel engine
+                # since pipe parallel already takes gradient_accumulation_steps into account - default to 1 here if pipe parallel is true
+                for _ in range(
+                    1
+                    if neox_args.is_pipe_parallel
+                    else neox_args.gradient_accumulation_steps
+                ):
+                    # Forward evaluation
+                    loss, metric_dict = forward_step_fn(
+                        model=model,
+                        data_iterator=data_iterator,
+                        neox_args=neox_args,
+                        timers=timers,
+                        reference_model=reference_model,
+                    )
+                    losses.append(loss)
+                    for key in metric_dict.keys():
+                        metric_dicts[key].append(metric_dict[key])
+                # When contiguous memory optimizations are enabled, the buffers
+                # allocated by the optimizations are deallocated during backward pass
+                # in the absence of backward pass the buffers should be reset after each
+                # forward pass
+                if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing:
+                    deepspeed.checkpointing.reset()
+
+        # reduces losses across processes for logging & run eval harness tasks
+        eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
+        for key in metric_dicts.keys():
+            eval_results[key] = reduce_losses(metric_dicts[key]).mean().item()
+
+        eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])
 
     if neox_args.char_level_ppl:
         # calculate character level perplexity, if specified
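
For context, the guard added in `evaluate()` can be read in isolation as the sketch below (illustration only; `evaluate_sketch` and `forward_step_fn` are hypothetical, and a local mean stands in for the repository's `reduce_losses` all-reduce). Ranks whose `data_iterator` is `None` skip the no-grad loop entirely and return an empty result dict, while ranks with data still count one evaluation iteration as `gradient_accumulation_steps` forward passes unless pipe parallelism already folds those steps into a single engine call.

import math
import torch

def evaluate_sketch(model, data_iterator, eval_iters, grad_accum_steps,
                    is_pipe_parallel, forward_step_fn):
    # Illustration of the guarded evaluation flow; not the repository's evaluate().
    losses = []
    eval_results = {}
    if data_iterator is not None:  # ranks without an iterator go straight to the return
        with torch.no_grad():  # evaluation never needs gradients
            for _ in range(eval_iters):
                # Pipe-parallel engines already account for gradient accumulation,
                # so only loop over the accumulation steps in the non-pipe case.
                for _ in range(1 if is_pipe_parallel else grad_accum_steps):
                    # forward_step_fn is assumed to return a scalar loss tensor
                    losses.append(forward_step_fn(model, data_iterator))
        # Populate results only when this rank actually produced losses,
        # so the perplexity below never reads a missing "lm_loss" key.
        eval_results["lm_loss"] = torch.stack(losses).mean().item()
        eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])
    return eval_results  # {} when there was nothing to evaluate on this rank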