Skip to content

Commit c51ef7d

Browse files
committed
WIP
Signed-off-by: nstarman <[email protected]>
1 parent fbe5a7b commit c51ef7d

File tree

122 files changed

+3278
-15896
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

122 files changed

+3278
-15896
lines changed

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from datetime import datetime
9+
910
from typing import Any
1011

1112
import pytz

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,8 @@ COORDINAX_ENABLE_RUNTIME_TYPECHECKING = "beartype.beartype"
290290

291291
[tool.ruff.lint.isort]
292292
combine-as-imports = true
293-
extra-standard-library = ["typing_extensions"]
293+
sections = { typing = ["typing", "typing_extensions", "jaxtyping"] }
294+
section-order = ["future", "standard-library", "typing", "third-party", "first-party", "local-folder"]
294295
known-first-party = ["dataclassish", "optional_dependencies", "quaxed", "unxt", "xmmutablemap"]
295296
known-local-folder = ["coordinax"]
296297

@@ -301,3 +302,6 @@ constraint-dependencies = [
301302
"jax<0.7",
302303
"jaxlib<0.7",
303304
]
305+
306+
[tool.uv.sources]
307+
unxt = { path = "../unxt", editable = true }

src/coordinax/__init__.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# modules
66
"angle",
77
"distance",
8+
"r",
89
"vecs",
910
"ops",
1011
"frames",
@@ -13,6 +14,7 @@
1314
# common vecs objects
1415
"vector",
1516
"vconvert",
17+
"Vector",
1618
"CartesianPos3D",
1719
"CartesianVel3D",
1820
"SphericalPos",
@@ -26,20 +28,11 @@
2628
from .setup_package import install_import_hook
2729

2830
with install_import_hook("coordinax"):
29-
from . import angle, distance, frames, ops, vecs
31+
from . import angle, distance, frames, ops, r, vecs
3032
from ._version import version as __version__ # noqa: F401
3133
from .distance import Distance
3234
from .frames import Coordinate
33-
from .vecs import (
34-
CartesianPos3D,
35-
CartesianVel3D,
36-
FourVector,
37-
KinematicSpace,
38-
SphericalPos,
39-
SphericalVel,
40-
vconvert,
41-
vector,
42-
)
35+
from .vecs import KinematicSpace, Vector, vconvert, vector
4336

4437
# isort: split
4538
# Interoperability

src/coordinax/_coordinax_space_frames/frame_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
__all__: tuple[str, ...] = ()
44

55

6+
from jaxtyping import Array, Shaped
67
from typing import TypeAlias
78

8-
from jaxtyping import Array, Shaped
99
from plum import dispatch
1010

1111
import quaxed.numpy as jnp

src/coordinax/_coordinax_space_frames/galactocentric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
__all__ = ("Galactocentric",)
44

55

6+
from jaxtyping import Array, Shaped
67
from typing import ClassVar, TypeAlias, final
78

89
import equinox as eqx
9-
from jaxtyping import Array, Shaped
1010

1111
import unxt as u
1212
from dataclassish.converters import Unless

src/coordinax/_coordinax_space_frames/icrs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
__all__ = ("ICRS",)
44

55

6-
from typing import TypeAlias, final
7-
86
from jaxtyping import Array, Shaped
7+
from typing import TypeAlias, final
98

109
import unxt as u
1110

src/coordinax/_interop/coordinax_interop_astropy/converters.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,78 +5,75 @@
55
__all__: tuple[str, ...] = ()
66

77

8+
from jaxtyping import Shaped
89
from typing import cast
910

1011
import astropy.coordinates as apyc
1112
import astropy.units as apyu
12-
from jaxtyping import Shaped
1313
from plum import conversion_method, convert
1414

1515
import unxt as u
1616

1717
import coordinax as cx
18+
import coordinax.r as cxr
19+
import coordinax.vecs as cxv
1820

1921
#####################################################################
2022

2123
# =====================================
2224
# Quantity
2325

2426

25-
@conversion_method(cx.vecs.AbstractPos3D, apyu.Quantity) # type: ignore[arg-type]
26-
def vec_to_q(obj: cx.vecs.AbstractPos3D, /) -> Shaped[apyu.Quantity, "*batch 3"]:
27+
@conversion_method(cxv.Vector, apyu.Quantity) # type: ignore[arg-type]
28+
def vec_to_q(obj: cxv.Vector, /) -> Shaped[apyu.Quantity, "*batch 3"]:
2729
"""`coordinax.AbstractPos3D` -> `astropy.units.Quantity`.
2830
2931
Examples
3032
--------
31-
>>> import coordinax as cx
33+
>>> import coordinax.vecs as cxv
3234
>>> from plum import convert
3335
>>> import astropy.units as apyu
3436
35-
>>> vec = cx.CartesianPos3D.from_([1, 2, 3], "km")
37+
>>> vec = cxv.CartesianPos3D.from_([1, 2, 3], "km")
3638
>>> convert(vec, apyu.Quantity)
3739
<Quantity [1., 2., 3.] km>
3840
39-
>>> vec = cx.SphericalPos(r=apyu.Quantity(1, unit="km"),
41+
>>> vec = cxv.SphericalPos(r=apyu.Quantity(1, unit="km"),
4042
... theta=apyu.Quantity(2, unit="deg"),
4143
... phi=apyu.Quantity(3, unit="deg"))
4244
>>> convert(vec, apyu.Quantity)
4345
<Quantity [0.03485167, 0.0018265 , 0.99939084] km>
4446
45-
>>> vec = cx.vecs.CylindricalPos(rho=apyu.Quantity(1, unit="km"),
47+
>>> vec = cxv.CylindricalPos(rho=apyu.Quantity(1, unit="km"),
4648
... phi=apyu.Quantity(2, unit="deg"),
4749
... z=apyu.Quantity(3, unit="m"))
4850
>>> convert(vec, apyu.Quantity)
4951
<Quantity [0.99939084, 0.0348995 , 0.003 ] km>
5052
51-
"""
52-
return convert(convert(obj, u.Quantity), apyu.Quantity)
53-
54-
55-
@conversion_method(cx.vecs.CartesianAcc3D, apyu.Quantity) # type: ignore[arg-type]
56-
@conversion_method(cx.CartesianVel3D, apyu.Quantity) # type: ignore[arg-type]
57-
def vec_diff_to_q(
58-
obj: cx.CartesianVel3D | cx.vecs.CartesianAcc3D, /
59-
) -> Shaped[apyu.Quantity, "*batch 3"]:
60-
"""`coordinax.CartesianVel3D` -> `astropy.units.Quantity`.
61-
62-
Examples
63-
--------
64-
>>> import coordinax as cx
65-
>>> from plum import convert
66-
>>> from astropy.units import Quantity as AstropyQuantity
67-
68-
>>> dif = cx.CartesianVel3D.from_([1, 2, 3], "km/s")
53+
>>> dif = cxv.CartesianVel3D.from_([1, 2, 3], "km/s")
6954
>>> convert(dif, AstropyQuantity)
7055
<Quantity [1., 2., 3.] km / s>
7156
72-
>>> dif2 = cx.vecs.CartesianAcc3D.from_([1, 2, 3], "km/s2")
57+
>>> dif2 = cxv.CartesianAcc3D.from_([1, 2, 3], "km/s2")
7358
>>> convert(dif2, AstropyQuantity)
7459
<Quantity [1., 2., 3.] km / s2>
7560
7661
"""
7762
return convert(convert(obj, u.Quantity), apyu.Quantity)
7863

7964

65+
# =====================================
66+
67+
68+
@conversion_method(cxv.Vector, apyc.CartesianRepresentation)
69+
def convert_vector_to_astropy(obj: cxv.Vector, /) -> apyc.CartesianRepresentation:
70+
return apyc.CartesianRepresentation(
71+
x=convert(obj["x"], apyu.Quantity),
72+
y=convert(obj["y"], apyu.Quantity),
73+
z=convert(obj["z"], apyu.Quantity),
74+
)
75+
76+
8077
# =====================================
8178
# CartesianPos3D
8279

src/coordinax/_src/custom_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
__all__: tuple[str, ...] = ()
44

5+
from jaxtyping import Real, Shaped
56
from typing import TypeAlias
67

78
from astropy.units import (
89
CompositeUnit as AstropyCompositeUnit,
910
Unit as AstropyUnit,
1011
UnitBase as AstropyUnitBase,
1112
)
12-
from jaxtyping import Real, Shaped
1313

1414
import unxt as u
1515

src/coordinax/_src/distances/funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
)
88

99

10+
from jaxtyping import ArrayLike
1011
from typing import Any
1112

12-
from jaxtyping import ArrayLike
1313
from plum import dispatch
1414

1515
import quaxed.numpy as jnp

src/coordinax/_src/distances/measures.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
__all__ = ("Distance", "DistanceModulus", "Parallax")
44

55
from dataclasses import KW_ONLY
6+
7+
from jaxtyping import Array, Shaped
68
from typing import Any, final
79

810
import equinox as eqx
911
import jax.numpy as jnp
1012
import wadler_lindig as wl
11-
from jaxtyping import Array, Shaped
1213

1314
import quaxed.numpy as jnp
1415
import unxt as u

0 commit comments

Comments
 (0)