diff --git a/.vscode/launch.json b/.vscode/launch.json index 69912d02..f7c5e375 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -17,13 +17,13 @@ "request": "launch", "program": "pretrain.py", "args": [ - // "data_path=data/sudoku-extreme-1k-aug-1000", - // "epochs=20000", - // "eval_interval=2000", - // "lr=1e-4", - // "puzzle_emb_lr=1e-4", - // "weight_decay=1.0", - // "puzzle_emb_weight_decay=1.0" + "data_path=data/sudoku-extreme-1k-aug-1000", + "epochs=20000", + "eval_interval=2000", + "lr=1e-4", + "puzzle_emb_lr=1e-4", + "weight_decay=1.0", + "puzzle_emb_weight_decay=1.0" ], "env": { "OMP_NUM_THREADS": "1", diff --git a/config/arch/hrm_v1.yaml b/config/arch/hrm_v1.yaml index a5646b89..b3fb9a2d 100644 --- a/config/arch/hrm_v1.yaml +++ b/config/arch/hrm_v1.yaml @@ -9,8 +9,8 @@ halt_max_steps: 16 H_cycles: 2 L_cycles: 2 -H_layers: 4 -L_layers: 4 +H_layers: 2 +L_layers: 2 hidden_size: 512 num_heads: 8 # min(2, hidden_size // 64) @@ -19,3 +19,8 @@ expansion: 4 puzzle_emb_ndim: ${.hidden_size} pos_encodings: rope + +# Image encoding +img_size: 96 +img_channels: 3 +patch_size: 32 diff --git a/config/cfg_pretrain.yaml b/config/cfg_pretrain.yaml index 51c55a07..ed2f73b9 100644 --- a/config/cfg_pretrain.yaml +++ b/config/cfg_pretrain.yaml @@ -9,6 +9,11 @@ hydra: # Data path data_path: data/arc-aug-1000 +dataset_name: sudoku + +# Image rendering +render_res: 144 +output_size: 96 # Hyperparams - Training global_batch_size: 768 diff --git a/dataset/build_sudoku_dataset.py b/dataset/build_sudoku_dataset.py index 7924438b..3b10b774 100644 --- a/dataset/build_sudoku_dataset.py +++ b/dataset/build_sudoku_dataset.py @@ -9,12 +9,95 @@ from tqdm import tqdm from huggingface_hub import hf_hub_download -from common import PuzzleDatasetMetadata - +from .common import PuzzleDatasetMetadata +import numpy as np +import cv2 +from concurrent.futures import ThreadPoolExecutor cli = ArgParser() + +class SudokuImageRenderer: + """ + Render Sudoku puzzles (flattened 81 ints: 0=blank, 1..9=digits) to RGB images. + + Args: + output_size: final side length of output image (e.g., 32 -> 32x32). + render_res: internal canvas side length for crisp glyphs (>= output_size). + line_thick: base grid line thickness at render_res (thicker 3x3 lines auto-handled). + font_scale: OpenCV font scale at render_res. + font_thick: OpenCV font thickness at render_res. 
+    """
+    def __init__(self,
+                 output_size: int = 32,
+                 render_res: int = 144,
+                 line_thick: int = 2,
+                 font_scale: float = 0.5,
+                 font_thick: int = 1):
+        assert render_res % 9 == 0, "render_res should be divisible by 9"
+        self.output_size = int(output_size)
+        self.render_res = int(render_res)
+        self.cell = self.render_res // 9
+        self.line_thick = int(line_thick)
+        self.font_scale = float(font_scale)
+        self.font_thick = int(font_thick)
+        self.font = cv2.FONT_HERSHEY_SIMPLEX
+
+    def render_single(self, grid_flat: np.ndarray) -> np.ndarray:
+        """Returns (H, W, 3) uint8 image with H=W=output_size."""
+        img = np.full((self.render_res, self.render_res, 3), 255, np.uint8)
+
+        # Grid lines (thicker for 3x3 boundaries)
+        for i in range(10):
+            t = self.line_thick + (1 if i % 3 == 0 else 0)
+            y = i * self.cell
+            x = i * self.cell
+            cv2.line(img, (0, y), (self.render_res, y), (0, 0, 0), t)
+            cv2.line(img, (x, 0), (x, self.render_res), (0, 0, 0), t)
+
+        # Digits
+        for r in range(9):
+            for c in range(9):
+                v = int(grid_flat[r * 9 + c])
+                if v == 0 or v == 1:  # values are stored as digit + 1 in this dataset: 0 = PAD, 1 = blank cell
+                    continue
+                text = str(v - 1)  # undo the +1 shift so cells show the actual digits 1..9
+                (tw, th), _ = cv2.getTextSize(text, self.font, self.font_scale, self.font_thick)
+                tx = c * self.cell + (self.cell - tw) // 2
+                ty = r * self.cell + (self.cell + th) // 2
+                cv2.putText(img, text, (tx, ty), self.font, self.font_scale, (0, 0, 0),
+                            self.font_thick, cv2.LINE_AA)
+
+        if self.output_size != self.render_res:
+            img = cv2.resize(img, (self.output_size, self.output_size), interpolation=cv2.INTER_AREA)
+        return img
+
+    def render_batch(self, inputs_np: np.ndarray, num_workers: int = 32) -> np.ndarray:
+        """
+        Threaded batch renderer with tqdm.
+        inputs_np: (N, 81) -> returns (N, 3, output_size, output_size) uint8
+        """
+        N = int(inputs_np.shape[0])
+
+        # sequential fallback (still shows progress)
+        if num_workers <= 1:
+            out = np.empty((N, self.output_size, self.output_size, 3), dtype=np.uint8)
+            for i in tqdm(range(N), desc="Creating images (1 thread)", unit="img"):
+                out[i] = self.render_single(inputs_np[i])
+            return np.transpose(out, (0, 3, 1, 2))
+
+        # multithreaded path (order-preserving)
+        with ThreadPoolExecutor(max_workers=num_workers) as ex:
+            # chunksize is accepted but has no effect for ThreadPoolExecutor (it only matters for ProcessPoolExecutor)
+            imgs_iter = ex.map(self.render_single, inputs_np, chunksize=64)
+            imgs = list(imgs_iter)
+
+        out = np.stack(imgs, axis=0)  # (N, H, W, 3)
+        return np.transpose(out, (0, 3, 1, 2))
+
+
+
 class DataProcessConfig(BaseModel):
     source_repo: str = "sapientinc/sudoku-extreme"
     output_dir: str = "data/sudoku-extreme-full"
@@ -127,6 +210,10 @@ def _seq_to_numpy(seq):
         "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
     }
 
+    # # Render images --> Painfully slow and huge!!
+ # renderer = SudokuImageRenderer(output_size=224, render_res=288) + # results["images"] = renderer.render_batch(results["inputs"]) + # Metadata metadata = PuzzleDatasetMetadata( seq_len=81, diff --git a/dataset/visualize_images.py b/dataset/visualize_images.py new file mode 100644 index 00000000..9ab52cf1 --- /dev/null +++ b/dataset/visualize_images.py @@ -0,0 +1,19 @@ +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + +# Load the .npy file +images = np.load('/home/pbhat1/projects/NeurAI/HRM/data/sudoku-extreme-1k-aug-1000/train/all__images.npy') + +# Save the first image +first_image = images[0] # Get the first image + +# Ensure the image is in uint8 format if necessary +first_image = (first_image * 255).astype(np.uint8) + +# Rearrange dimensions from (C, H, W) to (H, W, C) +first_image = first_image.transpose(1, 2, 0) + +# Option 1: Save using PIL +pil_image = Image.fromarray(first_image) +pil_image.save('first_image_pil.png') diff --git a/models/hrm/hrm_act_v1.py b/models/hrm/hrm_act_v1.py index 78500fc1..3d6102d1 100644 --- a/models/hrm/hrm_act_v1.py +++ b/models/hrm/hrm_act_v1.py @@ -16,6 +16,8 @@ class HierarchicalReasoningModel_ACTV1InnerCarry: z_H: torch.Tensor z_L: torch.Tensor + z_H_v: torch.Tensor + z_L_v: torch.Tensor @dataclass @@ -35,6 +37,12 @@ class HierarchicalReasoningModel_ACTV1Config(BaseModel): num_puzzle_identifiers: int vocab_size: int + # --- vision input --- + img_size: int + img_channels: int + patch_size: int + vision_pos_encodings: str = "learned" + H_cycles: int L_cycles: int @@ -77,9 +85,15 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: def forward(self, cos_sin: CosSin, hidden_states: torch.Tensor) -> torch.Tensor: # Post Norm # Self Attention - hidden_states = rms_norm(hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), variance_epsilon=self.norm_eps) + hidden_states = rms_norm( + hidden_states + self.self_attn(cos_sin=cos_sin, hidden_states=hidden_states), + variance_epsilon=self.norm_eps + ) # Fully Connected - hidden_states = rms_norm(hidden_states + self.mlp(hidden_states), variance_epsilon=self.norm_eps) + hidden_states = rms_norm( + hidden_states + self.mlp(hidden_states), + variance_epsilon=self.norm_eps + ) return hidden_states @@ -133,15 +147,43 @@ def __init__(self, config: HierarchicalReasoningModel_ACTV1Config) -> None: self.H_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) self.L_level = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) - # Initial states - # self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) - # self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True) + self.H_level_v = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.H_layers)]) + self.L_level_v = HierarchicalReasoningModel_ACTV1ReasoningModule(layers=[HierarchicalReasoningModel_ACTV1Block(self.config) for _i in range(self.config.L_layers)]) + + # ----------------------------- + # Vision: patchify + pos + injector + # ----------------------------- + assert (self.config.img_size % self.config.patch_size) == 0, "img_size must be divisible by patch_size" + 
self.vH = self.config.img_size // self.config.patch_size + self.vW = self.config.img_size // self.config.patch_size + self.v_num_patches = self.vH * self.vW + # Conv2d as patch embedding -> hidden_size channels + self.vision_patch_embed = nn.Conv2d( + in_channels=self.config.img_channels, + out_channels=self.config.hidden_size, + kernel_size=self.config.patch_size, + stride=self.config.patch_size, + bias=True, + ) + if self.config.vision_pos_encodings == "learned": + self.vision_pos = nn.Parameter( + trunc_normal_init_(torch.empty(1, self.v_num_patches, self.config.hidden_size), std=0.02) + ) + # Project a pooled visual token into per-token injection, then broadcast to seq_len+puzzle_emb_len + self.vision_to_seq = CastedLinear(self.config.hidden_size, self.config.hidden_size, bias=True) + # Initial states h_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1) self.register_buffer('H_init', h_init_tensor) l_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1) self.register_buffer('L_init', l_init_tensor) + + h_v_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1) + self.register_buffer('H_init_v', h_v_init_tensor) + + l_v_init_tensor = trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1) + self.register_buffer('L_init_v', l_v_init_tensor) # Q head special init # Init Q to (almost) zero for faster learning during bootstrapping @@ -171,16 +213,45 @@ def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tenso # Scale return self.embed_scale * embedding + def _image_embeddings(self, images: Optional[torch.Tensor]) -> torch.Tensor: + """ + images: [B, C, H, W] uint8|float -> returns [B, (seq_len+puzzle_emb_len), hidden_size] + Implementation: patchify -> [B, P, H] add pos -> mean-pool -> project -> broadcast. 
+ """ + if images is None: + # zero injection if no images provided + B = 1 + return torch.zeros((B, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size), + dtype=self.forward_dtype, device=self.H_init.device) + x = images + if x.dtype not in (torch.float16, torch.bfloat16, torch.float32): + x = x.float() + # normalize to [0,1] if looks like 0..255 + if x.max() > 1.5: + x = x / 255.0 + x = self.vision_patch_embed(x) # [B, H, vH, vW] (H == hidden_size) + x = x.flatten(2).transpose(1, 2) # [B, P, hidden] + if hasattr(self, "vision_pos"): + x = x + self.vision_pos.to(x.dtype) # learned 2D pos + v_global = x.mean(dim=1) # [B, hidden] + inj = self.vision_to_seq(v_global.to(self.forward_dtype)) # [B, hidden] + inj = inj.unsqueeze(1).expand(-1, self.config.seq_len + self.puzzle_emb_len, -1) # [B, T_txt, H] + return self.embed_scale * inj + def empty_carry(self, batch_size: int): return HierarchicalReasoningModel_ACTV1InnerCarry( - z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), - z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_H = torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L = torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_H_v = torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), + z_L_v = torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype), ) def reset_carry(self, reset_flag: torch.Tensor, carry: HierarchicalReasoningModel_ACTV1InnerCarry): return HierarchicalReasoningModel_ACTV1InnerCarry( - z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), - z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + z_H = torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H), + z_L = torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L), + z_H_v = torch.where(reset_flag.view(-1, 1, 1), self.H_init_v, carry.z_H_v), + z_L_v = torch.where(reset_flag.view(-1, 1, 1), self.L_init_v, carry.z_L_v), ) def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict[str, torch.Tensor]) -> Tuple[HierarchicalReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: @@ -190,31 +261,74 @@ def forward(self, carry: HierarchicalReasoningModel_ACTV1InnerCarry, batch: Dict # Input encoding input_embeddings = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"]) + + # Vision injection (broadcasted to text length for compatibility with current *_v sequence shapes) + image_embeddings = self._image_embeddings(batch.get("images", None)) # Forward iterations with torch.no_grad(): z_H, z_L = carry.z_H, carry.z_L + z_H_v, z_L_v = carry.z_H_v, carry.z_L_v for _H_step in range(self.config.H_cycles): for _L_step in range(self.config.L_cycles): if not ((_H_step == self.config.H_cycles - 1) and (_L_step == self.config.L_cycles - 1)): - z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) + use_cuda = z_L.is_cuda and torch.cuda.is_available() + if use_cuda: + s1 = torch.cuda.Stream(device=z_L.device) # text stream + s2 = torch.cuda.Stream(device=z_L.device) # vision stream + + # Launch the two independent branches in parallel streams + with torch.cuda.stream(s1): + z_L_next = 
self.L_level( + z_L + z_L_v, + z_H + z_H_v + input_embeddings, + **seq_info + ) + with torch.cuda.stream(s2): + z_L_v_next = self.L_level_v( + z_L_v, + z_H_v + image_embeddings, + **seq_info + ) + + # Synchronize before consuming the results + # e1 = torch.cuda.Event(enable_timing=False) + # e2 = torch.cuda.Event(enable_timing=False) + # e1.record(s1); e2.record(s2) + # torch.cuda.current_stream().wait_event(e1) + # torch.cuda.current_stream().wait_event(e2) + + torch.cuda.synchronize() + z_L, z_L_v = z_L_next, z_L_v_next + else: + # CPU / non-CUDA fallback (sequential) + z_L = self.L_level( z_L + z_L_v, z_H + z_H_v + input_embeddings, **seq_info) + z_L_v = self.L_level_v(z_L_v, z_H_v + image_embeddings, **seq_info) + if not (_H_step == self.config.H_cycles - 1): - z_H = self.H_level(z_H, z_L, **seq_info) + z_H = self.H_level(z_H + z_H_v, z_L + z_L_v, **seq_info) + z_H_v = self.H_level_v(z_H_v, z_L_v, **seq_info) - assert not z_H.requires_grad and not z_L.requires_grad + assert (not z_H.requires_grad) and (not z_L.requires_grad) and (not z_L_v.requires_grad) and (not z_H_v.requires_grad) # 1-step grad - z_L = self.L_level(z_L, z_H + input_embeddings, **seq_info) - z_H = self.H_level(z_H, z_L, **seq_info) + z_L = self.L_level(z_L + z_L_v, z_H + z_H_v + input_embeddings, **seq_info) + z_H = self.H_level(z_H + z_H_v, z_L + z_L_v, **seq_info) + + z_L_v = self.L_level_v(z_L_v, z_H_v + image_embeddings, **seq_info) + z_H_v = self.H_level_v(z_H_v, z_L_v, **seq_info) # LM Outputs - new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), z_L=z_L.detach()) # New carry no grad - output = self.lm_head(z_H)[:, self.puzzle_emb_len:] + new_carry = HierarchicalReasoningModel_ACTV1InnerCarry(z_H=z_H.detach(), + z_L=z_L.detach(), + z_L_v=z_L_v.detach(), + z_H_v=z_H_v.detach()) + output = self.lm_head(z_H + z_H_v)[:, self.puzzle_emb_len:] # Q head - q_logits = self.q_head(z_H[:, 0]).to(torch.float32) + q_logits = self.q_head(z_H[:, 0]+ z_H_v[:, 0]).to(torch.float32) return new_carry, output, (q_logits[..., 0], q_logits[..., 1]) diff --git a/pretrain.py b/pretrain.py index 245cb5c7..32d7e9d9 100644 --- a/pretrain.py +++ b/pretrain.py @@ -21,6 +21,9 @@ from puzzle_dataset import PuzzleDataset, PuzzleDatasetConfig, PuzzleDatasetMetadata from utils.functions import load_model_class, get_model_source_path from models.sparse_embedding import CastedSparseEmbeddingSignSGD_Distributed +from dataset.build_sudoku_dataset import SudokuImageRenderer +from concurrent.futures import ThreadPoolExecutor + class LossConfig(pydantic.BaseModel): @@ -39,8 +42,14 @@ class ArchConfig(pydantic.BaseModel): class PretrainConfig(pydantic.BaseModel): # Config arch: ArchConfig + # Data data_path: str + dataset_name: str + + # Image rendering + render_res: int + output_size: int # Hyperparams global_batch_size: int @@ -210,7 +219,7 @@ def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, glo train_state.step += 1 if train_state.step > train_state.total_steps: # At most train_total_steps return - + # To device batch = {k: v.cuda() for k, v in batch.items()} @@ -263,7 +272,8 @@ def train_batch(config: PretrainConfig, train_state: TrainState, batch: Any, glo return reduced_metrics -def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, eval_metadata: PuzzleDatasetMetadata, rank: int, world_size: int): +def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch.utils.data.DataLoader, + eval_metadata: PuzzleDatasetMetadata, rank: 
int, world_size: int, renderer = None):
     with torch.inference_mode():
         set_ids = {k: idx for idx, k in enumerate(eval_metadata.sets)}
 
@@ -275,8 +285,14 @@ def evaluate(config: PretrainConfig, train_state: TrainState, eval_loader: torch
 
         carry = None
         for set_name, batch, global_batch_size in eval_loader:
+            # Render images for the current eval batch
+            assert renderer is not None
+            batch["images"] = renderer.render_batch(batch["inputs"])
+            batch["images"] = torch.from_numpy(batch["images"])
+
             # To device
             batch = {k: v.cuda() for k, v in batch.items()}
+
             with torch.device("cuda"):
                 carry = train_state.model.initial_carry(batch)  # type: ignore
 
@@ -376,7 +392,31 @@ def load_synced_config(hydra_config: DictConfig, rank: int, world_size: int) ->
     return objects[0]  # type: ignore
 
 
+def create_renderer(dataset_name: str, output_size, render_res, **kwargs):
+    renderers = {
+        "sudoku": SudokuImageRenderer(output_size=output_size, render_res=render_res),
+        # "maze": CrosswordImageRenderer,
+        # "arc": KakuroImageRenderer,
+    }
+
+    try:
+        return renderers[dataset_name.lower()]
+    except KeyError:
+        raise ValueError(
+            f"Unknown dataset: {dataset_name}. "
+            f"Available options: {list(renderers.keys())}"
+        )
+
+def precompute_images(batch: dict, renderer = None):
+    # Render images for a training batch (runs in a background thread)
+    assert renderer is not None
+    images = renderer.render_batch(batch["inputs"])
+    images = torch.from_numpy(images)
+    return {'images': images}
+
+
+
 @hydra.main(config_path="config", config_name="cfg_pretrain", version_base=None)
 def launch(hydra_config: DictConfig):
     RANK = 0
@@ -407,6 +447,9 @@ def launch(hydra_config: DictConfig):
     train_loader, train_metadata = create_dataloader(config, "train", test_set_mode=False, epochs_per_iter=train_epochs_per_iter, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
     eval_loader, eval_metadata = create_dataloader(config, "test", test_set_mode=True, epochs_per_iter=1, global_batch_size=config.global_batch_size, rank=RANK, world_size=WORLD_SIZE)
 
+    # Image renderer
+    renderer = create_renderer(config.dataset_name, config.output_size, config.render_res)
+
     # Train state
     train_state = init_train_state(config, train_metadata, world_size=WORLD_SIZE)
 
@@ -423,18 +466,38 @@ def launch(hydra_config: DictConfig):
     for _iter_id in range(total_iters):
         print (f"[Rank {RANK}, World Size {WORLD_SIZE}]: Epoch {_iter_id * train_epochs_per_iter}")
 
+        # Render images for the next batch in a background thread while the current one trains
+        executor = ThreadPoolExecutor(max_workers=1)
+        it = iter(train_loader)
+        first = next(it, None)
+        future = executor.submit(precompute_images, first[1], renderer) if first is not None else None
+
+
         ############ Train Iter
         train_state.model.train()
-        for set_name, batch, global_batch_size in train_loader:
+        # for set_name, batch, global_batch_size in train_loader:
+        for nxt in it:
+            set_name, batch, global_batch_size = first
+            # Image generation: wait for precompute of CURRENT (submitted last iteration)
+            images = future.result() if future is not None else None
+            if images is not None:
+                batch.update(images)
+
+            # Image generation: kick off precompute for NEXT
+            future = executor.submit(precompute_images, nxt[1], renderer)
+
             metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
 
             if RANK == 0 and metrics is not None:
                 wandb.log(metrics, step=train_state.step)
                 progress_bar.update(train_state.step - progress_bar.n)  # type: ignore
+
+            # Image generation: slide the window
+            first = nxt
+
+        # Image generation: flush the final prefetched batch; without this the last
+        # batch of every training iteration would be silently dropped
+        if first is not None:
+            set_name, batch, global_batch_size = first
+            images = future.result() if future is not None else None
+            if images is not None:
+                batch.update(images)
+            metrics = train_batch(config, train_state, batch, global_batch_size, rank=RANK, world_size=WORLD_SIZE)
+            if RANK == 0 and metrics is not None:
+                wandb.log(metrics, step=train_state.step)
+                progress_bar.update(train_state.step - progress_bar.n)  # type: ignore
+        executor.shutdown(wait=True)
 
         ############ Evaluation
         train_state.model.eval()
-        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE)
+        metrics = evaluate(config, train_state, eval_loader, eval_metadata, rank=RANK, world_size=WORLD_SIZE, renderer=renderer)
 
         if RANK == 0 and metrics is not None:
             wandb.log(metrics, step=train_state.step)
diff --git a/puzzle_dataset.py b/puzzle_dataset.py
index 2782403c..1c6a2101 100644
--- a/puzzle_dataset.py
+++ b/puzzle_dataset.py
@@ -76,7 +76,6 @@ def _lazy_load_dataset(self):
         field_mmap_modes = {
             "inputs": "r",
             "labels": "r",
-
             # Keep indices in memory
             "puzzle_identifiers": None,
             "puzzle_indices": None,
@@ -100,7 +99,7 @@ def _collate_batch(self, batch):
         if self.metadata.ignore_label_id is not None:
             batch["labels"][batch["labels"] == self.metadata.ignore_label_id] = IGNORE_LABEL_ID
 
-        # Pad
+        # Pad --
         if batch["puzzle_identifiers"].size < self.local_batch_size:
             pad_size = self.local_batch_size - batch["puzzle_identifiers"].size
 
@@ -111,7 +110,7 @@ def _collate_batch(self, batch):
                 "puzzle_identifiers": self.metadata.blank_identifier_id
             }
             batch = {k: np.pad(v, ((0, pad_size), ) + ((0, 0), ) * (v.ndim - 1), constant_values=pad_values[k]) for k, v in batch.items()}
-
+
         # To tensor
         return {k: torch.from_numpy(v) for k, v in batch.items()}
 
@@ -196,4 +195,4 @@ def __iter__(self):
         if self.config.test_set_mode:
             yield from self._iter_test()
         else:
-            yield from self._iter_train()
+            yield from self._iter_train()
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100644
index 00000000..be9178ca
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,71 @@
+#!/bin/bash
+# Define other parameters
+data_path='./data/sudoku-extreme-1k-aug-1000'
+dataset_name='sudoku'
+epochs=20000
+eval_interval=2000
+lr=0.0001
+puzzle_emb_lr=0.0001
+weight_decay=1
+puzzle_emb_weight_decay=1
+start_seed=50
+num_runs=1
+render_res=288
+output_size=224
+global_batch_size=384
+
+
+
+# Loop over all seeds
+for seed in $(seq $start_seed $((start_seed + num_runs - 1))); do
+    # Create a unique experiment ID
+    exp_id="ix-s-${seed}-"
+    echo "Submitting job for combination: $exp_id"
+
+    # Write the SLURM job script via a heredoc (unescaped $vars expand now, \$vars at job runtime)
+    tmp_script=$(mktemp /home/pbhat1/projects/NeurAI/HRM/scripts/slurm_script.XXXXXX)
+    cat << EOF > "$tmp_script"
+#!/bin/bash
+#SBATCH --reservation=jhs_tue2022
+#SBATCH --partition=gpu_mig
+#SBATCH --time=10:00:00
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gpus-per-node=1
+#SBATCH --cpus-per-task=16
+#SBATCH -o /home/pbhat1/projects/NeurAI/HRM/slurm/${dataset_name}/${seed}/slurm-%j.out
+#SBATCH -e /home/pbhat1/projects/NeurAI/HRM/slurm/${dataset_name}/${seed}/slurm-%j.err
+
+# Load necessary modules (adjust based on your environment)
+source ~/miniconda3/bin/activate
+conda activate hrm
+
+export MASTER_PORT=\$((10000 + \$(echo -n \$SLURM_JOBID | tail -c 4)))
+export WORLD_SIZE=\$((\$SLURM_NNODES * \$SLURM_NTASKS_PER_NODE))
+echo "WORLD_SIZE=\$WORLD_SIZE"
+
+master_addr=\$(scontrol show hostnames "\$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=\$master_addr
+echo "MASTER_ADDR=\$MASTER_ADDR"
+
+# Run the training script (pretrain.py uses Hydra, so overrides are key=value, not --flags)
+export OMP_NUM_THREADS=96
+srun torchrun --nproc-per-node 1 /home/pbhat1/projects/NeurAI/HRM/pretrain.py \
+    data_path=$data_path \
+    epochs=$epochs \
+    eval_interval=$eval_interval \
+    lr=$lr \
+    puzzle_emb_lr=$puzzle_emb_lr \
+    weight_decay=$weight_decay \
+    puzzle_emb_weight_decay=$puzzle_emb_weight_decay \
+    dataset_name=$dataset_name \
+    render_res=$render_res \
+    output_size=$output_size \
+    global_batch_size=$global_batch_size
+
+EOF
+
+    # Submit the temporary script
+    sbatch "$tmp_script"
+    sleep 30
+    rm "$tmp_script"
+done
\ No newline at end of file
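
For reference, a minimal sketch of how the renderer output is expected to plug into a batch, assuming the SudokuImageRenderer above and the cfg_pretrain.yaml defaults (render_res=144, output_size=96); the random grid and batch size here are purely illustrative:

import numpy as np
import torch
from dataset.build_sudoku_dataset import SudokuImageRenderer

# Toy batch of 4 boards in the dataset's shifted encoding (1 = blank, 2..10 = digits 1..9)
inputs = np.random.randint(1, 11, size=(4, 81), dtype=np.int32)

renderer = SudokuImageRenderer(output_size=96, render_res=144)
images = renderer.render_batch(inputs, num_workers=4)   # (4, 3, 96, 96) uint8, CHW layout
assert images.shape == (4, 3, 96, 96) and images.dtype == np.uint8

# evaluate() / precompute_images() attach the rendered batch under the "images" key
batch = {"inputs": torch.from_numpy(inputs), "images": torch.from_numpy(images)}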
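
Similarly, a self-contained sketch of the shape bookkeeping behind the vision path added in hrm_act_v1.py (strided-Conv2d patchify, learned positions, mean-pool, broadcast to the text sequence); plain nn.Conv2d/nn.Linear stand in for vision_patch_embed/vision_to_seq, and the numbers assume the hrm_v1.yaml values (img_size=96, patch_size=32, hidden_size=512) with seq_len=81 and a single puzzle-embedding token:

import torch
import torch.nn as nn

B, C, S, P, H = 4, 3, 96, 32, 512                       # batch, channels, img_size, patch_size, hidden_size
seq_len, puzzle_emb_len = 81, 1                          # 81 Sudoku cells + 1 puzzle-embedding token (assumed)

patch_embed = nn.Conv2d(C, H, kernel_size=P, stride=P)  # stands in for vision_patch_embed
vision_pos = nn.Parameter(torch.zeros(1, (S // P) ** 2, H))  # learned positions, (1, 9, 512)
to_seq = nn.Linear(H, H)                                 # stands in for vision_to_seq

imgs = torch.rand(B, C, S, S)                            # images already scaled to [0, 1]
x = patch_embed(imgs)                                    # (B, 512, 3, 3)
x = x.flatten(2).transpose(1, 2)                         # (B, 9, 512): one token per 32x32 patch
x = x + vision_pos
v_global = x.mean(dim=1)                                 # (B, 512) pooled visual summary
inj = to_seq(v_global).unsqueeze(1).expand(-1, seq_len + puzzle_emb_len, -1)
print(inj.shape)                                         # torch.Size([4, 82, 512]): added to every visual-state token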