Here are some latency numbers, run with https://github.com/ServiceNow/Fast-LLM/tree/874cb2a875a439cff10d18a67b293ed59831ce4e, by measuring the time taken to run the reference-model's forward
Fast-LLM/fast_llm/models/gpt/model.py (lines 336 to 342 at 874cb2a):

```python
start_time = time.perf_counter()
# TODO: Do things work with >1?
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
    reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
    reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
elapsed_time = (time.perf_counter() - start_time) * 1000
```
With TP=2, mbs=1, the reference-model forward takes much longer than with TP=1 (roughly 3x), even though TP=2 splits the model across two GPUs. Another puzzling point is that the reference-model forward time is larger with TP=2, mbs=1 than with TP=2, mbs=2, despite the smaller micro-batch.
Could this be an issue with how the reference-model forward time is measured here?
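One measurement pitfall worth ruling out: if the forward launches asynchronous work (CUDA kernels are enqueued and the Python call returns early), `time.perf_counter()` around the call can capture an arbitrary mix of enqueue time and whatever earlier op forces a sync, rather than the forward itself. A minimal sketch of the pitfall, using a background thread as a stand-in for asynchronous GPU work (for real CUDA timing one would hypothetically call `torch.cuda.synchronize()` before reading the clock):

```python
import time
from concurrent.futures import ThreadPoolExecutor

def fake_kernel():
    # Stand-in for asynchronous GPU work: takes ~50 ms to complete.
    time.sleep(0.05)

executor = ThreadPoolExecutor(max_workers=1)

# Naive timing: measures only how long it takes to *launch* the work.
start = time.perf_counter()
future = executor.submit(fake_kernel)
launch_ms = (time.perf_counter() - start) * 1000
future.result()  # drain the queue before the next measurement

# Synchronized timing: wait for the work to finish before reading the
# clock (the role torch.cuda.synchronize() plays for CUDA streams).
start = time.perf_counter()
future = executor.submit(fake_kernel)
future.result()  # synchronize
elapsed_ms = (time.perf_counter() - start) * 1000

print(f"launch-only: {launch_ms:.2f} ms, synchronized: {elapsed_ms:.2f} ms")
executor.shutdown()
```

If the measured time moves significantly once an explicit synchronize is added before both clock reads, the numbers in the table below reflect queueing behavior rather than the forward pass itself.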
| Seq-length | TP | MBS | BS | Sequential micro-batches | Teacher-size | Student-size | Ref-model forward (ms) | Step-time (ms) |
|---|---|---|---|---|---|---|---|---|
| 2048 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 24 | 206 |
| 4096 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 46 | 333 |
| 8192 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 95 | 656 |
| 2048 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 80 | 230 |
| 4096 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 150 | 367 |
| 8192 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 301 | 655 |
| 2048 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 33 | 231 |
| 4096 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 59 | 376 |
| 8192 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 121 | 739 |
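To make the anomaly concrete, the ratios implied by the forward-time column above can be computed directly (values copied from the table; nothing new is measured here):

```python
# Reference-model forward times (ms) from the table, keyed by
# (seq_len, TP, MBS).
forward_ms = {
    (2048, 1, 1): 24, (4096, 1, 1): 46, (8192, 1, 1): 95,
    (2048, 2, 1): 80, (4096, 2, 1): 150, (8192, 2, 1): 301,
    (2048, 2, 2): 33, (4096, 2, 2): 59, (8192, 2, 2): 121,
}

for seq in (2048, 4096, 8192):
    # TP=2 should roughly halve per-GPU compute, yet it is ~3x slower.
    tp2_vs_tp1 = forward_ms[(seq, 2, 1)] / forward_ms[(seq, 1, 1)]
    # A smaller micro-batch (mbs=1) is ~2.5x slower than mbs=2.
    mbs1_vs_mbs2 = forward_ms[(seq, 2, 1)] / forward_ms[(seq, 2, 2)]
    print(f"seq={seq}: TP2/TP1 = {tp2_vs_tp1:.2f}x, "
          f"TP2 mbs1/mbs2 = {mbs1_vs_mbs2:.2f}x")
```

Across all three sequence lengths the TP=2, mbs=1 forward is more than 3x the TP=1 forward and roughly 2.5x the TP=2, mbs=2 forward, which is consistent at every sequence length and therefore unlikely to be noise.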