
WIP: experiment with first class dim objects #1517

Status: Open. Wants to merge 7 commits into base: main.
412 changes: 412 additions & 0 deletions doc/internal/named-dims.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -2,10 +2,12 @@
 
 import pytensor.xtensor.rewriting
 from pytensor.xtensor import linalg, math, random
+from pytensor.xtensor.basic import ones, xtensor_from_tensor, zeros
 from pytensor.xtensor.math import dot
 from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
 from pytensor.xtensor.type import (
     as_xtensor,
+    dim,
     xtensor,
     xtensor_constant,
 )
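For orientation, a minimal usage sketch of the newly exported names. This is hypothetical: the dim(...) constructor comes from pytensor.xtensor.type, which this diff does not show, so its signature is assumed here.

from pytensor.xtensor import dim, ones, zeros

row = dim("row", size=2)  # assumed signature: a named dim with a static size
col = dim("col", size=3)

x = zeros(row, col)                  # XTensor with dims (row, col), filled with 0
y = ones(row, col, dtype="float64")  # same dims, filled with 1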
82 changes: 59 additions & 23 deletions pytensor/xtensor/basic.py
@@ -1,9 +1,14 @@
 from collections.abc import Sequence
 
 from pytensor.compile.ops import TypeCastingOp
 from pytensor.graph import Apply, Op
+from pytensor.scalar.basic import uint64
+from pytensor.tensor.basic import ones as tensor_ones
+from pytensor.tensor.basic import zeros as tensor_zeros
+from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import TensorType
-from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
+from pytensor.xtensor.type import DimVariable, XTensorType, as_dim, as_xtensor, xtensor
+
+
+DIM_LENGTH_SCALAR = uint64
 
 
 class XOp(Op):
@@ -32,6 +37,7 @@ def make_node(self, x):
         return Apply(self, [x], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
         return [xtensor_from_tensor(g_out, dims=x.type.dims)]
@@ -41,46 +47,50 @@ def L_op(self, inputs, outs, g_outs):
 
 
 class XTensorFromTensor(XTypeCastOp):
-    __props__ = ("dims",)
-
-    def __init__(self, dims: Sequence[str]):
-        super().__init__()
-        self.dims = tuple(dims)
+    __props__ = ()
 
-    def make_node(self, x):
+    def make_node(self, x, *dims):
         if not isinstance(x.type, TensorType):
             raise TypeError(f"x must be a TensorType type, got {type(x.type)}")
-        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
-        return Apply(self, [x], [output])
+        output = xtensor(dtype=x.type.dtype, dims=dims)
+        return Apply(self, [x, *dims], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [g_out] = g_outs
         return [tensor_from_xtensor(g_out)]
 
 
-def xtensor_from_tensor(x, dims, name=None):
-    return XTensorFromTensor(dims=dims)(x, name=name)
+def xtensor_from_tensor(x, dims, name=None, check: bool = True):
+    if check:
+        x = specify_shape(x, [dim.size for dim in dims])
+    dims = [as_dim(dim) for dim in dims]
+    return XTensorFromTensor()(x, *dims, name=name)
 
 
-class Rename(XTypeCastOp):
-    __props__ = ("new_dims",)
+class MapDims(XTypeCastOp):
+    __props__ = ("new_dim_indices",)
 
-    def __init__(self, new_dims: tuple[str, ...]):
-        super().__init__()
-        self.new_dims = new_dims
+    def __init__(self, new_dim_indices: tuple[int, ...]):
+        self.new_dim_indices = new_dim_indices
 
-    def make_node(self, x):
+    def make_node(self, x, *new_dims):
         x = as_xtensor(x)
-        output = x.type.clone(dims=self.new_dims)()
+        out_dims = list(x.dims)
+        for i, idx in enumerate(self.new_dim_indices):
+            out_dims[idx] = new_dims[i]
+
+        output = x.type.clone(dims=out_dims)()
         return Apply(self, [x], [output])
 
     def L_op(self, inputs, outs, g_outs):
+        # TODO fix
         [x] = inputs
         [g_out] = g_outs
-        return [rename(g_out, dims=x.type.dims)]
+        return [map_dims(g_out, dims=x.type.dims)]
 
 
-def rename(x, name_dict: dict[str, str] | None = None, **names: str):
+def map_dims(x, name_dict: dict[DimVariable, DimVariable] | None = None, **names):
     if name_dict is not None:
         if names:
             raise ValueError("Cannot use both positional and keyword names in rename")
@@ -97,4 +107,30 @@ def rename(x, name_dict: dict[str, str] | None = None, **names: str):
                 f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
             )
 
-    return Rename(tuple(new_names))(x)
+    return MapDims(tuple(new_names))(x)
+
+
+def zeros(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with zeros."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_zeros(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
+
+
+def ones(*dims, dtype=None, name=None):
+    """Create a new XTensor filled with ones."""
+    if not dims:
+        raise ValueError("At least one dimension must be specified")
+
+    return xtensor_from_tensor(
+        tensor_ones(shape=[dim.size for dim in dims], dtype=dtype),
+        dims=dims,
+        name=name,
+        check=False,
+    )
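A sketch of what the new check flag does, using the assumed dim(...) constructor from above and hypothetical variable names: with check=True (the default), xtensor_from_tensor wraps the input in specify_shape with each dim's size, so a runtime shape mismatch raises instead of silently mislabeling dimensions. zeros and ones pass check=False because their shapes are correct by construction.

import pytensor.tensor as pt
from pytensor.xtensor import dim, xtensor_from_tensor

row = dim("row", size=2)  # assumed signature, as in the sketch above
col = dim("col", size=3)
t = pt.matrix("t")

x = xtensor_from_tensor(t, dims=(row, col))               # asserts t has shape (2, 3)
y = xtensor_from_tensor(t, dims=(row, col), check=False)  # trusts the caller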
189 changes: 189 additions & 0 deletions pytensor/xtensor/dims.py
@@ -0,0 +1,189 @@
from __future__ import annotations

from collections.abc import Iterable
from uuid import uuid4

import numpy as np

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op, Variable
from pytensor.xtensor.type import (
    DIM_LENGTH_TYPE,
    DIM_LENGTH_VARIABLE,
    BasicDim,
    CloneDim,
    DimType,
    DimVariable,
    XTensorVariable,
)


class DimOp(Op):
    def perform(self, node, inputs, outputs):
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )


# Not a dim op, because it doesn't return a DimVariable
class Length(Op):
    __props__ = ()

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [DIM_LENGTH_TYPE()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.array(inputs[0], dtype=DIM_LENGTH_TYPE.dtype)


def _dim_size(dim: DimVariable) -> DIM_LENGTH_VARIABLE:
    if dim.type.size is not None:
        return DIM_LENGTH_TYPE.filter_variable(dim.type.size)
    return Length()(dim)


class FromLength(DimOp):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (length,) = inputs
        if not isinstance(length, DIM_LENGTH_VARIABLE):
            raise TypeError(
                f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}"
            )
        if length.type != DIM_LENGTH_TYPE:
            raise TypeError(
                f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
            )
        return Apply(self, [length], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Pass the length through as the dim's runtime value."""
        outputs[0][0] = inputs[0]


def from_length(length: DIM_LENGTH_VARIABLE, name: str | None = None) -> DimVariable:
    # TODO add check for dtype
    if not isinstance(length, DIM_LENGTH_VARIABLE):
        raise TypeError(
            f"length must be a DIM_LENGTH_VARIABLE, got {type(length.type)}"
        )
    if length.type != DIM_LENGTH_TYPE:
        raise TypeError(
            f"length must be of dtype 'DIM_LENGTH_SCALAR', got {length.type.dtype}"
        )

    uuid = uuid4()
    dim_type = BasicDim(uuid=uuid, name=name)
    op = FromLength(dim_type)
    return op(length, name=name)


class DimFromTensor(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type: DimType):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, XTensorVariable):
            raise TypeError(f"x must be an XTensorVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        """Read the dimension's length from the tensor's runtime shape."""
        (x,) = inputs
        (x_var,) = node.inputs
        for i, dim in enumerate(x_var.type.dims):
            if dim == self.dim_type:
                outputs[0][0] = np.array(x.shape[i], dtype=DIM_LENGTH_TYPE.dtype)
                return
        raise ValueError(f"Dimension {self.dim_type} not found in tensor {x.type.dims}")


def _dim_from_tensor(x: XTensorVariable, idx: int) -> DimVariable:
    op = DimFromTensor(dim_type=x.type.dims[idx])
    return op(x, name=x.type.dims[idx].name)


class Clone(Op):
    __props__ = ("dim_type",)

    def __init__(self, dim_type):
        super().__init__()
        self.dim_type = dim_type

    def make_node(self, *inputs: Variable) -> Apply:
        (x,) = inputs
        if not isinstance(x, DimVariable):
            raise TypeError(f"x must be a DimVariable, got {type(x.type)}")
        return Apply(self, [x], [self.dim_type()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = inputs[0]


def _clone_dim(dim: DimVariable, *, name: str | None = None) -> DimVariable:
    """Clone a dimension variable, optionally under a new name.

    Args:
        name: The new name for the dimension.

    Returns:
        A new DimVariable with the updated name.
    """
    dim_type = CloneDim(uuid=uuid4(), base=dim.type, name=name)
    return Clone(dim_type)(dim, name=name)
class Product(Op):
    __props__ = ()

    def make_node(self, *dims: Variable) -> Apply:
        if not all(isinstance(dim, DimVariable) for dim in dims):
            raise TypeError("All inputs must be DimVariables.")
        # NOTE: assumes the product gets a fresh BasicDim type, as in
        # from_length; the WIP draft referenced an undefined dim_type here.
        out = BasicDim(uuid=uuid4(), name=None)()
        return Apply(self, list(dims), [out])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.prod(inputs, dtype=DIM_LENGTH_TYPE.dtype).item()


def product_dim(*dims: DimVariable, name: str | None = None) -> DimVariable:
    return Product()(*dims, name=name)
def rebase_dim(dim: DimVariable | DimType, *tensors: XTensorVariable) -> DimVariable:
    if not isinstance(dim, DimVariable | DimType):
        raise TypeError(f"dim must be a DimVariable, got {type(dim)}")

    if not tensors:
        raise ValueError("At least one tensor must be provided for rebasing.")

    if isinstance(dim, DimVariable):
        dim_type = dim.type
    else:
        dim_type = dim

    for tensor in tensors:
        for i, tensor_dim in enumerate(tensor.type.dims):
            if dim_type == tensor_dim:
                return _dim_from_tensor(tensor, idx=i)
    raise ValueError(f"Dimension {dim} not found in any of the provided tensors.")

Review thread on rebase_dim:

[Member] What is the purpose of rebase_dim?

[Member] Create a dim from an existing xtensor / get the length at runtime?

[Member, author] That's a helper for rewrites, to avoid infinite loops. For instance, in Elemwise:

@register_lower_xtensor
@node_rewriter(tracks=[XElemwise])
def lower_elemwise(fgraph, node):
    assert len(node.outputs) == 1
    out_dims = node.outputs[0].dims
    out_dims = [rebase_dim(dim, *node.inputs) for dim in out_dims]

    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]

    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )

    # Convert output Tensors to XTensors
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=out_dims, check=False)
        for tensor_out in tensor_outs
    ]
    return new_outs

The final XTensorFromTensor op takes the dim variables as inputs. If we were to use node.outputs[0].dims for those, the returned graph would still contain a reference to the XElemwise we want to replace, because those dims are variables that use DimFromTensor(XElemwise) to get a reference to the dimension length.


def rebase_dims(
    dims: Iterable[DimVariable | DimType], *tensors: XTensorVariable
) -> list[DimVariable]:
    return [rebase_dim(dim, *tensors) for dim in dims]
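A sketch of how these helpers compose, under assumed calling conventions (DIM_LENGTH_TYPE("length") creating a named scalar variable follows the usual pytensor Type call protocol; x stands for some existing xtensor):

from pytensor.xtensor.dims import from_length, rebase_dim
from pytensor.xtensor.type import DIM_LENGTH_TYPE

length = DIM_LENGTH_TYPE("length")     # symbolic scalar dim length (assumed)
row = from_length(length, name="row")  # fresh BasicDim backed by that length

# Given an xtensor x whose type carries row's dim type, rebase_dim re-derives
# the dim from x itself (via DimFromTensor), so a rewrite replacing x's owner
# does not hold on to the replaced node:
#   row_again = rebase_dim(row, x)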