Commit 9ecf089

Implement caching for prefix groups in OctoWithoutHead class
1 parent: 808537c

File tree

1 file changed: +50 -0 lines

src/lerobot/policies/octo/transformer.py

Lines changed: 50 additions & 0 deletions
@@ -519,6 +519,9 @@ def __init__(
         self.prefix_groups = None
         self.timestep_groups = None

+        # Caching
+        self.cached_prefix_groups = None
+
         # Projections
         self.obs_primary_projection = nn.Linear(512, self.token_embedding_size)
         self.obs_wrist_projection = nn.Linear(512, self.token_embedding_size)
@@ -587,6 +590,53 @@ def forward(
         )
         )

+        # with torch.profiler.record_function("Octo/TaskTokenization"):
+        #     if self.cached_prefix_groups is not None:
+        #         # Expand cached groups to current batch size and device
+        #         prefix_groups = [
+        #             PrefixGroup(
+        #                 tokens=g.tokens.to(timestep_pad_mask.device).expand(batch_size, -1, -1),
+        #                 mask=g.mask.to(timestep_pad_mask.device).expand(batch_size, -1),
+        #                 name=g.name,
+        #                 attention_rules=g.attention_rules,
+        #             )
+        #             for g in self.cached_prefix_groups
+        #         ]
+        #     else:
+        #         prefix_groups = []
+        #         for name, tokenizer in self.task_tokenizers.items():
+        #             if name in tasks:
+        #                 token_group = tokenizer(tasks[name], tasks)
+        #                 projected_tokens = self.task_language_projection(token_group.tokens)
+
+        #                 # Add positional embedding
+        #                 pos_embedding = self.task_language_pos_embedding[:, : projected_tokens.shape[1]]
+        #                 processed_tokens = projected_tokens + pos_embedding
+
+        #                 # ✅ store only one exemplar in cache (batch = 1)
+        #                 processed_tokens = processed_tokens[:1].detach().cpu()
+        #                 token_mask = token_group.mask[:1].detach().cpu()
+
+        #                 prefix_groups.append(
+        #                     PrefixGroup(
+        #                         tokens=processed_tokens.to(timestep_pad_mask.device).expand(batch_size, -1, -1),
+        #                         mask=token_mask.to(timestep_pad_mask.device).expand(batch_size, -1),
+        #                         name=f"task_{name}",
+        #                         attention_rules=task_attention_rules,
+        #                     )
+        #                 )
+
+        #         # ✅ cache the single exemplar (not the expanded version)
+        #         self.cached_prefix_groups = [
+        #             PrefixGroup(
+        #                 tokens=g.tokens[:1].detach().cpu(),
+        #                 mask=g.mask[:1].detach().cpu(),
+        #                 name=g.name,
+        #                 attention_rules=g.attention_rules,
+        #             )
+        #             for g in prefix_groups
+        #         ]
+
         # Create timestep groups for observation tokens
         timestep_groups = []
         for name, tokenizer in self.observation_tokenizers.items():
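
The commented-out block in forward() follows a cache-then-expand pattern: on the first call the task tokens are tokenized and projected once, a single batch-1 exemplar is detached and kept on CPU in self.cached_prefix_groups, and on every call that exemplar is moved to the current device and broadcast across the batch with expand(). Below is a minimal, standalone sketch of that pattern; PrefixGroup and TaskPrefixCache here are simplified stand-ins written for illustration, not the actual classes from lerobot.

from dataclasses import dataclass

import torch


@dataclass
class PrefixGroup:
    # Simplified stand-in: the real PrefixGroup also carries attention_rules.
    tokens: torch.Tensor  # (batch, num_tokens, embed_dim)
    mask: torch.Tensor    # (batch, num_tokens)
    name: str


class TaskPrefixCache:
    """Cache one batch-1 exemplar of the task prefix on CPU and re-expand it
    to the current batch size and device on every call."""

    def __init__(self):
        self.cached_prefix_groups = None

    def get(self, batch_size, device, build_fn):
        if self.cached_prefix_groups is None:
            # First call: build the prefix groups once, keep a detached batch-1 copy on CPU.
            self.cached_prefix_groups = [
                PrefixGroup(
                    tokens=g.tokens[:1].detach().cpu(),
                    mask=g.mask[:1].detach().cpu(),
                    name=g.name,
                )
                for g in build_fn()
            ]
        # Every call: move the exemplar to the target device and broadcast over the batch.
        return [
            PrefixGroup(
                tokens=g.tokens.to(device).expand(batch_size, -1, -1),
                mask=g.mask.to(device).expand(batch_size, -1),
                name=g.name,
            )
            for g in self.cached_prefix_groups
        ]


# Example usage with dummy tensors standing in for tokenized + projected task text.
cache = TaskPrefixCache()
build_fn = lambda: [
    PrefixGroup(
        tokens=torch.randn(1, 16, 384),
        mask=torch.ones(1, 16, dtype=torch.bool),
        name="task_language_instruction",
    )
]
groups = cache.get(batch_size=8, device="cpu", build_fn=build_fn)
print(groups[0].tokens.shape)  # torch.Size([8, 16, 384])

A caveat implied by caching a single exemplar: this is only valid when the task prefix is identical for every sample in the batch and does not change between calls, since the tokens produced on the first forward pass are reused verbatim afterwards.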
