@@ -366,21 +366,25 @@ def __trunc__(self):
366
366
# https://docs.xarray.dev/en/latest/api.html#id1
367
367
@property
368
368
def values (self ) -> TensorVariable :
369
+ """Convert to a TensorVariable with the same data."""
369
370
return typing .cast (TensorVariable , px .basic .tensor_from_xtensor (self ))
370
371
371
372
# Can't provide property data because that's already taken by Constants!
372
373
# data = values
373
374
374
375
@property
375
376
def coords (self ):
377
+ """Not implemented."""
376
378
raise NotImplementedError ("coords not implemented for XTensorVariable" )
377
379
378
380
@property
379
381
def dims (self ) -> tuple [str , ...]:
382
+ """The names of the dimensions of the variable."""
380
383
return self .type .dims
381
384
382
385
@property
383
386
def sizes (self ) -> dict [str , TensorVariable ]:
387
+ """The sizes of the dimensions of the variable."""
384
388
return dict (zip (self .dims , self .shape ))
385
389
386
390
@property
@@ -392,18 +396,22 @@ def as_numpy(self):
392
396
# https://docs.xarray.dev/en/latest/api.html#ndarray-attributes
393
397
@property
394
398
def ndim (self ) -> int :
399
+ """The number of dimensions of the variable."""
395
400
return self .type .ndim
396
401
397
402
@property
398
403
def shape (self ) -> tuple [TensorVariable , ...]:
404
+ """The shape of the variable."""
399
405
return tuple (px .basic .tensor_from_xtensor (self ).shape ) # type: ignore
400
406
401
407
@property
402
408
def size (self ) -> TensorVariable :
409
+ """The total number of elements in the variable."""
403
410
return typing .cast (TensorVariable , variadic_mul (* self .shape ))
404
411
405
412
@property
406
- def dtype (self ):
413
+ def dtype (self ) -> str :
414
+ """The data type of the variable."""
407
415
return self .type .dtype
408
416
409
417
@property
@@ -414,6 +422,7 @@ def broadcastable(self):
414
422
# DataArray contents
415
423
# https://docs.xarray.dev/en/latest/api.html#dataarray-contents
416
424
def rename (self , new_name_or_name_dict = None , ** names ):
425
+ """Rename the variable or its dimension(s)."""
417
426
if isinstance (new_name_or_name_dict , str ):
418
427
new_name = new_name_or_name_dict
419
428
name_dict = None
@@ -425,31 +434,41 @@ def rename(self, new_name_or_name_dict=None, **names):
425
434
return new_out
426
435
427
436
def copy (self , name : str | None = None ):
437
+ """Create a copy of the variable.
438
+
439
+ This is just an identity operation, as XTensorVariables are immutable.
440
+ """
428
441
out = px .math .identity (self )
429
442
out .name = name
430
443
return out
431
444
432
445
def astype (self , dtype ):
446
+ """Convert the variable to a different data type."""
433
447
return px .math .cast (self , dtype )
434
448
435
449
def item (self ):
450
+ """Not implemented."""
436
451
raise NotImplementedError ("item not implemented for XTensorVariable" )
437
452
438
453
# Indexing
439
454
# https://docs.xarray.dev/en/latest/api.html#id2
440
455
def __setitem__ (self , idx , value ):
456
+ """Not implemented. Use `x[idx].set(value)` or `x[idx].inc(value)` instead."""
441
457
raise TypeError (
442
458
"XTensorVariable does not support item assignment. Use the output of `x[idx].set` or `x[idx].inc` instead."
443
459
)
444
460
445
461
@property
446
462
def loc (self ):
463
+ """Not implemented."""
447
464
raise NotImplementedError ("loc not implemented for XTensorVariable" )
448
465
449
466
def sel (self , * args , ** kwargs ):
467
+ """Not implemented."""
450
468
raise NotImplementedError ("sel not implemented for XTensorVariable" )
451
469
452
470
def __getitem__ (self , idx ):
471
+ """Index the variable positionally."""
453
472
if isinstance (idx , dict ):
454
473
return self .isel (idx )
455
474
@@ -465,6 +484,7 @@ def isel(
465
484
missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
466
485
** indexers_kwargs ,
467
486
):
487
+ """Index the variable along the specified dimension(s)."""
468
488
if indexers_kwargs :
469
489
if indexers is not None :
470
490
raise ValueError (
@@ -505,6 +525,48 @@ def isel(
505
525
return px .indexing .index (self , * indices )
506
526
507
527
def set (self , value ):
528
+ """Return a copy of the variable indexed by self with the indexed values set to value.
529
+
530
+ The original variable is not modified.
531
+
532
+ Raises
533
+ ------
534
+ ValueError:
535
+ If self is not the result of an index operation.
536
+
537
+ Examples
538
+ --------
539
+
540
+ .. testcode::
541
+
542
+ import pytensor.xtensor as ptx
543
+
544
+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
545
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
546
+ out = x[:, idx].set(1)
547
+ out.eval()
548
+
549
+ .. testoutput::
550
+
551
+ array([[1, 0],
552
+ [0, 1]])
553
+
554
+
555
+ .. testcode::
556
+
557
+ import pytensor.xtensor as ptx
558
+
559
+ x = ptx.as_xtensor([[0, 0], [0, 0]], dims=("a", "b"))
560
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
561
+ out = x.isel({"b": idx}).set(-1)
562
+ out.eval()
563
+
564
+ .. testoutput::
565
+
566
+ array([[-1, 0],
567
+ [0, -1]])
568
+
569
+ """
508
570
if not (
509
571
self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
510
572
):
@@ -516,6 +578,48 @@ def set(self, value):
516
578
return px .indexing .index_assignment (x , value , * idxs )
517
579
518
580
def inc (self , value ):
581
+ """Return a copy of the variable indexed by self with the indexed values incremented by value.
582
+
583
+ The original variable is not modified.
584
+
585
+ Raises
586
+ ------
587
+ ValueError:
588
+ If self is not the result of an index operation.
589
+
590
+ Examples
591
+ --------
592
+
593
+ .. testcode::
594
+
595
+ import pytensor.xtensor as ptx
596
+
597
+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
598
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
599
+ out = x[:, idx].inc(1)
600
+ out.eval()
601
+
602
+ .. testoutput::
603
+
604
+ array([[2, 1],
605
+ [1, 2]])
606
+
607
+
608
+ .. testcode::
609
+
610
+ import pytensor.xtensor as ptx
611
+
612
+ x = ptx.as_xtensor([[1, 1], [1, 1]], dims=("a", "b"))
613
+ idx = ptx.as_xtensor([0, 1], dims=("a",))
614
+ out = x.isel({"b": idx}).inc(-1)
615
+ out.eval()
616
+
617
+ .. testoutput::
618
+
619
+ array([[0, 1],
620
+ [1, 0]])
621
+
622
+ """
519
623
if not (
520
624
self .owner is not None and isinstance (self .owner .op , px .indexing .Index )
521
625
):
@@ -579,7 +683,7 @@ def squeeze(
579
683
drop = None ,
580
684
axis : int | Sequence [int ] | None = None ,
581
685
):
582
- """Remove dimensions of size 1 from an XTensorVariable .
686
+ """Remove dimensions of size 1.
583
687
584
688
Parameters
585
689
----------
@@ -606,7 +710,7 @@ def expand_dims(
606
710
axis : int | Sequence [int ] | None = None ,
607
711
** dim_kwargs ,
608
712
):
609
- """Add one or more new dimensions to the tensor .
713
+ """Add one or more new dimensions to the variable .
610
714
611
715
Parameters
612
716
----------
@@ -616,14 +720,10 @@ def expand_dims(
616
720
- int: the new size
617
721
- sequence: coordinates (length determines size)
618
722
create_index_for_new_dim : bool, default: True
619
- Currently ignored. Reserved for future coordinate support.
620
- In xarray, when True (default), creates a coordinate index for the new dimension
621
- with values from 0 to size-1. When False, no coordinate index is created.
723
+ Ignored by PyTensor
622
724
axis : int | Sequence[int] | None, default: None
623
725
Not implemented yet. In xarray, specifies where to insert the new dimension(s).
624
726
By default (None), new dimensions are inserted at the beginning (axis=0).
625
- Symbolic axis is not supported yet.
626
- Negative values count from the end.
627
727
**dim_kwargs : int | Sequence
628
728
Alternative to `dim` dict. Only used if `dim` is None.
629
729
@@ -643,65 +743,75 @@ def expand_dims(
643
743
# ndarray methods
644
744
# https://docs.xarray.dev/en/latest/api.html#id7
645
745
def clip (self , min , max ):
746
+ """Clip the values of the variable to a specified range."""
646
747
return px .math .clip (self , min , max )
647
748
648
749
def conj (self ):
750
+ """Return the complex conjugate of the variable."""
649
751
return px .math .conj (self )
650
752
651
753
@property
652
754
def imag (self ):
755
+ """Return the imaginary part of the variable."""
653
756
return px .math .imag (self )
654
757
655
758
@property
656
759
def real (self ):
760
+ """Return the real part of the variable."""
657
761
return px .math .real (self )
658
762
659
763
@property
660
764
def T (self ):
661
- """Return the full transpose of the tensor .
765
+ """Return the full transpose of the variable .
662
766
663
767
This is equivalent to calling transpose() with no arguments.
664
-
665
- Returns
666
- -------
667
- XTensorVariable
668
- Fully transposed tensor.
669
768
"""
670
769
return self .transpose ()
671
770
672
771
# Aggregation
673
772
# https://docs.xarray.dev/en/latest/api.html#id6
674
773
def all (self , dim = None ):
774
+ """Reduce the variable by applying `all` along some dimension(s)."""
675
775
return px .reduction .all (self , dim )
676
776
677
777
def any (self , dim = None ):
778
+ """Reduce the variable by applying `any` along some dimension(s)."""
678
779
return px .reduction .any (self , dim )
679
780
680
781
def max (self , dim = None ):
782
+ """Compute the maximum along the given dimension(s)."""
681
783
return px .reduction .max (self , dim )
682
784
683
785
def min (self , dim = None ):
786
+ """Compute the minimum along the given dimension(s)."""
684
787
return px .reduction .min (self , dim )
685
788
686
789
def mean (self , dim = None ):
790
+ """Compute the mean along the given dimension(s)."""
687
791
return px .reduction .mean (self , dim )
688
792
689
793
def prod (self , dim = None ):
794
+ """Compute the product along the given dimension(s)."""
690
795
return px .reduction .prod (self , dim )
691
796
692
797
def sum (self , dim = None ):
798
+ """Compute the sum along the given dimension(s)."""
693
799
return px .reduction .sum (self , dim )
694
800
695
801
def std (self , dim = None , ddof = 0 ):
802
+ """Compute the standard deviation along the given dimension(s)."""
696
803
return px .reduction .std (self , dim , ddof = ddof )
697
804
698
805
def var (self , dim = None , ddof = 0 ):
806
+ """Compute the variance along the given dimension(s)."""
699
807
return px .reduction .var (self , dim , ddof = ddof )
700
808
701
809
def cumsum (self , dim = None ):
810
+ """Compute the cumulative sum along the given dimension(s)."""
702
811
return px .reduction .cumsum (self , dim )
703
812
704
813
def cumprod (self , dim = None ):
814
+ """Compute the cumulative product along the given dimension(s)."""
705
815
return px .reduction .cumprod (self , dim )
706
816
707
817
def diff (self , dim , n = 1 ):
@@ -720,7 +830,7 @@ def transpose(
720
830
* dim : str | EllipsisType ,
721
831
missing_dims : Literal ["raise" , "warn" , "ignore" ] = "raise" ,
722
832
):
723
- """Transpose dimensions of the tensor .
833
+ """Transpose the dimensions of the variable .
724
834
725
835
Parameters
726
836
----------
@@ -747,21 +857,38 @@ def transpose(
747
857
return px .shape .transpose (self , * dim , missing_dims = missing_dims )
748
858
749
859
def stack (self , dim , ** dims ):
860
+ """Stack existing dimensions into a single new dimension."""
750
861
return px .shape .stack (self , dim , ** dims )
751
862
752
863
def unstack (self , dim , ** dims ):
864
+ """Unstack a dimension into multiple dimensions of a given size.
865
+
866
+ Because XTensorVariables don't have coords, this operation requires the sizes of each unstacked dimension to be specified.
867
+ Also, unstacked dims will follow a C-style order, regardless of the order of the original dimensions.
868
+
869
+ .. testcode::
870
+
871
+ import pytensor.xtensor as ptx
872
+
873
+ x = ptx.as_xtensor([[1, 2], [3, 4]], dims=("a", "b"))
874
+ stacked_cumsum = x.stack({"c": ["a", "b"]}).cumsum("c")
875
+ unstacked_cumsum = stacked_cumsum.unstack({"c": x.sizes})
876
+ unstacked_cumsum.eval()
877
+
878
+ .. testoutput::
879
+
880
+ array([[ 1, 3],
881
+ [ 6, 10]])
882
+
883
+ """
753
884
return px .shape .unstack (self , dim , ** dims )
754
885
755
886
def dot (self , other , dim = None ):
756
- """Matrix multiplication with another XTensorVariable, contracting over matching or specified dims ."""
887
+ """Generalized dot product with another XTensorVariable."""
757
888
return px .math .dot (self , other , dim = dim )
758
889
759
- def broadcast (self , * others , exclude = None ):
760
- """Broadcast this tensor against other XTensorVariables."""
761
- return px .shape .broadcast (self , * others , exclude = exclude )
762
-
763
890
def broadcast_like (self , other , exclude = None ):
764
- """Broadcast this tensor against another XTensorVariable."""
891
+ """Broadcast against another XTensorVariable."""
765
892
_ , self_bcast = px .shape .broadcast (other , self , exclude = exclude )
766
893
return self_bcast
767
894
0 commit comments