@@ -366,21 +366,25 @@ def __trunc__(self):
366
366
# https://docs.xarray.dev/en/latest/api.html#id1
367
367
@property
def values(self) -> TensorVariable:
    """Convert to a TensorVariable with the same data."""
    as_tensor = px.basic.tensor_from_xtensor(self)
    return typing.cast(TensorVariable, as_tensor)
370
371
371
372
# Can't provide property data because that's already taken by Constants!
372
373
# data = values
373
374
374
375
@property
def coords(self):
    """Not implemented."""
    msg = "coords not implemented for XTensorVariable"
    raise NotImplementedError(msg)
377
379
378
380
@property
def dims(self) -> tuple[str, ...]:
    """The names of the dimensions of the variable."""
    dim_names = self.type.dims
    return dim_names
381
384
382
385
@property
383
386
def sizes (self ) -> dict [str , TensorVariable ]:
387
+ """The sizes of the dimensions of the variable."""
384
388
return dict (zip (self .dims , self .shape ))
385
389
386
390
@property
@@ -392,18 +396,22 @@ def as_numpy(self):
392
396
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
393
397
@property
def ndim(self) -> int:
    """The number of dimensions of the variable."""
    rank = self.type.ndim
    return rank
396
401
397
402
@property
def shape(self) -> tuple[TensorVariable, ...]:
    """The shape of the variable."""
    tensor = px.basic.tensor_from_xtensor(self)
    return tuple(tensor.shape)  # type: ignore
400
406
401
407
@property
def size(self) -> TensorVariable:
    """The total number of elements in the variable."""
    total = variadic_mul(*self.shape)
    return typing.cast(TensorVariable, total)
404
411
405
412
@property
def dtype(self) -> str:
    """The data type of the variable."""
    return self.type.dtype
408
416
409
417
@property
@@ -414,6 +422,7 @@ def broadcastable(self):
414
422
# DataArray contents
415
423
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
416
424
def rename (self , new_name_or_name_dict = None , ** names ):
425
+ """Rename the variable or its dimension(s)."""
417
426
if isinstance (new_name_or_name_dict , str ):
418
427
new_name = new_name_or_name_dict
419
428
name_dict = None
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
425
434
return new_out
426
435
427
436
def copy(self, name: str | None = None):
    """Create a copy of the variable.

    This is just an identity operation, as XTensorVariables are immutable.
    """
    clone = px.math.identity(self)
    clone.name = name
    return clone
431
444
432
445
def astype(self, dtype):
    """Convert the variable to a different data type."""
    converted = px.math.cast(self, dtype)
    return converted
434
448
435
449
def item(self):
    """Not implemented."""
    msg = "item not implemented for XTensorVariable"
    raise NotImplementedError(msg)
437
452
438
453
# Indexing
439
454
# https://docs.xarray.dev/en/latest/api.html#id2
440
455
def __setitem__(self, idx, value):
    """Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
    msg = "XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
    raise TypeError(msg)
444
460
445
461
@property
def loc(self):
    """Not implemented."""
    msg = "loc not implemented for XTensorVariable"
    raise NotImplementedError(msg)
448
465
449
466
def sel(self, *args, **kwargs):
    """Not implemented."""
    msg = "sel not implemented for XTensorVariable"
    raise NotImplementedError(msg)
451
469
452
470
def __getitem__ (self , idx ):
471
+ """Index the variable positionally."""
453
472
if isinstance (idx , dict ):
454
473
return self .isel (idx )
455
474
@@ -465,6 +484,7 @@ def isel(
465
484
missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
466
485
** indexers_kwargs ,
467
486
):
487
+ """Index the variable along the specified dimension(s)."""
468
488
if indexers_kwargs :
469
489
if indexers is not None :
470
490
raise ValueError (
@@ -505,6 +525,48 @@ def isel(
505
525
return px .indexing .index (self , * indices )
506
526
507
527
def set (self , value ):
528
+ """Return a copy of the variable indexed by self with the indexed values set to y.
529
+
530
+ The original variable is not modified.
531
+
532
+ Raises
533
+ ------
534
+ ValueError
535
+ If self is not the result of an index operation
536
+
537
+ Examples
538
+ --------
539
+
540
+ .. testcode::
541
+
542
+ import pytensor.xtensor as ptx
543
+
544
+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
545
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
546
+ out = x[:, idx].set(1)
547
+ print(out.eval())
548
+
549
+ .. testoutput::
550
+
551
+ [[1 0]
552
+ [0 1]]
553
+
554
+
555
+ .. testcode::
556
+
557
+ import pytensor.xtensor as ptx
558
+
559
+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
560
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
561
+ out = x.isel({"b": idx}).set(-1)
562
+ print(out.eval())
563
+
564
+ .. testoutput::
565
+
566
+ [[-1 0]
567
+ [ 0 -1]]
568
+
569
+ """
508
570
if not (
509
571
self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
510
572
):
@@ -516,6 +578,48 @@ def set(self, value):
516
578
return px .indexing .index_assignment (x , value , * idxs )
517
579
518
580
def inc (self , value ):
581
+ """Return a copy of the variable indexed by self with the indexed values incremented by value.
582
+
583
+ The original variable is not modified.
584
+
585
+ Raises
586
+ ------
587
+ ValueError
588
+ If self is not the result of an index operation
589
+
590
+ Examples
591
+ --------
592
+
593
+ .. testcode::
594
+
595
+ import pytensor.xtensor as ptx
596
+
597
+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
598
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
599
+ out = x[:, idx].inc(1)
600
+ print(out.eval())
601
+
602
+ .. testoutput::
603
+
604
+ [[2 1]
605
+ [1 2]]
606
+
607
+
608
+ .. testcode::
609
+
610
+ import pytensor.xtensor as ptx
611
+
612
+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
613
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
614
+ out = x.isel({"b": idx}).inc(-1)
615
+ print(out.eval())
616
+
617
+ .. testoutput::
618
+
619
+ [[0 1]
620
+ [1 0]]
621
+
622
+ """
519
623
if not (
520
624
self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
521
625
):
@@ -579,7 +683,7 @@ def squeeze(
579
683
drop = None ,
580
684
axis : int | Sequence [int ] | None = None ,
581
685
):
582
- """Remove dimensions of size 1 from an XTensorVariable .
686
+ """Remove dimensions of size 1.
583
687
584
688
Parameters
585
689
----------
@@ -606,24 +710,21 @@ def expand_dims(
606
710
axis : int | Sequence [int ] | None = None ,
607
711
** dim_kwargs ,
608
712
):
609
- """Add one or more new dimensions to the tensor .
713
+ """Add one or more new dimensions to the variable .
610
714
611
715
Parameters
612
716
----------
613
717
dim : str | Sequence[str] | dict[str, int | Sequence] | None
614
718
If str or sequence of str, new dimensions with size 1.
615
719
If dict, keys are dimension names and values are either:
616
- - int: the new size
617
- - sequence: coordinates (length determines size)
720
+
721
+ - int: the new size
722
+ - sequence: coordinates (length determines size)
618
723
create_index_for_new_dim : bool, default: True
619
- Currently ignored. Reserved for future coordinate support.
620
- In xarray, when True (default), creates a coordinate index for the new dimension
621
- with values from 0 to size-1. When False, no coordinate index is created.
724
+ Ignored by PyTensor
622
725
axis : int | Sequence[int] | None, default: None
623
726
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
624
727
By default (None), new dimensions are inserted at the beginning (axis=0).
625
- Symbolic axis is not supported yet.
626
- Negative values count from the end.
627
728
**dim_kwargs : int | Sequence
628
729
Alternative to `dim` dict. Only used if `dim` is None.
629
730
@@ -643,65 +744,75 @@ def expand_dims(
643
744
# ndarray methods
644
745
# https://docs.xarray.dev/en/latest/api.html#id7
645
746
def clip(self, min, max):
    """Clip the values of the variable to a specified range."""
    clipped = px.math.clip(self, min, max)
    return clipped
647
749
648
750
def conj(self):
    """Return the complex conjugate of the variable."""
    conjugated = px.math.conj(self)
    return conjugated
650
753
651
754
@property
def imag(self):
    """Return the imaginary part of the variable."""
    imag_part = px.math.imag(self)
    return imag_part
654
758
655
759
@property
def real(self):
    """Return the real part of the variable."""
    real_part = px.math.real(self)
    return real_part
658
763
659
764
@property
def T(self):
    """Return the full transpose of the variable.

    This is equivalent to calling transpose() with no arguments.
    """
    return self.transpose()
671
771
672
772
# Aggregation
673
773
# https://docs.xarray.dev/en/latest/api.html#id6
674
774
def all(self, dim=None):
    """Reduce the variable by applying `all` along some dimension(s)."""
    reduced = px.reduction.all(self, dim)
    return reduced
676
777
677
778
def any(self, dim=None):
    """Reduce the variable by applying `any` along some dimension(s)."""
    reduced = px.reduction.any(self, dim)
    return reduced
679
781
680
782
def max(self, dim=None):
    """Compute the maximum along the given dimension(s)."""
    reduced = px.reduction.max(self, dim)
    return reduced
682
785
683
786
def min(self, dim=None):
    """Compute the minimum along the given dimension(s)."""
    reduced = px.reduction.min(self, dim)
    return reduced
685
789
686
790
def mean(self, dim=None):
    """Compute the mean along the given dimension(s)."""
    reduced = px.reduction.mean(self, dim)
    return reduced
688
793
689
794
def prod(self, dim=None):
    """Compute the product along the given dimension(s)."""
    reduced = px.reduction.prod(self, dim)
    return reduced
691
797
692
798
def sum(self, dim=None):
    """Compute the sum along the given dimension(s)."""
    reduced = px.reduction.sum(self, dim)
    return reduced
694
801
695
802
def std(self, dim=None, ddof=0):
    """Compute the standard deviation along the given dimension(s)."""
    reduced = px.reduction.std(self, dim, ddof=ddof)
    return reduced
697
805
698
806
def var(self, dim=None, ddof=0):
    """Compute the variance along the given dimension(s)."""
    reduced = px.reduction.var(self, dim, ddof=ddof)
    return reduced
700
809
701
810
def cumsum(self, dim=None):
    """Compute the cumulative sum along the given dimension(s)."""
    accumulated = px.reduction.cumsum(self, dim)
    return accumulated
703
813
704
814
def cumprod(self, dim=None):
    """Compute the cumulative product along the given dimension(s)."""
    accumulated = px.reduction.cumprod(self, dim)
    return accumulated
706
817
707
818
def diff (self , dim , n = 1 ):
@@ -720,7 +831,7 @@ def transpose(
720
831
* dim : str | EllipsisType ,
721
832
missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
722
833
):
723
- """Transpose dimensions of the tensor .
834
+ """Transpose the dimensions of the variable .
724
835
725
836
Parameters
726
837
----------
@@ -729,6 +840,7 @@ def transpose(
729
840
Can use ellipsis (...) to represent remaining dimensions.
730
841
missing_dims : {"raise", "warn", "ignore"}, default="raise"
731
842
How to handle dimensions that don't exist in the tensor:
843
+
732
844
- "raise": Raise an error if any dimensions don't exist
733
845
- "warn": Warn if any dimensions don't exist
734
846
- "ignore": Silently ignore any dimensions that don't exist
@@ -747,21 +859,38 @@ def transpose(
747
859
return px .shape .transpose (self , * dim , missing_dims = missing_dims )
748
860
749
861
def stack(self, dim, **dims):
    """Stack existing dimensions into a single new dimension."""
    stacked = px.shape.stack(self, dim, **dims)
    return stacked
751
864
752
865
def unstack(self, dim, **dims):
    """Unstack a dimension into multiple dimensions of a given size.

    Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
    Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.

    .. testcode::

        import pytensor.xtensor as ptx

        x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
        stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
        unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
        print(unstacked_cumsum.eval())

    .. testoutput::

        [[ 1  3]
         [ 6 10]]

    """
    unstacked = px.shape.unstack(self, dim, **dims)
    return unstacked
754
887
755
888
def dot(self, other, dim=None):
    """Generalized dot product with another XTensorVariable."""
    product = px.math.dot(self, other, dim=dim)
    return product
758
891
759
- def broadcast (self , * others , exclude = None ):
760
- """Broadcast this tensor against other XTensorVariables."""
761
- return px .shape .broadcast (self , * others , exclude = exclude )
762
-
763
892
def broadcast_like(self, other, exclude=None):
    """Broadcast against another XTensorVariable."""
    # broadcast returns the inputs in order; we only need the broadcasted self.
    broadcasted = px.shape.broadcast(other, self, exclude=exclude)[1]
    return broadcasted
767
896
0 commit comments