diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f84be6c38..7195fa3c77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed ### Fixed +- Fixed interpolation handling for permittivity and conductivity gradients in CustomMedium. ## [2.10.0] - 2025-12-18 diff --git a/tests/test_components/autograd/numerical/test_autograd_medium_numerical.py b/tests/test_components/autograd/numerical/test_autograd_medium_numerical.py new file mode 100644 index 0000000000..df866bcf77 --- /dev/null +++ b/tests/test_components/autograd/numerical/test_autograd_medium_numerical.py @@ -0,0 +1,354 @@ +from __future__ import annotations + +import sys + +import autograd.numpy as anp +import matplotlib.pyplot as plt +import numpy as np +import pytest +from autograd import value_and_grad + +import tidy3d as td +import tidy3d.web as web +from tidy3d.components.autograd import get_static + +td.config.local_cache.enabled = True + +SIM_SIZE_SCALE = (4, 3, 4) +BOX_SIZE_SCALE = (1, 1, 1) +GRID_STEPS_PER_WVL = 30 +RUN_TIME = 2e-12 +ANGLE_TOL = 10.0 +FD_STEP = 5e-2 + +TEST_CASES = [ + { + "name": "opt_flux_iso", + "wavelength": 1.0, + "permittivities": (2.2, 2.2, 2.2), + "objective_kind": "flux", + "monitor_size": (np.inf, np.inf, 0.0), + "polarization": 0.0, + "medium_type": "isotropic", + }, + { + "name": "mw_intensity_iso", + "wavelength": 1.6, + "permittivities": (1.8, 1.8, 1.8), + "objective_kind": "intensity", + "monitor_size": (0.4, 0.4, 0.0), + "polarization": np.pi / 5, + "medium_type": "isotropic", + }, + { + "name": "opt_flux_custom_iso", + "wavelength": 1.3, + "permittivities": (2.0, 2.0, 2.0), + "objective_kind": "flux", + "monitor_size": (np.inf, np.inf, 0.0), + "polarization": 0.0, + "medium_type": "custom", + }, + { + "name": "mw_int_custom_iso", + "wavelength": 1.1, + "permittivities": (1.6, 1.6, 1.6), + "objective_kind": "intensity", + "monitor_size": (0.3, 0.3, 0.0), + "polarization": np.pi / 3, + "medium_type": "custom", + }, +] + + +def _scale_monitor_dim(dim: float, wavelength: float) -> float: + if np.isinf(dim): + return np.inf + return dim * wavelength + + +def _box_geometry(case) -> td.Box: + size = tuple(scale * case["wavelength"] for scale in BOX_SIZE_SCALE) + return td.Box(size=size, center=(0.0, 0.0, 0.0)) + + +def _build_base_sim(case): + wavelength = case["wavelength"] + freq0 = td.C_0 / wavelength + sim_size = tuple(scale * wavelength for scale in SIM_SIZE_SCALE) + + plane_wave = td.PlaneWave( + center=(0.0, 0.0, -0.75 * sim_size[2] / 2), + size=(sim_size[0], sim_size[1], 0.0), + source_time=td.GaussianPulse(freq0=freq0, fwidth=freq0 / 10.0), + direction="+", + pol_angle=case.get("polarization", 0.0), + ) + + monitor_center = (0.0, 0.0, sim_size[2] / 2 * 0.75) + monitor_size = tuple(_scale_monitor_dim(dim, wavelength) for dim in case["monitor_size"]) + monitor_name = f"{case['name']}_monitor" + monitor = td.FieldMonitor( + center=monitor_center, + size=monitor_size, + freqs=[freq0], + name=monitor_name, + colocate=False, + ) + + sim = td.Simulation( + size=sim_size, + center=(0.0, 0.0, 0.0), + grid_spec=td.GridSpec.auto(min_steps_per_wvl=GRID_STEPS_PER_WVL, wavelength=wavelength), + boundary_spec=td.BoundarySpec.pml(x=True, y=True, z=True), + sources=[plane_wave], + monitors=[monitor], + structures=[], + run_time=RUN_TIME, + ) + return sim, monitor_name, freq0 + + +def _add_medium(case, base_sim: td.Simulation, box_geom: td.Box, eps_vals) -> td.Simulation: + medium_type = case["medium_type"] + + coords = None + factor = None + if medium_type in ("custom_anisotropic", "custom"): + coords = { + "x": np.linspace(-box_geom.size[0] / 2, box_geom.size[0] / 2, 4), + "y": np.linspace(-box_geom.size[1] / 2, box_geom.size[1] / 2, 5), + "z": np.linspace(-box_geom.size[2] / 2, box_geom.size[2] / 2, 3), + } + _cx, _cy, _cz = np.meshgrid(coords["x"], coords["y"], coords["z"], indexing="ij") + factor = 1 + 0.2 * (_cx + _cy + _cz) / 3.0 + + if medium_type == "custom_anisotropic": + + def _custom_medium(val): + values = factor * val + data = td.SpatialDataArray(values, coords=coords) + return td.CustomMedium(permittivity=data) + + medium = td.CustomAnisotropicMedium( + xx=_custom_medium(eps_vals[0]), + yy=_custom_medium(eps_vals[1]), + zz=_custom_medium(eps_vals[2]), + ) + elif medium_type == "custom": + + def _custom_isotropic(val): + values = factor * val + data = td.SpatialDataArray(values, coords=coords) + return td.CustomMedium(permittivity=data) + + medium = _custom_isotropic(eps_vals[0]) + elif medium_type == "isotropic": + # use first entry; others are identical by construction + medium = td.Medium(permittivity=eps_vals[0]) + elif medium_type == "anisotropic": + medium = td.AnisotropicMedium( + xx=td.Medium(permittivity=eps_vals[0]), + yy=td.Medium(permittivity=eps_vals[1]), + zz=td.Medium(permittivity=eps_vals[2]), + ) + else: + raise ValueError( + "Medium type has to be one of 'custom', 'isotropic', 'anisotropic' or 'custom_anisotropic'" + ) + + structure = td.Structure(geometry=box_geom, medium=medium) + return base_sim.updated_copy(structures=[structure]) + + +def _metric_value(case, dataset, freq0): + if case["objective_kind"] == "flux": + return dataset.flux.values + ex_vals = dataset.Ex.values + ey_vals = dataset.Ey.values + ez_vals = dataset.Ez.values + intensity = np.abs(ex_vals) ** 2 + np.abs(ey_vals) ** 2 + np.abs(ez_vals) ** 2 + return anp.real(anp.mean(intensity)) + + +def _angle_deg(vec_a: np.ndarray, vec_b: np.ndarray) -> float: + norm_a = np.linalg.norm(vec_a) + norm_b = np.linalg.norm(vec_b) + if norm_a == 0 or norm_b == 0: + return np.nan + cos_theta = np.clip(np.dot(vec_a, vec_b) / (norm_a * norm_b), -1.0, 1.0) + return float(np.degrees(np.arccos(cos_theta))) + + +def _run_simulation( + case, base_sim, box_geom, eps_vals, label, tmp_path, monitor_name, freq0, gradient +): + sim = _add_medium(case, base_sim, box_geom, eps_vals) + sim_data = web.run( + sim, + task_name=f"medium_grad_{case['name']}_{label}", + local_gradient=gradient, + verbose=False, + path=str(tmp_path / f"{case['name']}_{label}.hdf5"), + ) + return _metric_value(case, sim_data[monitor_name], freq0) + + +@pytest.mark.numerical +@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["name"]) +def test_medium_grads_match_fd(case, numerical_case_dir, tmp_path): + base_sim, monitor_name, freq0 = _build_base_sim(case) + box_geom = _box_geometry(case) + params0 = anp.array(case["permittivities"]) + + def objective(eps_vals): + return _run_simulation( + case, + base_sim, + box_geom, + eps_vals, + label="adjoint", + tmp_path=tmp_path, + monitor_name=monitor_name, + freq0=freq0, + gradient=True, + ) + + _, grad_adj = value_and_grad(objective)(params0) + grad_adj = get_static(grad_adj) + + fd_sims = {} + base_params = get_static(params0) + for axis in range(3): + delta = np.zeros_like(base_params) + delta[axis] = FD_STEP + fd_sims[f"fd_plus_{axis}"] = _add_medium(case, base_sim, box_geom, base_params + delta) + fd_sims[f"fd_minus_{axis}"] = _add_medium(case, base_sim, box_geom, base_params - delta) + + fd_results = web.run_async( + fd_sims, + path_dir=str(numerical_case_dir / f"fd_batch_{case['name']}"), + local_gradient=False, + verbose=False, + ) + + grad_fd = np.zeros_like(grad_adj) + for axis in range(3): + plus = _metric_value(case, fd_results[f"fd_plus_{axis}"][monitor_name], freq0) + minus = _metric_value(case, fd_results[f"fd_minus_{axis}"][monitor_name], freq0) + grad_fd[axis] = (plus - minus) / (2.0 * FD_STEP) + + angle_deg = _angle_deg(grad_adj, grad_fd) + + print( + f"[medium-grad-test:{case['name']}] adjoint={grad_adj}, " + f"finite-difference={grad_fd}, angle_deg={angle_deg:.3f}", + file=sys.stderr, + ) + + angle_tol = case.get("angle_tol_deg", ANGLE_TOL) + assert angle_deg <= angle_tol or np.isnan(angle_deg), ( + f"Gradient angle deviation {angle_deg:.3f} deg exceeds tolerance ({angle_tol}). " + f"adj={grad_adj}, fd={grad_fd}" + ) + + +@pytest.mark.skip +@pytest.mark.parametrize("case", TEST_CASES, ids=lambda c: c["name"]) +def test_medium_fd_step_sweep(case, numerical_case_dir, tmp_path): + base_sim, monitor_name, freq0 = _build_base_sim(case) + box_geom = _box_geometry(case) + params0 = anp.array(case["permittivities"]) + + def objective(eps_vals): + return _run_simulation( + case, + base_sim, + box_geom, + eps_vals, + label="adjoint_sweep", + tmp_path=tmp_path, + monitor_name=monitor_name, + freq0=freq0, + gradient=True, + ) + + _, grad_adj = value_and_grad(objective)(params0) + grad_adj = get_static(grad_adj) + base_params = get_static(params0) + + sweep_steps = np.logspace(-4, -1, num=9) + step_labels = [f"{step:.3e}" for step in sweep_steps] + + sweep_runs: dict[str, td.Simulation] = {} + for step_label, step in zip(step_labels, sweep_steps): + for axis in range(base_params.size): + delta = np.zeros_like(base_params) + delta[axis] = step + key_base = f"{case['name']}_axis{axis}_{step_label}" + sweep_runs[f"{key_base}_plus"] = _add_medium( + case, + base_sim, + box_geom, + base_params + delta, + ) + sweep_runs[f"{key_base}_minus"] = _add_medium( + case, + base_sim, + box_geom, + base_params - delta, + ) + + sweep_results = web.run_async( + sweep_runs, + path_dir=str(numerical_case_dir / f"fd_sweep_{case['name']}"), + local_gradient=False, + verbose=False, + ) + + fd_sweep_matrix = np.zeros((len(sweep_steps), base_params.size), dtype=float) + for step_idx, (step_label, step) in enumerate(zip(step_labels, sweep_steps)): + for axis in range(base_params.size): + plus_key = f"{case['name']}_axis{axis}_{step_label}_plus" + minus_key = f"{case['name']}_axis{axis}_{step_label}_minus" + plus_val = _metric_value(case, sweep_results[plus_key][monitor_name], freq0) + minus_val = _metric_value(case, sweep_results[minus_key][monitor_name], freq0) + fd_sweep_matrix[step_idx, axis] = (plus_val - minus_val) / (2.0 * step) + + labels = ["xx", "yy", "zz"] + fig, ax = plt.subplots(figsize=(6, 4)) + for axis, label in enumerate(labels[: base_params.size]): + ax.plot(sweep_steps, fd_sweep_matrix[:, axis], marker="o", label=f"{label} (FD)") + color = ax.get_lines()[-1].get_color() + ax.axhline( + grad_adj[axis], + color=color, + linestyle="--", + alpha=0.7, + label=f"{label} (autograd)", + ) + + ax.set_xscale("log") + ax.set_xlabel("Finite difference step") + ax.set_ylabel("Gradient value") + ax.set_title(f"FD gradients vs. step size ({case['name']})") + ax.grid(True, which="both", ls=":") + ax.legend() + + fig_path = numerical_case_dir / f"medium_fd_step_sweep_{case['name']}.png" + fig.savefig(fig_path, dpi=200) + plt.close(fig) + + # FD gradient extrema per parameter (across all step sizes) + fd_min_per_param = fd_sweep_matrix.min(axis=0) + fd_max_per_param = fd_sweep_matrix.max(axis=0) + + print( + ( + f"[medium-fd-sweep:{case['name']}] " + f"grad_adj={np.array2string(grad_adj, precision=6, separator=', ')} " + f"fd_grad_per_param[min,max]=" + f"{[(f'({mn:.3e},{mx:.3e})') for mn, mx in zip(fd_min_per_param, fd_max_per_param)]}" + ), + file=sys.stderr, + ) diff --git a/tidy3d/components/medium.py b/tidy3d/components/medium.py index 0808ec9568..39a3856d2f 100644 --- a/tidy3d/components/medium.py +++ b/tidy3d/components/medium.py @@ -5,7 +5,7 @@ import functools from abc import ABC, abstractmethod from math import isclose -from typing import Any, Callable, Literal, Optional, Union +from typing import Any, Callable, Literal, Optional, Union, get_args import autograd.numpy as np @@ -14,6 +14,7 @@ import pydantic.v1 as pd import xarray as xr from autograd.differential_operators import tensor_jacobian_product +from numpy.typing import NDArray from tidy3d.components.autograd.utils import pack_complex_vec from tidy3d.components.material.tcad.heat import ThermalSpecType @@ -108,6 +109,8 @@ LOSSY_METAL_DEFAULT_MAX_POLES = 5 LOSSY_METAL_DEFAULT_TOLERANCE_RMS = 1e-3 +ALLOWED_INTERP_METHODS = get_args(InterpMethod) + def ensure_freq_in_range(eps_model: Callable[[float], complex]) -> Callable[[float], complex]: """Decorate ``eps_model`` to log warning if frequency supplied is out of bounds.""" @@ -1911,6 +1914,29 @@ def is_spatially_uniform(self) -> bool: """Whether the medium is spatially uniform.""" return self._medium.is_spatially_uniform + @cached_property + def _permittivity_sorted(self) -> SpatialDataArray | None: + """Cached copy of permittivity sorted along spatial axes.""" + if self.permittivity is None: + return None + return self.permittivity._spatially_sorted + + @cached_property + def _conductivity_sorted(self) -> SpatialDataArray | None: + """Cached copy of conductivity sorted along spatial axes.""" + if self.conductivity is None: + return None + return self.conductivity._spatially_sorted + + @cached_property + def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]: + """Cached copies of dataset components sorted along spatial axes.""" + if self.eps_dataset is None: + return {} + return { + key: comp._spatially_sorted for key, comp in self.eps_dataset.field_components.items() + } + @cached_property def freqs(self) -> np.ndarray: """float array of frequencies. @@ -2320,37 +2346,49 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField for field_path in derivative_info.paths: if field_path[0] == "permittivity": + spatial_data = self._permittivity_sorted + if spatial_data is None: + continue vjp_array = 0.0 for dim in "xyz": - vjp_array += self._derivative_field_cmp( + vjp_array += self._derivative_field_cmp_custom( E_der_map=derivative_info.E_der_map, - spatial_data=self.permittivity, + spatial_data=spatial_data, dim=dim, freqs=derivative_info.frequencies, + bounds=derivative_info.bounds_intersect, component="real", ) vjps[field_path] = vjp_array elif field_path[0] == "conductivity": + spatial_data = self._conductivity_sorted + if spatial_data is None: + continue vjp_array = 0.0 for dim in "xyz": - vjp_array += self._derivative_field_cmp( + vjp_array += self._derivative_field_cmp_custom( E_der_map=derivative_info.E_der_map, - spatial_data=self.conductivity, + spatial_data=spatial_data, dim=dim, freqs=derivative_info.frequencies, + bounds=derivative_info.bounds_intersect, component="sigma", ) vjps[field_path] = vjp_array elif field_path[0] == "eps_dataset": key = field_path[1] + spatial_data = self._eps_components_sorted.get(key) + if spatial_data is None: + continue dim = key[-1] - vjps[field_path] = self._derivative_field_cmp( + vjps[field_path] = self._derivative_field_cmp_custom( E_der_map=derivative_info.E_der_map, - spatial_data=self.eps_dataset.field_components[key], + spatial_data=spatial_data, dim=dim, freqs=derivative_info.frequencies, + bounds=derivative_info.bounds_intersect, component="complex", ) else: @@ -2360,96 +2398,202 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField return vjps - def _derivative_field_cmp( + def _derivative_field_cmp_custom( self, E_der_map: ElectromagneticFieldDataset, - spatial_data: CustomSpatialDataTypeAnnotated, + spatial_data: SpatialDataArray, dim: str, - freqs: np.ndarray, + freqs: NDArray, + bounds: Optional[Bound] = None, component: str = "real", - ) -> np.ndarray: + interp_method: Optional[InterpMethod] = None, + ) -> NDArray: """Compute the derivative with respect to a material property component.""" - coords_interp = {key: spatial_data.coords[key] for key in "xyz"} - coords_interp = {key: val for key, val in coords_interp.items() if len(val) > 1} - - eps_coordinate_shape = [ - len(spatial_data.coords[dim]) for dim in spatial_data.dims if dim in "xyz" - ] - - E_der_dim_interp = E_der_map[f"E{dim}"] + param_coords = {axis: np.asarray(spatial_data.coords[axis]) for axis in "xyz"} + eps_shape = [len(param_coords[axis]) for axis in "xyz"] + dtype_out = complex if component == "complex" else float + + E_der_dim = E_der_map.get(f"E{dim}") + if E_der_dim is None or np.all(E_der_dim.values == 0): + return np.zeros(eps_shape, dtype=dtype_out) + + field_coords = {axis: np.asarray(E_der_dim.coords[axis]) for axis in "xyz"} + values = E_der_dim.values + + def _bounds_slice(axis: NDArray, vmin: float, vmax: float, *, name: str) -> slice: + n = axis.size + i0 = int(np.searchsorted(axis, vmin, side="left")) + i1 = int(np.searchsorted(axis, vmax, side="right")) + if i1 <= i0 and n: + old = (i0, i1) + if i1 < n: + i1 = i0 + 1 # expand right + elif i0 > 0: + i0 = i1 - 1 # expand left + log.warning( + f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients " + f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], " + f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).", + log_once=True, + ) + return slice(i0, i1) - for dim_ in "xyz": - if dim_ not in coords_interp: - bound_max = np.max(E_der_dim_interp.coords[dim_]) - bound_min = np.min(E_der_dim_interp.coords[dim_]) - dimension_size = bound_max - bound_min + # usage + if bounds is not None: + (xmin, ymin, zmin), (xmax, ymax, zmax) = bounds - if dimension_size > 0.0: - E_der_dim_interp = E_der_dim_interp.integrate(dim_) + sx = _bounds_slice(field_coords["x"], xmin, xmax, name="x") + sy = _bounds_slice(field_coords["y"], ymin, ymax, name="y") + sz = _bounds_slice(field_coords["z"], zmin, zmax, name="z") - # compute sizes along each of the interpolation dimensions - sizes_list = [] - for _, coords in coords_interp.items(): - num_coords = len(coords) - coords = np.array(coords) + field_coords = {k: field_coords[k][s] for k, s in (("x", sx), ("y", sy), ("z", sz))} + values = values[sx, sy, sz, :] - # compute distances between midpoints for all internal coords + def _axis_sizes(coords: NDArray) -> NDArray: + if coords.size <= 1: + return np.array([1.0]) mid_points = (coords[1:] + coords[:-1]) / 2.0 dists = np.diff(mid_points) - sizes = np.zeros(num_coords) + sizes = np.zeros(coords.size) sizes[1:-1] = dists - - # estimate the sizes on the edges using 2 x the midpoint distance sizes[0] = 2 * abs(mid_points[0] - coords[0]) sizes[-1] = 2 * abs(coords[-1] - mid_points[-1]) + return sizes - sizes_list.append(sizes) + size_x = _axis_sizes(field_coords["x"]) + size_y = _axis_sizes(field_coords["y"]) + size_z = _axis_sizes(field_coords["z"]) + scale = ( + size_x[:, None, None, None] * size_y[None, :, None, None] * size_z[None, None, :, None] + ) + np.multiply(values, scale, out=values) - # turn this into a volume element, should be re-sizeable to the gradient shape - if sizes_list: - d_vol = functools.reduce(np.outer, sizes_list) - else: - # if sizes_list is empty, then reduce() fails - d_vol = np.array(1.0) + method = interp_method if interp_method is not None else self.interp_method - E_der_dim_interp_complex = E_der_dim_interp.interp( - **coords_interp, assume_sorted=True - ).fillna(0.0) + def _transpose_interp_axis( + field_values: NDArray, field_coords_1d: NDArray, param_coords_1d: NDArray + ) -> NDArray: + """ + Transpose (adjoint) of 1D interpolation along one axis. + + Parameters + ---------- + field_values : np.ndarray + Array of values sampled on the field grid along this axis. + Shape: (n_field, ...rest...). + Notes: + - The first axis corresponds to `field_coords_1d`. + - The remaining axes (...rest...) are treated as batch dimensions and are + carried through unchanged. + + field_coords_1d : np.ndarray + 1D coordinates of the field grid along this axis. + Shape: (n_field,). + + param_coords_1d : np.ndarray + 1D coordinates of the parameter grid along this axis. + Shape: (n_param,). Must be sorted ascending for the searchsorted-based logic. + + Returns + ------- + param_values : np.ndarray + Field contributions accumulated onto the parameter grid along this axis. + Shape: (n_param, ...rest...). + + Implementation note + ------------------- + For efficient accumulation, we flatten the trailing dimensions (...rest...) into a single + dimension so we can run a vectorized `np.add.at` on a 2D buffer of shape (n_param, n_rest), + then reshape back to (n_param, ...rest...). + """ + # Single-point parameter grid: every field sample maps to the only parameter entry, + if param_coords_1d.size == 1: + return field_values.sum(axis=0, keepdims=True) + + # Ensure parameter coordinates are sorted for searchsorted-based binning. + if np.any(param_coords_1d[1:] < param_coords_1d[:-1]): + raise ValueError("Spatial coordinates must be sorted before computing derivatives.") + param_coords_sorted = param_coords_1d + + n_param = param_coords_sorted.size + if method not in ALLOWED_INTERP_METHODS: + raise ValueError( + f"Unsupported interpolation method: {method!r}. " + f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}." + ) - if component == "sigma": - # compute conductivity gradient from imaginary-permittivity gradient - # apply per-frequency scaling before summing over frequencies - # d eps_imag / d sigma = 1 / (2 * pi * f * EPSILON_0) - E_der_dim_interp = E_der_dim_interp_complex.imag - freqs_da = E_der_dim_interp_complex.coords["f"] - scale = -1.0 / (2.0 * np.pi * freqs_da * EPSILON_0) - E_der_dim_interp *= scale - elif component == "complex": - # for complex permittivity in eps_dataset, return the full complex derivative - E_der_dim_interp = E_der_dim_interp_complex - elif component == "imag": - # pure imaginary component (no conductivity conversion) - E_der_dim_interp = E_der_dim_interp_complex.imag - else: - E_der_dim_interp = E_der_dim_interp_complex.real + # Flatten trailing dimensions into a single "rest" dimension for vectorized accumulation. + n_field = field_values.shape[0] + field_values_2d = field_values.reshape(n_field, -1) - E_der_dim_interp = E_der_dim_interp.sum("f") + if method == "nearest": + # Midpoints define bin edges between adjacent parameter coordinates. + param_midpoints = (param_coords_sorted[1:] + param_coords_sorted[:-1]) / 2.0 + # Map each field coordinate to a nearest parameter-bin index. + param_index_nearest = np.searchsorted(param_midpoints, field_coords_1d) - try: - E_der_dim_interp = E_der_dim_interp * d_vol.reshape(E_der_dim_interp.shape) - except ValueError: - log.warning( - "Skipping volume element normalization of 'CustomMedium' gradients. " - f"Could not reshape the volume elements of shape {d_vol.shape} " - f"to the shape of the fields {E_der_dim_interp.shape}. " - "If you encounter this warning, gradient direction will be accurate but the norm " - "will be inaccurate. Please raise an issue on the tidy3d front end with this " - "message and some information about your simulation setup and we will investigate. " + # Accumulate all field samples into their assigned parameter bins. + param_values_2d = npo.zeros( + (n_param, field_values_2d.shape[1]), dtype=field_values.dtype + ) + npo.add.at(param_values_2d, param_index_nearest, field_values_2d) + + param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:]) + return param_values + + # linear + # Find bracketing parameter indices for each field coordinate. + param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right") + param_index_upper = np.clip(param_index_upper, 1, n_param - 1) + param_index_lower = param_index_upper - 1 + + # Compute interpolation fraction within the bracketing segment. + segment_width = ( + param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower] ) - vjp_array = E_der_dim_interp.values - vjp_array = vjp_array.reshape(eps_coordinate_shape) + segment_width = np.where(segment_width == 0, 1.0, segment_width) + frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width + frac_upper = np.clip(frac_upper, 0.0, 1.0) - return vjp_array + # Weights per field sample (broadcast across the flattened trailing dimensions). + w_lower = (1.0 - frac_upper)[:, None] + w_upper = frac_upper[:, None] + + # Accumulate contributions into both bracketing parameter indices. + param_values_2d = npo.zeros( + (n_param, field_values_2d.shape[1]), dtype=field_values.dtype + ) + npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower) + npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper) + + param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:]) + return param_values + + def _interp_axis( + arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray + ) -> NDArray: + """Accumulate values from the field grid onto the parameter grid along one axis. + + Moves ``axis`` to the front, applies ``_transpose_interp_axis`` (adjoint of 1D interpolation) + to map from ``field_axis`` (n_field) to ``param_axis`` (n_param), then moves the axis back. + """ + moved = np.moveaxis(arr, axis, 0) + moved = _transpose_interp_axis(moved, field_axis, param_axis) + return np.moveaxis(moved, 0, axis) + + values = _interp_axis(values, 0, field_coords["x"], param_coords["x"]) + values = _interp_axis(values, 1, field_coords["y"], param_coords["y"]) + values = _interp_axis(values, 2, field_coords["z"], param_coords["z"]) + + freqs_da = np.asarray(E_der_dim.coords["f"]) + if component == "sigma": + values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0)) + elif component == "imag": + values = values.imag + elif component == "real": + values = values.real + + return values.sum(axis=-1).reshape(eps_shape) """ Dispersive Media """