@@ -519,6 +519,9 @@ def __init__(
         self.prefix_groups = None
         self.timestep_groups = None

+        # Caching
+        self.cached_prefix_groups = None
+
         # Projections
         self.obs_primary_projection = nn.Linear(512, self.token_embedding_size)
         self.obs_wrist_projection = nn.Linear(512, self.token_embedding_size)
@@ -587,6 +590,53 @@ def forward(
             )
         )

+        # with torch.profiler.record_function("Octo/TaskTokenization"):
+        #     if self.cached_prefix_groups is not None:
+        #         # Expand cached groups to current batch size and device
+        #         prefix_groups = [
+        #             PrefixGroup(
+        #                 tokens=g.tokens.to(timestep_pad_mask.device).expand(batch_size, -1, -1),
+        #                 mask=g.mask.to(timestep_pad_mask.device).expand(batch_size, -1),
+        #                 name=g.name,
+        #                 attention_rules=g.attention_rules,
+        #             )
+        #             for g in self.cached_prefix_groups
+        #         ]
+        #     else:
+        #         prefix_groups = []
+        #         for name, tokenizer in self.task_tokenizers.items():
+        #             if name in tasks:
+        #                 token_group = tokenizer(tasks[name], tasks)
+        #                 projected_tokens = self.task_language_projection(token_group.tokens)
+
+        #                 # Add positional embedding
+        #                 pos_embedding = self.task_language_pos_embedding[:, : projected_tokens.shape[1]]
+        #                 processed_tokens = projected_tokens + pos_embedding
+
+        #                 # ✅ store only one exemplar in cache (batch = 1)
+        #                 processed_tokens = processed_tokens[:1].detach().cpu()
+        #                 token_mask = token_group.mask[:1].detach().cpu()
+
+        #                 prefix_groups.append(
+        #                     PrefixGroup(
+        #                         tokens=processed_tokens.to(timestep_pad_mask.device).expand(batch_size, -1, -1),
+        #                         mask=token_mask.to(timestep_pad_mask.device).expand(batch_size, -1),
+        #                         name=f"task_{name}",
+        #                         attention_rules=task_attention_rules,
+        #                     )
+        #                 )
+
+        #         # ✅ cache the single exemplar (not the expanded version)
+        #         self.cached_prefix_groups = [
+        #             PrefixGroup(
+        #                 tokens=g.tokens[:1].detach().cpu(),
+        #                 mask=g.mask[:1].detach().cpu(),
+        #                 name=g.name,
+        #                 attention_rules=g.attention_rules,
+        #             )
+        #             for g in prefix_groups
+        #         ]
+
         # Create timestep groups for observation tokens
         timestep_groups = []
         for name, tokenizer in self.observation_tokenizers.items():
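The commented-out block above amounts to a "tokenize the task once, reuse it every step" cache: the language prefix tokens are projected once, a single batch-1 exemplar is stored detached on the CPU, and later forward passes only move it to the right device and expand it to the current batch size. The sketch below isolates that pattern in a self-contained module. It is a minimal illustration under assumptions, not the code from this diff: PrefixGroup is reduced to a plain dataclass with the four fields the diff uses, the language encoder is replaced by a pre-computed task_embeddings tensor, and the names CachedTaskTokens, attention_rules, the 512-dim input, and the 384-dim embedding are hypothetical. It also bakes in the diff's implicit assumption that every element of the batch shares the same task, since only the first example is cached.

import dataclasses

import torch
import torch.nn as nn


@dataclasses.dataclass
class PrefixGroup:
    # Minimal stand-in for the PrefixGroup used in the diff (fields assumed).
    tokens: torch.Tensor  # (batch, n_tokens, embed_dim)
    mask: torch.Tensor    # (batch, n_tokens)
    name: str
    attention_rules: dict


class CachedTaskTokens(nn.Module):
    """Hypothetical module showing the caching pattern: project the task
    tokens once, keep a batch-1 exemplar on the CPU, and expand it to the
    current batch size and device on every later call."""

    def __init__(self, embed_dim=384, max_tokens=16):
        super().__init__()
        self.task_language_projection = nn.Linear(512, embed_dim)
        self.task_language_pos_embedding = nn.Parameter(
            torch.zeros(1, max_tokens, embed_dim)
        )
        self.cached_prefix_groups = None  # list[PrefixGroup], batch == 1, on CPU

    def forward(self, task_embeddings, task_mask, batch_size, device, attention_rules):
        # task_embeddings: (1, n_tokens, 512) output of some language encoder
        # task_mask:       (1, n_tokens) padding mask for those tokens
        if self.cached_prefix_groups is None:
            projected = self.task_language_projection(task_embeddings)
            projected = projected + self.task_language_pos_embedding[:, : projected.shape[1]]
            # Cache a single exemplar (batch = 1), detached and moved to CPU.
            self.cached_prefix_groups = [
                PrefixGroup(
                    tokens=projected[:1].detach().cpu(),
                    mask=task_mask[:1].detach().cpu(),
                    name="task_language",
                    attention_rules=attention_rules,
                )
            ]
        # Expand the cached exemplar to the current batch size and device.
        return [
            PrefixGroup(
                tokens=g.tokens.to(device).expand(batch_size, -1, -1),
                mask=g.mask.to(device).expand(batch_size, -1),
                name=g.name,
                attention_rules=g.attention_rules,
            )
            for g in self.cached_prefix_groups
        ]


if __name__ == "__main__":
    module = CachedTaskTokens()
    task_embeddings = torch.randn(1, 10, 512)
    task_mask = torch.ones(1, 10, dtype=torch.bool)
    groups = module(task_embeddings, task_mask, batch_size=4, device="cpu", attention_rules={})
    print(groups[0].tokens.shape)  # torch.Size([4, 10, 384])

Note that expand returns a broadcasted view rather than copying the cached tokens per batch element, and .detach().cpu() keeps the cache out of the autograd graph and off the GPU between calls; the cache would need to be cleared (set back to None) whenever the task changes.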