From 4883afed22cbc8c2b815c6fd835628dfbe47d661 Mon Sep 17 00:00:00 2001 From: MarisaJH Date: Sat, 26 Apr 2025 19:23:57 +0000 Subject: [PATCH] fix bug in EvalHarnessAdapter when pipe parallel size is set to 1 --- eval_tasks/eval_adapter.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/eval_tasks/eval_adapter.py b/eval_tasks/eval_adapter.py index abbd5ca8d..f5a01107c 100644 --- a/eval_tasks/eval_adapter.py +++ b/eval_tasks/eval_adapter.py @@ -59,8 +59,8 @@ def __init__(self, model, forward_step_fn, neox_args, batch_size=None): self.is_main = neox_args.rank == 0 self.is_local_main = neox_args.local_rank == 0 self.is_model_parallel = neox_args.model_parallel_size > 1 - self.is_pipe_parallel = self.model.is_pipe_parallel - self.is_data_parallel = self.model.is_data_parallel + self.is_pipe_parallel = getattr(self.model, 'is_pipe_parallel', False) + self.is_data_parallel = getattr(self.model, 'is_data_parallel', False) self.is_last_stage = ( True if not self.is_pipe_parallel else model.is_last_stage() ) # only the last stage of the pipeline model will receive the logits @@ -369,7 +369,11 @@ def _model_call(self, inps): self.model.first_output_send = True self.model.pipe_recv_buf = None - _, logits = self._forward_step_fn(model=self.model, data_iterator=inps) + _, logits = self._forward_step_fn(model=self.model, data_iterator=inps) + + # since return_logits is true, forward will return 3 vals + else: + _, logits, _ = self._forward_step_fn(model=self.model, data_iterator=inps) # gather outputs from all dp ranks: logits = self._dp_gather(logits) @@ -396,9 +400,7 @@ def run_eval( ): was_training = self.model.training self.model.eval() - in_micro_batches = ( - self.model.micro_batches - ) # store input microbatches - we need to set to 1 during eval, but want to return to its original value after + in_micro_batches = getattr(self.model, 'micro_batches', 1) # store input microbatches - we need to set to 1 during eval, but want to return to its original value after self.model.micro_batches = 1 if eval_tasks is None: eval_tasks = [