Commit ceec6f1

Refactor: Vectorize Octo attention mask creation
- Replaced nested for-loop attention mask creation with a fully vectorized PyTorch implementation.
- Removed the `TokenMetadata` class from `base.py`.
- Introduced a `RULE_MAP` to convert attention rules to integers for faster processing.
1 parent 9ecf089 commit ceec6f1
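
The key idea behind the vectorization, as a minimal standalone sketch (not code from this commit; the toy rule table and group ids below are invented): each token is reduced to an integer group id, the pairwise rules live in a small num_groups x num_groups table, and PyTorch advanced indexing with broadcasting expands that table to the full token-by-token rule matrix, so no Python loop over token pairs is needed.

import torch

# Toy integer rule codes for 2 groups and 5 tokens (values are illustrative only).
rule_table = torch.tensor([[4, 0], [1, 4]])  # rule_table[i, j]: how group i attends to group j
group_ids = torch.tensor([0, 0, 1, 1, 1])    # group id of each token

# Broadcasted advanced indexing: entry (i, j) equals rule_table[group_ids[i], group_ids[j]].
pairwise_rules = rule_table[group_ids[:, None], group_ids[None, :]]
print(pairwise_rules.shape)  # torch.Size([5, 5])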

File tree

2 files changed: +79 -78 lines changed

src/lerobot/policies/octo/base.py

Lines changed: 10 additions & 35 deletions
@@ -15,8 +15,7 @@
 from dataclasses import dataclass
 from enum import Enum
 from fnmatch import fnmatch
-from typing import Any, Dict, List, Mapping, Union
-
+from typing import Any, Dict, List, Mapping
 import torch
 
 
@@ -30,6 +29,15 @@ class AttentionRule(Enum):
     ALL = "all"
 
 
+RULE_MAP = {
+    AttentionRule.NEVER: 0,
+    AttentionRule.CAUSAL: 1,
+    AttentionRule.CURRENT: 2,
+    AttentionRule.STRICT_PAST: 3,
+    AttentionRule.ALL: 4,
+}
+
+
 def find_match(pattern_dict: Dict[str, Any], name: str, default: Any) -> Any:
     """Find the first matching pattern in the dictionary, or return the default value."""
     for pattern, value in pattern_dict.items():
@@ -38,39 +46,6 @@ def find_match(pattern_dict: Dict[str, Any], name: str, default: Any) -> Any:
     return default
 
 
-@dataclass
-class TokenMetadata:
-    """Attention mask logic supported by AttentionRule."""
-
-    name: str
-    timestep: int  # -1 for prefix tokens
-    attention_rules: Mapping[str, AttentionRule]
-
-    @classmethod
-    def create(cls, group: Union["PrefixGroup", "TimestepGroup"], timestep: int):
-        return cls(
-            timestep=timestep,
-            name=group.name,
-            attention_rules=group.attention_rules,
-        )
-
-    def should_attend_to(self, other_metadata: "TokenMetadata") -> bool:
-        attention_rule = find_match(self.attention_rules, other_metadata.name, AttentionRule.NEVER)
-
-        if attention_rule == AttentionRule.CAUSAL:
-            return other_metadata.timestep <= self.timestep
-        elif attention_rule == AttentionRule.CURRENT:
-            return other_metadata.timestep == self.timestep
-        elif attention_rule == AttentionRule.STRICT_PAST:
-            return other_metadata.timestep < self.timestep
-        elif attention_rule == AttentionRule.ALL:
-            return True
-        elif attention_rule == AttentionRule.NEVER:
-            return False
-        else:
-            raise ValueError(f"Invalid attention rule: {attention_rule}")
-
-
 @dataclass
 class TokenGroup:
     """A group of tokens that have semantic meaning together (e.g. the tokens for a single observation)."""

src/lerobot/policies/octo/transformer.py

Lines changed: 69 additions & 43 deletions
@@ -14,12 +14,17 @@
 
 from typing import Any, Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 
-from lerobot.policies.octo.base import AttentionRule, PrefixGroup, TimestepGroup, TokenMetadata
+from lerobot.policies.octo.base import (
+    AttentionRule,
+    PrefixGroup,
+    TimestepGroup,
+    find_match,
+    RULE_MAP,
+)
 from lerobot.policies.octo.tokenizers import ImageTokenizer, LanguageTokenizer, SmallStem16
 
 
@@ -311,62 +316,83 @@ def _generate_attention_mask(
         if self.enforce_causal:
             self._verify_causality(prefix_groups, timestep_groups)
 
-        def _get_position(i, tokens_per_elem):
-            return np.searchsorted(np.cumsum(tokens_per_elem), i, side="right")
-
+        device = timestep_groups[0].tokens.device
         horizon = timestep_groups[0].tokens.shape[1]
-        tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups]
-        tokens_per_timestep_group = [group.tokens.shape[2] for group in timestep_groups]
 
-        tokens_for_prefix = sum(tokens_per_prefix_group)
-        tokens_per_time_step = sum(tokens_per_timestep_group)
+        all_groups = prefix_groups + timestep_groups
+        num_groups = len(all_groups)
+
+        prefix_token_counts = [g.tokens.shape[1] for g in prefix_groups]
+        timestep_token_counts = [g.tokens.shape[2] for g in timestep_groups]
+
+        tokens_for_prefix = sum(prefix_token_counts)
+        tokens_per_time_step = sum(timestep_token_counts)
         total_tokens = tokens_for_prefix + tokens_per_time_step * horizon
 
-        # Create attention mask using numpy for compatibility with JAX implementation
-        attention_mask = np.zeros((total_tokens, total_tokens), dtype=int)
+        group_ids = torch.zeros(total_tokens, dtype=torch.long, device=device)
+        timesteps = torch.full((total_tokens,), -1, dtype=torch.long, device=device)
+
+        current_pos = 0
+        for i, count in enumerate(prefix_token_counts):
+            group_ids[current_pos : current_pos + count] = i
+            current_pos += count
+
+        timestep_group_ids_per_step = []
+        for i, count in enumerate(timestep_token_counts):
+            timestep_group_ids_per_step.append(
+                torch.full(
+                    (count,),
+                    len(prefix_groups) + i,
+                    dtype=torch.long,
+                    device=device,
+                )
+            )
+        timestep_group_ids_per_step = torch.cat(timestep_group_ids_per_step)
+
+        start_pos = tokens_for_prefix
+        for t in range(horizon):
+            end_pos = start_pos + tokens_per_time_step
+            group_ids[start_pos:end_pos] = timestep_group_ids_per_step
+            timesteps[start_pos:end_pos] = t
+            start_pos = end_pos
 
-        def get_token_metadata(i):
-            if i < tokens_for_prefix:
-                position = _get_position(i, tokens_per_prefix_group)
-                return TokenMetadata.create(prefix_groups[position], timestep=-1)
+        rules_table = torch.zeros(
+            num_groups, num_groups, dtype=torch.long, device=device
+        )
+        for i, group_i in enumerate(all_groups):
+            for j, group_j in enumerate(all_groups):
+                rule = find_match(
+                    group_i.attention_rules, group_j.name, AttentionRule.NEVER
+                )
+                rules_table[i, j] = RULE_MAP[rule]
 
-            i -= tokens_for_prefix
-            timestep, i = divmod(i, tokens_per_time_step)
-            position = _get_position(i, tokens_per_timestep_group)
-            return TokenMetadata.create(timestep_groups[position], timestep)
+        attending_rules = rules_table[group_ids[:, None], group_ids[None, :]]
 
-        # Apply attention rules
-        for i in range(total_tokens):  # Token attending
-            for j in range(total_tokens):  # Token being attended to
-                metadata_i = get_token_metadata(i)
-                metadata_j = get_token_metadata(j)
-                mask = int(metadata_i.should_attend_to(metadata_j))
-                attention_mask[i, j] = mask
+        timesteps_i = timesteps[:, None]
+        timesteps_j = timesteps[None, :]
 
-        # Convert to torch tensor and move to correct device
-        device = timestep_groups[0].tokens.device
-        attention_mask = torch.from_numpy(attention_mask).bool().to(device)
+        mask = torch.zeros(
+            total_tokens, total_tokens, dtype=torch.bool, device=device
+        )
+        mask |= (attending_rules == RULE_MAP[AttentionRule.CAUSAL]) & (
+            timesteps_j <= timesteps_i
+        )
+        mask |= (attending_rules == RULE_MAP[AttentionRule.CURRENT]) & (
+            timesteps_j == timesteps_i
+        )
+        mask |= (
+            attending_rules == RULE_MAP[AttentionRule.STRICT_PAST]
+        ) & (timesteps_j < timesteps_i)
+        mask |= attending_rules == RULE_MAP[AttentionRule.ALL]
 
         # Combine with padding mask
         pad_attention_mask = self._generate_pad_attention_mask(prefix_groups, timestep_groups)
-
-        # The attention mask from rules is (total_tokens, total_tokens)
-        # The padding mask is (batch, total_tokens, total_tokens)
-        # We need to combine them properly
         batch_size = pad_attention_mask.shape[0]
-        attention_mask = attention_mask.unsqueeze(0).expand(
-            batch_size, -1, -1
-        )  # (batch, total_tokens, total_tokens)
-        # attention_mask = attention_mask.unsqueeze(1)  # (batch, 1, total_tokens, total_tokens)
 
-        # Combine with padding mask using logical AND
-        attention_mask = attention_mask & pad_attention_mask
+        attention_mask = mask.unsqueeze(0) & pad_attention_mask
 
         num_attention_heads = self.transformer_kwargs["num_attention_heads"]
-
-        attention_mask = attention_mask.unsqueeze(1).expand(
-            batch_size, self.transformer_kwargs["num_attention_heads"], total_tokens, total_tokens
-        )
+        attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_attention_heads, total_tokens, total_tokens)
         attention_mask = attention_mask.reshape(batch_size * num_attention_heads, total_tokens, total_tokens)
 
         return attention_mask
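
A standalone sanity check of the new construction (illustrative only, not part of the commit; the toy group layout and rule table are invented): build the rule mask with the same broadcasted comparisons as above and compare it against a per-token-pair loop that encodes the logic the removed `TokenMetadata.should_attend_to` expressed.

import torch

NEVER, CAUSAL, CURRENT, STRICT_PAST, ALL = 0, 1, 2, 3, 4

# Toy layout: 3 prefix tokens (group 0, timestep -1) followed by 2 timesteps,
# each containing one token from group 1 and one from group 2.
group_ids = torch.tensor([0, 0, 0, 1, 2, 1, 2])
timesteps = torch.tensor([-1, -1, -1, 0, 0, 1, 1])

# Toy rule table: rules_table[i, j] says how group i may attend to group j.
rules_table = torch.tensor(
    [
        [ALL, NEVER, NEVER],
        [ALL, CAUSAL, CURRENT],
        [ALL, STRICT_PAST, CAUSAL],
    ]
)

# Vectorized construction, mirroring the refactored _generate_attention_mask.
attending_rules = rules_table[group_ids[:, None], group_ids[None, :]]
ti, tj = timesteps[:, None], timesteps[None, :]
mask = torch.zeros(len(group_ids), len(group_ids), dtype=torch.bool)
mask |= (attending_rules == CAUSAL) & (tj <= ti)
mask |= (attending_rules == CURRENT) & (tj == ti)
mask |= (attending_rules == STRICT_PAST) & (tj < ti)
mask |= attending_rules == ALL


def should_attend(rule, t_i, t_j):
    # Per-pair logic equivalent to the removed TokenMetadata.should_attend_to.
    return {
        NEVER: False,
        CAUSAL: t_j <= t_i,
        CURRENT: t_j == t_i,
        STRICT_PAST: t_j < t_i,
        ALL: True,
    }[rule]


n = len(group_ids)
ref = torch.tensor(
    [
        [
            should_attend(int(rules_table[group_ids[i], group_ids[j]]), int(timesteps[i]), int(timesteps[j]))
            for j in range(n)
        ]
        for i in range(n)
    ]
)

assert torch.equal(mask, ref)
print("vectorized mask matches the loop-based reference")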
