Skip to content

Commit 5024d54

Browse files
committed
Add docstrings to more XTensorVariable methods
Also remove broadcast, which is not a method in Xarray
1 parent fdb4087 commit 5024d54

File tree

1 file changed

+152
-23
lines changed

1 file changed

+152
-23
lines changed

pytensor/xtensor/type.py

Lines changed: 152 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -366,21 +366,25 @@ def __trunc__(self):
366366
# https://docs.xarray.dev/en/latest/api.html#id1
367367
@property
368368
def values(self) -> TensorVariable:
369+
"""Convert to a TensorVariable with the same data."""
369370
return typing.cast(TensorVariable, px.basic.tensor_from_xtensor(self))
370371

371372
# Can't provide property data because that's already taken by Constants!
372373
# data = values
373374

374375
@property
375376
def coords(self):
377+
"""Not implemented."""
376378
raise NotImplementedError("coords not implemented for XTensorVariable")
377379

378380
@property
379381
def dims(self) -> tuple[str, ...]:
382+
"""The names of the dimensions of the variable."""
380383
return self.type.dims
381384

382385
@property
383386
def sizes(self) -> dict[str, TensorVariable]:
387+
"""The sizes of the dimensions of the variable."""
384388
return dict(zip(self.dims, self.shape))
385389

386390
@property
@@ -392,18 +396,22 @@ def as_numpy(self):
392396
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
393397
@property
394398
def ndim(self) -> int:
399+
"""The number of dimensions of the variable."""
395400
return self.type.ndim
396401

397402
@property
398403
def shape(self) -> tuple[TensorVariable, ...]:
404+
"""The shape of the variable."""
399405
return tuple(px.basic.tensor_from_xtensor(self).shape) # type: ignore
400406

401407
@property
402408
def size(self) -> TensorVariable:
409+
"""The total number of elements in the variable."""
403410
return typing.cast(TensorVariable, variadic_mul(*self.shape))
404411

405412
@property
406-
def dtype(self):
413+
def dtype(self) -> str:
414+
"""The data type of the variable."""
407415
return self.type.dtype
408416

409417
@property
@@ -414,6 +422,7 @@ def broadcastable(self):
414422
# DataArray contents
415423
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
416424
def rename(self, new_name_or_name_dict=None, **names):
425+
"""Rename the variable or its dimension(s)."""
417426
if isinstance(new_name_or_name_dict, str):
418427
new_name = new_name_or_name_dict
419428
name_dict = None
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
425434
return new_out
426435

427436
def copy(self, name: str | None = None):
437+
"""Create a copy of the variable.
438+
439+
This is just an identity operation, as XTensorVariables are immutable.
440+
"""
428441
out = px.math.identity(self)
429442
out.name = name
430443
return out
431444

432445
def astype(self, dtype):
446+
"""Convert the variable to a different data type."""
433447
return px.math.cast(self, dtype)
434448

435449
def item(self):
450+
"""Not implemented."""
436451
raise NotImplementedError("item not implemented for XTensorVariable")
437452

438453
# Indexing
439454
# https://docs.xarray.dev/en/latest/api.html#id2
440455
def __setitem__(self, idx, value):
456+
"""Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
441457
raise TypeError(
442458
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
443459
)
444460

445461
@property
446462
def loc(self):
463+
"""Not implemented."""
447464
raise NotImplementedError("loc not implemented for XTensorVariable")
448465

449466
def sel(self, *args, **kwargs):
467+
"""Not implemented."""
450468
raise NotImplementedError("sel not implemented for XTensorVariable")
451469

452470
def __getitem__(self, idx):
471+
"""Index the variable positionally."""
453472
if isinstance(idx, dict):
454473
return self.isel(idx)
455474

@@ -465,6 +484,7 @@ def isel(
465484
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
466485
**indexers_kwargs,
467486
):
487+
"""Index the variable along the specified dimension(s)."""
468488
if indexers_kwargs:
469489
if indexers is not None:
470490
raise ValueError(
@@ -505,6 +525,48 @@ def isel(
505525
return px.indexing.index(self, *indices)
506526

507527
def set(self, value):
528+
"""Return a copy of the variable indexed by self with the indexed values set to y.
529+
530+
The original variable is not modified.
531+
532+
Raises
533+
------
534+
ValueError
535+
If self is not the result of an index operation
536+
537+
Examples
538+
--------
539+
540+
.. testcode::
541+
542+
import pytensor.xtensor as ptx
543+
544+
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
545+
idx = ptx.as_xtensor([0, 1], dims=("a",))
546+
out = x[:, idx].set(1)
547+
print(out.eval())
548+
549+
.. testoutput::
550+
551+
[[1 0]
552+
[0 1]]
553+
554+
555+
.. testcode::
556+
557+
import pytensor.xtensor as ptx
558+
559+
x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
560+
idx = ptx.as_xtensor([0, 1], dims=("a",))
561+
out = x.isel({"b": idx}).set(-1)
562+
print(out.eval())
563+
564+
.. testoutput::
565+
566+
[[-1 0]
567+
[ 0 -1]]
568+
569+
"""
508570
if not (
509571
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
510572
):
@@ -516,6 +578,48 @@ def set(self, value):
516578
return px.indexing.index_assignment(x, value, *idxs)
517579

518580
def inc(self, value):
581+
"""Return a copy of the variable indexed by self with the indexed values incremented by value.
582+
583+
The original variable is not modified.
584+
585+
Raises
586+
------
587+
ValueError
588+
If self is not the result of an index operation
589+
590+
Examples
591+
--------
592+
593+
.. testcode::
594+
595+
import pytensor.xtensor as ptx
596+
597+
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
598+
idx = ptx.as_xtensor([0, 1], dims=("a",))
599+
out = x[:, idx].inc(1)
600+
print(out.eval())
601+
602+
.. testoutput::
603+
604+
[[2 1]
605+
[1 2]]
606+
607+
608+
.. testcode::
609+
610+
import pytensor.xtensor as ptx
611+
612+
x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
613+
idx = ptx.as_xtensor([0, 1], dims=("a",))
614+
out = x.isel({"b": idx}).inc(-1)
615+
print(out.eval())
616+
617+
.. testoutput::
618+
619+
[[0 1]
620+
[1 0]]
621+
622+
"""
519623
if not (
520624
self.owner is not None and isinstance(self.owner.op, px.indexing.Index)
521625
):
@@ -579,7 +683,7 @@ def squeeze(
579683
drop=None,
580684
axis: int | Sequence[int] | None = None,
581685
):
582-
"""Remove dimensions of size 1 from an XTensorVariable.
686+
"""Remove dimensions of size 1.
583687
584688
Parameters
585689
----------
@@ -606,24 +710,21 @@ def expand_dims(
606710
axis: int | Sequence[int] | None = None,
607711
**dim_kwargs,
608712
):
609-
"""Add one or more new dimensions to the tensor.
713+
"""Add one or more new dimensions to the variable.
610714
611715
Parameters
612716
----------
613717
dim : str | Sequence[str] | dict[str, int | Sequence] | None
614718
If str or sequence of str, new dimensions with size 1.
615719
If dict, keys are dimension names and values are either:
616-
- int: the new size
617-
- sequence: coordinates (length determines size)
720+
721+
- int: the new size
722+
- sequence: coordinates (length determines size)
618723
create_index_for_new_dim : bool, default: True
619-
Currently ignored. Reserved for future coordinate support.
620-
In xarray, when True (default), creates a coordinate index for the new dimension
621-
with values from 0 to size-1. When False, no coordinate index is created.
724+
Ignored by PyTensor
622725
axis : int | Sequence[int] | None, default: None
623726
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
624727
By default (None), new dimensions are inserted at the beginning (axis=0).
625-
Symbolic axis is not supported yet.
626-
Negative values count from the end.
627728
**dim_kwargs : int | Sequence
628729
Alternative to `dim` dict. Only used if `dim` is None.
629730
@@ -643,65 +744,75 @@ def expand_dims(
643744
# ndarray methods
644745
# https://docs.xarray.dev/en/latest/api.html#id7
645746
def clip(self, min, max):
747+
"""Clip the values of the variable to a specified range."""
646748
return px.math.clip(self, min, max)
647749

648750
def conj(self):
751+
"""Return the complex conjugate of the variable."""
649752
return px.math.conj(self)
650753

651754
@property
652755
def imag(self):
756+
"""Return the imaginary part of the variable."""
653757
return px.math.imag(self)
654758

655759
@property
656760
def real(self):
761+
"""Return the real part of the variable."""
657762
return px.math.real(self)
658763

659764
@property
660765
def T(self):
661-
"""Return the full transpose of the tensor.
766+
"""Return the full transpose of the variable.
662767
663768
This is equivalent to calling transpose() with no arguments.
664-
665-
Returns
666-
-------
667-
XTensorVariable
668-
Fully transposed tensor.
669769
"""
670770
return self.transpose()
671771

672772
# Aggregation
673773
# https://docs.xarray.dev/en/latest/api.html#id6
674774
def all(self, dim=None):
775+
"""Reduce the variable by applying `all` along some dimension(s)."""
675776
return px.reduction.all(self, dim)
676777

677778
def any(self, dim=None):
779+
"""Reduce the variable by applying `any` along some dimension(s)."""
678780
return px.reduction.any(self, dim)
679781

680782
def max(self, dim=None):
783+
"""Compute the maximum along the given dimension(s)."""
681784
return px.reduction.max(self, dim)
682785

683786
def min(self, dim=None):
787+
"""Compute the minimum along the given dimension(s)."""
684788
return px.reduction.min(self, dim)
685789

686790
def mean(self, dim=None):
791+
"""Compute the mean along the given dimension(s)."""
687792
return px.reduction.mean(self, dim)
688793

689794
def prod(self, dim=None):
795+
"""Compute the product along the given dimension(s)."""
690796
return px.reduction.prod(self, dim)
691797

692798
def sum(self, dim=None):
799+
"""Compute the sum along the given dimension(s)."""
693800
return px.reduction.sum(self, dim)
694801

695802
def std(self, dim=None, ddof=0):
803+
"""Compute the standard deviation along the given dimension(s)."""
696804
return px.reduction.std(self, dim, ddof=ddof)
697805

698806
def var(self, dim=None, ddof=0):
807+
"""Compute the variance along the given dimension(s)."""
699808
return px.reduction.var(self, dim, ddof=ddof)
700809

701810
def cumsum(self, dim=None):
811+
"""Compute the cumulative sum along the given dimension(s)."""
702812
return px.reduction.cumsum(self, dim)
703813

704814
def cumprod(self, dim=None):
815+
"""Compute the cumulative product along the given dimension(s)."""
705816
return px.reduction.cumprod(self, dim)
706817

707818
def diff(self, dim, n=1):
@@ -720,7 +831,7 @@ def transpose(
720831
*dim: str | EllipsisType,
721832
missing_dims: Literal["raise", "warn", "ignore"] = "raise",
722833
):
723-
"""Transpose dimensions of the tensor.
834+
"""Transpose the dimensions of the variable.
724835
725836
Parameters
726837
----------
@@ -729,6 +840,7 @@ def transpose(
729840
Can use ellipsis (...) to represent remaining dimensions.
730841
missing_dims : {"raise", "warn", "ignore"}, default="raise"
731842
How to handle dimensions that don't exist in the tensor:
843+
732844
- "raise": Raise an error if any dimensions don't exist
733845
- "warn": Warn if any dimensions don't exist
734846
- "ignore": Silently ignore any dimensions that don't exist
@@ -747,21 +859,38 @@ def transpose(
747859
return px.shape.transpose(self, *dim, missing_dims=missing_dims)
748860

749861
def stack(self, dim, **dims):
862+
"""Stack existing dimensions into a single new dimension."""
750863
return px.shape.stack(self, dim, **dims)
751864

752865
def unstack(self, dim, **dims):
866+
"""Unstack a dimension into multiple dimensions of a given size.
867+
868+
Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
869+
Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
870+
871+
.. testcode::
872+
873+
import pytensor.xtensor as ptx
874+
875+
x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
876+
stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
877+
unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
878+
print(unstacked_cumsum.eval())
879+
880+
.. testoutput::
881+
882+
[[ 1 3]
883+
[ 6 10]]
884+
885+
"""
753886
return px.shape.unstack(self, dim, **dims)
754887

755888
def dot(self, other, dim=None):
756-
"""Matrix multiplication with another XTensorVariable, contracting over matching or specified dims."""
889+
"""Generalized dot product with another XTensorVariable."""
757890
return px.math.dot(self, other, dim=dim)
758891

759-
def broadcast(self, *others, exclude=None):
760-
"""Broadcast this tensor against other XTensorVariables."""
761-
return px.shape.broadcast(self, *others, exclude=exclude)
762-
763892
def broadcast_like(self, other, exclude=None):
764-
"""Broadcast this tensor against another XTensorVariable."""
893+
"""Broadcast against another XTensorVariable."""
765894
_, self_bcast = px.shape.broadcast(other, self, exclude=exclude)
766895
return self_bcast
767896

0 commit comments

Comments
 (0)