Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
"request": "launch",
"program": "pretrain.py",
"args": [
// "data_path=data/sudoku-extreme-1k-aug-1000",
// "epochs=20000",
// "eval_interval=2000",
// "lr=1e-4",
// "puzzle_emb_lr=1e-4",
// "weight_decay=1.0",
// "puzzle_emb_weight_decay=1.0"
"data_path=data/sudoku-extreme-1k-aug-1000",
"epochs=20000",
"eval_interval=2000",
"lr=1e-4",
"puzzle_emb_lr=1e-4",
"weight_decay=1.0",
"puzzle_emb_weight_decay=1.0"
],
"env": {
"OMP_NUM_THREADS": "1",
Expand Down
9 changes: 7 additions & 2 deletions config/arch/hrm_v1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ halt_max_steps: 16
H_cycles: 2
L_cycles: 2

H_layers: 4
L_layers: 4
H_layers: 2
L_layers: 2

hidden_size: 512
num_heads: 8 # max(2, hidden_size // 64)
Expand All @@ -19,3 +19,8 @@ expansion: 4
puzzle_emb_ndim: ${.hidden_size}

pos_encodings: rope

# Image encoding
img_size: 96
img_channels: 3
patch_size: 32
5 changes: 5 additions & 0 deletions config/cfg_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ hydra:

# Data path
data_path: data/arc-aug-1000
dataset_name: sudoku

# Image rendering
render_res: 144
output_size: 96

# Hyperparams - Training
global_batch_size: 768
Expand Down
91 changes: 89 additions & 2 deletions dataset/build_sudoku_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,95 @@
from tqdm import tqdm
from huggingface_hub import hf_hub_download

from common import PuzzleDatasetMetadata

from .common import PuzzleDatasetMetadata
import numpy as np
import cv2
from concurrent.futures import ThreadPoolExecutor

cli = ArgParser()



class SudokuImageRenderer:
    """
    Render Sudoku puzzles (flat arrays of 81 integer tokens) to RGB images.

    Token encoding: a cell holding digit ``d`` (1..9) is expected to be stored
    as ``d + token_offset``; every token ``<= token_offset`` is drawn as an
    empty cell. The default ``token_offset=1`` targets inputs whose values
    span 1..10 (blank = 1, digits = 2..10; 0 presumably padding -- confirm
    against the dataset builder). Pass ``token_offset=0`` for raw grids where
    0 is blank and 1..9 are the digits.

    Args:
        output_size: final side length of output image (e.g., 32 -> 32x32).
        render_res: internal canvas side length for crisp glyphs; must be
            divisible by 9 so that all cells have the same integer size.
        line_thick: base grid line thickness at render_res (lines on 3x3 box
            boundaries are drawn one pixel thicker).
        font_scale: OpenCV font scale at render_res.
        font_thick: OpenCV font thickness at render_res.
        token_offset: amount subtracted from each token to obtain the digit
            that is printed (see "Token encoding" above).
    """

    def __init__(self,
                 output_size: int = 32,
                 render_res: int = 144,
                 line_thick: int = 2,
                 font_scale: float = 0.5,
                 font_thick: int = 1,
                 token_offset: int = 1):
        assert render_res % 9 == 0, "render_res should be divisible by 9"
        self.output_size = int(output_size)
        self.render_res = int(render_res)
        self.cell = self.render_res // 9
        self.line_thick = int(line_thick)
        self.font_scale = float(font_scale)
        self.font_thick = int(font_thick)
        self.token_offset = int(token_offset)
        self.font = cv2.FONT_HERSHEY_SIMPLEX

    def render_single(self, grid_flat: np.ndarray) -> np.ndarray:
        """Render one puzzle; returns (H, W, 3) uint8 image with H=W=output_size."""
        black = (0, 0, 0)
        img = np.full((self.render_res, self.render_res, 3), 255, np.uint8)

        # Grid lines (thicker for 3x3 boundaries)
        for i in range(10):
            thickness = self.line_thick + (1 if i % 3 == 0 else 0)
            pos = i * self.cell
            cv2.line(img, (0, pos), (self.render_res, pos), black, thickness)
            cv2.line(img, (pos, 0), (pos, self.render_res), black, thickness)

        # Digits
        for idx in range(81):
            row, col = divmod(idx, 9)
            # Undo the token shift so the glyph shows the real Sudoku digit
            # (previously the raw token was printed, e.g. "10" for digit 9).
            digit = int(grid_flat[idx]) - self.token_offset
            if digit <= 0:  # blank / padding cell
                continue
            text = str(digit)
            (tw, th), _ = cv2.getTextSize(text, self.font, self.font_scale, self.font_thick)
            # Centre the glyph in its cell; putText anchors at the bottom-left corner.
            tx = col * self.cell + (self.cell - tw) // 2
            ty = row * self.cell + (self.cell + th) // 2
            cv2.putText(img, text, (tx, ty), self.font, self.font_scale, black,
                        self.font_thick, cv2.LINE_AA)

        if self.output_size != self.render_res:
            img = cv2.resize(img, (self.output_size, self.output_size),
                             interpolation=cv2.INTER_AREA)
        return img

    def render_batch(self, inputs_np: np.ndarray, num_workers: int = 32) -> np.ndarray:
        """
        Batch renderer (threaded when num_workers > 1) with a tqdm progress bar.

        inputs_np: (N, 81) -> returns (N, 3, output_size, output_size) uint8.
        An empty batch (N == 0) yields an empty array of that shape.
        """
        n = int(inputs_np.shape[0])

        # np.stack() raises on an empty list, so short-circuit empty batches
        # and give both code paths the same result.
        if n == 0:
            return np.empty((0, 3, self.output_size, self.output_size), dtype=np.uint8)

        # sequential fallback (still shows progress)
        if num_workers <= 1:
            out = np.empty((n, self.output_size, self.output_size, 3), dtype=np.uint8)
            for i in tqdm(range(n), desc="Creating images (1 thread)", unit="img"):
                out[i] = self.render_single(inputs_np[i])
            return np.transpose(out, (0, 3, 1, 2))

        # multithreaded path; Executor.map yields results in input order.
        # (chunksize is not passed: it has no effect on ThreadPoolExecutor.)
        with ThreadPoolExecutor(max_workers=num_workers) as ex:
            imgs = list(tqdm(ex.map(self.render_single, inputs_np),
                             total=n,
                             desc=f"Creating images ({num_workers} threads)",
                             unit="img"))

        out = np.stack(imgs, axis=0)  # (N, H, W, 3)
        return np.transpose(out, (0, 3, 1, 2))



class DataProcessConfig(BaseModel):
source_repo: str = "sapientinc/sudoku-extreme"
output_dir: str = "data/sudoku-extreme-full"
Expand Down Expand Up @@ -127,6 +210,10 @@ def _seq_to_numpy(seq):
"puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
}

# # Render images --> Painfully slow and huge!!
# renderer = SudokuImageRenderer(output_size=224, render_res=288)
# results["images"] = renderer.render_batch(results["inputs"])

# Metadata
metadata = PuzzleDatasetMetadata(
seq_len=81,
Expand Down
19 changes: 19 additions & 0 deletions dataset/visualize_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Location of the rendered image array and of the PNG written by this script.
IMAGES_PATH = '/home/pbhat1/projects/NeurAI/HRM/data/sudoku-extreme-1k-aug-1000/train/all__images.npy'
OUTPUT_PATH = 'first_image_pil.png'


def to_hwc_uint8(image):
    """Convert one (C, H, W) image to an (H, W, C) uint8 array.

    Floating-point input is treated as being in [0, 1] and scaled by 255
    (truncating, after clipping to the valid range). uint8 input is passed
    through unchanged: multiplying it by 255 would wrap around modulo 256
    and turn white pixels almost black. Other integer dtypes are clipped to
    [0, 255] before the cast.
    """
    arr = np.asarray(image)
    if np.issubdtype(arr.dtype, np.floating):
        arr = np.clip(arr * 255, 0, 255).astype(np.uint8)
    elif arr.dtype != np.uint8:
        arr = np.clip(arr, 0, 255).astype(np.uint8)
    # Rearrange dimensions from (C, H, W) to (H, W, C)
    return arr.transpose(1, 2, 0)


def main():
    """Load the image array and save its first image as a PNG via PIL."""
    # Load the .npy file
    images = np.load(IMAGES_PATH)

    # Convert the first image to the layout/dtype PIL expects
    first_image = to_hwc_uint8(images[0])

    # Save using PIL
    pil_image = Image.fromarray(first_image)
    pil_image.save(OUTPUT_PATH)


if __name__ == "__main__":
    main()
Loading