Commit d84e9c6

Small tweaks to XRV Ops
* Fix core_dims_needed calculation
* Handle lazy dtype
* Nicer __str__ with use of `name`
1 parent 5761704 commit d84e9c6

File tree: 7 files changed (+51 −6 lines)


pytensor/tensor/random/op.py

Lines changed: 7 additions & 0 deletions

@@ -392,6 +392,13 @@ def make_node(self, rng, size, *dist_params):
         out_type = TensorType(dtype=self.dtype, shape=static_shape)
         outputs = (rng.type(), out_type())
 
+        if self.dtype == "floatX":
+            # Commit to a specific float type if the Op is still using "floatX"
+            dtype = config.floatX
+            props = self._props_dict()
+            props["dtype"] = dtype
+            self = type(self)(**props)
+
         return Apply(self, inputs, outputs)
 
     def batch_ndim(self, node: Apply) -> int:
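
Note: with this change an Op declared with the lazy "floatX" dtype commits to a concrete float type when its node is built, so the Op stored on the Apply node no longer carries the "floatX" placeholder. A minimal sketch of the resulting behaviour, assuming `ptr.normal` is still declared with dtype "floatX":

import pytensor.tensor.random as ptr
from pytensor import config

with config.change_flags(floatX="float32"):
    x = ptr.normal(0, 1)
    # The output dtype follows floatX, and the Op attached to the node has been
    # re-instantiated with the concrete dtype instead of keeping "floatX".
    assert x.dtype == "float32"
    assert x.owner.op.dtype == "float32"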

pytensor/xtensor/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import warnings
 
 import pytensor.xtensor.rewriting
-from pytensor.xtensor import linalg
+from pytensor.xtensor import linalg, random
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import concat
 from pytensor.xtensor.type import (

pytensor/xtensor/random.py

Lines changed: 14 additions & 3 deletions

@@ -5,15 +5,16 @@
 import pytensor.tensor.random.basic as ptr
 from pytensor.graph.basic import Variable
 from pytensor.tensor.random.op import RandomVariable
-from pytensor.xtensor import as_xtensor
 from pytensor.xtensor.math import sqrt
+from pytensor.xtensor.type import as_xtensor
 from pytensor.xtensor.vectorization import XRV
 
 
 def _as_xrv(
     core_op: RandomVariable,
     core_inps_dims_map: Sequence[Sequence[int]] | None = None,
     core_out_dims_map: Sequence[int] | None = None,
+    name: str | None = None,
 ):
     """Helper function to define an XRV constructor.
 
@@ -41,7 +42,14 @@ def _as_xrv(
         core_out_dims_map = tuple(range(core_op.ndim_supp))
 
     core_dims_needed = max(
-        (*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0
+        max(
+            (
+                max((entry + 1 for entry in dims_map), default=0)
+                for dims_map in core_inps_dims_map
+            ),
+            default=0,
+        ),
+        max((entry + 1 for entry in core_out_dims_map), default=0),
     )
 
     @wraps(core_op)
@@ -76,7 +84,10 @@ def xrv_constructor(
             extra_dims = {}
 
         return XRV(
-            core_op, core_dims=full_core_dims, extra_dims=tuple(extra_dims.keys())
+            core_op,
+            core_dims=full_core_dims,
+            extra_dims=tuple(extra_dims.keys()),
+            name=name,
         )(rng, *extra_dims.values(), *params)
 
     return xrv_constructor
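
The core_dims_needed fix counts the largest core-dim index referenced (plus one) rather than the number of entries in each dims map, which matters when a map points at a later core dim without listing the earlier ones. A small standalone illustration with a made-up dims map, not an Op from the library:

# Hypothetical maps: the single input uses only the second core dim (index 1),
# while the output uses only the first (index 0).
core_inps_dims_map = [(1,)]
core_out_dims_map = (0,)

# Old formula counted entries per map, reporting 1 even though index 1 is referenced.
old = max((*(len(i) for i in core_inps_dims_map), len(core_out_dims_map)), default=0)
assert old == 1

# New formula takes the highest referenced index + 1, correctly asking for 2 core dims.
new = max(
    max(
        (max((e + 1 for e in dims_map), default=0) for dims_map in core_inps_dims_map),
        default=0,
    ),
    max((e + 1 for e in core_out_dims_map), default=0),
)
assert new == 2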

pytensor/xtensor/rewriting/vectorization.py

Lines changed: 1 addition & 1 deletion

@@ -116,7 +116,7 @@ def lower_rv(fgraph, node):
     size = [*extra_dim_lengths, *param_batch_shape]
 
     # RVs are their own core Op
-    new_next_rng, tensor_out = core_op(*tensor_params, rng=rng, size=size).owner.outputs
+    new_next_rng, tensor_out = core_op.make_node(rng, size, *tensor_params).outputs
 
     # Convert output Tensors to XTensors
     new_out = xtensor_from_tensor(tensor_out, dims=old_out.type.dims)

pytensor/xtensor/type.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ def __init__(
             raise ValueError(
                 f"Shape {self.shape} must have the same length as dims {self.dims}"
             )
+        self.broadcastable = tuple(s == 1 for s in self.shape)
         self.ndim = len(self.dims)
         self.name = name
         self.numpy_dtype = np.dtype(self.dtype)
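
For reference, the new attribute mirrors TensorType.broadcastable: a dim counts as broadcastable exactly when its static length is 1, and unknown lengths count as not broadcastable. A rough sketch, assuming the XTensorType constructor accepts dtype plus dims= and shape= keywords:

from pytensor.xtensor.type import XTensorType

# Static length 1 -> broadcastable; any other (or unknown) length -> not.
xt = XTensorType("float64", dims=("a", "b"), shape=(1, 3))
assert xt.broadcastable == (True, False)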

pytensor/xtensor/vectorization.py

Lines changed: 13 additions & 0 deletions

@@ -142,8 +142,12 @@ def __init__(
         core_op,
         core_dims: tuple[tuple[tuple[str, ...], ...], tuple[str, ...]],
         extra_dims: tuple[str, ...],
+        name: str | None = None,
     ):
         super().__init__()
+        if name is None:
+            name = getattr(core_op, "name", None)
+        self.name = name
         self.core_op = core_op
         inps_core_dims, out_core_dims = core_dims
         for operand_dims in (*inps_core_dims, out_core_dims):
@@ -154,6 +158,15 @@ def __init__(
             raise ValueError("size_dims must be unique")
         self.extra_dims = tuple(extra_dims)
 
+    def __str__(self):
+        if self.name is not None:
+            name = self.name
+            attrs = f"(core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        else:
+            name = self.__class__.__name__
+            attrs = f"(core_op={self.core_op}, core_dims={self.core_dims}, extra_dims={self.extra_dims})"
+        return f"{name}({attrs})"
+
     def update(self, node):
         # RNG input and update are the first input and output respectively
         return {node.inputs[0]: node.outputs[0]}
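
Together with the `name` plumbing in random.py, the new __str__ makes XRV nodes print under the core Op's name ("normal", "poisson", ...) instead of the generic class name. A quick way to see it, assuming the `normal` constructor built by `_as_xrv` is used:

import pytensor.xtensor.random as pxr

x = pxr.normal(0, 1)
# The XRV Op picked up the core Op's name, so this prints something like
# "normal(core_dims=..., extra_dims=...)" rather than "XRV(...)".
print(x.owner.op)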

tests/xtensor/test_random.py

Lines changed: 14 additions & 1 deletion

@@ -7,7 +7,7 @@
 
 import pytensor.tensor.random as ptr
 import pytensor.xtensor.random as pxr
-from pytensor import function, shared
+from pytensor import config, function, shared
 from pytensor.graph import rewrite_graph
 from pytensor.graph.basic import equal_computations
 from pytensor.tensor import broadcast_arrays, tensor
@@ -112,6 +112,19 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
     )
 
 
+def test_dtype():
+    x = normal(0, 1)
+    assert x.type.dtype == config.floatX
+
+    with config.change_flags(floatX="float64"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float64"
+
+    with config.change_flags(floatX="float32"):
+        x = normal(0, 1)
+        assert x.type.dtype == "float32"
+
+
 def test_normal():
     rng = random_generator_type("rng")
     c_size = tensor("c_size", shape=(), dtype=int)
