 
 from typing import Any, Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F  # noqa: N812
 
-from lerobot.policies.octo.base import AttentionRule, PrefixGroup, TimestepGroup, TokenMetadata
+from lerobot.policies.octo.base import (
+    AttentionRule,
+    PrefixGroup,
+    TimestepGroup,
+    find_match,
+    RULE_MAP,
+)
 from lerobot.policies.octo.tokenizers import ImageTokenizer, LanguageTokenizer, SmallStem16
 
 
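Note on the new imports: the mask construction below relies on find_match and RULE_MAP from lerobot.policies.octo.base. As a rough, hypothetical sketch (not the actual lerobot source), RULE_MAP is assumed to map each AttentionRule to a stable integer id so rules can be stored in a torch.long table, and find_match is assumed to return the rule of the first key pattern that matches a group name, falling back to a default:

# Hypothetical sketch of the assumed helpers; names mirror the diff, bodies are guesses.
from enum import Enum
from fnmatch import fnmatch


class AttentionRule(Enum):
    NEVER = "never"
    CAUSAL = "other.timestep <= self.timestep"
    CURRENT = "other.timestep == self.timestep"
    STRICT_PAST = "other.timestep < self.timestep"
    ALL = "all"


# Stable integer id per rule, so a (num_groups, num_groups) rule table can be a LongTensor.
RULE_MAP = {rule: i for i, rule in enumerate(AttentionRule)}


def find_match(attention_rules: dict, name: str, default: AttentionRule) -> AttentionRule:
    """Return the rule for the first key pattern matching `name`, else `default`."""
    for pattern, rule in attention_rules.items():
        if fnmatch(name, pattern):
            return rule
    return default
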
@@ -311,62 +316,83 @@ def _generate_attention_mask(
         if self.enforce_causal:
             self._verify_causality(prefix_groups, timestep_groups)
 
-        def _get_position(i, tokens_per_elem):
-            return np.searchsorted(np.cumsum(tokens_per_elem), i, side="right")
-
+        device = timestep_groups[0].tokens.device
         horizon = timestep_groups[0].tokens.shape[1]
-        tokens_per_prefix_group = [group.tokens.shape[1] for group in prefix_groups]
-        tokens_per_timestep_group = [group.tokens.shape[2] for group in timestep_groups]
 
-        tokens_for_prefix = sum(tokens_per_prefix_group)
-        tokens_per_time_step = sum(tokens_per_timestep_group)
+        all_groups = prefix_groups + timestep_groups
+        num_groups = len(all_groups)
+
+        prefix_token_counts = [g.tokens.shape[1] for g in prefix_groups]
+        timestep_token_counts = [g.tokens.shape[2] for g in timestep_groups]
+
+        tokens_for_prefix = sum(prefix_token_counts)
+        tokens_per_time_step = sum(timestep_token_counts)
         total_tokens = tokens_for_prefix + tokens_per_time_step * horizon
 
-        # Create attention mask using numpy for compatibility with JAX implementation
-        attention_mask = np.zeros((total_tokens, total_tokens), dtype=int)
+        group_ids = torch.zeros(total_tokens, dtype=torch.long, device=device)
+        timesteps = torch.full((total_tokens,), -1, dtype=torch.long, device=device)
+
+        current_pos = 0
+        for i, count in enumerate(prefix_token_counts):
+            group_ids[current_pos : current_pos + count] = i
+            current_pos += count
+
+        timestep_group_ids_per_step = []
+        for i, count in enumerate(timestep_token_counts):
+            timestep_group_ids_per_step.append(
+                torch.full(
+                    (count,),
+                    len(prefix_groups) + i,
+                    dtype=torch.long,
+                    device=device,
+                )
+            )
+        timestep_group_ids_per_step = torch.cat(timestep_group_ids_per_step)
+
+        start_pos = tokens_for_prefix
+        for t in range(horizon):
+            end_pos = start_pos + tokens_per_time_step
+            group_ids[start_pos:end_pos] = timestep_group_ids_per_step
+            timesteps[start_pos:end_pos] = t
+            start_pos = end_pos
 
-        def get_token_metadata(i):
-            if i < tokens_for_prefix:
-                position = _get_position(i, tokens_per_prefix_group)
-                return TokenMetadata.create(prefix_groups[position], timestep=-1)
+        rules_table = torch.zeros(
+            num_groups, num_groups, dtype=torch.long, device=device
+        )
+        for i, group_i in enumerate(all_groups):
+            for j, group_j in enumerate(all_groups):
+                rule = find_match(
+                    group_i.attention_rules, group_j.name, AttentionRule.NEVER
+                )
+                rules_table[i, j] = RULE_MAP[rule]
 
-            i -= tokens_for_prefix
-            timestep, i = divmod(i, tokens_per_time_step)
-            position = _get_position(i, tokens_per_timestep_group)
-            return TokenMetadata.create(timestep_groups[position], timestep)
+        attending_rules = rules_table[group_ids[:, None], group_ids[None, :]]
 
-        # Apply attention rules
-        for i in range(total_tokens):  # Token attending
-            for j in range(total_tokens):  # Token being attended to
-                metadata_i = get_token_metadata(i)
-                metadata_j = get_token_metadata(j)
-                mask = int(metadata_i.should_attend_to(metadata_j))
-                attention_mask[i, j] = mask
+        timesteps_i = timesteps[:, None]
+        timesteps_j = timesteps[None, :]
 
-        # Convert to torch tensor and move to correct device
-        device = timestep_groups[0].tokens.device
-        attention_mask = torch.from_numpy(attention_mask).bool().to(device)
+        mask = torch.zeros(
+            total_tokens, total_tokens, dtype=torch.bool, device=device
+        )
+        mask |= (attending_rules == RULE_MAP[AttentionRule.CAUSAL]) & (
+            timesteps_j <= timesteps_i
+        )
+        mask |= (attending_rules == RULE_MAP[AttentionRule.CURRENT]) & (
+            timesteps_j == timesteps_i
+        )
+        mask |= (
+            attending_rules == RULE_MAP[AttentionRule.STRICT_PAST]
+        ) & (timesteps_j < timesteps_i)
+        mask |= attending_rules == RULE_MAP[AttentionRule.ALL]
 
         # Combine with padding mask
         pad_attention_mask = self._generate_pad_attention_mask(prefix_groups, timestep_groups)
-
-        # The attention mask from rules is (total_tokens, total_tokens)
-        # The padding mask is (batch, total_tokens, total_tokens)
-        # We need to combine them properly
         batch_size = pad_attention_mask.shape[0]
-        attention_mask = attention_mask.unsqueeze(0).expand(
-            batch_size, -1, -1
-        )  # (batch, total_tokens, total_tokens)
-        # attention_mask = attention_mask.unsqueeze(1)  # (batch, 1, total_tokens, total_tokens)
 
-        # Combine with padding mask using logical AND
-        attention_mask = attention_mask & pad_attention_mask
+        attention_mask = mask.unsqueeze(0) & pad_attention_mask
 
         num_attention_heads = self.transformer_kwargs["num_attention_heads"]
-
-        attention_mask = attention_mask.unsqueeze(1).expand(
-            batch_size, self.transformer_kwargs["num_attention_heads"], total_tokens, total_tokens
-        )
+        attention_mask = attention_mask.unsqueeze(1).expand(batch_size, num_attention_heads, total_tokens, total_tokens)
         attention_mask = attention_mask.reshape(batch_size * num_attention_heads, total_tokens, total_tokens)
 
         return attention_mask
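The gist of the change: the per-token Python double loop over TokenMetadata is replaced by a (num_groups, num_groups) rule table indexed with per-token group ids, and the timestep comparisons become broadcast tensor ops. A self-contained toy sketch of that idea (no lerobot imports; the group layout and integer rule ids are made up for illustration, not the actual RULE_MAP values):

import torch

# Toy layout: 2 prefix tokens (group 0), then horizon=3 steps with one token each
# from groups 1 and 2. Rule ids are illustrative stand-ins for RULE_MAP values.
NEVER, CAUSAL, CURRENT, STRICT_PAST, ALL = range(5)

group_ids = torch.tensor([0, 0, 1, 2, 1, 2, 1, 2])
timesteps = torch.tensor([-1, -1, 0, 0, 1, 1, 2, 2])  # -1 marks prefix tokens

# rules_table[i, j]: how a token of group i may attend to a token of group j.
rules_table = torch.tensor(
    [
        [ALL, NEVER, NEVER],        # prefix attends only within the prefix
        [CAUSAL, CAUSAL, CAUSAL],   # group 1 attends causally to everything
        [CAUSAL, CURRENT, CAUSAL],  # group 2 sees group 1 only at the same step
    ]
)

# Same vectorized pattern as the diff: gather the rule per (i, j) pair, then
# turn each rule into a boolean condition on the two timesteps.
attending_rules = rules_table[group_ids[:, None], group_ids[None, :]]
ti, tj = timesteps[:, None], timesteps[None, :]

mask = torch.zeros_like(attending_rules, dtype=torch.bool)
mask |= (attending_rules == CAUSAL) & (tj <= ti)
mask |= (attending_rules == CURRENT) & (tj == ti)
mask |= (attending_rules == STRICT_PAST) & (tj < ti)
mask |= attending_rules == ALL

print(mask.int())  # 8x8 attention pattern; rows attend, columns are attended to

One behavioral detail visible in the diff itself: prefix tokens keep timestep -1, so the CAUSAL and STRICT_PAST conditions automatically let timestep tokens attend back to the prefix without any special casing.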