Skip to content

Commit c24e0ad

Browse files
committed
Fix mean, var and std of XTensorVariables
1 parent b4522d2 commit c24e0ad

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

pytensor/xtensor/reduction.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def bool_reduce(x, dim: REDUCE_DIM = None, *, binary_op):
8181
def _infer_reduced_size(original_var, reduced_var):
8282
reduced_dims = reduced_var.dims
8383
return variadic_mul(
84-
*[size for dim, size in original_var.sizes if dim not in reduced_dims]
84+
*[size for dim, size in original_var.sizes.items() if dim not in reduced_dims]
8585
)
8686

8787

@@ -96,7 +96,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
9696
x = as_xtensor(x)
9797
x_mean = mean(x, dim)
9898
n = _infer_reduced_size(x, x_mean)
99-
return square(x - x_mean) / (n - ddof)
99+
return square(x - x_mean).sum(dim) / (n - ddof)
100100

101101

102102
def std(x, dim: REDUCE_DIM, *, ddof: int = 0):

pytensor/xtensor/type.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,11 @@ def prod(self, dim=None):
692692
def sum(self, dim=None):
693693
return px.reduction.sum(self, dim)
694694

695-
def std(self, dim=None):
696-
return px.reduction.std(self, dim)
695+
def std(self, dim=None, ddof=0):
696+
return px.reduction.std(self, dim, ddof=ddof)
697697

698-
def var(self, dim=None):
699-
return px.reduction.var(self, dim)
698+
def var(self, dim=None, ddof=0):
699+
return px.reduction.var(self, dim, ddof=ddof)
700700

701701
def cumsum(self, dim=None):
702702
return px.reduction.cumsum(self, dim)

tests/xtensor/test_reduction.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
1313
)
1414
@pytest.mark.parametrize(
15-
"method", ["sum", "prod", "all", "any", "max", "min", "cumsum", "cumprod"][2:]
15+
"method",
16+
["sum", "prod", "all", "any", "max", "min", "mean", "cumsum", "cumprod"],
1617
)
1718
def test_reduction(method, dim):
1819
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
@@ -25,3 +26,29 @@ def test_reduction(method, dim):
2526
fn(x_test),
2627
getattr(x_test, method)(dim=dim),
2728
)
29+
30+
31+
@pytest.mark.parametrize(
32+
"dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
33+
)
34+
@pytest.mark.parametrize("method", ["std", "var"])
35+
def test_std_var(method, dim):
36+
x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
37+
out = [
38+
getattr(x, method)(dim=dim),
39+
getattr(x, method)(dim=dim, ddof=2),
40+
]
41+
42+
fn = xr_function([x], out)
43+
x_test = xr_arange_like(x)
44+
results = fn(x_test)
45+
46+
xr_assert_allclose(
47+
results[0],
48+
getattr(x_test, method)(dim=dim),
49+
)
50+
51+
xr_assert_allclose(
52+
results[1],
53+
getattr(x_test, method)(dim=dim, ddof=2),
54+
)

0 commit comments

Comments
 (0)