Skip to content

Conversation

@marcorudolphflex
Copy link
Contributor

@marcorudolphflex marcorudolphflex commented Dec 12, 2025

Greptile Overview

Greptile Summary

This PR adds autograd support for the Sphere geometry class, enabling automatic differentiation for inverse design optimization. The implementation follows the same pattern as Cylinder and other geometries by overriding the radius field with TracedSize1D and implementing the _compute_derivatives method.

Key changes:

  • Overrides radius field with TracedSize1D to enable autograd tracking
  • Implements _compute_derivatives method that discretizes the sphere surface using icosphere triangulation and computes gradients with respect to radius and center parameters
  • Adds helper methods to_triangle_mesh, _triangulated_surface, _discretization_wavelength, and _tangent_basis_from_normals
  • Includes comprehensive numerical validation tests comparing autograd gradients with finite-difference approximations
  • Adds basic integration test to verify non-zero gradients

Issues found:

  • Division by zero risk in normal computation (line 373 in primitives.py) - missing safety check that polyslab.py uses
  • Test pollution from global variable modification in test_triangle_sphere_fd_step_sweep_ref (line 530-531)

Confidence Score: 3/5

  • This PR requires fixes for two logic issues before merging - division by zero risk and test pollution
  • Score reflects solid implementation following established patterns, but two critical logic issues need addressing: (1) unprotected division by zero when computing normals could cause NaN/inf values, and (2) global variable modification in tests causes test pollution. Both are straightforward to fix.
  • Pay close attention to tidy3d/components/geometry/primitives.py lines 372-373 (division by zero) and tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py line 530-531 (global modification)

Important Files Changed

File Analysis

Filename Score Overview
tidy3d/components/geometry/primitives.py 4/5 Core implementation of Sphere autograd support with gradient computation for radius and center parameters
tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py 4/5 Comprehensive numerical validation tests comparing finite-difference with autograd gradients for Sphere

Sequence Diagram

sequenceDiagram
    participant User
    participant Sphere
    participant _compute_derivatives
    participant _triangulated_surface
    participant TriangleMesh
    participant DerivativeInfo
    
    User->>Sphere: Create Sphere(center, radius)
    User->>Sphere: Compute gradients
    Sphere->>_compute_derivatives: derivative_info
    _compute_derivatives->>_compute_derivatives: Validate derivative paths
    _compute_derivatives->>_triangulated_surface: Discretize sphere surface
    _triangulated_surface->>_triangulated_surface: Get unit sphere triangles
    _triangulated_surface-->>_compute_derivatives: Physical triangles
    _compute_derivatives->>TriangleMesh: Convert to trimesh object
    TriangleMesh-->>_compute_derivatives: vertices, faces, areas
    _compute_derivatives->>_compute_derivatives: Compute normals from vertices
    _compute_derivatives->>_compute_derivatives: Filter vertices inside bounds
    _compute_derivatives->>_compute_derivatives: Compute tangent basis
    _compute_derivatives->>DerivativeInfo: Evaluate gradient at points
    DerivativeInfo-->>_compute_derivatives: Gradient values
    _compute_derivatives->>_compute_derivatives: Weight by face areas
    _compute_derivatives->>_compute_derivatives: Integrate to get vjps
    _compute_derivatives-->>Sphere: Gradients for radius/center
    Sphere-->>User: Autograd gradients
Loading

@marcorudolphflex
Copy link
Contributor Author

@greptile

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

@marcorudolphflex marcorudolphflex force-pushed the FXC-4520-add-autograd-support-for-sphere branch from ad19310 to 98ebcc0 Compare December 12, 2025 14:06
@marcorudolphflex marcorudolphflex marked this pull request as ready for review December 12, 2025 14:07
Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py, line 530-531 (link)

    logic: Modifying global constant causes test pollution. Use local variable instead:

4 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

@marcorudolphflex marcorudolphflex force-pushed the FXC-4520-add-autograd-support-for-sphere branch from 98ebcc0 to 8cf0d39 Compare December 12, 2025 14:15
Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall LGTM, everything makes sense! Spotted a couple things to iron out.

) -> tuple[np.ndarray, np.ndarray]:
"""Return physical and unit triangles for the surface discretization."""

unit_tris = self._unit_sphere_triangles(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._unit_sphere_triangles() checks if target_edge_length is not None and subdivisions is not None, and here both subdivisions and max_edge_length are always non-None so I think this will always raise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mh I do not see a case where we directly set both non-None - but sure does not make sense to make target_edge_length always non-none and then reject calls with non-none subdivisions.

@marcorudolphflex marcorudolphflex force-pushed the FXC-4520-add-autograd-support-for-sphere branch 2 times, most recently from 18801f0 to 6bf04d6 Compare December 16, 2025 11:15
@github-actions
Copy link
Contributor

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/geometry/primitives.py (58.6%): Missing lines 132,135,272,339,345,363,372-378,380-381,384-389,391,393-394,398-399,401-405,407-409,413,421-423,425-428,430-431,433,448,452,464,490,516,519,584-585,588-589,592-594,596-598,600-602,604

Summary

  • Total: 157 lines
  • Missing: 65 lines
  • Coverage: 58%

tidy3d/components/geometry/primitives.py

Lines 128-139

  128         TriangleMesh
  129             Triangle mesh approximation of the sphere surface.
  130         """
  131 
! 132         triangles, _ = self._triangulated_surface(
  133             max_edge_length=max_edge_length, subdivisions=subdivisions
  134         )
! 135         return TriangleMesh.from_triangles(triangles)
  136 
  137     def inside(
  138         self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float]
  139     ) -> np.ndarray[bool]:

Lines 268-276

  268         grid_cfg = config.adjoint
  269 
  270         min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min
  271         if wvl_mat < min_wvl_mat:
! 272             log.warning(
  273                 f"The minimum wavelength inside the sphere material is {wvl_mat:.3e} μm, which would "
  274                 f"create a large number of discretization points for computing the gradient. "
  275                 f"To prevent performance degradation, the discretization wavelength has "
  276                 f"been clipped to {min_wvl_mat:.3e} μm.",

Lines 335-343

  335         """Compute adjoint derivatives using smooth sphere surface samples."""
  336         valid_paths = {("radius",), *{("center", i) for i in range(3)}}
  337         for path in derivative_info.paths:
  338             if path not in valid_paths:
! 339                 raise ValueError(
  340                     f"No derivative defined w.r.t. 'Sphere' field '{path}'. "
  341                     "Supported fields are 'radius' and 'center'."
  342                 )

Lines 341-349

  341                     "Supported fields are 'radius' and 'center'."
  342                 )
  343 
  344         if not derivative_info.paths:
! 345             return {}
  346 
  347         grid_cfg = config.adjoint
  348         wvl_mat = self._discretization_wavelength(derivative_info)
  349         target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps)

Lines 359-367

  359         sim_extents = sim_max - sim_min
  360         collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol))
  361         if collapsed_indices.size:
  362             if collapsed_indices.size > 1:
! 363                 return dict.fromkeys(derivative_info.paths, 0.0)
  364             axis_idx = int(collapsed_indices[0])
  365             plane_value = float(sim_min[axis_idx])
  366             return self._compute_derivatives_collapsed_axis(
  367                 derivative_info=derivative_info,

Lines 368-417

  368                 axis_idx=axis_idx,
  369                 plane_value=plane_value,
  370             )
  371 
! 372         trimesh_obj = TriangleMesh._triangles_to_trimesh(triangles)
! 373         vertices = np.asarray(trimesh_obj.vertices, dtype=grid_cfg.gradient_dtype_float)
! 374         center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float)
! 375         verts_centered = vertices - center
! 376         norms = np.linalg.norm(verts_centered, axis=1, keepdims=True)
! 377         norms = np.where(norms == 0, 1, norms)
! 378         normals = verts_centered / norms
  379 
! 380         if vertices.size == 0:
! 381             return dict.fromkeys(derivative_info.paths, 0.0)
  382 
  383         # get vertex weights
! 384         faces = np.asarray(trimesh_obj.faces, dtype=int)
! 385         face_areas = np.asarray(trimesh_obj.area_faces, dtype=grid_cfg.gradient_dtype_float)
! 386         weights = np.zeros(len(vertices), dtype=grid_cfg.gradient_dtype_float)
! 387         np.add.at(weights, faces[:, 0], face_areas / 3.0)
! 388         np.add.at(weights, faces[:, 1], face_areas / 3.0)
! 389         np.add.at(weights, faces[:, 2], face_areas / 3.0)
  390 
! 391         perp1, perp2 = self._tangent_basis_from_normals(normals)
  392 
! 393         valid_axes = np.abs(sim_max - sim_min) > tol
! 394         inside_mask = np.all(
  395             vertices[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1
  396         ) & np.all(vertices[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1)
  397 
! 398         if not np.any(inside_mask):
! 399             return dict.fromkeys(derivative_info.paths, 0.0)
  400 
! 401         points = vertices[inside_mask]
! 402         normals_sel = normals[inside_mask]
! 403         perp1_sel = perp1[inside_mask]
! 404         perp2_sel = perp2[inside_mask]
! 405         weights_sel = weights[inside_mask]
  406 
! 407         interpolators = derivative_info.interpolators
! 408         if interpolators is None:
! 409             interpolators = derivative_info.create_interpolators(
  410                 dtype=grid_cfg.gradient_dtype_float
  411             )
  412 
! 413         g = derivative_info.evaluate_gradient_at_points(
  414             points,
  415             normals_sel,
  416             perp1_sel,
  417             perp2_sel,

Lines 417-437

  417             perp2_sel,
  418             interpolators,
  419         )
  420 
! 421         weighted = (weights_sel * g).real
! 422         grad_center = np.sum(weighted[:, None] * normals_sel, axis=0)
! 423         grad_radius = np.sum(weighted)
  424 
! 425         vjps: AutogradFieldMap = {}
! 426         for path in derivative_info.paths:
! 427             if path == ("radius",):
! 428                 vjps[path] = float(grad_radius)
  429             else:
! 430                 _, idx = path
! 431                 vjps[path] = float(grad_center[idx])
  432 
! 433         return vjps
  434 
  435     def _compute_derivatives_collapsed_axis(
  436         self,
  437         derivative_info: DerivativeInfo,

Lines 444-456

  444         center = np.asarray(self.center, dtype=float)
  445         delta = plane_value - center[axis_idx]
  446         radius_sq = radius**2 - delta**2
  447         if radius_sq <= tol**2:
! 448             return dict.fromkeys(derivative_info.paths, 0.0)
  449 
  450         radius_plane = float(np.sqrt(max(radius_sq, 0.0)))
  451         if radius_plane <= tol:
! 452             return dict.fromkeys(derivative_info.paths, 0.0)
  453 
  454         cyl_paths: set[tuple[str, int | None]] = set()
  455         need_radius = False
  456         for path in derivative_info.paths:

Lines 460-468

  460             elif path[0] == "center" and path[1] != axis_idx:
  461                 cyl_paths.add(("center", path[1]))
  462 
  463         if not cyl_paths:
! 464             return dict.fromkeys(derivative_info.paths, 0.0)
  465 
  466         cyl_center = center.copy()
  467         cyl_center[axis_idx] = plane_value
  468         cylinder = Cylinder(

Lines 486-494

  486         )
  487         intersect_min = tuple(max(bounds[0][i], sim_min_arr[i]) for i in range(3))
  488         intersect_max = tuple(min(bounds[1][i], sim_max_arr[i]) for i in range(3))
  489         if any(lo > hi for lo, hi in zip(intersect_min, intersect_max)):
! 490             return dict.fromkeys(derivative_info.paths, 0.0)
  491 
  492         derivative_info_cyl = derivative_info.updated_copy(
  493             paths=list(cyl_paths),
  494             bounds=bounds,

Lines 512-523

  512     def _edge_length_on_unit_sphere(self, max_edge_length: Optional[float]) -> Optional[float]:
  513         """Convert ``max_edge_length`` in μm to unit-sphere coordinates."""
  514 
  515         if max_edge_length is None:
! 516             return _DEFAULT_EDGE_FRACTION
  517         radius = float(self.radius)
  518         if radius <= 0.0:
! 519             return None
  520         return max_edge_length / radius
  521 
  522     def _triangulated_surface(
  523         self,

Lines 580-608

  580     @staticmethod
  581     def _tangent_basis_from_normals(normals: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
  582         """Construct orthonormal tangential bases for each normal vector (vectorized)."""
  583 
! 584         dtype = normals.dtype
! 585         tol = np.finfo(dtype).eps
  586 
  587         # Normalize normals (in case they are not perfectly unit length).
! 588         n_norm = np.linalg.norm(normals, axis=1)
! 589         n = normals / np.maximum(n_norm, tol)[:, None]
  590 
  591         # Pick a reference axis least aligned with each normal: argmin(|nx|,|ny|,|nz|).
! 592         ref_idx = np.argmin(np.abs(n), axis=1)
! 593         ref = np.zeros_like(n)
! 594         ref[np.arange(n.shape[0]), ref_idx] = 1.0
  595 
! 596         basis1 = np.cross(n, ref)
! 597         b1_norm = np.linalg.norm(basis1, axis=1)
! 598         basis1 = basis1 / np.maximum(b1_norm, tol)[:, None]
  599 
! 600         basis2 = np.cross(n, basis1)
! 601         b2_norm = np.linalg.norm(basis2, axis=1)
! 602         basis2 = basis2 / np.maximum(b2_norm, tol)[:, None]
  603 
! 604         return basis1, basis2
  605 
  606     def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]:
  607         cache = self._icosphere_cache
  608         if subdivisions in cache:

return []
return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)]

def _discretization_wavelength(self, derivative_info: DerivativeInfo) -> float:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possibly since we have one of these in cylinder as well, there is a way to combine them together into a general function for choosing the discretization wavelength and issuing the warning?

center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float)
verts_centered = vertices - center
norms = np.linalg.norm(verts_centered, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this catching the case where you have a vertex point right at the sphere center?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I guess not - good point

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this case make sense at all? The area and thus the weights would be zero then. Take an epsilon radius in this case?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we also might just raise an exception here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah good point! the area would be negligibly small I think. I would suspect this case would only happen for a small or zero radius with respect to the mesh. I could see something like this happening in an optimization if the gradient is saying to reduce the sphere size. do you think a one-time warning would make sense here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a warning and zero grad return


vjps_cyl = cylinder._compute_derivatives(derivative_info_cyl)
result = dict.fromkeys(derivative_info.paths, 0.0)
vjp_radius = float(vjps_cyl.get(("radius",), 0.0)) if need_radius else 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious why we need the need_radius here? will "radius" just not be in the derivative_info.paths if it isn't needed?

Copy link
Contributor Author

@marcorudolphflex marcorudolphflex Dec 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is because the cylinder radius can effectively change with the sphere center (on the sim-orthogonal axis) which causes a different intersection circle/cylinder with the simulation plane

"radius": {
"minimum": 0,
"type": "number"
"anyOf": [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ran into a similar thing here in my custom vjp PR. Talking with @yaugenst-flex , I think they are going to open up schema breaking changes again Jan 5th. So the PR might need to wait to get merged until then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's fine :)

Copy link
Contributor

@groberts-flex groberts-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looking good @marcorudolphflex!

@marcorudolphflex marcorudolphflex force-pushed the FXC-4520-add-autograd-support-for-sphere branch from 6bf04d6 to 15afd43 Compare December 18, 2025 09:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants