Skip to content

Commit b174166

Browse files
Calculates tokens per second for actors. (#1034)
* Added actor tokens_per_second metric.
* Fixed bug where one timing was using time.time() and the other was using time.perf_counter().
* Updated code.
* Cleaned up PR.
1 parent 6d141b6 commit b174166

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

open_instruct/grpo_fast.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1828,6 +1828,9 @@ def data_preparation_thread(
18281828
**reward_metrics,
18291829
}
18301830

1831+
total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens
1832+
metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time
1833+
18311834
if args.save_traces:
18321835
traces = {
18331836
"scores": scores.tolist(),
@@ -2287,8 +2290,8 @@ def one_training_step(
22872290
"val/num_total_tokens": num_total_tokens,
22882291
"val/num_step_tokens": num_step_tokens,
22892292
"epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset),
2290-
"tokens_per_second_overall": num_total_tokens / total_training_time if total_training_time > 0 else 0,
2291-
"tokens_per_second_step": num_step_tokens / step_time if step_time > 0 else 0,
2293+
"learner_tokens_per_second_overall": num_total_tokens / total_training_time,
2294+
"learner_tokens_per_second_step": num_step_tokens / step_time,
22922295
"time/total": step_time,
22932296
"time/training": train_timer.duration,
22942297
"time/saving": save_time,
@@ -2374,6 +2377,12 @@ def maybe_evaluate(
23742377
}
23752378
if "time/generation" in eval_generate_metrics:
23762379
eval_metrics["eval/generation_time"] = eval_generate_metrics["time/generation"]
2380+
2381+
total_tokens = (
2382+
eval_result.token_statistics.num_prompt_tokens + eval_result.token_statistics.num_response_tokens
2383+
)
2384+
eval_metrics["eval/actor_tokens_per_second"] = total_tokens / eval_result.token_statistics.generation_time
2385+
23772386
print_rich_single_line_metrics(eval_metrics)
23782387

23792388
table = {}

open_instruct/vllm_utils3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def process_from_queue(self, timeout: float = 60.0):
497497
self._prefetch_future.result()
498498

499499
self._poll_tool_futures(self.tracking, self.llm_engine.tokenizer)
500-
current_time = time.time()
500+
current_time = time.perf_counter()
501501
if self.llm_engine.has_unfinished_requests():
502502
for output in [o for o in self.llm_engine.step() if o.finished]:
503503
# Fix the index field for all sub-requests
@@ -824,7 +824,7 @@ def _poll_tool_futures(self, tracking, tokenizer):
824824
tracking["pending_tool_futures"].pop(req_id, None)
825825

826826
complete_output = tracking["concat_outputs"][req_id].outputs[0]
827-
current_time = time.time()
827+
current_time = time.perf_counter()
828828
self._finalize_sub_request(req_id, last_output, complete_output, current_time)
829829
# Don't add to dict_keys_to_delete since we already removed it
830830
continue

0 commit comments

Comments (0)