5 changes: 3 additions & 2 deletions doc/dev_start_guide.rst
@@ -209,7 +209,8 @@
You can now build the documentation from the root of the project with:

.. code-block:: bash

-    python -m sphinx -b html ./doc ./html
+    # -j for parallel and faster doc build
+    sphinx-build -b html ./doc ./html -j auto


Afterward, you can go to `html/index.html` and navigate the changes in a browser. One way to do this is to go to the `html` directory and run:
@@ -219,7 +220,7 @@ Afterward, you can go to `html/index.html` and navigate the changes in a browser

python -m http.server

-**Do not commit the `html` directory. The documentation is built automatically.**
+**Do not commit the `html` directory.**
For other documentation customizations, such as different output formats (e.g. PDF), refer to the `Sphinx documentation <https://www.sphinx-doc.org/en/master/usage/builders/index.html>`_.

Other tools that might help
38 changes: 9 additions & 29 deletions doc/extending/creating_a_numba_jax_op.rst
@@ -1,13 +1,13 @@
Adding JAX, Numba and Pytorch support for `Op`\s
-=======================================
+================================================

PyTensor is able to convert its graphs into JAX, Numba and Pytorch compiled functions. In order to do
this, each :class:`Op` in a PyTensor graph must have an equivalent JAX/Numba/Pytorch implementation function.

This tutorial will explain how JAX, Numba and Pytorch implementations are created for an :class:`Op`.

Step 1: Identify the PyTensor :class:`Op` you'd like to implement
-------------------------------------------------------------------------
+-----------------------------------------------------------------

Find the source for the PyTensor :class:`Op` you'd like to be supported and
identify the function signature and return values. These can be determined by
@@ -97,8 +97,8 @@
Next, we look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented.

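To make that calling convention concrete, here is a stdlib-only sketch of the `perform` shape: the `(node, inputs, output_storage)` signature follows PyTensor's convention, but `ToyCumProd` and its list-based cumulative product are illustrative assumptions, not real PyTensor code.

```python
# Hedged sketch: a toy class mirroring the perform(node, inputs, output_storage)
# convention, using plain Python lists instead of NumPy arrays.
import itertools
import operator


class ToyCumProd:
    """Toy stand-in for an Op whose perform computes a cumulative product."""

    def perform(self, node, inputs, output_storage):
        # `inputs` is a list of input values; `output_storage` is a list of
        # one-element lists that perform must fill in, mirroring how the
        # runtime hands results back.
        (x,) = inputs
        output_storage[0][0] = list(itertools.accumulate(x, operator.mul))


op = ToyCumProd()
out = [[None]]
op.perform(None, [[1, 2, 3, 4]], out)
print(out[0][0])  # prints [1, 2, 6, 24]
```

Reading a real `perform` this way tells you exactly what the backend implementation must reproduce.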
-Step 2: Find the relevant method in JAX/Numba/Pytorch (or something close)
----------------------------------------------------------
+Step 2: Find the relevant or close method in JAX/Numba/Pytorch
+--------------------------------------------------------------

With a precise idea of what the PyTensor :class:`Op` does we need to figure out how
to implement it in JAX, Numba or Pytorch. In the best case scenario, there is a similarly named
@@ -269,7 +269,7 @@
and :func:`torch.cumprod`
z[0] = np.cumprod(x, axis=self.axis)

Step 3: Register the function with the respective dispatcher
----------------------------------------------------------------
+------------------------------------------------------------

With the PyTensor `Op` replicated, we'll need to register the
function with the backend's `Linker`. This is done through the use of
@@ -626,28 +626,8 @@ Step 4: Write tests
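The dispatcher registration described in Step 3 (truncated in the diff above) is essentially a type-keyed dispatch table: the backend looks up an implementation builder by the `Op`'s class. As a hedged, stdlib-only analogy, `functools.singledispatch` captures the pattern; the names `funcify` and `CumOp` and the returned closure are assumptions for illustration, not PyTensor's actual API.

```python
# Hedged analogy: backend dispatchers behave like functools.singledispatch,
# mapping an Op *type* to a function that builds the backend implementation.
from functools import singledispatch
from itertools import accumulate


class Op:
    """Minimal stand-in for a PyTensor Op."""


class CumOp(Op):
    def __init__(self, axis=None):
        self.axis = axis


@singledispatch
def funcify(op, **kwargs):
    # Unregistered Op types have no backend implementation.
    raise NotImplementedError(f"No backend implementation for {type(op).__name__}")


@funcify.register(CumOp)
def funcify_CumOp(op, **kwargs):
    # Return the callable the backend graph will actually run, closing over
    # any Op attributes it needs (here a plain cumulative sum).
    def cumop(x):
        return list(accumulate(x))

    return cumop


impl = funcify(CumOp())
print(impl([1, 2, 3]))  # prints [1, 3, 6]
```

Registering a new backend implementation then amounts to adding one more `register` entry for your `Op` class.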

Note
----
-In out previous example of extending JAX, :class:`Eye`\ :class:`Op` was used with the test function as follows:
-
-.. code:: python
-
-    def test_jax_Eye():
-        """Test JAX conversion of the `Eye` `Op`."""
-
-        # Create a symbolic input for `Eye`
-        x_at = pt.scalar()
-
-        # Create a variable that is the output of an `Eye` `Op`
-        eye_var = pt.eye(x_at)
-
-        # Create an PyTensor `FunctionGraph`
-        out_fg = FunctionGraph(outputs=[eye_var])
-
-        # Pass the graph and any inputs to the testing function
-        compare_jax_and_py(out_fg, [3])
-
-This one nowadays leads to a test failure due to new restrictions in JAX + JIT,
-as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
-All jitted functions now must have constant shape, which means a graph like the
+Due to new restrictions in JAX JIT as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_,
+all jitted functions now must have constant shape. In other words, only PyTensor graphs with static shapes
+can be translated to JAX at the moment. It means a graph like the
one of :class:`Eye` can never be translated to JAX, since it's fundamentally a
-function with dynamic shapes. In other words, only PyTensor graphs with static shapes
-can be translated to JAX at the moment.
+function with dynamic shapes.
28 changes: 1 addition & 27 deletions doc/internal/metadocumentation.rst
@@ -8,33 +8,7 @@ Documentation Documentation AKA Meta-Documentation
How to build documentation
--------------------------

-Let's say you are writing documentation, and want to see the `sphinx
-<http://sphinx.pocoo.org/>`__ output before you push it.
-The documentation will be generated in the ``html`` directory.
-
-.. code-block:: bash
-
-    cd PyTensor/
-    python ./doc/scripts/docgen.py
-
-If you don't want to generate the pdf, do the following:
-
-.. code-block:: bash
-
-    cd PyTensor/
-    python ./doc/scripts/docgen.py --nopdf
-
-
-For more details:
-
-.. code-block:: bash
-
-    $ python doc/scripts/docgen.py --help
-    Usage: doc/scripts/docgen.py [OPTIONS]
-        -o <dir>: output the html files in the specified dir
-        --rst: only compile the doc (requires sphinx)
-        --nopdf: do not produce a PDF file from the doc, only HTML
-        --help: this help
+Refer to `relevant section of Developer Start Guide <https://pytensor.readthedocs.io/en/latest/dev_start_guide.html#contributing-to-the-documentation>`_.

Use ReST for documentation
--------------------------