feat(tidy3d): FXC-4520-add-autograd-support-for-sphere #3082
base: develop
Conversation
4 files reviewed, 3 comments
ad19310 to
98ebcc0
Compare
Additional Comments (1)
tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py, lines 530-531
logic: Modifying a global constant causes test pollution. Use a local variable instead.
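The concern about mutating a module-level constant can be illustrated with a small sketch. `FD_STEP`, `central_difference`, and `run_sweep` are hypothetical names, not taken from the test file; the point is only that passing the step as a local argument leaves shared state untouched for other tests.

```python
# Hypothetical module-level constant shared by several tests.
FD_STEP = 1e-3


def central_difference(f, x, step):
    """Central finite difference of ``f`` at ``x`` with an explicit step."""
    return (f(x + step) - f(x - step)) / (2.0 * step)


def run_sweep(f, x, steps):
    """Evaluate the derivative for several step sizes without touching globals."""
    return [central_difference(f, x, step) for step in steps]


# The sweep uses local step values; FD_STEP keeps its original value afterwards.
results = run_sweep(lambda x: x**2, 1.0, steps=[1e-2, 1e-3, 1e-4])
```

If a test really does need to override a module attribute, pytest's `monkeypatch` fixture restores the original value after the test finishes, which is another way to avoid the pollution.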
4 files reviewed, 2 comments
98ebcc0 to
8cf0d39
Compare
yaugenst-flex
left a comment
Overall LGTM, everything makes sense! Spotted a couple things to iron out.
) -> tuple[np.ndarray, np.ndarray]:
    """Return physical and unit triangles for the surface discretization."""

    unit_tris = self._unit_sphere_triangles(
self._unit_sphere_triangles() checks whether target_edge_length is not None and subdivisions is not None, and here both subdivisions and max_edge_length are always non-None, so I think this will always raise?
Hm, I do not see a case where we directly set both to non-None, but it certainly does not make sense to make target_edge_length always non-None and then reject calls with non-None subdivisions.
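One way to resolve the conflict discussed above is to run the exclusivity check on the caller's arguments first and fill in a default only afterwards. The sketch below is an assumption about how that could look: `resolve_discretization` is a hypothetical helper and `DEFAULT_EDGE_FRACTION = 0.25` is a placeholder value, not the PR's actual default.

```python
from typing import Optional

# Placeholder value for illustration only; the PR's actual default is not shown here.
DEFAULT_EDGE_FRACTION = 0.25


def resolve_discretization(
    target_edge_length: Optional[float] = None,
    subdivisions: Optional[int] = None,
) -> tuple[str, float]:
    """Pick exactly one discretization control (hypothetical helper)."""
    if target_edge_length is not None and subdivisions is not None:
        raise ValueError("Specify either 'target_edge_length' or 'subdivisions', not both.")
    if subdivisions is not None:
        return ("subdivisions", float(subdivisions))
    if target_edge_length is not None:
        return ("target_edge_length", float(target_edge_length))
    # Apply the default only after the exclusivity check, so a caller who passes
    # just ``subdivisions`` is never rejected because of a filled-in default.
    return ("target_edge_length", DEFAULT_EDGE_FRACTION)
```

With this ordering, passing only `subdivisions` succeeds, passing neither falls back to the default, and only an explicit double specification raises.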
18801f0 to
6bf04d6
Compare
Diff Coverage
Diff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/components/geometry/primitives.py
Lines 128-139
128 TriangleMesh
129 Triangle mesh approximation of the sphere surface.
130 """
131
! 132 triangles, _ = self._triangulated_surface(
133 max_edge_length=max_edge_length, subdivisions=subdivisions
134 )
! 135 return TriangleMesh.from_triangles(triangles)
136
137 def inside(
138 self, x: np.ndarray[float], y: np.ndarray[float], z: np.ndarray[float]
139 ) -> np.ndarray[bool]:
Lines 268-276
268 grid_cfg = config.adjoint
269
270 min_wvl_mat = grid_cfg.min_wvl_fraction * wvl0_min
271 if wvl_mat < min_wvl_mat:
! 272 log.warning(
273 f"The minimum wavelength inside the sphere material is {wvl_mat:.3e} μm, which would "
274 f"create a large number of discretization points for computing the gradient. "
275 f"To prevent performance degradation, the discretization wavelength has "
276 f"been clipped to {min_wvl_mat:.3e} μm.",
Lines 335-343
335 """Compute adjoint derivatives using smooth sphere surface samples."""
336 valid_paths = {("radius",), *{("center", i) for i in range(3)}}
337 for path in derivative_info.paths:
338 if path not in valid_paths:
! 339 raise ValueError(
340 f"No derivative defined w.r.t. 'Sphere' field '{path}'. "
341 "Supported fields are 'radius' and 'center'."
342 )
Lines 341-349
341 "Supported fields are 'radius' and 'center'."
342 )
343
344 if not derivative_info.paths:
! 345 return {}
346
347 grid_cfg = config.adjoint
348 wvl_mat = self._discretization_wavelength(derivative_info)
349 target_edge = max(wvl_mat / grid_cfg.points_per_wavelength, np.finfo(float).eps)
Lines 359-367
359 sim_extents = sim_max - sim_min
360 collapsed_indices = np.flatnonzero(np.isclose(sim_extents, 0.0, atol=tol))
361 if collapsed_indices.size:
362 if collapsed_indices.size > 1:
! 363 return dict.fromkeys(derivative_info.paths, 0.0)
364 axis_idx = int(collapsed_indices[0])
365 plane_value = float(sim_min[axis_idx])
366 return self._compute_derivatives_collapsed_axis(
367 derivative_info=derivative_info,
Lines 368-417
368 axis_idx=axis_idx,
369 plane_value=plane_value,
370 )
371
! 372 trimesh_obj = TriangleMesh._triangles_to_trimesh(triangles)
! 373 vertices = np.asarray(trimesh_obj.vertices, dtype=grid_cfg.gradient_dtype_float)
! 374 center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float)
! 375 verts_centered = vertices - center
! 376 norms = np.linalg.norm(verts_centered, axis=1, keepdims=True)
! 377 norms = np.where(norms == 0, 1, norms)
! 378 normals = verts_centered / norms
379
! 380 if vertices.size == 0:
! 381 return dict.fromkeys(derivative_info.paths, 0.0)
382
383 # get vertex weights
! 384 faces = np.asarray(trimesh_obj.faces, dtype=int)
! 385 face_areas = np.asarray(trimesh_obj.area_faces, dtype=grid_cfg.gradient_dtype_float)
! 386 weights = np.zeros(len(vertices), dtype=grid_cfg.gradient_dtype_float)
! 387 np.add.at(weights, faces[:, 0], face_areas / 3.0)
! 388 np.add.at(weights, faces[:, 1], face_areas / 3.0)
! 389 np.add.at(weights, faces[:, 2], face_areas / 3.0)
390
! 391 perp1, perp2 = self._tangent_basis_from_normals(normals)
392
! 393 valid_axes = np.abs(sim_max - sim_min) > tol
! 394 inside_mask = np.all(
395 vertices[:, valid_axes] >= (sim_min - tol)[valid_axes], axis=1
396 ) & np.all(vertices[:, valid_axes] <= (sim_max + tol)[valid_axes], axis=1)
397
! 398 if not np.any(inside_mask):
! 399 return dict.fromkeys(derivative_info.paths, 0.0)
400
! 401 points = vertices[inside_mask]
! 402 normals_sel = normals[inside_mask]
! 403 perp1_sel = perp1[inside_mask]
! 404 perp2_sel = perp2[inside_mask]
! 405 weights_sel = weights[inside_mask]
406
! 407 interpolators = derivative_info.interpolators
! 408 if interpolators is None:
! 409 interpolators = derivative_info.create_interpolators(
410 dtype=grid_cfg.gradient_dtype_float
411 )
412
! 413 g = derivative_info.evaluate_gradient_at_points(
414 points,
415 normals_sel,
416 perp1_sel,
417 perp2_sel,
Lines 417-437
417 perp2_sel,
418 interpolators,
419 )
420
! 421 weighted = (weights_sel * g).real
! 422 grad_center = np.sum(weighted[:, None] * normals_sel, axis=0)
! 423 grad_radius = np.sum(weighted)
424
! 425 vjps: AutogradFieldMap = {}
! 426 for path in derivative_info.paths:
! 427 if path == ("radius",):
! 428 vjps[path] = float(grad_radius)
429 else:
! 430 _, idx = path
! 431 vjps[path] = float(grad_center[idx])
432
! 433 return vjps
434
435 def _compute_derivatives_collapsed_axis(
436 self,
437 derivative_info: DerivativeInfo,
Lines 444-456
444 center = np.asarray(self.center, dtype=float)
445 delta = plane_value - center[axis_idx]
446 radius_sq = radius**2 - delta**2
447 if radius_sq <= tol**2:
! 448 return dict.fromkeys(derivative_info.paths, 0.0)
449
450 radius_plane = float(np.sqrt(max(radius_sq, 0.0)))
451 if radius_plane <= tol:
! 452 return dict.fromkeys(derivative_info.paths, 0.0)
453
454 cyl_paths: set[tuple[str, int | None]] = set()
455 need_radius = False
456 for path in derivative_info.paths:
Lines 460-468
460 elif path[0] == "center" and path[1] != axis_idx:
461 cyl_paths.add(("center", path[1]))
462
463 if not cyl_paths:
! 464 return dict.fromkeys(derivative_info.paths, 0.0)
465
466 cyl_center = center.copy()
467 cyl_center[axis_idx] = plane_value
468 cylinder = Cylinder(
Lines 486-494
486 )
487 intersect_min = tuple(max(bounds[0][i], sim_min_arr[i]) for i in range(3))
488 intersect_max = tuple(min(bounds[1][i], sim_max_arr[i]) for i in range(3))
489 if any(lo > hi for lo, hi in zip(intersect_min, intersect_max)):
! 490 return dict.fromkeys(derivative_info.paths, 0.0)
491
492 derivative_info_cyl = derivative_info.updated_copy(
493 paths=list(cyl_paths),
494 bounds=bounds,
Lines 512-523
512 def _edge_length_on_unit_sphere(self, max_edge_length: Optional[float]) -> Optional[float]:
513 """Convert ``max_edge_length`` in μm to unit-sphere coordinates."""
514
515 if max_edge_length is None:
! 516 return _DEFAULT_EDGE_FRACTION
517 radius = float(self.radius)
518 if radius <= 0.0:
! 519 return None
520 return max_edge_length / radius
521
522 def _triangulated_surface(
523 self,
Lines 580-608
580 @staticmethod
581 def _tangent_basis_from_normals(normals: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
582 """Construct orthonormal tangential bases for each normal vector (vectorized)."""
583
! 584 dtype = normals.dtype
! 585 tol = np.finfo(dtype).eps
586
587 # Normalize normals (in case they are not perfectly unit length).
! 588 n_norm = np.linalg.norm(normals, axis=1)
! 589 n = normals / np.maximum(n_norm, tol)[:, None]
590
591 # Pick a reference axis least aligned with each normal: argmin(|nx|,|ny|,|nz|).
! 592 ref_idx = np.argmin(np.abs(n), axis=1)
! 593 ref = np.zeros_like(n)
! 594 ref[np.arange(n.shape[0]), ref_idx] = 1.0
595
! 596 basis1 = np.cross(n, ref)
! 597 b1_norm = np.linalg.norm(basis1, axis=1)
! 598 basis1 = basis1 / np.maximum(b1_norm, tol)[:, None]
599
! 600 basis2 = np.cross(n, basis1)
! 601 b2_norm = np.linalg.norm(basis2, axis=1)
! 602 basis2 = basis2 / np.maximum(b2_norm, tol)[:, None]
603
! 604 return basis1, basis2
605
606 def _icosphere_data(self, subdivisions: int) -> tuple[np.ndarray, float]:
607 cache = self._icosphere_cache
608 if subdivisions in cache:
return []
return [shapely.Point(x0, y0).buffer(0.5 * intersect_dist, quad_segs=quad_segs)]

def _discretization_wavelength(self, derivative_info: DerivativeInfo) -> float:
Since we have one of these in Cylinder as well, is there possibly a way to combine them into a general function for choosing the discretization wavelength and issuing the warning?
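A shared helper along the lines suggested here might look like the sketch below. It is an assumption, not the PR's code: `clip_discretization_wavelength` is a hypothetical name, the configuration values are passed in as plain arguments, and the standard-library `warnings.warn` stands in for the project's `log.warning`.

```python
import warnings


def clip_discretization_wavelength(
    wvl_mat: float, wvl0_min: float, min_wvl_fraction: float, geometry_name: str
) -> float:
    """Clip the in-material wavelength from below and warn if clipping occurs."""
    min_wvl_mat = min_wvl_fraction * wvl0_min
    if wvl_mat < min_wvl_mat:
        # Stand-in for the logger call used in the geometry classes.
        warnings.warn(
            f"The minimum wavelength inside the {geometry_name} material is "
            f"{wvl_mat:.3e} μm; the discretization wavelength has been clipped "
            f"to {min_wvl_mat:.3e} μm.",
            stacklevel=2,
        )
        return min_wvl_mat
    return wvl_mat
```

Passing the geometry name as an argument keeps the warning text specific while letting both the sphere and the cylinder share one code path.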
center = np.asarray(self.center, dtype=grid_cfg.gradient_dtype_float)
verts_centered = vertices - center
norms = np.linalg.norm(verts_centered, axis=1, keepdims=True)
norms = np.where(norms == 0, 1, norms)
is this catching the case where you have a vertex point right at the sphere center?
Ah, I guess not - good point
does this case make sense at all? The area and thus the weights would be zero then. Take an epsilon radius in this case?
we also might just raise an exception here?
yeah good point! the area would be negligibly small I think. I would suspect this case would only happen for a small or zero radius with respect to the mesh. I could see something like this happening in an optimization if the gradient is saying to reduce the sphere size. do you think a one-time warning would make sense here?
added a warning and zero grad return
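The "warning and zero grad return" behavior could be sketched as follows. This is an illustration under assumptions: the function name, the tolerance, and the choice to test the radius directly are all hypothetical, and the PR may instead inspect the vertex norms.

```python
import warnings


def zero_gradient_if_degenerate(radius: float, paths: list, tol: float = 1e-12):
    """Return zero VJPs with a warning when the sphere has collapsed, else ``None``."""
    if abs(radius) <= tol:
        warnings.warn(
            "Sphere radius is numerically zero; returning zero gradients.",
            stacklevel=2,
        )
        # A collapsed sphere has no surface area, so every requested path gets 0.0.
        return dict.fromkeys(paths, 0.0)
    return None
```

Returning zeros rather than raising lets an optimization that shrinks the sphere continue, which matches the scenario raised in the thread.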
vjps_cyl = cylinder._compute_derivatives(derivative_info_cyl)
result = dict.fromkeys(derivative_info.paths, 0.0)
vjp_radius = float(vjps_cyl.get(("radius",), 0.0)) if need_radius else 0.0
curious why we need the need_radius here? will "radius" just not be in the derivative_info.paths if it isn't needed?
This is because the cylinder radius effectively changes with the sphere center (along the axis orthogonal to the simulation plane), which produces a different intersection circle/cylinder with the simulation plane.
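The dependence described here follows from the intersection geometry: with delta = plane_value - center[axis], the in-plane radius is sqrt(radius**2 - delta**2), so it varies with both the sphere radius and the center coordinate along the collapsed axis. The sketch below works out those two partial derivatives; the function name is hypothetical and how the PR combines them with the cylinder VJPs is not shown in this thread.

```python
import math


def plane_radius_and_sensitivities(radius: float, center_axis: float, plane_value: float):
    """Radius of the sphere/plane intersection circle and its partial derivatives."""
    delta = plane_value - center_axis
    radius_plane = math.sqrt(radius**2 - delta**2)  # assumes the plane cuts the sphere
    d_dradius = radius / radius_plane  # d(radius_plane) / d(radius)
    d_dcenter = delta / radius_plane  # d(radius_plane) / d(center along collapsed axis)
    return radius_plane, d_dradius, d_dcenter


r_plane, d_dr, d_dc = plane_radius_and_sensitivities(radius=5.0, center_axis=0.0, plane_value=3.0)
# r_plane = 4.0, d_dr = 1.25, d_dc = 0.75
```

Because `d_dcenter` is nonzero whenever the plane does not pass through the sphere center, a gradient request for that center component still needs the cylinder's radius derivative even if `("radius",)` itself was not requested.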
"radius": {
"minimum": 0,
"type": "number"
"anyOf": [
I ran into a similar thing here in my custom vjp PR. Talking with @yaugenst-flex, I think they are going to open up schema-breaking changes again on Jan 5th, so the PR might need to wait until then to get merged.
that's fine :)
groberts-flex
left a comment
looking good @marcorudolphflex!
6bf04d6 to
15afd43
Compare
Greptile Overview
Greptile Summary
This PR adds autograd support for the Sphere geometry class, enabling automatic differentiation for inverse design optimization. The implementation follows the same pattern as Cylinder and other geometries by overriding the radius field with TracedSize1D and implementing the _compute_derivatives method.
Key changes:
- radius field with TracedSize1D to enable autograd tracking
- _compute_derivatives method that discretizes the sphere surface using icosphere triangulation and computes gradients with respect to radius and center parameters
- to_triangle_mesh, _triangulated_surface, _discretization_wavelength, and _tangent_basis_from_normals
Issues found:
- test_triangle_sphere_fd_step_sweep_ref (lines 530-531)
Confidence Score: 3/5
tidy3d/components/geometry/primitives.py lines 372-373 (division by zero) and tests/test_components/autograd/numerical/test_autograd_sphere_trianglemesh_numerical.py lines 530-531 (global modification)
Important Files Changed
File Analysis
Sequence Diagram
sequenceDiagram
    participant User
    participant Sphere
    participant _compute_derivatives
    participant _triangulated_surface
    participant TriangleMesh
    participant DerivativeInfo
    User->>Sphere: Create Sphere(center, radius)
    User->>Sphere: Compute gradients
    Sphere->>_compute_derivatives: derivative_info
    _compute_derivatives->>_compute_derivatives: Validate derivative paths
    _compute_derivatives->>_triangulated_surface: Discretize sphere surface
    _triangulated_surface->>_triangulated_surface: Get unit sphere triangles
    _triangulated_surface-->>_compute_derivatives: Physical triangles
    _compute_derivatives->>TriangleMesh: Convert to trimesh object
    TriangleMesh-->>_compute_derivatives: vertices, faces, areas
    _compute_derivatives->>_compute_derivatives: Compute normals from vertices
    _compute_derivatives->>_compute_derivatives: Filter vertices inside bounds
    _compute_derivatives->>_compute_derivatives: Compute tangent basis
    _compute_derivatives->>DerivativeInfo: Evaluate gradient at points
    DerivativeInfo-->>_compute_derivatives: Gradient values
    _compute_derivatives->>_compute_derivatives: Weight by face areas
    _compute_derivatives->>_compute_derivatives: Integrate to get vjps
    _compute_derivatives-->>Sphere: Gradients for radius/center
    Sphere-->>User: Autograd gradients
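The "weight by face areas" and "integrate to get vjps" steps in the diagram can be sketched on a very coarse mesh. The example below uses an octahedron inscribed in the unit sphere as a stand-in for the icosphere and a constant placeholder integrand `g`; it is written in plain Python so it runs without NumPy, whereas the PR performs the same accumulation with `np.add.at`.

```python
import math
from itertools import product

# Octahedron inscribed in the unit sphere: a coarse stand-in for the icosphere.
vertices = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
# One face per octant: pick +x or -x, +y or -y, +z or -z.
faces = [(ix, iy, iz) for ix, iy, iz in product((0, 1), (2, 3), (4, 5))]


def triangle_area(a, b, c):
    """Area of a triangle from half the norm of the edge cross product."""
    u = [b[i] - a[i] for i in range(3)]
    v = [c[i] - a[i] for i in range(3)]
    cross = (u[1] * v[2] - u[2] * v[1], u[2] * v[0] - u[0] * v[2], u[0] * v[1] - u[1] * v[0])
    return 0.5 * math.sqrt(sum(comp**2 for comp in cross))


# Each face contributes one third of its area to each of its three vertices.
weights = [0.0] * len(vertices)
for face in faces:
    area = triangle_area(*(vertices[i] for i in face))
    for i in face:
        weights[i] += area / 3.0

# On a sphere centered at the origin, the outward normal is the normalized position.
normals = [tuple(comp / math.sqrt(sum(c**2 for c in v)) for comp in v) for v in vertices]

g = [1.0] * len(vertices)  # placeholder integrand; the PR evaluates it from adjoint fields
grad_radius = sum(w * gi for w, gi in zip(weights, g))
grad_center = [sum(w * gi * n[k] for w, gi, n in zip(weights, g, normals)) for k in range(3)]
```

With a constant integrand, `grad_radius` reduces to the total mesh area and `grad_center` cancels by symmetry, which is a convenient sanity check for the weighting scheme.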