@marcorudolphflex
Contributor

@marcorudolphflex marcorudolphflex commented Dec 22, 2025

Fixes interpolation and boundary handling in the gradient computation for custom media.
Note that the new test file test_autograd_medium_numerical.py will be extended to anisotropic media in PR #3080, which is why some test branches are already prepared.

Greptile Summary

This PR fixes critical bugs in gradient computation for CustomMedium by replacing the previous xarray-based interpolation approach with a custom transpose-accumulation implementation. The old implementation had issues with interpolation handling and boundary conditions that led to incorrect gradients.

Key changes:

  • Added cached properties (_permittivity_sorted, _conductivity_sorted, _eps_components_sorted) to ensure spatial coordinates are sorted before gradient computation, preventing errors during interpolation
  • Replaced _derivative_field_cmp with _derivative_field_cmp_custom that implements a custom transpose-based accumulation strategy instead of relying on xarray's .interp() method
  • Implemented proper boundary handling via searchsorted to clip field data to intersection bounds before processing
  • Added explicit volume element scaling based on field grid coordinates rather than parameter grid coordinates
  • Implemented custom nearest and linear interpolation that accumulates gradients with npo.add.at; npo is regular numpy, used instead of autograd numpy (np) because these operations need in-place accumulation
  • Added comprehensive numerical validation tests comparing adjoint gradients to finite-difference gradients
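
To illustrate the transpose-accumulation idea from the list above, here is a minimal 1D sketch in plain NumPy. It is not the code from this PR: the function names and the test setup are assumptions made for illustration. The forward map samples parameter values at field coordinates with np.interp, and the transpose scatters field values back onto the parameter grid with np.add.at so that repeated indices accumulate. The last lines run a dot-product test, a common way to check that one linear map is the adjoint of another.

```python
import numpy as np


def interp_linear(param_coords, param_values, field_coords):
    """Forward map: sample parameter values at field coordinates (clamped at the ends)."""
    return np.interp(field_coords, param_coords, param_values)


def interp_linear_transpose(param_coords, field_values, field_coords):
    """Adjoint map: scatter field values back onto the parameter grid.

    Assumes at least two strictly increasing parameter coordinates.
    """
    n_param = param_coords.size
    upper = np.clip(np.searchsorted(param_coords, field_coords, side="right"), 1, n_param - 1)
    lower = upper - 1
    width = param_coords[upper] - param_coords[lower]
    frac = np.clip((field_coords - param_coords[lower]) / width, 0.0, 1.0)
    out = np.zeros(n_param, dtype=field_values.dtype)
    # np.add.at accumulates repeated indices; `out[lower] += ...` would keep only one write.
    np.add.at(out, lower, field_values * (1.0 - frac))
    np.add.at(out, upper, field_values * frac)
    return out


rng = np.random.default_rng(0)
param_coords = np.linspace(0.0, 1.0, 5)
field_coords = np.sort(rng.uniform(-0.2, 1.2, size=40))  # some samples lie outside the grid
p = rng.normal(size=param_coords.size)
g = rng.normal(size=field_coords.size)

# Dot-product (adjoint) test: <A p, g> must equal <p, A^T g> up to round-off.
lhs = np.dot(interp_linear(param_coords, p, field_coords), g)
rhs = np.dot(p, interp_linear_transpose(param_coords, g, field_coords))
```

If the two inner products disagree beyond round-off, the scatter weights do not match the forward interpolation, which is the kind of mismatch that produces wrong gradients.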

Technical improvements:

  • The new implementation correctly handles the transpose (adjoint) operation needed for gradient backpropagation by accumulating field values onto parameter grid points using weighted contributions
  • Volume elements are now computed from field coordinates and applied before interpolation, fixing the previous normalization issues
  • The approach works for both "nearest" and "linear" interpolation methods with proper validation
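
The volume-element scaling mentioned above can be sketched as follows. The implementation of _axis_sizes is not shown in this thread, so the helper below is a hypothetical stand-in that assumes trapezoid-style cell widths (edges at the midpoints between neighbouring field coordinates). It only illustrates the principle of weighting field samples by the cell volume of the field grid before they are accumulated onto the parameter grid.

```python
import numpy as np


def axis_sizes(coords):
    """Cell width per sample, with cell edges at midpoints (hypothetical helper)."""
    coords = np.asarray(coords, dtype=float)
    if coords.size < 2:
        # Collapsed axis: no length scale to apply in this sketch.
        return np.ones_like(coords)
    edges = np.concatenate(([coords[0]], 0.5 * (coords[1:] + coords[:-1]), [coords[-1]]))
    return np.diff(edges)


# Non-uniform field grid (assumed values, for illustration only).
x = np.array([0.0, 0.1, 0.3, 0.6])
y = np.linspace(0.0, 1.0, 3)
z = np.array([0.0, 2.0])

# Per-sample volume element as an outer product of the per-axis widths.
volume = (
    axis_sizes(x)[:, None, None]
    * axis_sizes(y)[None, :, None]
    * axis_sizes(z)[None, None, :]
)

values = np.ones((x.size, y.size, z.size))
scaled = values * volume  # weight each field sample by the volume it represents
```

With these widths, summing the scaled samples of a constant field reproduces the volume of the box spanned by the field grid, regardless of how fine the parameter grid is.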

Confidence Score: 4/5

  • This PR is safe to merge with minor suggestions for improvement
  • The implementation is mathematically sound and includes comprehensive numerical validation tests. The refactor addresses real bugs in gradient computation. Score reduced from 5 to 4 due to one style suggestion regarding floating-point comparison that could improve robustness
  • No files require special attention - the implementation is solid with good test coverage

Important Files Changed

| Filename | Overview |
| --- | --- |
| CHANGELOG.md | Added changelog entry documenting the interpolation bug fix for CustomMedium gradients |
| tests/test_components/autograd/numerical/test_autograd_medium_numerical.py | New comprehensive test file for numerical gradient validation using finite differences; includes parametrized tests for isotropic and custom media with flux/intensity objectives |
| tidy3d/components/medium.py | Major refactor of gradient computation: replaced xarray interpolation with custom implementation using transpose accumulation for proper boundary handling; added sorted data caching |
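
The finite-difference validation described for the new test file follows a general pattern that can be sketched on a toy problem. Everything below is hypothetical: the objective is a simple closed-form function rather than a simulation, and its exact gradient stands in for the adjoint gradient. The sketch only shows how a central-difference gradient is compared against a reference gradient.

```python
import numpy as np


def objective(params):
    """Toy stand-in for a simulation objective (hypothetical)."""
    return float(np.sum(np.sin(params) * params**2))


def reference_grad(params):
    """Exact gradient of the toy objective, playing the role of the adjoint gradient."""
    return np.cos(params) * params**2 + 2.0 * params * np.sin(params)


def finite_difference_grad(fn, params, step=1e-6):
    """Central differences, perturbing one parameter at a time."""
    grad = np.zeros_like(params)
    for idx in np.ndindex(params.shape):
        shift = np.zeros_like(params)
        shift[idx] = step
        grad[idx] = (fn(params + shift) - fn(params - shift)) / (2.0 * step)
    return grad


params = np.linspace(0.5, 2.0, 6).reshape(2, 3)
fd = finite_difference_grad(objective, params)
ref = reference_grad(params)
rel_err = np.linalg.norm(fd - ref) / np.linalg.norm(ref)
```

In a real test the tolerance on rel_err has to allow for discretization error in the simulation, so it is usually much looser than what this analytic example would permit.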

Sequence Diagram

sequenceDiagram
    participant Client as Autograd System
    participant CM as CustomMedium
    participant CDC as _compute_derivatives
    participant DFCC as _derivative_field_cmp_custom
    participant TIA as _transpose_interp_axis
    
    Client->>CM: Request gradients for permittivity/conductivity/eps_dataset
    CM->>CM: Access cached _permittivity_sorted, _conductivity_sorted, _eps_components_sorted
    CM->>CDC: _compute_derivatives(derivative_info)
    
    loop For each field_path
        CDC->>CDC: Get sorted spatial_data from cache
        
        alt permittivity or conductivity
            loop For each dimension (x, y, z)
                CDC->>DFCC: _derivative_field_cmp_custom(E_der_map, spatial_data, dim, freqs, bounds, component)
                DFCC->>DFCC: Extract field coordinates and values
                DFCC->>DFCC: Apply bounds filtering (searchsorted)
                DFCC->>DFCC: Compute volume element scaling (_axis_sizes)
                DFCC->>DFCC: Multiply values by scale
                
                loop For each axis (x, y, z)
                    DFCC->>TIA: _transpose_interp_axis(arr, field_axis, param_axis)
                    
                    alt method == "nearest"
                        TIA->>TIA: Compute midpoints between param coords
                        TIA->>TIA: Use searchsorted to find indices
                        TIA->>TIA: Accumulate using npo.add.at
                    else method == "linear"
                        TIA->>TIA: Find upper/lower indices via searchsorted
                        TIA->>TIA: Compute interpolation weights
                        TIA->>TIA: Accumulate weighted contributions using npo.add.at
                    end
                    
                    TIA-->>DFCC: Return interpolated array
                end
                
                DFCC->>DFCC: Apply component extraction (real/imag/sigma/complex)
                DFCC->>DFCC: Sum over frequencies
                DFCC-->>CDC: Return gradient array
            end
            CDC->>CDC: Sum contributions from all dimensions
        else eps_dataset component
            CDC->>DFCC: _derivative_field_cmp_custom for specific component
            DFCC-->>CDC: Return gradient array
        end
        
        CDC->>CDC: Store in vjps dictionary
    end
    
    CDC-->>Client: Return vjps (gradient dictionary)
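
The "nearest" branch in the diagram (midpoints, searchsorted, accumulation) can be sketched in a few lines of NumPy. This is an illustration under assumptions, not the PR's code: the function name is made up, and the tie-breaking rule for a field coordinate that falls exactly on a midpoint (here it goes to the lower parameter point) is a choice made for this sketch.

```python
import numpy as np


def nearest_transpose(param_coords, field_values, field_coords):
    """Sum field values into the bin of their nearest parameter coordinate."""
    n_param = param_coords.size
    # Midpoints between neighbouring parameter coordinates act as bin edges.
    midpoints = 0.5 * (param_coords[1:] + param_coords[:-1])
    # Index of the nearest parameter point for every field coordinate.
    nearest = np.searchsorted(midpoints, field_coords, side="left")
    out = np.zeros((n_param,) + field_values.shape[1:], dtype=field_values.dtype)
    # Unbuffered add so that several field samples in one bin all contribute.
    np.add.at(out, nearest, field_values)
    return out


param_coords = np.array([0.0, 1.0, 2.0])
field_coords = np.array([-0.5, 0.2, 0.4, 0.9, 1.6, 3.0])
field_values = np.ones((field_coords.size, 2))  # trailing axis, e.g. frequency
binned = nearest_transpose(param_coords, field_values, field_coords)
```

Field samples outside the parameter grid end up in the first or last bin, which mirrors nearest-neighbour extrapolation in the forward direction.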

@marcorudolphflex
Contributor Author

@greptile

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment

@marcorudolphflex marcorudolphflex marked this pull request as ready for review December 22, 2025 09:14
@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment

@marcorudolphflex marcorudolphflex force-pushed the FXC-4641-fix-gradients-in-custom-medium branch from 37b5e06 to 39c9050 Compare December 22, 2025 13:56
@github-actions
Contributor

github-actions bot commented Dec 22, 2025

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/medium.py (74.1%): Missing lines 1921,1928,1935,2351,2367,2384,2418,2428-2433,2515,2520,2546-2548,2551,2554-2556,2559-2560,2563,2566-2567,2569-2570,2592

Summary

  • Total: 116 lines
  • Missing: 30 lines
  • Coverage: 74%

tidy3d/components/medium.py

Lines 1917-1925

  1917     @cached_property
  1918     def _permittivity_sorted(self) -> SpatialDataArray | None:
  1919         """Cached copy of permittivity sorted along spatial axes."""
  1920         if self.permittivity is None:
! 1921             return None
  1922         return self.permittivity._spatially_sorted
  1923 
  1924     @cached_property
  1925     def _conductivity_sorted(self) -> SpatialDataArray | None:

Lines 1924-1932

  1924     @cached_property
  1925     def _conductivity_sorted(self) -> SpatialDataArray | None:
  1926         """Cached copy of conductivity sorted along spatial axes."""
  1927         if self.conductivity is None:
! 1928             return None
  1929         return self.conductivity._spatially_sorted
  1930 
  1931     @cached_property
  1932     def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]:

Lines 1931-1939

  1931     @cached_property
  1932     def _eps_components_sorted(self) -> dict[str, ScalarFieldDataArray]:
  1933         """Cached copies of dataset components sorted along spatial axes."""
  1934         if self.eps_dataset is None:
! 1935             return {}
  1936         return {
  1937             key: comp._spatially_sorted for key, comp in self.eps_dataset.field_components.items()
  1938         }

Lines 2347-2355

  2347         for field_path in derivative_info.paths:
  2348             if field_path[0] == "permittivity":
  2349                 spatial_data = self._permittivity_sorted
  2350                 if spatial_data is None:
! 2351                     continue
  2352                 vjp_array = 0.0
  2353                 for dim in "xyz":
  2354                     vjp_array += self._derivative_field_cmp_custom(
  2355                         E_der_map=derivative_info.E_der_map,

Lines 2363-2371

  2363 
  2364             elif field_path[0] == "conductivity":
  2365                 spatial_data = self._conductivity_sorted
  2366                 if spatial_data is None:
! 2367                     continue
  2368                 vjp_array = 0.0
  2369                 for dim in "xyz":
  2370                     vjp_array += self._derivative_field_cmp_custom(
  2371                         E_der_map=derivative_info.E_der_map,

Lines 2380-2388

  2380             elif field_path[0] == "eps_dataset":
  2381                 key = field_path[1]
  2382                 spatial_data = self._eps_components_sorted.get(key)
  2383                 if spatial_data is None:
! 2384                     continue
  2385                 dim = key[-1]
  2386                 vjps[field_path] = self._derivative_field_cmp_custom(
  2387                     E_der_map=derivative_info.E_der_map,
  2388                     spatial_data=spatial_data,

Lines 2414-2422

  2414         dtype_out = complex if component == "complex" else float
  2415 
  2416         E_der_dim = E_der_map.get(f"E{dim}")
  2417         if E_der_dim is None or np.all(E_der_dim.values == 0):
! 2418             return np.zeros(eps_shape, dtype=dtype_out)
  2419 
  2420         field_coords = {axis: np.asarray(E_der_dim.coords[axis]) for axis in "xyz"}
  2421         values = E_der_dim.values

Lines 2424-2437

  2424             n = axis.size
  2425             i0 = int(np.searchsorted(axis, vmin, side="left"))
  2426             i1 = int(np.searchsorted(axis, vmax, side="right"))
  2427             if i1 <= i0 and n:
! 2428                 old = (i0, i1)
! 2429                 if i1 < n:
! 2430                     i1 = i0 + 1  # expand right
! 2431                 elif i0 > 0:
! 2432                     i0 = i1 - 1  # expand left
! 2433                 log.warning(
  2434                     f"Empty bounds crop on '{name}' while computing CustomMedium parameter gradients "
  2435                     f"(adjoint field grid -> medium grid): bounds=[{vmin!r}, {vmax!r}], "
  2436                     f"grid=[{axis[0]!r}, {axis[-1]!r}] -> indices {old}; using ({i0}, {i1}).",
  2437                     log_once=True,

Lines 2511-2524

  2511                 return field_values.sum(axis=0, keepdims=True)
  2512 
  2513             # Ensure parameter coordinates are sorted for searchsorted-based binning.
  2514             if np.any(param_coords_1d[1:] < param_coords_1d[:-1]):
! 2515                 raise ValueError("Spatial coordinates must be sorted before computing derivatives.")
  2516             param_coords_sorted = param_coords_1d
  2517 
  2518             n_param = param_coords_sorted.size
  2519             if method not in ALLOWED_INTERP_METHODS:
! 2520                 raise ValueError(
  2521                     f"Unsupported interpolation method: {method!r}. "
  2522                     f"Choose one of: {', '.join(ALLOWED_INTERP_METHODS)}."
  2523                 )

Lines 2542-2574

  2542                 return param_values
  2543 
  2544             # linear
  2545             # Find bracketing parameter indices for each field coordinate.
! 2546             param_index_upper = np.searchsorted(param_coords_sorted, field_coords_1d, side="right")
! 2547             param_index_upper = np.clip(param_index_upper, 1, n_param - 1)
! 2548             param_index_lower = param_index_upper - 1
  2549 
  2550             # Compute interpolation fraction within the bracketing segment.
! 2551             segment_width = (
  2552                 param_coords_sorted[param_index_upper] - param_coords_sorted[param_index_lower]
  2553             )
! 2554             segment_width = np.where(segment_width == 0, 1.0, segment_width)
! 2555             frac_upper = (field_coords_1d - param_coords_sorted[param_index_lower]) / segment_width
! 2556             frac_upper = np.clip(frac_upper, 0.0, 1.0)
  2557 
  2558             # Weights per field sample (broadcast across the flattened trailing dimensions).
! 2559             w_lower = (1.0 - frac_upper)[:, None]
! 2560             w_upper = frac_upper[:, None]
  2561 
  2562             # Accumulate contributions into both bracketing parameter indices.
! 2563             param_values_2d = npo.zeros(
  2564                 (n_param, field_values_2d.shape[1]), dtype=field_values.dtype
  2565             )
! 2566             npo.add.at(param_values_2d, param_index_lower, field_values_2d * w_lower)
! 2567             npo.add.at(param_values_2d, param_index_upper, field_values_2d * w_upper)
  2568 
! 2569             param_values = param_values_2d.reshape((n_param,) + field_values.shape[1:])
! 2570             return param_values
  2571 
  2572         def _interp_axis(
  2573             arr: NDArray, axis: int, field_axis: NDArray, param_axis: NDArray
  2574         ) -> NDArray:

Lines 2588-2596

  2588         freqs_da = np.asarray(E_der_dim.coords["f"])
  2589         if component == "sigma":
  2590             values = values.imag * (-1.0 / (2.0 * np.pi * freqs_da * EPSILON_0))
  2591         elif component == "imag":
! 2592             values = values.imag
  2593         elif component == "real":
  2594             values = values.real
  2595 
  2596         return values.sum(axis=-1).reshape(eps_shape)
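
For reference, the component extraction in the last excerpt can be written as a standalone function. The name and array shapes are hypothetical, and the SI value of EPSILON_0 is an assumption made only to keep the sketch runnable; the library defines its own constant in its own unit system.

```python
import numpy as np

EPSILON_0 = 8.8541878128e-12  # F/m; assumed SI value, not the library's constant


def extract_component(values, freqs, component):
    """Reduce a complex array (last axis = frequency) to the requested component."""
    if component == "sigma":
        # Same scaling as in the excerpt: imaginary part times -1 / (2*pi*f*eps0).
        values = values.imag * (-1.0 / (2.0 * np.pi * freqs * EPSILON_0))
    elif component == "imag":
        values = values.imag
    elif component == "real":
        values = values.real
    # "complex" falls through unchanged.
    return values.sum(axis=-1)


values = np.array([[1.0 + 2.0j, 3.0 - 4.0j]])
freqs = np.array([1e14, 3e14])
real_part = extract_component(values, freqs, "real")
imag_part = extract_component(values, freqs, "imag")
full = extract_component(values, freqs, "complex")
sigma_part = extract_component(values, freqs, "sigma")
```

Because the conductivity scaling depends on frequency, it has to be applied before the sum over the frequency axis, which is the order used in the excerpt.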

Contributor

@groberts-flex groberts-flex left a comment

thanks @marcorudolphflex this is great! overall, looks good and thanks for adding some good tests! I left a few comments on it, excited to have this change!

@marcorudolphflex
Contributor Author

> thanks @marcorudolphflex this is great! overall, looks good and thanks for adding some good tests! I left a few comments on it, excited to have this change!

Thanks for the careful review! I hope it is clearer overall now.

@marcorudolphflex marcorudolphflex force-pushed the FXC-4641-fix-gradients-in-custom-medium branch from 39c9050 to e4bf5b3 Compare December 23, 2025 09:36
@marcorudolphflex marcorudolphflex force-pushed the FXC-4641-fix-gradients-in-custom-medium branch from e4bf5b3 to 5dca3f2 Compare December 23, 2025 09:59
Contributor

@groberts-flex groberts-flex left a comment

thanks for the changes @marcorudolphflex, looks good to me!

3 participants