Reference-model is slow on long sequences, especially with TP>1 #353

@RaymondLi0

Here are some latency numbers, run with https://github.com/ServiceNow/Fast-LLM/tree/874cb2a875a439cff10d18a67b293ed59831ce4e, obtained by measuring the time taken to run the reference model's forward pass:

start_time = time.perf_counter()
# TODO: Do things work with >1?
Assert.eq(len(reference_batch), len(preprocessed_meta), 1)
for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch):
    reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration)
    reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"]
elapsed_time = (time.perf_counter() - start_time) * 1000

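With TP=2, mbs=1, the reference-model forward takes roughly 3.2x to 3.3x longer than with TP=1, mbs=1 at the same sequence length.
Another puzzling point is that the reference-model forward time is about 2.4x to 2.5x larger with TP=2, mbs=1 than with TP=2, mbs=2, even though the latter processes twice as many sequences per micro-batch.
Could this be an issue with how the reference-model forward time is measured here?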
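One way to rule out a measurement artifact: CUDA kernels are launched asynchronously, so a wall-clock timer around `forward` can capture only launch overhead, or absorb work queued earlier, unless the device is synchronized at both ends of the timed region. The helper below is a minimal sketch of that idea; `timed_ms` and its `sync` parameter are hypothetical names, not part of Fast-LLM, and whether the existing forward already synchronizes is not known from this issue.

```python
import time
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def timed_ms(fn: Callable[[], T], sync: Optional[Callable[[], None]] = None) -> Tuple[T, float]:
    """Run fn() and return (result, elapsed milliseconds).

    `sync` is an optional barrier (hypothetical parameter) called before the
    timer starts and again before it stops, so that asynchronously launched
    work is attributed to the right region.
    """
    if sync is not None:
        sync()  # drain work queued before the timed region
    start = time.perf_counter()
    result = fn()
    if sync is not None:
        sync()  # wait for the work launched by fn() to finish
    return result, (time.perf_counter() - start) * 1000
```

In the snippet above, this could wrap the forward loop with `sync=torch.cuda.synchronize`. If the TP=2, mbs=1 numbers change substantially once synchronized, the timer was likely measuring something other than the forward pass itself.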

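| Seq-length | TP | MBS | BS | Sequential micro-batches | Teacher size | Student size | Ref-model forward (ms) | Step time (ms) |
|---|---|---|---|---|---|---|---|---|
| 2048 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 24 | 206 |
| 4096 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 46 | 333 |
| 8192 | 1 | 1 | 16 | 1 | 4.6B | 4.6B | 95 | 656 |
| 2048 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 80 | 230 |
| 4096 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 150 | 367 |
| 8192 | 2 | 1 | 8 | 1 | 4.6B | 4.6B | 301 | 655 |
| 2048 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 33 | 231 |
| 4096 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 59 | 376 |
| 8192 | 2 | 2 | 16 | 1 | 4.6B | 4.6B | 121 | 739 |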
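
To make the comparison easier to read, the measurements can be normalized against the TP=1, mbs=1 run at the same sequence length. The sketch below only re-derives ratios from the numbers reported in this issue; the `summarize` function and its field names are illustrative.

```python
# Rows copied from the measurements in this issue:
# (seq_length, tp, mbs, ref_forward_ms, step_ms)
ROWS = [
    (2048, 1, 1, 24, 206),
    (4096, 1, 1, 46, 333),
    (8192, 1, 1, 95, 656),
    (2048, 2, 1, 80, 230),
    (4096, 2, 1, 150, 367),
    (8192, 2, 1, 301, 655),
    (2048, 2, 2, 33, 231),
    (4096, 2, 2, 59, 376),
    (8192, 2, 2, 121, 739),
]


def summarize(rows):
    """Return per-row slowdown vs. the TP=1, mbs=1 run and share of step time."""
    # Baseline reference-forward time per sequence length (TP=1, mbs=1).
    baseline = {seq: ref for seq, tp, mbs, ref, _ in rows if tp == 1 and mbs == 1}
    summary = []
    for seq, tp, mbs, ref, step in rows:
        summary.append(
            {
                "seq_length": seq,
                "tp": tp,
                "mbs": mbs,
                "slowdown_vs_tp1": ref / baseline[seq],
                "share_of_step": ref / step,
            }
        )
    return summary


if __name__ == "__main__":
    for row in summarize(ROWS):
        print(
            f"seq={row['seq_length']:>5} TP={row['tp']} mbs={row['mbs']} "
            f"slowdown={row['slowdown_vs_tp1']:.2f}x "
            f"share_of_step={row['share_of_step']:.0%}"
        )
```

The ratios do not account for the different number of sequences per forward pass (mbs=1 vs. mbs=2), so they compare raw forward times only.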
