Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions tests/test_components/test_mode_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,10 @@ def test_mode_solver_data_interp_single_frequency():
field_data = getattr(data_interp, field_name)
assert field_data is not None
assert field_data.coords["f"].size == 1
assert float(field_data.coords["f"]) == 1.5e14
assert float(field_data.coords["f"].item()) == 1.5e14

# Check n_group_raw and dispersion_raw if present
if data_interp.n_group_raw is not None:
print(data_interp.n_group_raw.shape)
print((1, original_num_modes))
assert data_interp.n_group_raw.shape == (1, original_num_modes)
if data_interp.dispersion_raw is not None:
assert data_interp.dispersion_raw.shape == (1, original_num_modes)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_plugins/autograd/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from tidy3d.plugins.autograd.functions import _normalize_axes
from tidy3d.plugins.autograd.types import PaddingType

try:
from numpy import trapezoid as np_trapezoid
except ImportError: # NumPy < 2.0
from numpy import trapz as np_trapezoid

_mode_to_scipy = {
"constant": "constant",
"edge": "nearest",
Expand Down Expand Up @@ -552,7 +557,7 @@ def test_trapz_val(self, rng, shape, axis, use_x):
"""Test trapz values against NumPy for different array dimensions and integration axes."""
y, x, dx = self.generate_y_x_dx(rng, shape, use_x)
result_custom = trapz(y, x=x, dx=dx, axis=axis)
result_numpy = np.trapz(y, x=x, dx=dx, axis=axis)
result_numpy = np_trapezoid(y, x=x, dx=dx, axis=axis)
npt.assert_allclose(result_custom, result_numpy)

def test_trapz_grad(self, rng, shape, axis, use_x):
Expand Down
33 changes: 32 additions & 1 deletion tidy3d/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
import math
import os
import tempfile
from collections.abc import Callable
from functools import wraps
from math import ceil
from os import PathLike
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal, Optional, Union, get_args, get_origin

import h5py
import numpy as np
Expand Down Expand Up @@ -188,6 +189,36 @@ def __init__(self, **kwargs: Any) -> None:
self._post_init_validators()
log.end_capture(self)

@classmethod
def _field_allows_scalar(cls, field: ModelField) -> bool:
annotation = field.outer_type_

def allows_scalar(a: Any) -> bool:
origin = get_origin(a)
if origin is Union:
args = (arg for arg in get_args(a) if arg is not type(None))
return any(allows_scalar(arg) for arg in args)
if origin is not None:
return False
return isinstance(a, type) and issubclass(a, (float, int, np.generic))

return allows_scalar(annotation)

@pydantic.root_validator(pre=True)
def _coerce_numpy_scalars(cls, values: dict[str, Any]) -> dict[str, Any]:
if not isinstance(values, dict):
return values

for name, field in cls.__fields__.items():
if name not in values or not cls._field_allows_scalar(field):
continue

value = values[name]
if isinstance(value, np.generic) or (isinstance(value, np.ndarray) and value.size == 1):
values[name] = value.item()

return values

def _post_init_validators(self) -> None:
"""Call validators taking ``self`` that get run after init, implement in subclasses."""

Expand Down
11 changes: 5 additions & 6 deletions tidy3d/components/geometry/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import functools
import pathlib
from abc import ABC, abstractmethod
from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from os import PathLike
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import autograd.numpy as np
import pydantic.v1 as pydantic
Expand All @@ -20,6 +20,7 @@
except ImportError:
pass


from tidy3d.compat import _shapely_is_older_than
from tidy3d.components.autograd import (
AutogradFieldMap,
Expand Down Expand Up @@ -3544,13 +3545,11 @@ def inside_meshgrid(

def _volume(self, bounds: Bound) -> float:
"""Returns object's volume within given bounds."""
individual_volumes = (geometry.volume(bounds) for geometry in self.geometries)
return np.sum(individual_volumes)
return sum(geometry.volume(bounds) for geometry in self.geometries)

def _surface_area(self, bounds: Bound) -> float:
"""Returns object's surface area within given bounds."""
individual_areas = (geometry.surface_area(bounds) for geometry in self.geometries)
return np.sum(individual_areas)
return sum(geometry.surface_area(bounds) for geometry in self.geometries)

@cached_property
def _normal_2dmaterial(self) -> Axis:
Expand Down
2 changes: 1 addition & 1 deletion tidy3d/components/geometry/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,7 @@ def _shift_value_signed(

# get the index of the grid cell where the obj lies
obj_position = obj.center[normal_axis]
obj_pos_gt_grid_bounds = np.argwhere(obj_position > grid_boundaries)
obj_pos_gt_grid_bounds = np.argwhere(obj_position > grid_boundaries)[:, 0]

# no obj index can be determined
if len(obj_pos_gt_grid_bounds) == 0 or obj_position > grid_boundaries[-1]:
Expand Down
5 changes: 3 additions & 2 deletions tidy3d/components/medium.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import functools
from abc import ABC, abstractmethod
from collections.abc import Callable
from math import isclose
from typing import Any, Callable, Literal, Optional, Union
from typing import Any, Literal, Optional, Union

import autograd.numpy as np

Expand Down Expand Up @@ -3092,7 +3093,7 @@ def _compute_derivatives(self, derivative_info: DerivativeInfo) -> AutogradField
for freq in freqs:
dJ_deps_complex_f = dJ_deps_complex.sel(f=freq)
vjps_f = self._get_vjps_from_params(
dJ_deps_complex=complex(dJ_deps_complex_f),
dJ_deps_complex=complex(dJ_deps_complex_f.item()),
poles_vals=poles_vals,
omega=2 * np.pi * freq,
requested_paths=derivative_info.paths,
Expand Down