-
Notifications
You must be signed in to change notification settings - Fork 67
fix(tidy3d): FXC-4641-fix-gradients-in-custom-medium #3113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 1 comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
3 files reviewed, 1 comment
37b5e06 to
39c9050
Compare
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/medium.pyLines 1917-1925 1917 @cached_property
1918 def _permittivity_sorted(self) -> SpatialDataArray | None:
1919 """Cached copy of permittivity sorted along spatial axes."""
1920 if self.permittivity is None:
! 1921 return None
1922 return self.permittivity._spatially_sorted
1923
1924 @cached_property
1925 def _conductivity_sorted(self) -> SpatialDataArray | None:Lines 1924-1932 1924 @cached_property
1925 def _conductivity_sorted(self) -> SpatialDataArray | None:
1926 """Cached copy of conductivity sorted along spatial axes."""
1927 if self.conductivity is None:
! 1928 return None
1929 return self.conductivity._spatially_sorted
1930
1931 @cached_property
1932 def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]:Lines 1931-1939 1931 @cached_property
1932 def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]:
1933 """Cached copies of dataset components sorted along spatial axes."""
1934 if self.eps_dataset is None:
! 1935 return {}
1936 return {
1937 key: comp._spatially_sorted for key, comp in self.eps_dataset.field_components.items()
1938 }Lines 2347-2355 2347 for field_path in derivative_info.paths:
2348 if field_path[0] == "permittivity":
2349 spatial_data = self._permittivity_sorted
2350 if spatial_data is None:
! 2351 continue
2352 vjp_array = 0.0
2353 for dim in "xyz":
2354 vjp_array += self._derivative_field_cmp_custom(
2355 E_der_map=derivative_info.E_der_map,Lines 2363-2371 2363
2364 elif field_path[0] == "conductivity":
2365 spatial_data = self._conductivity_sorted
2366 if spatial_data is None:
! 2367 continue
2368 vjp_array = 0.0
2369 for dim in "xyz":
2370 vjp_array += self._derivative_field_cmp_custom(
2371 E_der_map=derivative_info.E_der_map,Lines 2380-2388 2380 elif field_path[0] == "eps_dataset":
2381 key = field_path[1]
2382 spatial_data = self._eps_components_sorted.get(key)
2383 if spatial_data is None:
! 2384 continue
2385 dim = key[-1]
2386 vjps[field_path] = self._derivative_field_cmp_custom(
2387 E_der_map=derivative_info.E_der_map,
2388 spatial_data=spatial_data,Lines 2414-2422 2414 dtype_out = complex if component == "complex" else float
2415
2416 E_der_dim = E_der_map.get(f"E{dim}")
2417 if E_der_dim is None or np.all(E_der_dim.values == 0):
! 2418 return np.zeros(eps_shape, dtype=dtype_out)
2419
2420 field_coords = {axis: np.asarray(E_der_dim.coords[axis]) for axis in "xyz"}
2421 values = E_der_dim.valuesLines 2424-2437 2424 n = axis.size
2425 i0 = int(np.searchsorted(axis, vmin, side="left"))
2426 i1 = int(np.searchsorted(axis, vmax, side="right"))
2427 if i1 <= i0 and n:
! 2428 old = (i0, i1)
! 2429 if i1 < n:
! 2430 i1 = i0 + 1 # expand right
! 2431 elif i0 > 0:
! 2432 i0 = i1 - 1 # expand left
! 2433 log.warning(
2434 f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients "
2435 f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], "
2436 f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).",
2437 log_once=True,Lines 2511-2524 2511 return field_values.sum(axis=0, keepdims=True)
2512
2513 # Ensure parameter coordinates are sorted for searchsorted-based binning.
2514 if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 2515 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
2516 param_coords_sorted = param_coords_1d
2517
2518 n_param = param_coords_sorted.size
2519 if method not in ALLOWED_INTERP_METHODS:
! 2520 raise ValueError(
2521 f"Unsupported interpolation method: {method!r}. "
2522 f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}."
2523 )Lines 2542-2574 2542 return param_values
2543
2544 # linear
2545 # Find bracketing parameter indices for each field coordinate.
! 2546 param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right")
! 2547 param_index_upper = np.clip(param_index_upper, 1, n_param - 1)
! 2548 param_index_lower = param_index_upper - 1
2549
2550 # Compute interpolation fraction within the bracketing segment.
! 2551 segment_width = (
2552 param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower]
2553 )
! 2554 segment_width = np.where(segment_width == 0, 1.0, segment_width)
! 2555 frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width
! 2556 frac_upper = np.clip(frac_upper, 0.0, 1.0)
2557
2558 # Weights per field sample (broadcast across the flattened trailing dimensions).
! 2559 w_lower = (1.0 - frac_upper)[:, None]
! 2560 w_upper = frac_upper[:, None]
2561
2562 # Accumulate contributions into both bracketing parameter indices.
! 2563 param_values_2d = npo.zeros(
2564 (n_param, field_values_2d.shape[1]), dtype=field_values.dtype
2565 )
! 2566 npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower)
! 2567 npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper)
2568
! 2569 param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:])
! 2570 return param_values
2571
2572 def _interp_axis(
2573 arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray
2574 ) -> NDArray:Lines 2588-2596 2588 freqs_da = np.asarray(E_der_dim.coords["f"])
2589 if component == "sigma":
2590 values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0))
2591 elif component == "imag":
! 2592 values = values.imag
2593 elif component == "real":
2594 values = values.real
2595
2596 return values.sum(axis=-1).reshape(eps_shape) |
groberts-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @marcorudolphflex this is great! overall, looks good and thanks for adding some good tests! I left a few comments on it, excited to have this change!
Thanks for the careful review! Hope it is now more clear overall. |
39c9050 to
e4bf5b3
Compare
e4bf5b3 to
5dca3f2
Compare
groberts-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the changes @marcorudolphflex, looks good to me!
Fixed interpolation and boundary handling in gradient computation of custom medium.
Note that the new testfile test_autograd_medium_numerical.py will be extended to anisotropic medium for this PR #3080 which is why some test branches are already prepared.
Greptile Summary
This PR fixes critical bugs in gradient computation for
CustomMediumby replacing the previous xarray-based interpolation approach with a custom transpose-accumulation implementation. The old implementation had issues with interpolation handling and boundary conditions that led to incorrect gradients.Key changes:
_permittivity_sorted,_conductivity_sorted,_eps_components_sorted) to ensure spatial coordinates are sorted before gradient computation, preventing errors during interpolation_derivative_field_cmpwith_derivative_field_cmp_customthat implements a custom transpose-based accumulation strategy instead of relying on xarray's.interp()methodsearchsortedto clip field data to intersection bounds before processingnpo.add.atfor proper gradient accumulation (using regular numpynpoinstead of autograd numpynpfor operations that need in-place accumulation)Technical improvements:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram
sequenceDiagram participant Client as Autograd System participant CM as CustomMedium participant CDC as _compute_derivatives participant DFCC as _derivative_field_cmp_custom participant TIA as _transpose_interp_axis Client->>CM: Request gradients for permittivity/conductivity/eps_dataset CM->>CM: Access cached _permittivity_sorted, _conductivity_sorted, _eps_components_sorted CM->>CDC: _compute_derivatives(derivative_info) loop For each field_path CDC->>CDC: Get sorted spatial_data from cache alt permittivity or conductivity loop For each dimension (x, y, z) CDC->>DFCC: _derivative_field_cmp_custom(E_der_map, spatial_data, dim, freqs, bounds, component) DFCC->>DFCC: Extract field coordinates and values DFCC->>DFCC: Apply bounds filtering (searchsorted) DFCC->>DFCC: Compute volume element scaling (_axis_sizes) DFCC->>DFCC: Multiply values by scale loop For each axis (x, y, z) DFCC->>TIA: _transpose_interp_axis(arr, field_axis, param_axis) alt method == "nearest" TIA->>TIA: Compute midpoints between param coords TIA->>TIA: Use searchsorted to find indices TIA->>TIA: Accumulate using npo.add.at else method == "linear" TIA->>TIA: Find upper/lower indices via searchsorted TIA->>TIA: Compute interpolation weights TIA->>TIA: Accumulate weighted contributions using npo.add.at end TIA-->>DFCC: Return interpolated array end DFCC->>DFCC: Apply component extraction (real/imag/sigma/complex) DFCC->>DFCC: Sum over frequencies DFCC-->>CDC: Return gradient array end CDC->>CDC: Sum contributions from all dimensions else eps_dataset component CDC->>DFCC: _derivative_field_cmp_custom for specific component DFCC-->>CDC: Return gradient array end CDC->>CDC: Store in vjps dictionary end CDC-->>Client: Return vjps (gradient dictionary)