From 97509dd96312dc64b8dc8b9eb609119cb7bdd6fe Mon Sep 17 00:00:00 2001
From: marcorudolphflex
Date: Thu, 18 Dec 2025 17:30:18 +0100
Subject: [PATCH] feat(tidy3d): FXC-4607-autograd-for-clip-operation

---
 CHANGELOG.md                                  |   2 +-
 .../test_autograd_clip_operation_numerical.py | 496 ++++++++++++
 .../test_components/autograd/test_autograd.py |   8 +-
 .../autograd/test_autograd_clip_operation.py  | 452 +++++++++++
 .../autograd/test_mesh_derivatives.py         | 208 +++++
 tidy3d/components/geometry/base.py            | 312 +++++++-
 tidy3d/components/geometry/mesh.py            | 726 +++++++++++++-----
 tidy3d/components/geometry/polyslab.py        | 336 ++++++++
 tidy3d/components/geometry/primitives.py      |  23 +-
 9 files changed, 2339 insertions(+), 224 deletions(-)
 create mode 100644 tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
 create mode 100644 tests/test_components/autograd/test_autograd_clip_operation.py
 create mode 100644 tests/test_components/autograd/test_mesh_derivatives.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9f84be6c38..852d22ae08 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [Unreleased]
 ### Added
-
+- Added autograd support for `ClipOperation` geometries, such as unions or intersections of geometries.
 ### Changed
 ### Fixed
diff --git a/tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py b/tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
new file mode 100644
index 0000000000..ac925a67e6
--- /dev/null
+++ b/tests/test_components/autograd/numerical/test_autograd_clip_operation_numerical.py
@@ -0,0 +1,496 @@
+"""Numerical finite-difference validation for ClipOperation gradients with real simulations."""

from __future__ import annotations

import uuid
from collections.abc import Callable
from pathlib import Path

import autograd.numpy as anp
import numpy as np
import pytest
from autograd import value_and_grad
from matplotlib import pyplot as plt

import tidy3d as td
import tidy3d.web as web
from tests.test_components.autograd.numerical.test_autograd_box_polyslab_numerical import (
    angled_overlap_deg,
)
from tidy3d import config

config.local_cache.enabled = True

WL_UM = 0.8
FREQ0 = td.C_0 / WL_UM
SRC_OFFSET = -2.2
MONITOR_OFFSET = 2.2
SIM_SIZE = (4.0, 4.0, 6.0)
RUN_TIME = 2e-11
PERMITTIVITY = 1.4**2
GRID_STEPS_PER_WVL = 30
FINITE_DIFF_STEP = 0.05
BASE_OFFSET = (0.05, -0.035, 0.02)
BASE_CENTER_A = (-0.25, 0.12, 0.0)
BASE_CENTER_B = (0.15, -0.08, 0.0)
ANGLE_OVERLAP_FD_ADJ_THRESH_DEG = 12.0
BASE_CENTER_A_VEC = anp.array(BASE_CENTER_A, dtype=float)
GEOMETRY_GROUP_SHIFT = anp.array((0.12, -0.06, 0.03), dtype=float)

SIZE_MAP = {
    "box": {"a": (1.0, 0.9, 0.8), "b": (0.9, 0.75, 0.7)},
    "polyslab": {"a": (0.95, 0.85, 0.8), "b": (0.8, 0.7, 0.6)},
    "mesh": {"a": (0.85, 0.95, 0.95), "b": (0.7, 0.95, 0.8)},
}

STRUCTURE_MEDIUM = td.Medium(permittivity=PERMITTIVITY)


def _make_base_simulation() -> tuple[td.Simulation, Callable[[td.SimulationData], float]]:
    """Shared ClipOperation simulation (point-dipole excitation and point field monitor)."""

    source = td.PointDipole(
        center=(0.0, 0.0, 
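        # x-polarized point dipole below the structures; the point monitor at
        # MONITOR_OFFSET records the fields whose mean intensity is the objective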
SRC_OFFSET),
        source_time=td.GaussianPulse(freq0=FREQ0, fwidth=0.2 * FREQ0),
        polarization="Ex",
    )
    field_monitor = td.FieldMonitor(
        center=(0.0, 0.0, MONITOR_OFFSET),
        size=(0, 0, 0.0),
        freqs=[FREQ0],
        name="field",
    )
    base_sim = td.Simulation(
        center=(0.0, 0.0, 0.0),
        size=SIM_SIZE,
        sources=[source],
        monitors=[field_monitor],
        structures=[],
        boundary_spec=td.BoundarySpec(
            x=td.Boundary.pml(), y=td.Boundary.pml(), z=td.Boundary.pml()
        ),
        run_time=RUN_TIME,
        grid_spec=td.GridSpec.auto(min_steps_per_wvl=GRID_STEPS_PER_WVL),
    )

    def fom(sim_data: td.SimulationData) -> float:
        dataset = sim_data["field"]
        ex_vals = dataset.Ex.values
        ey_vals = dataset.Ey.values
        ez_vals = dataset.Ez.values
        intensity = np.abs(ex_vals) ** 2 + np.abs(ey_vals) ** 2 + np.abs(ez_vals) ** 2
        return anp.real(anp.mean(intensity))

    return base_sim, fom


def _triangle_prism(
    center: tuple[float, float, float], size: tuple[float, float, float]
) -> td.PolySlab:
    half_x, half_y, half_z = 0.5 * size[0], 0.5 * size[1], 0.5 * size[2]
    cx, cy, cz = center
    vertices = (
        (cx - half_x, cy - half_y),
        (cx + half_x, cy - half_y),
        (cx, cy + half_y),
    )
    slab_bounds = (cz - half_z, cz + half_z)
    return td.PolySlab(vertices=vertices, axis=2, slab_bounds=slab_bounds)


def _tetra_mesh(
    center: tuple[float, float, float], size: tuple[float, float, float]
) -> td.TriangleMesh:
    center_arr = _normalize_array(center)
    half = 0.5 * _normalize_array(size)
    cx, cy, cz = center_arr
    vertices = anp.array(
        [
            (cx - half[0], cy - half[1], cz - half[2]),
            (cx + half[0], cy - half[1], cz - half[2]),
            (cx, cy + half[1], cz - half[2]),
            (cx, cy, cz + half[2]),
        ],
        dtype=float,
    )
    faces = anp.array(
        [
            (0, 2, 1),
            (0, 1, 3),
            (1, 2, 3),
            (2, 0, 3),
        ],
        dtype=int,
    )
    return td.TriangleMesh.from_triangles(vertices[faces])


def _make_geometry(
    geometry_type: str, center: tuple[float, float, float], size: tuple[float, float, float]
) -> td.Geometry:
    if geometry_type == "box":
        return td.Box(center=center, size=size)
    if geometry_type == "polyslab":
        return _triangle_prism(center, size)
    if geometry_type == "mesh":
        return _tetra_mesh(center, size)
    raise ValueError(f"Unsupported geometry_type '{geometry_type}'.")


def _normalize_array(offset_vec):
    """Return the offset as an autograd array of length 3."""

    offset = anp.array(offset_vec, dtype=float)
    if offset.shape != (3,):
        raise ValueError("ClipOperation offset vector must have length 3.")
    return offset


def _clip_structure(offset_vec, geometry_type: str, operation: str) -> td.Structure:
    offset = _normalize_array(offset_vec)
    size_spec = SIZE_MAP[geometry_type]
    center_arr = BASE_CENTER_A_VEC + offset
    center_a = tuple(center_arr)
    geometry_a = _make_geometry(geometry_type, center=center_a, size=size_spec["a"])
    geometry_b = _make_geometry(geometry_type, center=BASE_CENTER_B, size=size_spec["b"])
    clip_geom = td.ClipOperation(
        operation=operation,
        geometry_a=geometry_a,
        geometry_b=geometry_b,
    )
    return td.Structure(geometry=clip_geom, medium=STRUCTURE_MEDIUM)


def _clip_structure_geometry_group(offset_vec, geometry_type: str, operation: str) -> td.Structure:
    offset = 
_normalize_array(offset_vec)
    size_spec = SIZE_MAP[geometry_type]
    center_arr = BASE_CENTER_A_VEC + offset
    center_a = tuple(center_arr)
    primary = _make_geometry(geometry_type, center=center_a, size=size_spec["a"])
    center_b = tuple(center_arr + GEOMETRY_GROUP_SHIFT)
    secondary = _make_geometry(geometry_type, center=center_b, size=size_spec["a"])
    geometry_group = primary + secondary
    geometry_b = _make_geometry(geometry_type, center=BASE_CENTER_B, size=size_spec["b"])
    clip_geom = td.ClipOperation(
        operation=operation,
        geometry_a=geometry_group,
        geometry_b=geometry_b,
    )
    return td.Structure(geometry=clip_geom, medium=STRUCTURE_MEDIUM)


def _run_clip_simulation(
    base_sim: td.Simulation,
    fom: Callable[[td.SimulationData], float],
    offset_vec,
    geometry_type: str,
    operation: str,
    result_dir: Path,
    *,
    local_gradient: bool,
    tag: str,
    structure_builder: Callable = _clip_structure,
) -> float:
    return _run_clip_simulation_batch(
        base_sim,
        fom,
        offset_list=[offset_vec],
        geometry_type=geometry_type,
        operation=operation,
        result_dir=result_dir,
        local_gradient=local_gradient,
        tag=tag,
        structure_builder=structure_builder,
    )[0]


def _run_clip_simulation_batch(
    base_sim: td.Simulation,
    fom: Callable[[td.SimulationData], float],
    offset_list,
    geometry_type: str,
    operation: str,
    result_dir: Path,
    *,
    local_gradient: bool,
    tag: str,
    structure_builder: Callable = _clip_structure,
) -> list:
    """Run a batch of simulations (in parallel if possible) for different offsets."""

    offsets = [_normalize_array(off) for off in offset_list]
    simulations: dict[str, td.Simulation] = {}
    key_order: list[str] = []

    for idx, offset in enumerate(offsets):
        structure = structure_builder(offset, geometry_type, operation)
        grid_spec = td.GridSpec.auto(
            min_steps_per_wvl=GRID_STEPS_PER_WVL, override_structures=[structure]
        )
        sim = base_sim.updated_copy(structures=[structure], grid_spec=grid_spec, validate=True)
        task_name = f"clip_{geometry_type}_{operation}_{tag}_{idx}_{uuid.uuid4().hex[:6]}"
        simulations[task_name] = sim
        key_order.append(task_name)

    result_dir.mkdir(parents=True, exist_ok=True)

    if len(simulations) == 1:
        key = key_order[0]
        sim = simulations[key]
        result_path = result_dir / f"{key}.hdf5"
        sim_data = web.run(
            sim,
            task_name=key,
            path=str(result_path),
            local_gradient=local_gradient,
            verbose=False,
        )
        return [fom(sim_data)]

    sim_data_map = web.run_async(
        simulations,
        path_dir=str(result_dir),
        local_gradient=local_gradient,
        verbose=False,
    )

    return [fom(sim_data_map[key]) for key in key_order]


def _make_clip_objective(
    base_sim: td.Simulation,
    fom,
    geometry_type: str,
    operation: str,
    case_dir: Path,
    *,
    local_gradient: bool,
    structure_builder: Callable = _clip_structure,
):
    run_dir = case_dir / ("adjoint" if local_gradient else "finite_difference")

    def objective(params, *, batched: bool = False):
        tag = "adj" if local_gradient else "fd"
        if batched:
            offsets = [_normalize_array(p) for p in params]
            return _run_clip_simulation_batch(
                base_sim,
                fom,
                offset_list=offsets,
                geometry_type=geometry_type,
                operation=operation,
                result_dir=run_dir,
                local_gradient=local_gradient,
                tag=tag,
                structure_builder=structure_builder,
            )
        offset = _normalize_array(params)
        return _run_clip_simulation(
            base_sim,
            fom,
            offset_vec=offset,
            geometry_type=geometry_type,
            
operation=operation, + result_dir=run_dir, + local_gradient=local_gradient, + tag=tag, + structure_builder=structure_builder, + ) + + return objective + + +def _finite_difference_gradient(objective, params: anp.ndarray, step: float) -> np.ndarray: + base = _normalize_array(params) + offsets = [] + axis_indices = [] + + for idx in range(3): + delta = anp.array( + [step if dim == idx else 0.0 for dim in range(3)], + dtype=float, + ) + offsets.append(base + delta) + offsets.append(base - delta) + axis_indices.append(idx) + + results = objective(offsets, batched=True) + grads = np.zeros(3, dtype=float) + for pair_idx, axis in enumerate(axis_indices): + obj_up = float(np.asarray(results[2 * pair_idx], dtype=float)) + obj_down = float(np.asarray(results[2 * pair_idx + 1], dtype=float)) + grads[axis] = (obj_up - obj_down) / (2.0 * step) + + return grads + + +@pytest.mark.numerical +@pytest.mark.parametrize("geometry_type", ["b", "p", "m"]) +@pytest.mark.parametrize("operation", ["u", "i", "d", "s"]) +def test_clip_operation_vs_fd(geometry_type: str, operation: str, numerical_case_dir: Path): + """Compare ClipOperation adjoint gradients against finite differences (3D offsets).""" + operation = {"u": "union", "i": "intersection", "d": "difference", "s": "symmetric_difference"}[ + operation + ] + geometry_type = {"b": "box", "p": "polyslab", "m": "mesh"}[geometry_type] + base_sim, fom = _make_base_simulation() + case_dir = numerical_case_dir / f"{geometry_type}_{operation}" + case_dir.mkdir(parents=True, exist_ok=True) + adjoint_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=True, + ) + params0 = anp.array(BASE_OFFSET, dtype=float) + _, grad = value_and_grad(adjoint_objective)(params0) + # return + fd_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=False, + ) + + autograd_grad = np.asarray(grad, dtype=float).ravel() + fd_grad = _finite_difference_gradient(fd_objective, params0, FINITE_DIFF_STEP) + + angle = angled_overlap_deg(autograd_grad, fd_grad) + assert angle < ANGLE_OVERLAP_FD_ADJ_THRESH_DEG, ( + f"FD–adjoint angle overlap too large ({angle:.2f}°) for {geometry_type}/{operation}" + ) + + +@pytest.mark.numerical +@pytest.mark.parametrize("operation", ["u", "i", "d", "s"]) +def test_clip_operation_geometry_group_vs_fd(operation: str, numerical_case_dir: Path): + operation = {"u": "union", "i": "intersection", "d": "difference", "s": "symmetric_difference"}[ + operation + ] + geometry_type = "box" + base_sim, fom = _make_base_simulation() + case_dir = numerical_case_dir / f"group_{operation}" + case_dir.mkdir(parents=True, exist_ok=True) + + adjoint_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=True, + structure_builder=_clip_structure_geometry_group, + ) + params0 = anp.array(BASE_OFFSET, dtype=float) + _, grad = value_and_grad(adjoint_objective)(params0) + + fd_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=False, + structure_builder=_clip_structure_geometry_group, + ) + + autograd_grad = np.asarray(grad, dtype=float).ravel() + fd_grad = _finite_difference_gradient(fd_objective, params0, FINITE_DIFF_STEP) + angle = angled_overlap_deg(autograd_grad, fd_grad) + assert angle < 13, ( + f"FD–adjoint angle overlap too large ({angle:.2f}°) for geometry group / {operation}" + ) + + +@pytest.mark.skip +@pytest.mark.parametrize("geometry_type", 
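# Reference for the helper above: _finite_difference_gradient approximates each
# gradient component with the central difference (f(x + h*e_i) - f(x - h*e_i)) / (2*h),
# submitting all six perturbed offsets as a single batch.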
["b", "p", "m"]) +@pytest.mark.parametrize("operation", ["u", "i", "d", "s"]) +def test_clip_operation_fd_sweep(geometry_type: str, operation: str, numerical_case_dir: Path): + """Sweep FD step sizes and save comparison plots for diagnostic use.""" + + operation = {"u": "union", "i": "intersection", "d": "difference", "s": "symmetric_difference"}[ + operation + ] + geometry_type = {"b": "box", "p": "polyslab", "m": "mesh"}[geometry_type] + + base_sim, fom = _make_base_simulation() + case_dir = numerical_case_dir / f"{geometry_type}_{operation}_fd_sweep" + case_dir.mkdir(parents=True, exist_ok=True) + adjoint_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=True, + ) + fd_objective = _make_clip_objective( + base_sim, + fom, + geometry_type, + operation, + case_dir, + local_gradient=False, + ) + + params0 = anp.array(BASE_OFFSET, dtype=float) + _, grad = value_and_grad(adjoint_objective)(params0) + autograd_grad = np.asarray(grad, dtype=float).ravel() + + steps = np.logspace(-4, -1, num=7) + fd_grads = np.array( + [_finite_difference_gradient(fd_objective, params0, float(step)) for step in steps] + ) + + fig, ax = plt.subplots(figsize=(6, 4)) + labels = ("dx", "dy", "dz") + for idx, label in enumerate(labels): + ax.plot(steps, fd_grads[:, idx], marker="o", label=f"{label} FD") + ax.axhline( + autograd_grad[idx], + color=ax.get_lines()[-1].get_color(), + linestyle="--", + alpha=0.7, + label=f"{label} adjoint", + ) + ax.set_xscale("log") + ax.set_xlabel("Finite-difference step (µm)") + ax.set_ylabel("Gradient value") + ax.set_title(f"FD sweep for {geometry_type} / {operation}") + ax.grid(True, which="both", linestyle=":") + ax.legend(loc="best", fontsize="small") + + fig_path = case_dir / f"fd_sweep_{geometry_type}_{operation}.png" + fig.savefig(fig_path, dpi=200) + plt.close(fig) + + np.savez( + case_dir / f"fd_sweep_{geometry_type}_{operation}.npz", + steps=steps, + gradients=fd_grads, + autograd_grad=autograd_grad, + ) diff --git a/tests/test_components/autograd/test_autograd.py b/tests/test_components/autograd/test_autograd.py index b56a6345cb..7a8a4eb60b 100644 --- a/tests/test_components/autograd/test_autograd.py +++ b/tests/test_components/autograd/test_autograd.py @@ -2952,8 +2952,8 @@ def objective(center, size): grad = ag.grad(objective, argnum=1)(base_sim.center, base_sim.size) -def test_error_clip(use_emulated_run): - """Make sure proper error raised if differentiating a ``ClipOperation``.""" +def test_clip_operation_autograd(use_emulated_run): + """Ensure ``ClipOperation`` geometries support differentiation.""" def objective(x): box1 = td.Box(center=(0, 0, 0), size=(x, x, x)) @@ -2969,8 +2969,8 @@ def objective(x): data = run(sim, task_name="clip_error") return anp.sum(data["field"].intensity.item()) - with pytest.raises(ValueError): - g = ag.grad(objective)(1.0) + grad_val = ag.grad(objective)(1.0) + assert anp.all(grad_val != 0.0) def test_custom_medium_conductivity_only_gradient(rng, use_emulated_run, tmp_path): diff --git a/tests/test_components/autograd/test_autograd_clip_operation.py b/tests/test_components/autograd/test_autograd_clip_operation.py new file mode 100644 index 0000000000..3afac79541 --- /dev/null +++ b/tests/test_components/autograd/test_autograd_clip_operation.py @@ -0,0 +1,452 @@ +"""Tests for ``ClipOperation`` autograd support.""" + +from __future__ import annotations + +import copy +from collections.abc import Sequence +from typing import Callable + +import numpy as np +import numpy.testing as 
npt +import pytest + +import tidy3d as td +from tidy3d import TriangleMesh + +DEFAULT_SIM_BOUNDS = ((-10.0, -10.0, -10.0), (10.0, 10.0, 10.0)) + + +def _default_gradient(points: np.ndarray) -> np.ndarray: + """Deterministic gradient profile used in analytical objective evaluations.""" + if points.size == 0: + return np.zeros(0, dtype=float) + return points[:, 0] + 0.5 * points[:, 1] - 0.25 * points[:, 2] + + +class SimpleDerivativeInfo: + """Lightweight derivative info stub for geometry autograd testing.""" + + def __init__( + self, + paths: Sequence[tuple], + bounds: tuple[tuple[float, float, float], tuple[float, float, float]] | None = None, + *, + gradient_func: Callable[[np.ndarray], np.ndarray] | None = None, + spacing: float = 0.01, + simulation_bounds: tuple[tuple[float, float, float], tuple[float, float, float]] + | None = None, + ) -> None: + self.paths = list(paths) + self.frequencies = [200e12] + self.eps_in = 1.0 + self.eps_out = 1.0 + self.interpolators = {} + self._spacing = spacing + self.gradient_func = gradient_func or _default_gradient + self.simulation_bounds = simulation_bounds or DEFAULT_SIM_BOUNDS + self.bounds = bounds if bounds is not None else self.simulation_bounds + self.bounds_intersect = self.bounds + + def updated_copy(self, **kwargs): + clone = copy.copy(self) + for key, value in kwargs.items(): + setattr(clone, key, value) + return clone + + def adaptive_vjp_spacing(self) -> float: + return float(self._spacing) + + def create_interpolators(self, dtype=None): + return {} + + def evaluate_gradient_at_points( + self, spatial_coords, normals, perps1, perps2, interpolators=None + ): + points = np.asarray(spatial_coords, dtype=float) + if points.size == 0: + return np.zeros(0, dtype=float) + return self.gradient_func(points) + + +def _tetrahedron_mesh(center: Sequence[float], size: Sequence[float]) -> TriangleMesh: + """Return a watertight tetrahedron mesh centered at ``center`` with ``size`` extents.""" + + center = np.asarray(center, dtype=float) + half = 0.5 * np.asarray(size, dtype=float) + cx, cy, cz = center + hx, hy, hz = half + vertices = np.array( + [ + (cx + hx, cy + hy, cz + hz), + (cx + hx, cy - hy, cz - hz), + (cx - hx, cy + hy, cz - hz), + (cx - hx, cy - hy, cz + hz), + ], + dtype=float, + ) + triangles = np.array( + [ + (vertices[0], vertices[1], vertices[2]), + (vertices[0], vertices[3], vertices[1]), + (vertices[0], vertices[2], vertices[3]), + (vertices[1], vertices[3], vertices[2]), + ], + dtype=float, + ) + return TriangleMesh.from_triangles(triangles) + + +def build_geometry( + geometry_type: str, center: Sequence[float], size: Sequence[float] +) -> td.Geometry: + """Return a geometry instance of the requested type.""" + + center = tuple(center) + size = tuple(size) + if geometry_type == "box": + return td.Box(center=center, size=size) + + if geometry_type == "polyslab": + half_x = size[0] / 2.0 + half_y = size[1] / 2.0 + half_z = size[2] / 2.0 + vertices = np.array( + [ + (center[0] - half_x, center[1] - half_y), + (center[0] + half_x, center[1] - half_y), + (center[0], center[1] + half_y), + ], + dtype=float, + ) + slab_bounds = (center[2] - half_z, center[2] + half_z) + return td.PolySlab(vertices=vertices, axis=2, slab_bounds=slab_bounds) + + if geometry_type == "mesh": + return _tetrahedron_mesh(center=center, size=size) + + raise ValueError(f"Unsupported geometry type '{geometry_type}'.") + + +SAMPLE_POINTS = np.array( + [ + (0.0, 0.0, 0.0), + (1.5, 0.0, 0.0), + (-0.5, 0.0, 0.0), + ], + dtype=float, +) +SAMPLE_NORMALS = np.array( + [ + 
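        # unit normals for the three mocked surface samples; rows pair with
        # SAMPLE_POINTS above and INSIDE_MASK defined below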
(1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + (0.0, 0.0, 1.0), + ], + dtype=float, +) +SAMPLE_PERPS1 = np.array( + [ + (0.0, 1.0, 0.0), + (0.0, 0.0, 1.0), + (1.0, 0.0, 0.0), + ], + dtype=float, +) +SAMPLE_PERPS2 = np.array( + [ + (0.0, 0.0, 1.0), + (1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + ], + dtype=float, +) +SAMPLE_WEIGHTS = np.ones(3, dtype=float) +SAMPLE_FACES = np.zeros(3, dtype=int) +SAMPLE_BARY = np.full((3, 3), 1.0 / 3.0, dtype=float) +INSIDE_MASK = np.array([True, False, True], dtype=bool) + + +def _sample_dict() -> dict[str, np.ndarray]: + """Return a fresh copy of the mocked sampling dictionary.""" + + return { + "points": SAMPLE_POINTS.copy(), + "normals": SAMPLE_NORMALS.copy(), + "perps1": SAMPLE_PERPS1.copy(), + "perps2": SAMPLE_PERPS2.copy(), + "weights": SAMPLE_WEIGHTS.copy(), + "faces": SAMPLE_FACES.copy(), + "barycentric": SAMPLE_BARY.copy(), + } + + +def _make_mesh() -> td.TriangleMesh: + """Create a simple triangle mesh for testing.""" + vertices = np.array( + [ + (0.0, 0.0, 0.0), + (1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + ], + dtype=float, + ) + faces = np.array([(0, 1, 2)], dtype=int) + return td.TriangleMesh.from_vertices_faces(vertices, faces) + + +def _patch_sample_collection(monkeypatch: pytest.MonkeyPatch) -> None: + """Replace surface sampling with deterministic data.""" + + def fake_collect(self, *args, **kwargs): + return _sample_dict() + + monkeypatch.setattr(td.TriangleMesh, "_collect_surface_samples", fake_collect, raising=True) + + +class RecordingDerivativeInfo: + """DerivativeInfo stub that records sampling points.""" + + def __init__(self) -> None: + self.paths = [("mesh_dataset", "surface_mesh")] + self.simulation_bounds = ((-5.0, -5.0, -5.0), (5.0, 5.0, 5.0)) + self.bounds_intersect = self.simulation_bounds + self.interpolators: dict | None = {} + self.last_points: np.ndarray | None = None + self.last_normals: np.ndarray | None = None + + def adaptive_vjp_spacing(self) -> float: + return 0.5 + + def create_interpolators(self, dtype=None): + return {} + + def evaluate_gradient_at_points( + self, + spatial_coords, + normals, + perps1, + perps2, + interpolators=None, + ): + self.last_points = np.array(spatial_coords) + self.last_normals = np.array(normals) + return np.ones(spatial_coords.shape[0], dtype=float) + + +class MinimalDerivativeInfo: + """Lightweight derivative info container for ClipOperation routing tests.""" + + def __init__(self, paths) -> None: + self.paths = [tuple(path) for path in paths] + self.interpolators = None + self.bounds = ((-2.0, -2.0, -2.0), (2.0, 2.0, 2.0)) + self.simulation_bounds = self.bounds + self.bounds_intersect = self.bounds + self.frequencies = [200e12] + self.E_der_map = {} + self.D_der_map = {} + self.E_fwd = {} + self.E_adj = {} + self.D_fwd = {} + self.D_adj = {} + self.eps_data = {} + self.eps_in = 1.0 + self.eps_out = 1.0 + self.eps_approx = False + + def create_interpolators(self, dtype=None): + return {} + + def updated_copy(self, **kwargs): + kwargs.pop("deep", None) + kwargs.pop("validate", None) + new = MinimalDerivativeInfo(self.paths) + new.__dict__.update(self.__dict__) + if "paths" in kwargs: + new.paths = list(kwargs.pop("paths")) + for key, value in kwargs.items(): + setattr(new, key, value) + return new + + def adaptive_vjp_spacing(self) -> float: + return 0.5 + + def evaluate_gradient_at_points( + self, + spatial_coords, + normals, + perps1, + perps2, + interpolators=None, + ): + return np.zeros(len(spatial_coords), dtype=float) + + +@pytest.mark.parametrize( + ("operation", "expected_use", "expected_flip"), + [ 
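        # truth table mirroring ClipOperation._clip_use_mask/_clip_flip_mask for
        # geometry_a (INSIDE_MASK marks samples inside geometry_b):
        #   intersection         -> keep inside samples, no normal flips
        #   union                -> keep outside samples, no normal flips
        #   difference (A - B)   -> keep outside samples, no normal flips
        #   symmetric_difference -> keep all samples, flip normals of inside ones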
+ ("intersection", INSIDE_MASK, np.array([False, False, False], dtype=bool)), + ("union", ~INSIDE_MASK, np.array([False, False, False], dtype=bool)), + ("difference", ~INSIDE_MASK, np.array([False, False, False], dtype=bool)), + ("symmetric_difference", np.array([True, True, True], dtype=bool), INSIDE_MASK), + ], +) +def test_triangle_mesh_clip_filters_geometry_a(operation, expected_use, expected_flip, monkeypatch): + """TriangleMesh sampling honors ClipOperation rules for geometry_a.""" + + mesh = _make_mesh() + _patch_sample_collection(monkeypatch) + info = RecordingDerivativeInfo() + other = td.Box(center=(0.0, 0.0, 0.0), size=(2.0, 2.0, 2.0)) + clip = td.ClipOperation(operation=operation, geometry_a=mesh, geometry_b=other) + + mesh._compute_derivatives(info, clip_operation=(clip, "geometry_a")) + + expected_points = SAMPLE_POINTS[expected_use] + expected_normals = SAMPLE_NORMALS.copy() + expected_normals[expected_flip] *= -1.0 + npt.assert_allclose(info.last_points, expected_points) + npt.assert_allclose(info.last_normals, expected_normals[expected_use]) + + +@pytest.mark.parametrize( + ("operation", "expected_use", "expected_flip"), + [ + ("intersection", INSIDE_MASK, np.array([False, False, False], dtype=bool)), + ("union", ~INSIDE_MASK, np.array([False, False, False], dtype=bool)), + ("difference", INSIDE_MASK, INSIDE_MASK), + ("symmetric_difference", np.array([True, True, True], dtype=bool), INSIDE_MASK), + ], +) +def test_triangle_mesh_clip_filters_geometry_b(operation, expected_use, expected_flip, monkeypatch): + """TriangleMesh sampling honors ClipOperation rules for geometry_b.""" + + mesh = _make_mesh() + _patch_sample_collection(monkeypatch) + info = RecordingDerivativeInfo() + other = td.Box(center=(0.0, 0.0, 0.0), size=(2.0, 2.0, 2.0)) + clip = td.ClipOperation(operation=operation, geometry_a=other, geometry_b=mesh) + + mesh._compute_derivatives(info, clip_operation=(clip, "geometry_b")) + + expected_points = SAMPLE_POINTS[expected_use] + expected_normals = SAMPLE_NORMALS.copy() + expected_normals[expected_flip] *= -1.0 + npt.assert_allclose(info.last_points, expected_points) + npt.assert_allclose(info.last_normals, expected_normals[expected_use]) + + +def test_clip_operation_box_passes_clip_context(monkeypatch): + """``ClipOperation`` forwards context when differentiating Box geometries.""" + + contexts = [] + + def fake_mesh_derivatives(self, derivative_info, clip_operation=None): + contexts.append(clip_operation) + triangles = np.asarray(self.triangles, dtype=float) + return {("mesh_dataset", "surface_mesh"): np.zeros_like(triangles)} + + monkeypatch.setattr( + td.TriangleMesh, "_compute_derivatives", fake_mesh_derivatives, raising=True + ) + + box_a = td.Box(center=(0.0, 0.0, 0.0), size=(1.0, 1.0, 1.0)) + box_b = td.Box(center=(1.0, 0.0, 0.0), size=(1.0, 1.0, 1.0)) + clip = td.ClipOperation(operation="union", geometry_a=box_a, geometry_b=box_b) + info = MinimalDerivativeInfo(paths=[("geometry_a", "center", 0), ("geometry_b", "size", 1)]) + + result = clip._compute_derivatives(info) + + assert contexts == [(clip, "geometry_a"), (clip, "geometry_b")] + assert ("geometry_a", "center", 0) in result + assert ("geometry_b", "size", 1) in result + + +def test_clip_operation_polyslab_passes_clip_context(monkeypatch): + """``ClipOperation`` forwards context when differentiating PolySlab geometries.""" + + contexts = [] + + def fake_mesh_derivatives(self, derivative_info, clip_operation=None): + contexts.append(clip_operation) + triangles = np.asarray(self.triangles, dtype=float) 
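        # stub mimicking TriangleMesh._compute_derivatives: one zero gradient per
        # triangle vertex, keyed by the surface-mesh field path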
+ return {("mesh_dataset", "surface_mesh"): np.zeros_like(triangles)} + + monkeypatch.setattr( + td.TriangleMesh, "_compute_derivatives", fake_mesh_derivatives, raising=True + ) + + vertices = np.array(((0.0, 0.0), (1.0, 0.0), (0.0, 1.0)), dtype=float) + slab = td.PolySlab(vertices=vertices, slab_bounds=(-0.5, 0.5), axis=2) + box = td.Box(center=(0.0, 0.0, 0.0), size=(3.0, 3.0, 3.0)) + clip = td.ClipOperation(operation="union", geometry_a=slab, geometry_b=box) + info = MinimalDerivativeInfo(paths=[("geometry_a", "vertices")]) + + result = clip._compute_derivatives(info) + + assert contexts == [(clip, "geometry_a")] + assert ("geometry_a", "vertices") in result + + +FIELD_PATHS = { + "box": ("center", 0), + # "polyslab": ("vertices",), + "polyslab": ("slab_bounds", 0), + "mesh": ("mesh_dataset", "surface_mesh"), +} + + +@pytest.mark.parametrize("geometry_type", ["box", "polyslab", "mesh"]) +@pytest.mark.parametrize( + "operation, expected_scales", + [ + ("union", (1.0, 0.0)), + ("intersection", (0.0, 1.0)), + ("difference", (1.0, -1.0)), + ("symmetric_difference", (1.0, -1.0)), + ], +) +def test_clip_operation_known_gradient_relations(geometry_type, operation, expected_scales): + """Compare ClipOperation gradients against analytical expectations for nested boxes.""" + + center_a = (0.0, 0.0, 0.0) + center_b = (0.2, 0.0, 0.0) + size_a = (2.0, 1.4, 1.0) + size_b = (1.0, 0.8, 0.6) + + geometry_a = build_geometry(geometry_type, center=center_a, size=size_a) + geometry_b = build_geometry(geometry_type, center=center_b, size=size_b) + + field_path = FIELD_PATHS[geometry_type] + geo_di_a = SimpleDerivativeInfo(paths=[field_path], bounds=geometry_a.bounds) + geo_di_b = SimpleDerivativeInfo(paths=[field_path], bounds=geometry_b.bounds) + baseline_a = geometry_a._compute_derivatives(geo_di_a).get(field_path, 0.0) + baseline_b = geometry_b._compute_derivatives(geo_di_b).get(field_path, 0.0) + + clip = td.ClipOperation(operation=operation, geometry_a=geometry_a, geometry_b=geometry_b) + clip_path_a = ("geometry_a", *field_path) + clip_path_b = ("geometry_b", *field_path) + clip_di = SimpleDerivativeInfo(paths=[clip_path_a, clip_path_b], bounds=clip.bounds) + gradients = clip._compute_derivatives(clip_di) + + expected_scale_a, expected_scale_b = expected_scales + grad_a = gradients.get(clip_path_a, 0.0) + grad_b = gradients.get(clip_path_b, 0.0) + + rtol = 6e-2 + atol = 3e-2 + # if geometry_type == "mesh": + # rtol = 6e-2 + # atol = 8e-2 + + np.testing.assert_allclose( + np.asarray(grad_a, dtype=float), + expected_scale_a * np.asarray(baseline_a, dtype=float), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + np.asarray(grad_b, dtype=float), + expected_scale_b * np.asarray(baseline_b, dtype=float), + rtol=rtol, + atol=atol, + ) diff --git a/tests/test_components/autograd/test_mesh_derivatives.py b/tests/test_components/autograd/test_mesh_derivatives.py new file mode 100644 index 0000000000..abd0e11a2f --- /dev/null +++ b/tests/test_components/autograd/test_mesh_derivatives.py @@ -0,0 +1,208 @@ +"""Regression tests comparing mesh-based derivatives to legacy implementations.""" + +from __future__ import annotations + +import copy +import time + +import numpy as np +import numpy.testing as npt + +import tidy3d as td + + +class DummyDerivativeInfo: + """Minimal derivative info stub used for geometry unit tests.""" + + def __init__(self, grad_func, paths): + self.paths = paths + self._grad_func = grad_func + self.frequencies = [200e12] + self.eps_in = 12.0 + self._spacing = 0.05 + 
self.simulation_bounds = ((-2.0, -2.0, -2.0), (2.0, 2.0, 2.0))
        self.bounds = self.simulation_bounds
        self.bounds_intersect = self.simulation_bounds
        self.interpolators = None

    def adaptive_vjp_spacing(self) -> float:
        return self._spacing

    def create_interpolators(self, dtype=None):
        return {}

    def updated_copy(self, **kwargs):
        kwargs.pop("deep", None)
        kwargs.pop("validate", None)
        new_info = copy.copy(self)
        for key, value in kwargs.items():
            setattr(new_info, key, value)
        return new_info

    @property
    def wavelength_min(self) -> float:
        return td.C_0 / max(self.frequencies)

    def evaluate_gradient_at_points(
        self,
        spatial_coords=None,
        normals=None,
        perps1=None,
        perps2=None,
        interpolators=None,
    ):
        coords = spatial_coords if spatial_coords is not None else np.zeros((0, 3))
        return self._grad_func(coords)


def constant_grad(points: np.ndarray) -> np.ndarray:
    return np.ones(points.shape[0], dtype=float)


def _assert_mesh_legacy_match(geometry: td.Geometry, derivative_info: DummyDerivativeInfo) -> None:
    derivative_info.bounds = geometry.bounds
    derivative_info.bounds_intersect = geometry.bounds
    mesh_vjps = geometry._compute_derivatives(derivative_info)
    legacy_vjps = geometry._compute_derivatives_via_mesh(derivative_info)

    assert set(mesh_vjps) == set(legacy_vjps)
    for key in mesh_vjps:
        mesh_val = mesh_vjps[key]
        legacy_val = legacy_vjps[key]
        npt.assert_allclose(mesh_val, legacy_val, rtol=1e-4, atol=1e-6)


def test_box_mesh_derivatives_match_legacy_gradients():
    box = td.Box(center=(0.2, -0.1, 0.05), size=(1.2, 0.9, 0.8))

    derivative_info = DummyDerivativeInfo(
        constant_grad,
        paths=[("center",), ("size",)],
    )

    _assert_mesh_legacy_match(box, derivative_info)


def test_cylinder_mesh_derivatives_match_legacy_gradients():
    cylinder = td.Cylinder(
        center=(-0.3, 0.15, 0.0),
        radius=0.45,
        length=1.1,
        axis=2,
        sidewall_angle=0.08,
    )

    derivative_info = DummyDerivativeInfo(
        constant_grad,
        paths=[("center", 0), ("center", 1), ("radius",), ("length",), ("sidewall_angle",)],
    )

    _assert_mesh_legacy_match(cylinder, derivative_info)


def _runtime_comparison(
    polyslab: td.PolySlab, derivative_info: DummyDerivativeInfo
) -> tuple[float, float]:
    start = time.perf_counter()
    polyslab._compute_derivatives(derivative_info)
    mesh_time = time.perf_counter() - start

    start = time.perf_counter()
    polyslab._compute_derivatives_legacy(derivative_info)
    legacy_time = time.perf_counter() - start
    return mesh_time, legacy_time


def _runtime_comparison_by_path(polyslab, derivative_info):
    timings = {}

    for path in derivative_info.paths:
        if path[0] == "slab_bounds":
            continue
        single_info = copy.copy(derivative_info)  # shallow clone of the stub
        single_info.paths = [path]  # override only the path under test

        # --- mesh backend ---
        t0 = time.perf_counter()
        polyslab._compute_derivatives(single_info)
        
mesh_time = time.perf_counter() - t0 + + # --- legacy backend --- + t0 = time.perf_counter() + polyslab._compute_derivatives_legacy(single_info) + legacy_time = time.perf_counter() - t0 + + ratio = mesh_time / legacy_time if legacy_time != 0 else float("inf") + + timings[path] = (mesh_time, legacy_time, ratio) + + return timings + + +def test_polyslab_mesh_derivatives_match_legacy(): + base_vertices = np.array( + [ + (0.0, 0.0), + (1.0, 0.0), + (1.0, 1.0), + (0.0, 1.0), + ], + dtype=float, + ) + vertices_list = [base_vertices] + for scale in np.logspace(1, 6, num=6, base=2, dtype=int): + radial = np.linspace(0, 2 * np.pi, 4 * scale, endpoint=False) + vertices_list.append(np.column_stack((np.cos(radial), np.sin(radial)))) + + slab_bounds = (-0.6, 0.7) + for verts in vertices_list: + polyslab = td.PolySlab(vertices=verts, slab_bounds=slab_bounds, axis=2, sidewall_angle=0.1) + + derivative_info = DummyDerivativeInfo( + constant_grad, + paths=[("vertices",), ("sidewall_angle",), ("slab_bounds", 0), ("slab_bounds", 1)], + ) + derivative_info.bounds = polyslab.bounds + derivative_info.bounds_intersect = polyslab.bounds + + mesh_vjps = polyslab._compute_derivatives(derivative_info) + legacy_vjps = polyslab._compute_derivatives_via_mesh(derivative_info) + + assert set(mesh_vjps) == set(legacy_vjps) + for key in mesh_vjps: + mesh_val = mesh_vjps[key] + legacy_val = legacy_vjps[key] + npt.assert_allclose(mesh_val, legacy_val, rtol=1e-4, atol=1e-6) + + # timings = _runtime_comparison_by_path(polyslab, derivative_info) + # + # for path, (mesh_time, legacy_time, ratio) in timings.items(): + # print( + # f"Polyslab vertices={len(verts)} | path={path}: " + # f"mesh {mesh_time * 1e3:.2f} ms, " + # f"legacy {legacy_time * 1e3:.2f} ms, " + # f"ratio={ratio:.2f}x", + # flush=True, + # ) diff --git a/tidy3d/components/geometry/base.py b/tidy3d/components/geometry/base.py index 9a2fc61401..0bd2711850 100644 --- a/tidy3d/components/geometry/base.py +++ b/tidy3d/components/geometry/base.py @@ -7,7 +7,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable, Sequence from os import PathLike -from typing import TYPE_CHECKING, Any, Callable, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union import autograd.numpy as np import pydantic.v1 as pydantic @@ -61,6 +61,7 @@ polygon_patch, set_default_labels_and_title, ) +from tidy3d.config import config from tidy3d.constants import LARGE_NUMBER, MICROMETER, RADIAN, fp_eps, inf from tidy3d.exceptions import ( SetupError, @@ -77,6 +78,11 @@ from matplotlib.backend_bases import Event from matplotlib.patches import FancyArrowPatch + from tidy3d import TriangleMesh + +ClipGeometryKey = Literal["geometry_a", "geometry_b"] +ClipOperationContext = tuple["ClipOperation", ClipGeometryKey] + POLY_GRID_SIZE = 1e-12 POLY_TOLERANCE_RATIO = 1e-12 POLY_DISTANCE_TOLERANCE = 8e-12 @@ -1483,10 +1489,23 @@ def to_gds_file( fname.parent.mkdir(parents=True, exist_ok=True) library.write_gds(fname) - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + def _compute_derivatives( + self, + derivative_info: DerivativeInfo, + ) -> AutogradFieldMap: """Compute the adjoint derivatives for this object.""" raise NotImplementedError(f"Can't compute derivative for 'Geometry': '{type(self)}'.") + def _compute_derivatives_via_mesh( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[ClipOperationContext] = None, + ) -> AutogradFieldMap: + """Compute the adjoint derivatives for this object.""" + 
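+        # base-class fallback; concrete geometries (e.g. Box, TriangleMesh,
+        # GeometryGroup) override this to route gradients through a surface mesh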
raise NotImplementedError( + f"Can't compute derivative for clipped 'Geometry': '{type(self)}'." + ) + def _as_union(self) -> list[Geometry]: """Return a list of geometries that, united, make up the given geometry.""" if isinstance(self, GeometryGroup): @@ -1909,6 +1928,50 @@ def _intersect_dist(self, position: float, z0: float) -> float: """Primitive classes""" +def _default_box_faces() -> NDArray[np.int_]: + """Return the canonical triangle indices for a cube with CCW winding.""" + + # faces ordered as: bottom, top, x-, x+, y-, y+ + return np.asarray( + [ + (0, 1, 2), + (0, 2, 3), + (4, 6, 5), + (4, 7, 6), + (0, 4, 5), + (0, 5, 1), + (2, 6, 7), + (2, 7, 3), + (1, 5, 6), + (1, 6, 2), + (0, 3, 7), + (0, 7, 4), + ], + dtype=int, + ) + + +_BOX_VERTEX_SIGNS = np.asarray( + [ + (-1.0, -1.0, -1.0), + (1.0, -1.0, -1.0), + (1.0, 1.0, -1.0), + (-1.0, 1.0, -1.0), + (-1.0, -1.0, 1.0), + (1.0, -1.0, 1.0), + (1.0, 1.0, 1.0), + (-1.0, 1.0, 1.0), + ], + dtype=float, +) + +_BOX_FACE_VERTEX_INDICES = { + 0: ((0, 3, 7, 4), (1, 2, 6, 5)), + 1: ((0, 4, 5, 1), (3, 7, 6, 2)), + 2: ((0, 1, 2, 3), (4, 5, 6, 7)), +} + + class Box(SimplePlaneIntersection, Centered): """Rectangular prism. Also base class for :class:`.Simulation`, :class:`Monitor`, and :class:`Source`. @@ -1925,6 +1988,9 @@ class Box(SimplePlaneIntersection, Centered): units=MICROMETER, ) + _triangle_faces: NDArray[np.int_] = pydantic.PrivateAttr(default_factory=_default_box_faces) + _triangle_mesh_cache: Optional[TriangleMesh] = pydantic.PrivateAttr(default=None) + @classmethod def from_bounds(cls, rmin: Coordinate, rmax: Coordinate, **kwargs: Any) -> Self: """Constructs a :class:`Box` from minimum and maximum coordinate bounds @@ -2539,6 +2605,85 @@ def _surface_area(self, bounds: Bound) -> float: """ Autograd code """ + def _box_vertices_array(self, dtype: np.dtype | None = None) -> NDArray: + dtype = dtype or config.adjoint.gradient_dtype_float + center = np.asarray(self.center, dtype=dtype) + half_size = 0.5 * np.asarray(self.size, dtype=dtype) + return center + half_size * _BOX_VERTEX_SIGNS.astype(dtype) + + def to_triangle_mesh(self) -> TriangleMesh: + """Return (and lazily construct) the triangle mesh representation.""" + + if self._triangle_mesh_cache is None: + from .mesh import TriangleMesh + + vertices = self._box_vertices_array() + self._triangle_mesh_cache = TriangleMesh.from_vertices_faces( + vertices, self._triangle_faces + ) + return self._triangle_mesh_cache + + def _accumulate_vertex_gradients(self, triangle_grads: NDArray) -> NDArray: + """Aggregate per-triangle gradients into unique vertex gradients.""" + + num_vertices = _BOX_VERTEX_SIGNS.shape[0] + vertex_grads = np.zeros((num_vertices, 3), dtype=triangle_grads.dtype) + for face_index, face in enumerate(self._triangle_faces): + for local_idx, vertex_idx in enumerate(face): + vertex_grads[vertex_idx] += triangle_grads[face_index, local_idx] + return vertex_grads + + def _compute_derivatives_via_mesh( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[ClipOperationContext] = None, + ) -> AutogradFieldMap: + """Compute the adjoint derivatives using the ``TriangleMesh`` surface sampling.""" + + mesh = self.to_triangle_mesh() + original_paths = derivative_info.paths + derivative_info.paths = [("mesh_dataset", "surface_mesh")] + try: + mesh_vjps = mesh._compute_derivatives(derivative_info, clip_operation=clip_operation) + finally: + derivative_info.paths = original_paths + gradient_key = ("mesh_dataset", "surface_mesh") + if gradient_key not in mesh_vjps: + return 
{} + + triangle_grads = mesh_vjps[gradient_key] + vertex_grads = np.asarray(self._accumulate_vertex_gradients(triangle_grads), dtype=float) + vjps_faces = np.zeros((2, 3), dtype=vertex_grads.dtype) + + for axis in range(3): + for min_max_index, indices in enumerate(_BOX_FACE_VERTEX_INDICES[axis]): + direction = 1.0 if min_max_index == 0 else -1.0 + vjps_faces[min_max_index, axis] = direction * float( + np.sum(vertex_grads[np.asarray(indices), axis]) + ) + + vjps_center_size = self._derivatives_center_size(vjps_faces) + noise_floor = 1e-7 + cleaned_center_size = {} + for key, val in vjps_center_size.items(): + arr = np.asarray(val, dtype=float) + arr = np.where(np.abs(arr) < noise_floor, 0.0, arr) + cleaned_center_size[key] = tuple(arr.tolist()) + vjps_center_size = cleaned_center_size + + derivative_map: AutogradFieldMap = {} + for field_path in derivative_info.paths: + field_name, *index = field_path + + if field_name in vjps_center_size: + if index and len(index) == 1: + idx = int(index[0]) + derivative_map[field_path] = vjps_center_size[field_name][idx] + else: + derivative_map[field_path] = vjps_center_size[field_name] + + return derivative_map + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: """Compute the adjoint derivatives for this object.""" @@ -3086,17 +3231,6 @@ class ClipOperation(Geometry): description="Second operand for the set operation. It can also be any geometry type.", ) - @pydantic.validator("geometry_a", "geometry_b", always=True) - def _geometries_untraced(cls, val: GeometryType) -> GeometryType: - """Make sure that ``ClipOperation`` geometries do not contain tracers.""" - traced = val._strip_traced_fields() - if traced: - raise ValidationError( - f"{val.type} contains traced fields {list(traced.keys())}. Note that " - "'ClipOperation' does not currently support automatic differentiation." 
- ) - return val - @staticmethod def to_polygon_list(base_geometry: Shapely, cleanup: bool = False) -> list[Shapely]: """Return a list of valid polygons from a shapely geometry, discarding points, lines, and @@ -3162,6 +3296,97 @@ def _bit_operation(self) -> Callable[[Any, Any], Any]: ) return result + def _geometry_from_key(self, which: ClipGeometryKey) -> Geometry: + """Return the geometry referenced by ``which``.""" + if which == "geometry_a": + return self.geometry_a + if which == "geometry_b": + return self.geometry_b + raise ValueError(f"Unsupported geometry key '{which}'.") + + def _other_geometry(self, which: ClipGeometryKey) -> Geometry: + """Return the opposing geometry for ``which``.""" + return self.geometry_b if which == "geometry_a" else self.geometry_a + + @staticmethod + def _points_to_array(points: ArrayLike) -> tuple[np.ndarray, bool]: + """Convert sample points to a 2D array and track whether input was a single point.""" + arr = np.asarray(points, dtype=float) + single_point = arr.ndim == 1 + if arr.size == 0: + arr = arr.reshape((0, 3)) + else: + arr = arr.reshape((-1, 3)) + return arr, single_point + + @staticmethod + def _clip_use_mask( + operation: ClipOperationType, which: ClipGeometryKey, inside_mask: np.ndarray + ) -> np.ndarray: + """Return the inclusion mask for the requested clip operation.""" + + mask = np.asarray(inside_mask, dtype=bool) + if operation == "intersection": + return mask.copy() + if operation == "union": + return ~mask + if operation == "difference": + return (~mask) if which == "geometry_a" else mask.copy() + if operation == "symmetric_difference": + return np.ones_like(mask, dtype=bool) + raise ValueError(f"Unsupported clip operation '{operation}'.") + + @staticmethod + def _clip_flip_mask( + operation: ClipOperationType, which: ClipGeometryKey, inside_mask: np.ndarray + ) -> np.ndarray: + """Return the normal flip mask for the requested clip operation.""" + + mask = np.asarray(inside_mask, dtype=bool) + if operation == "difference": + if which == "geometry_b": + return mask.copy() + return np.zeros_like(mask, dtype=bool) + if operation == "symmetric_difference": + return mask.copy() + return np.zeros_like(mask, dtype=bool) + + def sample_points_should_use( + self, which: ClipGeometryKey, points: ArrayLike + ) -> Union[bool, NDArray[np.bool_]]: + """Return a mask indicating which samples contribute to the gradient.""" + points_arr, single_point = self._points_to_array(points) + if points_arr.size == 0: + result = np.zeros(0, dtype=bool) + return bool(result.size) if single_point else result + + other = self._other_geometry(which) + inside_other = np.asarray( + other.inside(points_arr[:, 0], points_arr[:, 1], points_arr[:, 2]), dtype=bool + ) + + use_mask = self._clip_use_mask(self.operation, which, inside_other) + + return bool(use_mask[0]) if single_point else use_mask + + def sample_normals_should_flip( + self, which: ClipGeometryKey, points: ArrayLike + ) -> Union[bool, NDArray[np.bool_]]: + """Return a mask indicating which sample normals require flipping.""" + points_arr, single_point = self._points_to_array(points) + if points_arr.size == 0: + result = np.zeros(0, dtype=bool) + return bool(result.size) if single_point else result + + other = self._other_geometry(which) + inside_other = np.asarray( + other.inside(points_arr[:, 0], points_arr[:, 1], points_arr[:, 2]), dtype=bool + ) + + flip_mask = self._clip_flip_mask(self.operation, which, inside_other) + + return bool(flip_mask[0]) if single_point else flip_mask + def 
intersections_tilted_plane( self, normal: Coordinate, @@ -3363,6 +3588,46 @@ def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> ClipOp new_geom_b = self.geometry_b._update_from_bounds(bounds=bounds, axis=axis) return self.updated_copy(geometry_a=new_geom_a, geometry_b=new_geom_b) + def _compute_derivatives( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[ClipOperationContext] = None, + ) -> AutogradFieldMap: + """Compute adjoint gradients for both operands in the clip operation.""" + + grad_vjps: AutogradFieldMap = {} + interpolators = derivative_info.interpolators or derivative_info.create_interpolators() + + for field_path in derivative_info.paths: + if not field_path: + continue + which, *geo_path = field_path + if which not in ("geometry_a", "geometry_b"): + raise ValueError( + "ClipOperation derivatives are only defined for 'geometry_a' or 'geometry_b'." + ) + if not geo_path: + raise ValueError("ClipOperation derivative path must specify a geometry field.") + geometry = self._geometry_from_key(which) + geo_info = derivative_info.updated_copy( + paths=[tuple(geo_path)], + bounds=geometry.bounds, + bounds_intersect=self.bounds_intersection( + geometry.bounds, derivative_info.simulation_bounds + ), + eps_approx=True, + deep=False, + interpolators=interpolators, + ) + vjps_geo = geometry._compute_derivatives_via_mesh( + geo_info, clip_operation=(self, which) + ) + if len(vjps_geo) != 1: + raise AssertionError("Expected a single gradient value for each geometry field.") + grad_vjps[field_path] = vjps_geo.popitem()[1] + + return grad_vjps + class GeometryGroup(Geometry): """A collection of Geometry objects that can be called as a single geometry object.""" @@ -3579,7 +3844,18 @@ def _update_from_bounds(self, bounds: tuple[float, float], axis: Axis) -> Geomet ] return self.updated_copy(geometries=new_geometries) - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + def _compute_derivatives_via_mesh( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[ClipOperationContext] = None, + ) -> AutogradFieldMap: + return self._compute_derivatives(derivative_info, clip_operation=clip_operation) + + def _compute_derivatives( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[ClipOperationContext] = None, + ) -> AutogradFieldMap: """Compute the adjoint derivatives for this object.""" grad_vjps = {} @@ -3601,8 +3877,12 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField deep=False, interpolators=interpolators, ) - - vjp_dict_geo = geo._compute_derivatives(geo_info) + if clip_operation is not None: + vjp_dict_geo = geo._compute_derivatives_via_mesh( + geo_info, clip_operation=clip_operation + ) + else: + vjp_dict_geo = geo._compute_derivatives(geo_info) if len(vjp_dict_geo) != 1: raise AssertionError("Got multiple gradients for single geometry field.") diff --git a/tidy3d/components/geometry/mesh.py b/tidy3d/components/geometry/mesh.py index 006f49dc0a..5c20eb7d19 100644 --- a/tidy3d/components/geometry/mesh.py +++ b/tidy3d/components/geometry/mesh.py @@ -746,7 +746,18 @@ def plot( return base.Geometry.plot(self, x=x, y=y, z=z, ax=ax, **patch_kwargs) - def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + def _compute_derivatives_via_mesh( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[base.ClipOperationContext] = None, + ) -> AutogradFieldMap: + return self._compute_derivatives(derivative_info, 
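+            # a TriangleMesh is already surface-mesh based, so the via-mesh entry
+            # point simply forwards to the standard derivative computation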
clip_operation=clip_operation) + + def _compute_derivatives( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[base.ClipOperationContext] = None, + ) -> AutogradFieldMap: """Compute adjoint derivatives for a ``TriangleMesh`` geometry.""" vjps: AutogradFieldMap = {} @@ -777,12 +788,16 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField # gather surface samples within the simulation bounds dx = derivative_info.adaptive_vjp_spacing() + if clip_operation is not None: + dx = max(dx * 0.2, np.finfo(float).eps) samples = self._collect_surface_samples( triangles=triangles, spacing=dx, sim_min=sim_min, sim_max=sim_max, ) + if clip_operation is not None: + samples = self._apply_clip_filters(samples, clip_operation) if samples["points"].shape[0] == 0: zeros = np.zeros_like(triangles) @@ -802,6 +817,21 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField samples["perps2"], interpolators, ) + # g = np.asarray(g, dtype=config.adjoint.gradient_dtype_float) + # n_points = samples["points"].shape[0] + # if g.ndim == 0: + # g = np.full(n_points, g, dtype=config.adjoint.gradient_dtype_float) + # elif g.ndim > 1: + # last_dim = g.shape[-1] + # if last_dim != n_points: + # g = np.reshape(g, (-1,)) + # if g.size != n_points: + # raise ValueError( + # "Gradient evaluation result shape does not match sample count." + # ) + # g = g + # else: + # g = g.reshape(-1, n_points).sum(axis=0) # accumulate per-vertex contributions using barycentric weights weights = (samples["weights"] * g).real @@ -819,6 +849,110 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField vjps[("mesh_dataset", "surface_mesh")] = triangle_grads return vjps + def _apply_clip_filters( + self, + samples: dict[str, np.ndarray], + clip_operation: base.ClipOperationContext, + ) -> dict[str, np.ndarray]: + """Filter and adjust samples according to clip operation rules.""" + + clip_obj, which = clip_operation + other_geometry = clip_obj._other_geometry(which) + if not isinstance(other_geometry, TriangleMesh): + return self._apply_basic_clip_filters(samples, clip_operation) + + clip_geometry = self._prepare_clip_geometry(other_geometry) + + points = np.asarray(samples["points"], dtype=config.adjoint.gradient_dtype_float) + normals = np.asarray(samples["normals"], dtype=config.adjoint.gradient_dtype_float) + total_points = points.shape[0] + if total_points == 0: + return samples + + shift = max(float(config.adjoint.edge_clip_tolerance), 1e-9) + probe_points = points - normals * shift + inside_mask = np.asarray( + clip_geometry.inside( + probe_points[:, 0], + probe_points[:, 1], + probe_points[:, 2], + ), + dtype=bool, + ).reshape(-1) + + use_mask = clip_obj._clip_use_mask(clip_obj.operation, which, inside_mask) + if use_mask.size != total_points: + raise ValueError("ClipOperation sample mask has incorrect shape.") + if not np.any(use_mask): + return {key: np.asarray(value[:0]).copy() for key, value in samples.items()} + + filtered = {key: np.asarray(value[use_mask]).copy() for key, value in samples.items()} + + flip_mask = clip_obj._clip_flip_mask(clip_obj.operation, which, inside_mask) + if flip_mask.size != total_points: + raise ValueError("ClipOperation normal flip mask has incorrect shape.") + flip_mask = flip_mask[use_mask] + if np.any(flip_mask): + flip_signs = np.where(flip_mask[:, None], -1.0, 1.0) + filtered["normals"] = filtered["normals"] * flip_signs + + return filtered + + def _apply_basic_clip_filters( + self, + samples: dict[str, 
np.ndarray], + clip_operation: base.ClipOperationContext, + ) -> dict[str, np.ndarray]: + """Fallback clip filtering that mirrors the historical behavior.""" + + clip_obj, which = clip_operation + points = samples["points"] + total_points = points.shape[0] + if total_points == 0: + return {key: np.asarray(value[:0]).copy() for key, value in samples.items()} + + raw_use = clip_obj.sample_points_should_use(which, points) + use_mask = np.asarray(raw_use, dtype=bool).reshape(-1) + if use_mask.size != total_points: + raise ValueError("ClipOperation sample mask has incorrect shape.") + if not np.any(use_mask): + return {key: np.asarray(value[:0]).copy() for key, value in samples.items()} + + filtered = {key: np.asarray(value[use_mask]).copy() for key, value in samples.items()} + + raw_flip = clip_obj.sample_normals_should_flip(which, points) + flip_mask = np.asarray(raw_flip, dtype=bool).reshape(-1) + if flip_mask.size != total_points: + raise ValueError("ClipOperation normal flip mask has incorrect shape.") + flip_mask = flip_mask[use_mask] + if np.any(flip_mask): + flip_signs = np.where(flip_mask[:, None], -1.0, 1.0) + filtered["normals"] = filtered["normals"] * flip_signs + + return filtered + + @staticmethod + def _prepare_clip_geometry(other: base.Geometry) -> base.Geometry: + """Return a TriangleMesh suitable for geometric clipping operations.""" + + if not isinstance(other, TriangleMesh): + return other + + try: + tri_mesh = other.trimesh + except Exception: + return other + + if tri_mesh.is_volume: + return other + + try: + hull = tri_mesh.convex_hull + except Exception: + return other + + return TriangleMesh.from_trimesh(hull) + def _collect_surface_samples( self, triangles: NDArray, @@ -844,6 +978,37 @@ def _collect_surface_samples( spacing = max(float(spacing), np.finfo(float).eps) triangles_arr = np.asarray(triangles, dtype=dtype) + if triangles_arr.size == 0: + return self._empty_sample_result(dtype) + + edges01 = triangles_arr[:, 1, :] - triangles_arr[:, 0, :] + edges02 = triangles_arr[:, 2, :] - triangles_arr[:, 0, :] + edges12 = triangles_arr[:, 2, :] - triangles_arr[:, 1, :] + cross = np.cross(edges01, edges02) + norm = np.linalg.norm(cross, axis=1) + areas = 0.5 * norm + + normals = np.zeros((triangles_arr.shape[0], 3), dtype=dtype) + nonzero_norm = norm > 0.0 + normals[nonzero_norm] = cross[nonzero_norm] / norm[nonzero_norm][:, None] + + edge_tol = np.finfo(dtype).eps + edge_choice = np.where( + (np.linalg.norm(edges01, axis=1) > edge_tol)[:, None], + edges01, + edges12, + ) + edge_choice_norm = np.linalg.norm(edge_choice, axis=1) + has_edge = edge_choice_norm > edge_tol + perps1 = np.zeros_like(normals) + perps1[has_edge] = edge_choice[has_edge] / edge_choice_norm[has_edge][:, None] + perps2_tmp = np.cross(normals, perps1) + perps2_norm = np.linalg.norm(perps2_tmp, axis=1) + has_basis = perps2_norm > edge_tol + perps2 = np.zeros_like(normals) + perps2[has_basis] = perps2_tmp[has_basis] / perps2_norm[has_basis][:, None] + + usable_mask = (areas > AREA_SIZE_THRESHOLD) & has_edge & has_basis sim_extents = sim_max - sim_min valid_axes = np.abs(sim_extents) > tol @@ -854,229 +1019,306 @@ def _collect_surface_samples( collapsed_axis = int(collapsed_indices[0]) plane_value = float(sim_min[collapsed_axis]) - warned = False - warning_msg = "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
- for face_index, tri in enumerate(triangles_arr): - area, normal = self._triangle_area_and_normal(tri) - if area <= AREA_SIZE_THRESHOLD: - continue - - perps = self._triangle_tangent_basis(tri, normal) - if perps is None: - continue - perp1, perp2 = perps - - if collapsed_axis is not None and plane_value is not None: - samples, outside_bounds = self._collect_surface_samples_2d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - spacing=spacing, - collapsed_axis=collapsed_axis, - plane_value=plane_value, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - else: - samples, outside_bounds = self._collect_surface_samples_3d( - triangle=tri, - face_index=face_index, - normal=normal, - perp1=perp1, - perp2=perp2, - area=area, - spacing=spacing, - sim_min=sim_min, - sim_max=sim_max, - valid_axes=valid_axes, - tol=tol, - dtype=dtype, - ) - - if outside_bounds and not warned: - log.warning(warning_msg) - warned = True + if collapsed_axis is not None and plane_value is not None: + return self._collect_plane_samples( + triangles_arr=triangles_arr, + normals=normals, + perps1=perps1, + perps2=perps2, + usable_mask=usable_mask, + spacing=spacing, + sim_min=sim_min, + sim_max=sim_max, + valid_axes=valid_axes, + collapsed_axis=collapsed_axis, + plane_value=plane_value, + tol=tol, + dtype=dtype, + ) - if samples is None: - continue + tri_min = np.min(triangles_arr, axis=1) + tri_max = np.max(triangles_arr, axis=1) + overlaps = np.all(tri_max >= (sim_min - tol), axis=1) & np.all( + tri_min <= (sim_max + tol), axis=1 + ) + usable_mask &= overlaps - points_list.append(samples["points"]) - normals_list.append(samples["normals"]) - perps1_list.append(samples["perps1"]) - perps2_list.append(samples["perps2"]) - weights_list.append(samples["weights"]) - faces_list.append(samples["faces"]) - bary_list.append(samples["barycentric"]) + fully_inside = ( + usable_mask + & np.all(tri_min >= (sim_min - tol), axis=1) + & np.all(tri_max <= (sim_max + tol), axis=1) + ) + needs_clip = usable_mask & ~fully_inside - if not points_list: - return { - "points": np.zeros((0, 3), dtype=dtype), - "normals": np.zeros((0, 3), dtype=dtype), - "perps1": np.zeros((0, 3), dtype=dtype), - "perps2": np.zeros((0, 3), dtype=dtype), - "weights": np.zeros((0,), dtype=dtype), - "faces": np.zeros((0,), dtype=int), - "barycentric": np.zeros((0, 3), dtype=dtype), - } + warned = False - return { - "points": np.concatenate(points_list, axis=0), - "normals": np.concatenate(normals_list, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights_list, axis=0), - "faces": np.concatenate(faces_list, axis=0), - "barycentric": np.concatenate(bary_list, axis=0), - } + if np.any(fully_inside): + face_indices = np.flatnonzero(fully_inside) + tri_inside = triangles_arr[face_indices] + normals_inside = normals[face_indices] + perp1_inside = perps1[face_indices] + perp2_inside = perps2[face_indices] + areas_inside = areas[face_indices] + edge_lengths = np.linalg.norm( + np.stack( + ( + edges01[face_indices], + edges02[face_indices], + edges12[face_indices], + ), + axis=1, + ), + axis=2, + ) + subdivisions = self._vectorized_subdivisions(areas_inside, spacing, edge_lengths) + unique_subdiv, inverse = np.unique(subdivisions, return_inverse=True) + + for group_idx, group_subdiv in enumerate(unique_subdiv): + group_faces = np.flatnonzero(inverse == group_idx) + if group_faces.size == 0: + continue + 
barycentric = self._get_barycentric_samples(int(group_subdiv), dtype) + num_samples = barycentric.shape[0] + tris_group = tri_inside[group_faces] + normals_group = normals_inside[group_faces] + perp1_group = perp1_inside[group_faces] + perp2_group = perp2_inside[group_faces] + areas_group = areas_inside[group_faces] + face_ids_group = face_indices[group_faces] + + sample_points = np.einsum("sb,fbc->fsc", barycentric, tris_group).reshape(-1, 3) + points_list.append(sample_points) + + normals_list.append(np.repeat(normals_group, num_samples, axis=0)) + perps1_list.append(np.repeat(perp1_group, num_samples, axis=0)) + perps2_list.append(np.repeat(perp2_group, num_samples, axis=0)) + + weights = np.repeat((areas_group / num_samples).astype(dtype), num_samples) + weights_list.append(weights) + faces_list.append(np.repeat(face_ids_group, num_samples)) + bary_tile = np.broadcast_to( + barycentric, (group_faces.size, num_samples, 3) + ).reshape(-1, 3) + bary_list.append(bary_tile) + + if np.any(needs_clip): + for face_index in np.flatnonzero(needs_clip): + tri = triangles_arr[face_index] + normal = normals[face_index] + perp1 = perps1[face_index] + perp2 = perps2[face_index] + clipped, was_clipped = self._clip_triangle_to_bounds(tri, sim_min, sim_max, tol) + if was_clipped and not warned: + log.warning( + "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." + ) + warned = True + if not clipped: + continue + + for tri_clip in clipped: + area_clip, _ = self._triangle_area_and_normal(tri_clip) + if area_clip <= AREA_SIZE_THRESHOLD: + continue + + edge_lengths = ( + np.linalg.norm(tri_clip[1] - tri_clip[0]), + np.linalg.norm(tri_clip[2] - tri_clip[1]), + np.linalg.norm(tri_clip[0] - tri_clip[2]), + ) + subdivisions = self._subdivision_count(area_clip, spacing, edge_lengths) + barycentric_clip = self._get_barycentric_samples(subdivisions, dtype) + num_samples = barycentric_clip.shape[0] + base_weight = area_clip / num_samples + + bary_basis = np.stack( + [ + self._barycentric_coordinates(tri, vertex[None, :], tol)[0] + for vertex in tri_clip + ], + axis=0, + ) + bary_orig = barycentric_clip @ bary_basis + sample_points = bary_orig @ tri + + normal_tile = np.repeat(normal[None, :], num_samples, axis=0) + perp1_tile = np.repeat(perp1[None, :], num_samples, axis=0) + perp2_tile = np.repeat(perp2[None, :], num_samples, axis=0) + weights_tile = np.full(num_samples, base_weight, dtype=dtype) + faces_tile = np.full(num_samples, face_index, dtype=int) + points_list.append(sample_points) + normals_list.append(normal_tile) + perps1_list.append(perp1_tile) + perps2_list.append(perp2_tile) + weights_list.append(weights_tile) + faces_list.append(faces_tile) + bary_list.append(bary_orig) + + return self._finalize_sample_lists( + dtype, + points_list, + normals_list, + perps1_list, + perps2_list, + weights_list, + faces_list, + bary_list, + ) - def _collect_surface_samples_2d( + def _collect_plane_samples( self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, + triangles_arr: np.ndarray, + normals: np.ndarray, + perps1: np.ndarray, + perps2: np.ndarray, + usable_mask: np.ndarray, spacing: float, - collapsed_axis: int, - plane_value: float, sim_min: np.ndarray, sim_max: np.ndarray, valid_axes: np.ndarray, + collapsed_axis: int, + plane_value: float, tol: float, dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds collapse onto a 2D plane.""" - - 
segments = self._triangle_plane_segments( - triangle=triangle, axis=collapsed_axis, plane_value=plane_value, tol=tol - ) + ) -> dict[str, np.ndarray]: + """Sample intersection of triangles with a collapsed-axis plane.""" - points: list[np.ndarray] = [] - normals: list[np.ndarray] = [] + points_list: list[np.ndarray] = [] + normals_list: list[np.ndarray] = [] perps1_list: list[np.ndarray] = [] perps2_list: list[np.ndarray] = [] - weights: list[np.ndarray] = [] - faces: list[np.ndarray] = [] - barycentric: list[np.ndarray] = [] - outside_bounds = False - - for start, end in segments: - vec = end - start - length = float(np.linalg.norm(vec)) - if length <= tol: - continue + weights_list: list[np.ndarray] = [] + faces_list: list[np.ndarray] = [] + bary_list: list[np.ndarray] = [] - subdivisions = max(1, int(np.ceil(length / spacing))) - t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions - sample_points = start[None, :] + t_vals[:, None] * vec[None, :] - bary = self._barycentric_coordinates(triangle, sample_points, tol) - - inside_mask = np.ones(sample_points.shape[0], dtype=bool) - if np.any(valid_axes): - min_bound = (sim_min - tol)[valid_axes] - max_bound = (sim_max + tol)[valid_axes] - coords = sample_points[:, valid_axes] - inside_mask = np.all(coords >= min_bound, axis=1) & np.all( - coords <= max_bound, axis=1 - ) + warned = False + face_indices = np.flatnonzero(usable_mask) + for face_index in face_indices: + tri = triangles_arr[face_index] + normal = normals[face_index] + perp1 = perps1[face_index] + perp2 = perps2[face_index] + + segments = self._triangle_plane_segments( + triangle=tri, axis=collapsed_axis, plane_value=plane_value, tol=tol + ) + for start, end in segments: + vec = end - start + length = float(np.linalg.norm(vec)) + if length <= tol: + continue + + subdivisions = max(1, int(np.ceil(length / spacing))) + t_vals = (np.arange(subdivisions, dtype=dtype) + 0.5) / subdivisions + sample_points = start[None, :] + t_vals[:, None] * vec[None, :] + barycentric = self._barycentric_coordinates(tri, sample_points, tol) + + inside_mask = np.ones(sample_points.shape[0], dtype=bool) + if np.any(valid_axes): + min_bound = (sim_min - tol)[valid_axes] + max_bound = (sim_max + tol)[valid_axes] + coords = sample_points[:, valid_axes] + inside_mask = np.all(coords >= min_bound, axis=1) & np.all( + coords <= max_bound, axis=1 + ) + + if not np.all(inside_mask) and not warned: + log.warning( + "Some triangles from the mesh lie outside the simulation bounds - this may lead to inaccurate gradients." 
+ ) + warned = True + + if not np.any(inside_mask): + continue + + sample_points = sample_points[inside_mask] + bary_inside = barycentric[inside_mask] + n_inside = sample_points.shape[0] + + normal_tile = np.repeat(normal[None, :], n_inside, axis=0) + perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) + perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) + weights_tile = np.full(n_inside, length / subdivisions, dtype=dtype) + faces_tile = np.full(n_inside, face_index, dtype=int) + + points_list.append(sample_points) + normals_list.append(normal_tile) + perps1_list.append(perp1_tile) + perps2_list.append(perp2_tile) + weights_list.append(weights_tile) + faces_list.append(faces_tile) + bary_list.append(bary_inside) + + return self._finalize_sample_lists( + dtype, + points_list, + normals_list, + perps1_list, + perps2_list, + weights_list, + faces_list, + bary_list, + ) - outside_bounds = outside_bounds or (not np.all(inside_mask)) - if not np.any(inside_mask): - continue + @staticmethod + def _vectorized_subdivisions( + areas: np.ndarray, spacing: float, edge_lengths: np.ndarray + ) -> np.ndarray: + """Compute subdivision counts for many triangles at once.""" - sample_points = sample_points[inside_mask] - bary_inside = bary[inside_mask] - n_inside = sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_inside, axis=0) - weights_tile = np.full(n_inside, length / subdivisions, dtype=dtype) - faces_tile = np.full(n_inside, face_index, dtype=int) - - points.append(sample_points) - normals.append(normal_tile) - perps1_list.append(perp1_tile) - perps2_list.append(perp2_tile) - weights.append(weights_tile) - faces.append(faces_tile) - barycentric.append(bary_inside) - - if not points: - return None, outside_bounds - - samples = { - "points": np.concatenate(points, axis=0), - "normals": np.concatenate(normals, axis=0), - "perps1": np.concatenate(perps1_list, axis=0), - "perps2": np.concatenate(perps2_list, axis=0), - "weights": np.concatenate(weights, axis=0), - "faces": np.concatenate(faces, axis=0), - "barycentric": np.concatenate(barycentric, axis=0), + spacing = max(float(spacing), np.finfo(float).eps) + target = np.sqrt(np.maximum(areas, 0.0)) + area_based = np.ceil(np.sqrt(2.0) * target / spacing) + + max_edge = np.max(edge_lengths, axis=1) + edge_based = np.ceil(max_edge / spacing) + + subdivisions = np.maximum(area_based, edge_based) + subdivisions = np.maximum(subdivisions, 1.0) + return subdivisions.astype(int) + + @staticmethod + def _empty_sample_result(dtype: np.dtype) -> dict[str, np.ndarray]: + """Return the default empty sampling dictionary.""" + + zeros_vec = np.zeros((0, 3), dtype=dtype) + zeros_scalar = np.zeros((0,), dtype=dtype) + zeros_faces = np.zeros((0,), dtype=int) + return { + "points": zeros_vec, + "normals": zeros_vec.copy(), + "perps1": zeros_vec.copy(), + "perps2": zeros_vec.copy(), + "weights": zeros_scalar, + "faces": zeros_faces, + "barycentric": zeros_vec.copy(), } - return samples, outside_bounds - def _collect_surface_samples_3d( - self, - triangle: NDArray, - face_index: int, - normal: np.ndarray, - perp1: np.ndarray, - perp2: np.ndarray, - area: float, - spacing: float, - sim_min: np.ndarray, - sim_max: np.ndarray, - valid_axes: np.ndarray, - tol: float, + @staticmethod + def _finalize_sample_lists( dtype: np.dtype, - ) -> tuple[Optional[dict[str, np.ndarray]], bool]: - """Collect samples when the simulation bounds 
represent a full 3D region.""" + points_list: list[np.ndarray], + normals_list: list[np.ndarray], + perps1_list: list[np.ndarray], + perps2_list: list[np.ndarray], + weights_list: list[np.ndarray], + faces_list: list[np.ndarray], + bary_list: list[np.ndarray], + ) -> dict[str, np.ndarray]: + """Concatenate accumulated sampling data or return an empty structure.""" - edge_lengths = ( - np.linalg.norm(triangle[1] - triangle[0]), - np.linalg.norm(triangle[2] - triangle[1]), - np.linalg.norm(triangle[0] - triangle[2]), - ) - subdivisions = self._subdivision_count(area, spacing, edge_lengths) - barycentric = self._get_barycentric_samples(subdivisions, dtype) - num_samples = barycentric.shape[0] - base_weight = area / num_samples - - sample_points = barycentric @ triangle - - inside_mask = np.all( - sample_points[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1 - ) & np.all(sample_points[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1) - outside_bounds = not np.all(inside_mask) - if not np.any(inside_mask): - return None, outside_bounds - - sample_points = sample_points[inside_mask] - bary_inside = barycentric[inside_mask] - n_samples_inside = sample_points.shape[0] - - normal_tile = np.repeat(normal[None, :], n_samples_inside, axis=0) - perp1_tile = np.repeat(perp1[None, :], n_samples_inside, axis=0) - perp2_tile = np.repeat(perp2[None, :], n_samples_inside, axis=0) - weights_tile = np.full(n_samples_inside, base_weight, dtype=dtype) - faces_tile = np.full(n_samples_inside, face_index, dtype=int) - - samples = { - "points": sample_points, - "normals": normal_tile, - "perps1": perp1_tile, - "perps2": perp2_tile, - "weights": weights_tile, - "faces": faces_tile, - "barycentric": bary_inside, + if not points_list: + return TriangleMesh._empty_sample_result(dtype) + + return { + "points": np.concatenate(points_list, axis=0), + "normals": np.concatenate(normals_list, axis=0), + "perps1": np.concatenate(perps1_list, axis=0), + "perps2": np.concatenate(perps2_list, axis=0), + "weights": np.concatenate(weights_list, axis=0), + "faces": np.concatenate(faces_list, axis=0), + "barycentric": np.concatenate(bary_list, axis=0), } - return samples, outside_bounds @staticmethod def _triangle_area_and_normal(triangle: NDArray) -> tuple[float, np.ndarray]: @@ -1149,6 +1391,88 @@ def add_point(pt: np.ndarray) -> None: return [] + @staticmethod + def _clip_polygon_with_plane( + polygon: list[np.ndarray], axis: int, bound: float, keep_below: bool, tol: float + ) -> list[np.ndarray]: + """Clip a polygon with an axis-aligned plane.""" + + if not polygon: + return [] + + result: list[np.ndarray] = [] + prev = polygon[-1] + prev_val = prev[axis] + prev_inside = (prev_val <= bound + tol) if keep_below else (prev_val >= bound - tol) + + for current in polygon: + curr_val = current[axis] + curr_inside = (curr_val <= bound + tol) if keep_below else (curr_val >= bound - tol) + + if curr_inside: + if not prev_inside: + result.append( + TriangleMesh._segment_plane_intersection(prev, current, axis, bound, tol) + ) + result.append(current) + elif prev_inside: + result.append( + TriangleMesh._segment_plane_intersection(prev, current, axis, bound, tol) + ) + + prev = current + prev_inside = curr_inside + + return result + + @staticmethod + def _segment_plane_intersection( + p0: np.ndarray, p1: np.ndarray, axis: int, bound: float, tol: float + ) -> np.ndarray: + """Return intersection point between segment (p0,p1) and axis-aligned plane.""" + + v0 = float(p0[axis]) - bound + v1 = float(p1[axis]) - bound + denom = v1 - 
v0 + if abs(denom) <= tol: + return p0.copy() + t = -v0 / denom + t = float(np.clip(t, 0.0, 1.0)) + return p0 + t * (p1 - p0) + + @classmethod + def _clip_triangle_to_bounds( + cls, triangle: NDArray, sim_min: NDArray, sim_max: NDArray, tol: float + ) -> tuple[list[NDArray], bool]: + """Clip triangle against axis-aligned bounds, return list of sub-triangles and flag.""" + + vertices = np.asarray(triangle) + inside = np.all(vertices >= (sim_min - tol), axis=1) & np.all( + vertices <= (sim_max + tol), axis=1 + ) + if np.all(inside): + return [triangle], False + + polygon = [triangle[0].copy(), triangle[1].copy(), triangle[2].copy()] + clipped_flag = True + for axis in range(3): + polygon = cls._clip_polygon_with_plane(polygon, axis, sim_min[axis], False, tol) + if not polygon: + return [], True + polygon = cls._clip_polygon_with_plane(polygon, axis, sim_max[axis], True, tol) + if not polygon: + return [], True + + if len(polygon) < 3: + return [], True + + triangles: list[NDArray] = [] + anchor = polygon[0] + for idx in range(1, len(polygon) - 1): + tri_clip = np.array([anchor, polygon[idx], polygon[idx + 1]], dtype=triangle.dtype) + triangles.append(tri_clip) + return triangles, clipped_flag + @staticmethod def _barycentric_coordinates(triangle: NDArray, points: np.ndarray, tol: float) -> np.ndarray: """Compute barycentric coordinates of ``points`` with respect to ``triangle``.""" diff --git a/tidy3d/components/geometry/polyslab.py b/tidy3d/components/geometry/polyslab.py index a8e633663d..90c8a67b08 100644 --- a/tidy3d/components/geometry/polyslab.py +++ b/tidy3d/components/geometry/polyslab.py @@ -10,6 +10,7 @@ import autograd.numpy as np import pydantic.v1 as pydantic import shapely +from autograd.core import make_vjp from autograd.tracer import getval, isbox from numpy._typing import NDArray from numpy.polynomial.legendre import leggauss as _leggauss @@ -99,6 +100,10 @@ class PolySlab(base.Planar): units=MICROMETER, ) + _mesh_faces: Optional[tuple[NDArray[np.int_], dict[str, slice]]] = pydantic.PrivateAttr( + default=None + ) + @staticmethod def make_shapely_polygon(vertices: ArrayLike) -> shapely.Polygon: """Make a shapely polygon from some vertices, first ensures they are untraced.""" @@ -1449,6 +1454,74 @@ def _surface_area(self, bounds: Bound) -> float: """ Autograd code """ + def _compute_derivatives_via_mesh( + self, + derivative_info: DerivativeInfo, + clip_operation: Optional[base.ClipOperationContext] = None, + ) -> AutogradFieldMap: + """Compute adjoint derivatives via mesh-based sampling.""" + + if not self._mesh_derivatives_supported(): + return self._zero_derivative_map(derivative_info) + + dtype = config.adjoint.gradient_dtype_float + vertices_arr = np.asarray(self.vertices, dtype=dtype) + slab_bounds_arr = np.asarray(self.slab_bounds, dtype=dtype) + sidewall_angle_val = np.array(self.sidewall_angle, dtype=dtype) + + if vertices_arr.shape[0] < 3: + return self._zero_derivative_map(derivative_info) + + mesh_vertices, base_polygon, top_polygon = self._mesh_vertex_positions( + vertices=vertices_arr, + slab_bounds=slab_bounds_arr, + sidewall_angle=sidewall_angle_val, + ) + + faces, partitions = self._ensure_mesh_faces(base_polygon, top_polygon) + + if mesh_vertices.size == 0 or faces.size == 0: + return self._zero_derivative_map(derivative_info) + + from .mesh import TriangleMesh + + mesh = TriangleMesh.from_vertices_faces(mesh_vertices, faces) + + original_paths = derivative_info.paths + derivative_info.paths = [("mesh_dataset", "surface_mesh")] + try: + mesh_vjps = 
mesh._compute_derivatives(derivative_info, clip_operation=clip_operation) + finally: + derivative_info.paths = original_paths + gradient_key = ("mesh_dataset", "surface_mesh") + if gradient_key not in mesh_vjps: + return self._zero_derivative_map(derivative_info) + + triangle_grads = mesh_vjps[gradient_key] + num_vertices = mesh_vertices.shape[0] + base_slice = partitions["base"] + top_slice = partitions["top"] + side_slice = partitions["side"] + + vertex_grads_side = self._accumulate_vertex_gradients( + triangle_grads[side_slice], faces[side_slice], num_vertices=num_vertices + ) + vertex_grads_base = self._accumulate_vertex_gradients( + triangle_grads[base_slice], faces[base_slice], num_vertices=num_vertices + ) + vertex_grads_top = self._accumulate_vertex_gradients( + triangle_grads[top_slice], faces[top_slice], num_vertices=num_vertices + ) + vertex_grads_caps = vertex_grads_base + vertex_grads_top + return self._map_mesh_vjps_to_fields( + vertex_grads_side, + vertex_grads_caps, + derivative_info, + vertices_arr, + slab_bounds_arr, + sidewall_angle_val, + ) + def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: """ Return VJPs while handling several edge-cases: @@ -1518,6 +1591,269 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField return vjps + def _mesh_derivatives_supported(self) -> bool: + """Return ``True`` if we can evaluate mesh-based derivatives.""" + + return True + + def _mesh_vertex_positions( + self, + vertices: NDArray, + slab_bounds: NDArray, + sidewall_angle: NDArray, + *, + return_numpy: bool = True, + ) -> tuple[NDArray, NDArray, NDArray]: + """Return stacked vertex coordinates for the PolySlab mesh.""" + + dtype = config.adjoint.gradient_dtype_float + + def empty_result() -> tuple[NDArray, NDArray, NDArray]: + verts3d = np.zeros((0, 3), dtype=dtype) + polys = np.zeros((0, 2), dtype=dtype) + return verts3d, polys, polys + + reference_polygon = PolySlab._proper_vertices(vertices) + if reference_polygon.shape[0] < 3: + return empty_result() + + bounds_vals = np.array([getval(slab_bounds[0]), getval(slab_bounds[1])], dtype=float) + length_val = bounds_vals[1] - bounds_vals[0] + if length_val <= fp_eps: + return empty_result() + + zmin = np.maximum(slab_bounds[0], -LARGE_NUMBER) + zmax = np.minimum(slab_bounds[1], LARGE_NUMBER) + finite_length = zmax - zmin + half_length = finite_length / 2.0 + + tan_val = np.tan(sidewall_angle) + offset = np.where(np.isclose(tan_val, 0.0), 0.0, -half_length * tan_val) + + if self.reference_plane == "bottom": + middle_polygon = PolySlab._shift_vertices(reference_polygon, offset)[0] + elif self.reference_plane == "top": + middle_polygon = PolySlab._shift_vertices(reference_polygon, -offset)[0] + else: + middle_polygon = reference_polygon + + if self.reference_plane == "bottom": + base_polygon = reference_polygon + else: + base_polygon = PolySlab._shift_vertices(middle_polygon, -offset)[0] + + if self.reference_plane == "top": + top_polygon = reference_polygon + else: + top_polygon = PolySlab._shift_vertices(middle_polygon, offset)[0] + + planar = np.vstack((base_polygon, top_polygon)) + axis_vals = np.concatenate( + ( + np.full(base_polygon.shape[0], zmin), + np.full(top_polygon.shape[0], zmax), + ) + ) + coords = np.vstack(self.unpop_axis(axis_vals, (planar[:, 0], planar[:, 1]), self.axis)) + vertices3d = coords.T + if return_numpy: + return ( + np.asarray(vertices3d, dtype=dtype), + np.asarray(base_polygon, dtype=dtype), + np.asarray(top_polygon, dtype=dtype), + ) + return 
vertices3d, base_polygon, top_polygon
+
+    def _ensure_mesh_faces(
+        self, base_polygon: NDArray, top_polygon: NDArray
+    ) -> tuple[NDArray[np.int_], dict[str, slice]]:
+        """Construct (and cache) the triangle indices for the PolySlab mesh."""
+
+        if self._mesh_faces is not None:
+            return self._mesh_faces
+
+        def empty_faces() -> tuple[NDArray[np.int_], dict[str, slice]]:
+            faces = np.zeros((0, 3), dtype=int)
+            empty = slice(0, 0)
+            partitions = {"base": empty, "top": empty, "side": empty}
+            self._mesh_faces = (faces, partitions)
+            return self._mesh_faces
+
+        n_base = int(base_polygon.shape[0])
+        n_top = int(top_polygon.shape[0])
+        if n_base < 3 or n_top < 3 or n_base != n_top:
+            return empty_faces()
+
+        try:
+            base_triangles = triangulation.triangulate(base_polygon)
+            if math.isclose(self.sidewall_angle, 0):
+                top_triangles = base_triangles
+            else:
+                top_triangles = triangulation.triangulate(top_polygon)
+        except Exception as exc:
+            log.debug("Failed to triangulate 'PolySlab' mesh faces: %s", exc)
+            return empty_faces()
+
+        base_faces = [[a, b, c] for c, b, a in base_triangles]
+        top_shift = n_base
+        top_faces = [[top_shift + a, top_shift + b, top_shift + c] for a, b, c in top_triangles]
+        side_faces = [(i, (i + 1) % n_base, n_base + i) for i in range(n_base)] + [
+            ((i + 1) % n_base, n_base + ((i + 1) % n_base), n_base + i) for i in range(n_base)
+        ]
+
+        faces = np.asarray(base_faces + top_faces + side_faces, dtype=int)
+        partitions = {
+            "base": slice(0, len(base_faces)),
+            "top": slice(len(base_faces), len(base_faces) + len(top_faces)),
+            "side": slice(len(base_faces) + len(top_faces), faces.shape[0]),
+        }
+
+        self._mesh_faces = (faces, partitions)
+        return self._mesh_faces
+
+    @staticmethod
+    def _accumulate_vertex_gradients(
+        triangle_grads: NDArray, faces: NDArray, *, num_vertices: Optional[int] = None
+    ) -> NDArray:
+        """Aggregate per-triangle gradients into per-vertex values."""
+
+        if triangle_grads.size == 0 or faces.size == 0:
+            length = int(num_vertices or 0)
+            return np.zeros((length, 3), dtype=triangle_grads.dtype)
+
+        if num_vertices is None:
+            num_vertices = int(faces.max() + 1)
+
+        # scatter-add each triangle-corner gradient onto its source vertex
+        vertex_grads = np.zeros((num_vertices, 3), dtype=triangle_grads.dtype)
+        for face_index, face in enumerate(faces):
+            for local_idx, vertex_idx in enumerate(face):
+                vertex_grads[vertex_idx] += triangle_grads[face_index, local_idx]
+        return vertex_grads
+
+    def _map_mesh_vjps_to_fields(
+        self,
+        vertex_grads_side: NDArray,
+        vertex_grads_caps: NDArray,
+        derivative_info: DerivativeInfo,
+        vertices: NDArray,
+        slab_bounds: NDArray,
+        sidewall_angle: NDArray,
+    ) -> AutogradFieldMap:
+        """Convert mesh vertex gradients into PolySlab field derivatives."""
+
+        grad_vertices, _, grad_angle_side = self._mesh_parameter_gradients(
+            vertex_grads_side, vertices, slab_bounds, sidewall_angle
+        )
+        _, grad_bounds, grad_angle_caps = self._mesh_parameter_gradients(
+            vertex_grads_caps, vertices, slab_bounds, sidewall_angle
+        )
+        # the mesh-based angle gradients (grad_angle_side + grad_angle_caps) are
+        # superseded by the exact sidewall-angle gradient computed below
+        grad_vertices *= self._planar_orientation_sign()
+        grad_bounds *= self._planar_orientation_sign()
+        if self._is_2d_slice(derivative_info):
+            slab_thickness = float(getval(slab_bounds[1]) - getval(slab_bounds[0]))
+            if not np.isfinite(slab_thickness) or slab_thickness <= fp_eps:
+                thickness = 1.0
+            else:
+                thickness = slab_thickness
+            grad_vertices /= thickness
+
+        sim_min, sim_max = map(np.asarray, derivative_info.simulation_bounds)
+        intersect_min, intersect_max = map(np.asarray, derivative_info.bounds_intersect)
+        is_2d =
np.isclose(intersect_max[self.axis] - intersect_min[self.axis], 0.0) + if is_2d: + grad_bounds = np.zeros_like(grad_bounds) + interpolators = derivative_info.interpolators or derivative_info.create_interpolators( + dtype=config.adjoint.gradient_dtype_float + ) + grad_angle_exact = self._compute_derivative_sidewall_angle( + derivative_info, + sim_min, + sim_max, + is_2d=is_2d, + interpolators=interpolators, + ) + + results: AutogradFieldMap = {} + for path in derivative_info.paths: + if path == ("vertices",): + results[path] = grad_vertices + elif path == ("sidewall_angle",): + results[path] = float(grad_angle_exact) + elif path[0] == "slab_bounds": + idx = int(path[1]) + results[path] = float(grad_bounds[idx]) + else: + raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.") + + return results + + def _planar_orientation_sign(self) -> float: + """Return +1 or -1 based on (plane_axes, axis) permutation parity.""" + + plane_axes = [idx for idx in range(3) if idx != self.axis] + perm = (*plane_axes, self.axis) + even_perms = {(0, 1, 2), (1, 2, 0), (2, 0, 1)} + return 1.0 if perm in even_perms else -1.0 + + def _is_2d_slice(self, derivative_info: DerivativeInfo) -> bool: + """Return True if the intersection bounds collapse along the extrusion axis.""" + + intersect_min, intersect_max = derivative_info.bounds_intersect + axis_extent = intersect_max[self.axis] - intersect_min[self.axis] + return np.isclose(axis_extent, 0.0) + + def _zero_derivative_map(self, derivative_info: DerivativeInfo) -> AutogradFieldMap: + """Return a zero-valued derivative map for requested fields.""" + + result: AutogradFieldMap = {} + for path in derivative_info.paths: + if path == ("vertices",): + result[path] = np.zeros_like(self.vertices) + elif path == ("sidewall_angle",): + result[path] = 0.0 + elif path[0] == "slab_bounds": + result[path] = 0.0 + else: + raise ValueError(f"No derivative defined w.r.t. 'PolySlab' field '{path}'.") + return result + + def _mesh_parameter_gradients( + self, + vertex_grads: NDArray, + vertices: NDArray, + slab_bounds: NDArray, + sidewall_angle: NDArray, + ) -> tuple[NDArray, NDArray, float]: + """Return gradients w.r.t. 
(vertices, slab_bounds, sidewall_angle) parameters."""
+
+        dtype = config.adjoint.gradient_dtype_float
+        flattened = np.asarray(vertex_grads, dtype=dtype).reshape(-1)
+
+        vertex_flat = np.asarray(vertices, dtype=dtype).reshape(-1)
+        bounds_flat = np.asarray(slab_bounds, dtype=dtype)
+        angle_flat = np.array([sidewall_angle], dtype=dtype)
+        param = np.concatenate((vertex_flat, bounds_flat, angle_flat))
+
+        num_vertex_params = vertex_flat.size
+        bounds_offset = num_vertex_params
+        angle_offset = num_vertex_params + 2
+
+        def param_builder(packed: Any) -> Any:
+            verts = packed[:num_vertex_params].reshape(vertices.shape)
+            bounds = packed[bounds_offset : bounds_offset + 2]
+            angle = packed[angle_offset]
+            coords, _, _ = self._mesh_vertex_positions(verts, bounds, angle, return_numpy=False)
+            return coords.reshape(-1)
+
+        vjp_fn, _ = make_vjp(param_builder, param)
+        grad_param = np.asarray(vjp_fn(flattened))
+
+        grad_vertices = grad_param[:num_vertex_params].reshape(vertices.shape)
+        grad_bounds = grad_param[bounds_offset : bounds_offset + 2]
+        grad_angle = grad_param[angle_offset]
+        return grad_vertices, grad_bounds, float(grad_angle)
+
     # ---- Shared helpers for VJP surface integrations ----
     def _z_slices(
         self, sim_min: NDArray, sim_max: NDArray, is_2d: bool, dx: float
diff --git a/tidy3d/components/geometry/primitives.py b/tidy3d/components/geometry/primitives.py
index f6916b535d..681847f511 100644
--- a/tidy3d/components/geometry/primitives.py
+++ b/tidy3d/components/geometry/primitives.py
@@ -16,6 +16,7 @@
 from tidy3d.components.autograd.derivative_utils import DerivativeInfo
 from tidy3d.components.base import cached_property, skip_if_fields_missing
 from tidy3d.components.geometry import base
+from tidy3d.components.geometry.base import ClipOperationContext
 from tidy3d.components.geometry.mesh import TriangleMesh
 from tidy3d.components.geometry.polyslab import PolySlab
 from tidy3d.components.types import Axis, Bound, Coordinate, MatrixReal4x4, Shapely
@@ -474,7 +475,20 @@ def _discretization_wavelength(self, derivative_info: DerivativeInfo) -> float:
 
         return wvl_mat
 
-    def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradFieldMap:
+    def _compute_derivatives_via_mesh(
+        self,
+        derivative_info: DerivativeInfo,
+        clip_operation: Optional[ClipOperationContext] = None,
+    ) -> AutogradFieldMap:
+        return self._compute_derivatives(
+            derivative_info=derivative_info, clip_operation=clip_operation
+        )
+
+    def _compute_derivatives(
+        self,
+        derivative_info: DerivativeInfo,
+        clip_operation: Optional[ClipOperationContext] = None,
+    ) -> AutogradFieldMap:
         """Compute the adjoint derivatives for this object."""
 
         # compute circumference discretization
@@ -516,7 +530,12 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
             update_kwargs["interpolators"] = derivative_info.interpolators
 
         derivative_info_polyslab = derivative_info.updated_copy(**update_kwargs)
-        vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab)
+        if clip_operation is not None:
+            vjps_polyslab = polyslab._compute_derivatives_via_mesh(
+                derivative_info_polyslab, clip_operation=clip_operation
+            )
+        else:
+            vjps_polyslab = polyslab._compute_derivatives(derivative_info_polyslab)
 
         vjps = {}
         for path in derivative_info.paths:
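For reference, `_mesh_parameter_gradients` above uses `autograd.core.make_vjp` to pull the flattened mesh-coordinate cotangents back onto the packed `(vertices, slab_bounds, sidewall_angle)` vector. A minimal self-contained sketch of that pattern, with a toy function standing in for `param_builder`:

    import autograd.numpy as anp
    from autograd.core import make_vjp

    def toy_builder(packed):
        # stand-in for `param_builder`: packed params -> flattened coordinates
        verts = packed[:4].reshape(2, 2)
        scale = packed[4]
        return (scale * verts).reshape(-1)

    packed = anp.array([1.0, 2.0, 3.0, 4.0, 0.5])
    vjp_fn, coords = make_vjp(toy_builder, packed)

    # cotangent on the outputs, e.g. flattened per-vertex gradients
    cotangent = anp.ones_like(coords)
    grad_packed = vjp_fn(cotangent)  # equals J^T @ cotangent
    # grad_packed[:4] == 0.5 each (d/d verts); grad_packed[4] == 10.0 (d/d scale)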