Skip to content

Conversation

@groberts-flex
Copy link
Contributor

@groberts-flex groberts-flex commented Nov 19, 2025

@marcorudolphflex @yaugenst-flex

After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!

Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.

Greptile Summary

  • Adds user_vjp parameter to autograd run functions enabling custom gradient calculations for specific structures
  • Implements VJP lookup mechanism in backward pass to route computation through user-defined functions when specified
  • Extends DerivativeInfo with updated_epsilon helper for finite difference gradient computations in custom VJPs

Confidence Score: 4/5

  • This PR is safe to merge with minor documentation improvements recommended
  • Score reflects solid implementation with comprehensive test coverage and proper error handling. The core logic correctly routes custom VJP functions through the gradient computation pipeline. Minor documentation issues with inline docstrings and incomplete comments don't impact functionality but should be addressed for maintainability.
  • Pay attention to tidy3d/components/structure.py and tidy3d/web/api/autograd/backward.py for the docstring formatting issues

Important Files Changed

Filename Overview
tidy3d/web/api/autograd/types.py New file introducing UserVJPConfig dataclass for custom gradient computation and SetupRunResult for run preparation
tidy3d/web/api/autograd/autograd.py Adds user_vjp parameter throughout run functions with validation and broadcasting logic for single/batch simulations
tidy3d/web/api/autograd/backward.py Implements user VJP lookup mechanism and updated_epsilon helper function for finite difference gradient computations
tidy3d/components/structure.py Extends _compute_derivatives to accept optional vjp_fns dict for custom gradient paths per geometry/medium field

Sequence Diagram

sequenceDiagram
    participant User
    participant run_custom
    participant _run_primitive
    participant setup_fwd
    participant _run_tidy3d
    participant _run_bwd
    participant postprocess_adj
    participant Structure
    participant UserVJP

    User->>run_custom: "call with simulation and user_vjp"
    run_custom->>_run_primitive: "pass user_vjp to primitive"
    _run_primitive->>setup_fwd: "setup forward simulation"
    setup_fwd-->>_run_primitive: "combined simulation"
    _run_primitive->>_run_tidy3d: "run forward simulation"
    _run_tidy3d-->>_run_primitive: "simulation data"
    
    Note over _run_bwd: Backward pass triggered
    _run_bwd->>postprocess_adj: "compute gradients with user_vjp"
    postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict"
    postprocess_adj->>Structure: "_compute_derivatives with vjp_fns"
    
    alt user VJP exists for path
        Structure->>UserVJP: "call user-defined vjp function"
        UserVJP-->>Structure: "custom gradients"
    else default path
        Structure->>Structure: "call internal gradient method"
        Structure-->>Structure: "standard gradients"
    end
    
    Structure-->>postprocess_adj: "gradient values"
    postprocess_adj-->>_run_bwd: "VJP field map"
    _run_bwd-->>User: "gradients for optimization"
Loading

Copy link

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

13 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format

@github-actions
Copy link
Contributor

github-actions bot commented Nov 19, 2025

Diff Coverage

Diff: origin/develop...HEAD, staged and unstaged changes

  • tidy3d/components/autograd/derivative_utils.py (100%)
  • tidy3d/components/geometry/primitives.py (100%)
  • tidy3d/components/structure.py (100%)
  • tidy3d/plugins/smatrix/run.py (100%)
  • tidy3d/web/api/autograd/autograd.py (91.4%): Missing lines 130,147-148,152,154-155,158-159,574,582,592
  • tidy3d/web/api/autograd/backward.py (81.8%): Missing lines 275,285,290,368,407-409,413
  • tidy3d/web/api/autograd/types.py (100%)

Summary

  • Total: 231 lines
  • Missing: 19 lines
  • Coverage: 91%

tidy3d/web/api/autograd/autograd.py

Lines 126-134

  126         if isinstance(vjp_config.structure, type) and issubclass(
  127             vjp_config.structure, allowed_classes_geometry
  128         ):
  129             if vjp_config.structure in geometry_types_seen:
! 130                 raise AdjointError(
  131                     f"custom_vjp assigned multiple times for geometry type {vjp_config.structure}"
  132                 )
  133 
  134             geometry_types_seen.append(vjp_config.structure)

Lines 143-163

  143 
  144         elif isinstance(vjp_config.structure, type) and issubclass(
  145             vjp_config.structure, allowed_classes_medium
  146         ):
! 147             if vjp_config.structure in medium_types_seen:
! 148                 raise AdjointError(
  149                     f"custom_vjp multiple times for medium type {vjp_config.structure}"
  150                 )
  151 
! 152             medium_types_seen.append(vjp_config.structure)
  153 
! 154             for structure_idx, structure in enumerate(simulation.structures):
! 155                 if isinstance(structure.medium, vjp_config.structure) and (
  156                     structure_idx not in custom_vjp_indices
  157                 ):
! 158                     updated_vjp_config = replace(vjp_config, structure=structure_idx)
! 159                     expanded_custom_vjp.append(updated_vjp_config)
  160 
  161         else:
  162             expanded_custom_vjp.append(vjp_config)

Lines 570-578

  570             return expanded
  571 
  572         expanded = {}
  573         if not isinstance(fn_arg, type(orig_sim_arg)):
! 574             raise AdjointError(
  575                 f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
  576             )
  577 
  578         if isinstance(orig_sim_arg, dict):

Lines 578-586

  578         if isinstance(orig_sim_arg, dict):
  579             check_keys = fn_arg.keys() == sim_dict.keys()
  580 
  581             if not check_keys:
! 582                 raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")
  583 
  584             for key, val in fn_arg.items():
  585                 if isinstance(val, base_type):
  586                     expanded[key] = (val,)

Lines 588-596

  588                     expanded[key] = val
  589 
  590         elif isinstance(orig_sim_arg, (list, tuple)):
  591             if not (len(fn_arg) == len(orig_sim_arg)):
! 592                 raise AdjointError(
  593                     f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
  594                 )
  595 
  596             for idx, key in enumerate(sim_dict.keys()):

tidy3d/web/api/autograd/backward.py

Lines 271-279

  271             """Return the simulation permittivity for eps_box after replacing the geometry
  272             for this structure with a new geometry. This is helpful for carrying out finite
  273             difference permittivity computations.
  274             """
! 275             update_sim = sim_orig.updated_copy(
  276                 structures=[
  277                     sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
  278                     if idx == structure_index
  279                     else sim_orig.structures[idx]

Lines 281-294

  281                 ],
  282                 grid_spec=td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid),
  283             )
  284 
! 285             eps_by_f = [
  286                 update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
  287                 for f in adjoint_frequencies
  288             ]
  289 
! 290             return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)
  291 
  292         updated_epsilon_full = functools.partial(
  293             updated_epsilon_full_impl,
  294             adjoint_frequencies=adjoint_frequencies,

Lines 364-372

  364                 select_adjoint_freqs: typing.Optional[FreqDataArray],
  365                 updated_epsilon_full: typing.Optional[typing.Callable],
  366             ) -> ScalarFieldDataArray:
  367                 # Get permittivity function for a subset of frequencies
! 368                 return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)
  369 
  370             updated_epsilon = functools.partial(
  371                 updated_epsilon_wrapper,
  372                 select_adjoint_freqs=select_adjoint_freqs,

Lines 403-417

  403                 vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)
  404 
  405                 for path, value in vjp_chunk.items():
  406                     if path in vjp_value_map:
! 407                         existing = vjp_value_map[path]
! 408                         if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
! 409                             vjp_value_map[path] = type(existing)(
  410                                 x + y for x, y in zip(existing, value)
  411                             )
  412                         else:
! 413                             vjp_value_map[path] = existing + value
  414                     else:
  415                         vjp_value_map[path] = value
  416 
  417         # store vjps in output map

Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!

@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch 3 times, most recently from 116571c to 3a444ca Compare December 2, 2025 19:49
@groberts-flex
Copy link
Contributor Author

when you guys get a chance to take another look at this, it would be much appreciated! I rebased the changes I made last week, so should be ready to go if things look good to you all

Copy link
Contributor

@marcorudolphflex marcorudolphflex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks much better now. Great to have that feature!
Just a few code styling comments/questions from my side.

Copy link
Collaborator

@yaugenst-flex yaugenst-flex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great don't have too much to add!

@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch 2 times, most recently from ca5e293 to fe881ee Compare December 9, 2025 21:54
@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch 3 times, most recently from fd07752 to 7b35188 Compare December 17, 2025 19:14
…p arguments to provide hook into gradient computation for custom vjp calculation.
@groberts-flex groberts-flex force-pushed the groberts-flex/custom_user_vjp_hook_FXC-3730 branch from 7b35188 to 63ff34c Compare December 17, 2025 19:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants