Overview of the compilation pipeline

Given a PyTensor graph, pytensor.function() compiles a function that performs the computations modeled by the graph, using Python, C, Numba, or JAX as the execution backend.

More specifically, pytensor.function() takes lists of input and output Variables, which together delimit the exact sub-graph that corresponds to the desired computations.
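
The idea of a sub-graph delimited by inputs and outputs can be made concrete with a small stdlib-only sketch. Node and subgraph are hypothetical names, not PyTensor classes: the walk starts at the outputs and stops at the declared inputs, so nodes upstream of an input, or nodes that feed no output, are never part of the computation.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False keeps identity-based hashing and comparison
class Node:
    """Hypothetical toy graph node (not PyTensor's Variable/Apply)."""
    name: str
    inputs: list = field(default_factory=list)  # parent nodes


def subgraph(inputs, outputs):
    """Names of the nodes needed to compute `outputs` from `inputs`.

    The walk starts at the outputs, recurses into parents, and stops at
    the declared inputs, so anything upstream of an input is skipped.
    """
    stop = {id(n) for n in inputs}
    seen, order = set(), []

    def visit(node):
        if id(node) in seen:
            return
        seen.add(id(node))
        if id(node) not in stop:
            for parent in node.inputs:
                visit(parent)
        order.append(node.name)  # post-order: parents are listed first

    for out in outputs:
        visit(out)
    return order


# Usage: c = f(a, b), d = g(c), e = h(a)
a, b = Node("a"), Node("b")
c = Node("c", [a, b])
d = Node("d", [c])
e = Node("e", [a])

from_c = subgraph([c], [d])      # c is an input, so a and b are cut off
from_ab = subgraph([a, b], [d])  # e is never visited: it feeds no output
```

Choosing c rather than a and b as the input yields a smaller sub-graph for the same output, which is the sense in which the input and output lists select the computation.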

The following sections outline the steps pytensor.function() takes during compilation.

Step 1 - Clone the graph and collect shared variables

pytensor.function() first validates the user-supplied inputs and resolves the profiling settings. It then calls construct_pfunc_ins_and_outs, which clones the computation graph, collects the shared variables that appear in it, applies any givens substitutions, and attaches each requested update to the input of the shared variable it updates. The result is a list of In objects (In is a subclass of SymbolicInput) together with the cloned output variables, ready for the next stage.
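
Cloning with substitutions and collecting shared leaves can be sketched as a memoized graph copy. This is a simplified model under stated assumptions, not PyTensor's code: Var and clone_graph are hypothetical, internal nodes are copied, leaves are reused as-is, and givens are pre-seeded into the memo so every reference to a replaced variable points at its replacement.

```python
from dataclasses import dataclass


@dataclass(eq=False)
class Var:
    """Hypothetical toy variable (not PyTensor's Variable)."""
    name: str
    parents: tuple = ()   # empty for leaves (inputs, constants, shared)
    op: str = ""          # label of the operation that produced this var
    shared: bool = False  # True for shared-variable leaves


def clone_graph(outputs, givens=None):
    """Clone the graph feeding `outputs`, applying `givens` substitutions.

    Internal nodes are copied; in this toy, leaves are reused as-is.
    Shared leaves met along the way are collected, since they become
    implicit inputs of the compiled function.
    """
    memo = {id(old): new for old, new in (givens or {}).items()}
    found_shared = []

    def clone(v):
        if id(v) in memo:
            return memo[id(v)]
        if not v.parents:
            new = v
            if v.shared:
                found_shared.append(v)
        else:
            new = Var(v.name, tuple(clone(p) for p in v.parents), v.op)
        memo[id(v)] = new
        return new

    return [clone(o) for o in outputs], found_shared


# Usage: z = (x * w) + w, where w is shared
x, w = Var("x"), Var("w", shared=True)
y = Var("y", (x, w), op="mul")
z = Var("z", (y, w), op="add")

(z_clone,), shared = clone_graph([z])
x_new = Var("x_new")
(z_sub,), _ = clone_graph([z], givens={x: x_new})  # replace x everywhere
```

Because the original graph is never mutated, the same user graph can be compiled several times with different givens.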

Step 2 - Create a FunctionGraph and rewrite it

pytensor.function() then resolves the Mode and obtains the FunctionMaker class through mode.function_maker. Because modes can override this attribute, DebugMode, for example, substitutes its own maker that performs additional validation.
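
The mode-based dispatch can be modeled with plain classes. The Toy* names below are hypothetical rather than PyTensor's real classes; the sketch only illustrates that the function-building code asks the mode for a maker class, so a subclass such as a debugging mode can swap in its own maker without the caller changing.

```python
class ToyFunctionMaker:
    """Hypothetical stand-in for FunctionMaker."""

    def __init__(self, inputs, outputs, mode):
        self.inputs, self.outputs, self.mode = inputs, outputs, mode
        self.extra_checks = False


class ToyDebugMaker(ToyFunctionMaker):
    """Hypothetical stand-in for a debugging maker that adds validation."""

    def __init__(self, inputs, outputs, mode):
        super().__init__(inputs, outputs, mode)
        self.extra_checks = True  # where additional validation would hook in


class ToyMode:
    function_maker = ToyFunctionMaker  # class attribute subclasses can override


class ToyDebugMode(ToyMode):
    function_maker = ToyDebugMaker


def toy_function(inputs, outputs, mode):
    # The caller never names a maker class; the mode decides which one runs.
    maker_cls = mode.function_maker
    return maker_cls(inputs, outputs, mode)


plain = toy_function(["x"], ["y"], ToyMode())
debug = toy_function(["x"], ["y"], ToyDebugMode())
```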

FunctionMaker.__init__ then:

  • Wraps raw input and output Variables in SymbolicInput and SymbolicOutput objects.

  • Builds a FunctionGraph via FunctionMaker.create_fgraph, which also extracts update outputs from the input specs and sets up the update_mapping. If an existing fgraph is passed (as Scan does for its inner loop), FunctionMaker.create_fgraph augments it with update outputs instead of creating a new graph.

  • Applies the rewriter produced by the mode to the FunctionGraph (via prepare_fgraph). The rewriter is typically obtained through a query on optdb.

  • Configures the linker with the rewritten graph.
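
The order of operations in the list above can be sketched with nested tuples standing in for graphs. ToyInput, ToyFGraph, create_fgraph, simplify, and make are hypothetical; simplify is a single made-up rewrite (x * 1 to x), not a rule taken from optdb, and the linker step is only indicated by a comment.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Expressions are nested tuples such as ("add", "x", 1); leaves are names or constants.


@dataclass
class ToyInput:
    """Hypothetical stand-in for SymbolicInput."""
    variable: str
    update: Optional[Any] = None  # expression giving the new value after each call


@dataclass
class ToyFGraph:
    """Hypothetical stand-in for FunctionGraph."""
    inputs: list
    outputs: list
    update_mapping: dict = field(default_factory=dict)  # output idx -> input idx


def create_fgraph(input_specs, outputs):
    """Append update expressions as extra outputs and record what they update."""
    outs, mapping = list(outputs), {}
    for i, spec in enumerate(input_specs):
        if spec.update is not None:
            mapping[len(outs)] = i
            outs.append(spec.update)
    return ToyFGraph([s.variable for s in input_specs], outs, mapping)


def simplify(expr):
    """A one-rule rewriter applied bottom-up: ("mul", e, 1) -> e."""
    if not isinstance(expr, tuple):
        return expr
    op, *args = expr
    args = [simplify(a) for a in args]
    if op == "mul" and len(args) == 2 and args[1] == 1:
        return args[0]
    return (op, *args)


def make(input_specs, outputs, rewriter=simplify):
    # 1. Wrap raw inputs (plain names) in ToyInput.
    specs = [s if isinstance(s, ToyInput) else ToyInput(s) for s in input_specs]
    # 2. Build the graph; update expressions become trailing outputs.
    fgraph = create_fgraph(specs, outputs)
    # 3. Apply the rewriter to every output expression.
    fgraph.outputs = [rewriter(o) for o in fgraph.outputs]
    # 4. A real maker would now hand `fgraph` to its linker.
    return fgraph


acc_spec = ToyInput("acc", update=("add", "acc", ("mul", "x", 1)))
fg = make(["x", acc_spec], [("mul", ("add", "x", "acc"), 1)])
```

Building the update expressions into the graph before rewriting lets the same rewrites simplify them as well, as the second output shows.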

During this stage, several Features are also attached to the FunctionGraph; for instance, some prevent rewrites from operating in place on inputs declared as immutable.
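
The effect of such a feature can be illustrated with a toy validation hook. ProtectInputs and ToyGraph are hypothetical and much simpler than PyTensor's actual machinery: every feature gets a chance to veto a proposed in-place rewrite, and a rewrite that would overwrite a protected input is rejected while the graph is left unchanged.

```python
class ProtectInputs:
    """Hypothetical feature that forbids overwriting protected inputs."""

    def __init__(self, protected):
        self.protected = set(protected)

    def validate(self, destroyed):
        """Raise if a proposed rewrite would overwrite a protected input."""
        clash = self.protected & set(destroyed)
        if clash:
            raise ValueError(f"cannot operate in place on {sorted(clash)}")


class ToyGraph:
    """Hypothetical graph that consults its features before in-place rewrites."""

    def __init__(self, features):
        self.features = features
        self.inplace_ops = []

    def try_inplace(self, op_name, destroyed):
        """Apply an in-place op only if every feature accepts it."""
        try:
            for f in self.features:
                f.validate(destroyed)
        except ValueError:
            return False  # rejected: the graph keeps the non-in-place version
        self.inplace_ops.append(op_name)
        return True


g = ToyGraph([ProtectInputs(["x"])])
on_input = g.try_inplace("add_inplace", ["x"])   # would clobber user data
on_temp = g.try_inplace("add_inplace", ["tmp"])  # intermediate: allowed
```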

Step 3 - Wrap everything in a Function

The VM produced by the linker, together with the input and output containers, is wrapped in a Function object that behaves like an ordinary Python callable. When called, Function.__call__() places the user-provided values into the input containers, runs the VM, copies the update outputs back into the shared-variable containers, and returns the results.
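
The call sequence can be modeled in a few lines. ToyFunction is hypothetical, not PyTensor's Function: each container is a one-element list so the wrapper and the VM share the same cell, and, as an assumption of this sketch, update expressions occupy the trailing output slots and are mapped to the shared containers they overwrite.

```python
class ToyFunction:
    """Hypothetical callable mirroring the described __call__ steps."""

    def __init__(self, vm, n_inputs, n_outputs, updates):
        # One-element lists act as containers shared with the VM.
        self.input_storage = [[None] for _ in range(n_inputs)]
        self.output_storage = [[None] for _ in range(n_outputs)]
        self.vm = vm
        self.updates = updates  # output index -> shared container (trailing slots)

    def __call__(self, *args):
        for cell, value in zip(self.input_storage, args):
            cell[0] = value                                    # 1. fill inputs
        self.vm(self.input_storage, self.output_storage)       # 2. run the VM
        for out_idx, shared_cell in self.updates.items():
            shared_cell[0] = self.output_storage[out_idx][0]   # 3. apply updates
        n_user = len(self.output_storage) - len(self.updates)
        return [cell[0] for cell in self.output_storage[:n_user]]  # 4. results


counter = [0]  # container of a hypothetical shared variable


def vm(inputs, outputs):
    x = inputs[0][0]
    outputs[0][0] = x + counter[0]   # user-visible output
    outputs[1][0] = counter[0] + 1   # update expression for the counter


f = ToyFunction(vm, n_inputs=1, n_outputs=2, updates={1: counter})
first = f(10)
second = f(10)  # sees the counter value written back by the first call
```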