-
Notifications
You must be signed in to change notification settings - Fork 67
Add user_vjp hook and custom run functions to allow overriding the internal vjp #3015
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
13 files reviewed, 6 comments
Edit Code Review Agent Settings | Greptile
React with 👍 or 👎 to share your feedback on this new summary format
Diff CoverageDiff: origin/develop...HEAD, staged and unstaged changes
Summary
tidy3d/web/api/autograd/autograd.pyLines 126-134 126 if isinstance(vjp_config.structure, type) and issubclass(
127 vjp_config.structure, allowed_classes_geometry
128 ):
129 if vjp_config.structure in geometry_types_seen:
! 130 raise AdjointError(
131 f"custom_vjp assigned multiple times for geometry type {vjp_config.structure}"
132 )
133
134 geometry_types_seen.append(vjp_config.structure)Lines 143-163 143
144 elif isinstance(vjp_config.structure, type) and issubclass(
145 vjp_config.structure, allowed_classes_medium
146 ):
! 147 if vjp_config.structure in medium_types_seen:
! 148 raise AdjointError(
149 f"custom_vjp multiple times for medium type {vjp_config.structure}"
150 )
151
! 152 medium_types_seen.append(vjp_config.structure)
153
! 154 for structure_idx, structure in enumerate(simulation.structures):
! 155 if isinstance(structure.medium, vjp_config.structure) and (
156 structure_idx not in custom_vjp_indices
157 ):
! 158 updated_vjp_config = replace(vjp_config, structure=structure_idx)
! 159 expanded_custom_vjp.append(updated_vjp_config)
160
161 else:
162 expanded_custom_vjp.append(vjp_config)Lines 570-578 570 return expanded
571
572 expanded = {}
573 if not isinstance(fn_arg, type(orig_sim_arg)):
! 574 raise AdjointError(
575 f"{fn_arg_name} type ({type(fn_arg)}) should match simulations type ({type(simulations)})"
576 )
577
578 if isinstance(orig_sim_arg, dict):Lines 578-586 578 if isinstance(orig_sim_arg, dict):
579 check_keys = fn_arg.keys() == sim_dict.keys()
580
581 if not check_keys:
! 582 raise AdjointError(f"{fn_arg_name} keys do not match simulations keys")
583
584 for key, val in fn_arg.items():
585 if isinstance(val, base_type):
586 expanded[key] = (val,)Lines 588-596 588 expanded[key] = val
589
590 elif isinstance(orig_sim_arg, (list, tuple)):
591 if not (len(fn_arg) == len(orig_sim_arg)):
! 592 raise AdjointError(
593 f"{fn_arg_name} is not the same length as simulations ({len(expanded)} vs. {len(simulations)})"
594 )
595
596 for idx, key in enumerate(sim_dict.keys()):tidy3d/web/api/autograd/backward.pyLines 271-279 271 """Return the simulation permittivity for eps_box after replacing the geometry
272 for this structure with a new geometry. This is helpful for carrying out finite
273 difference permittivity computations.
274 """
! 275 update_sim = sim_orig.updated_copy(
276 structures=[
277 sim_orig.structures[idx].updated_copy(geometry=replacement_geometry)
278 if idx == structure_index
279 else sim_orig.structures[idx]Lines 281-294 281 ],
282 grid_spec=td.components.grid.grid_spec.GridSpec.from_grid(sim_orig.grid),
283 )
284
! 285 eps_by_f = [
286 update_sim.epsilon(box=eps_box, coord_key="centers", freq=f)
287 for f in adjoint_frequencies
288 ]
289
! 290 return xr.concat(eps_by_f, dim="f").assign_coords(f=adjoint_frequencies)
291
292 updated_epsilon_full = functools.partial(
293 updated_epsilon_full_impl,
294 adjoint_frequencies=adjoint_frequencies,Lines 364-372 364 select_adjoint_freqs: typing.Optional[FreqDataArray],
365 updated_epsilon_full: typing.Optional[typing.Callable],
366 ) -> ScalarFieldDataArray:
367 # Get permittivity function for a subset of frequencies
! 368 return updated_epsilon_full(replacement_geometry).sel(f=select_adjoint_freqs)
369
370 updated_epsilon = functools.partial(
371 updated_epsilon_wrapper,
372 select_adjoint_freqs=select_adjoint_freqs,Lines 403-417 403 vjp_chunk = structure._compute_derivatives(derivative_info_struct, vjp_fns=vjp_fns)
404
405 for path, value in vjp_chunk.items():
406 if path in vjp_value_map:
! 407 existing = vjp_value_map[path]
! 408 if isinstance(existing, (list, tuple)) and isinstance(value, (list, tuple)):
! 409 vjp_value_map[path] = type(existing)(
410 x + y for x, y in zip(existing, value)
411 )
412 else:
! 413 vjp_value_map[path] = existing + value
414 else:
415 vjp_value_map[path] = value
416
417 # store vjps in output map |
yaugenst-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @groberts-flex this is very nice to have! As discussed, I think we should change the name to "custom" instead of "user" VJP. Left a couple of other comments but overall looks good!
116571c to
3a444ca
Compare
|
when you guys get a chance to take another look at this, it would be much appreciated! I rebased the changes I made last week, so should be ready to go if things look good to you all |
marcorudolphflex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks much better now. Great to have that feature!
Just a few code styling comments/questions from my side.
yaugenst-flex
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great don't have too much to add!
ca5e293 to
fe881ee
Compare
fd07752 to
7b35188
Compare
…p arguments to provide hook into gradient computation for custom vjp calculation.
7b35188 to
63ff34c
Compare
@marcorudolphflex @yaugenst-flex
After the interface reconfiguration, I split the PR for these custom autograd hooks into two so that hopefully it's easier to review! This one is for the user_vjp which allows someone to override the internal vjp calculation for a structure geometry or medium. The other hook is done as well, but I'll save it for after this one is done with review!
Based on the other review, I updated the interface to ideally be a little more straightforward to use and less cumbersome. The specification of paths in the user_vjp is not required unless you want it to only apply to a specific path in the structure. It can also be specified as just a single user_vjp value in run_async_custom if you want the same one to apply to all of the simulations (instead of having to manually broadcast it). I think there are other helper functions that could be added in the future that might make things even easier like applying a certain user_vjp for all structures with a specific geometry type, but I'll leave those for a future upgrade.
Greptile Summary
user_vjpparameter to autograd run functions enabling custom gradient calculations for specific structuresDerivativeInfowithupdated_epsilonhelper for finite difference gradient computations in custom VJPsConfidence Score: 4/5
tidy3d/components/structure.pyandtidy3d/web/api/autograd/backward.pyfor the docstring formatting issuesImportant Files Changed
UserVJPConfigdataclass for custom gradient computation andSetupRunResultfor run preparationuser_vjpparameter throughout run functions with validation and broadcasting logic for single/batch simulationsupdated_epsilonhelper function for finite difference gradient computations_compute_derivativesto accept optionalvjp_fnsdict for custom gradient paths per geometry/medium fieldSequence Diagram
sequenceDiagram participant User participant run_custom participant _run_primitive participant setup_fwd participant _run_tidy3d participant _run_bwd participant postprocess_adj participant Structure participant UserVJP User->>run_custom: "call with simulation and user_vjp" run_custom->>_run_primitive: "pass user_vjp to primitive" _run_primitive->>setup_fwd: "setup forward simulation" setup_fwd-->>_run_primitive: "combined simulation" _run_primitive->>_run_tidy3d: "run forward simulation" _run_tidy3d-->>_run_primitive: "simulation data" Note over _run_bwd: Backward pass triggered _run_bwd->>postprocess_adj: "compute gradients with user_vjp" postprocess_adj->>postprocess_adj: "build user_vjp_lookup dict" postprocess_adj->>Structure: "_compute_derivatives with vjp_fns" alt user VJP exists for path Structure->>UserVJP: "call user-defined vjp function" UserVJP-->>Structure: "custom gradients" else default path Structure->>Structure: "call internal gradient method" Structure-->>Structure: "standard gradients" end Structure-->>postprocess_adj: "gradient values" postprocess_adj-->>_run_bwd: "VJP field map" _run_bwd-->>User: "gradients for optimization"