Transforming Array Expression Graphs

class pytato.transform.Mapper

A class that, when called with a pytato.Array, recursively iterates over the DAG, calling the _mapper_method of each node. Users of this class are expected to override its methods or to create a subclass.


This class might visit a node multiple times. Use a CachedMapper if this is not desired.

handle_unsupported_array(expr: MappedT, *args: Any, **kwargs: Any) → Any

Mapper method that is invoked for pytato.Array subclasses for which a mapper method does not exist in this mapper.

map_foreign(expr: Any, *args: Any, **kwargs: Any) → Any

Mapper method that is invoked for an object of a class for which a mapper method does not exist in this mapper.

rec(expr: MappedT, *args: Any, **kwargs: Any) → Any

Call the mapper method of expr and return the result.

__call__(expr: MappedT, *args: Any, **kwargs: Any) → Any

Handle the mapping of expr.
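To make the dispatch and the possibility of repeated visits concrete, here is a minimal, self-contained sketch in plain Python. It does not use pytato: ToyMapper, VisitCounter, Placeholder and Add are hypothetical stand-ins, and dispatching on the lowercased class name is an assumption made for brevity. Because nothing is cached, a leaf reachable along several paths is visited once per path.

```python
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-ins for pytato.Array subclasses (not pytato's API).
@dataclass(eq=False)
class Placeholder:
    name: str


@dataclass(eq=False)
class Add:
    left: Any
    right: Any


class ToyMapper:
    """Dispatches to map_<lowercased class name>, e.g. Add -> map_add."""

    def rec(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        method = getattr(self, "map_" + type(expr).__name__.lower(), None)
        if method is None:
            return self.handle_unsupported(expr, *args, **kwargs)
        return method(expr, *args, **kwargs)

    def __call__(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        return self.rec(expr, *args, **kwargs)

    def handle_unsupported(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(f"no mapper method for {type(expr).__name__}")


class VisitCounter(ToyMapper):
    """Counts leaf visits; there is no caching, so shared nodes repeat."""

    def __init__(self) -> None:
        self.visits: dict = {}

    def map_placeholder(self, expr: Placeholder) -> None:
        self.visits[expr.name] = self.visits.get(expr.name, 0) + 1

    def map_add(self, expr: Add) -> None:
        self.rec(expr.left)
        self.rec(expr.right)


x = Placeholder("x")
shared = Add(x, x)          # x is used twice here ...
root = Add(shared, shared)  # ... and shared is used twice here
counter = VisitCounter()
counter(root)
print(counter.visits)  # prints {'x': 4}
```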

class pytato.transform.CachedMapper

Mapper class that maps each node in the DAG exactly once. This loses some information compared to Mapper as a node is visited only from one of its predecessors.
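The effect of caching can be sketched with the same kind of hypothetical toy nodes (plain Python, not pytato's implementation). Results are memoized per node identity, so each node's mapper method runs once even though the leaf below is reachable along four paths.

```python
from dataclasses import dataclass
from typing import Any, Dict


# Hypothetical toy nodes; eq=False makes them hash by identity.
@dataclass(eq=False)
class Leaf:
    name: str


@dataclass(eq=False)
class Add:
    left: Any
    right: Any


class ToyCachedMapper:
    """Memoizes one result per node, so each map_* method runs once per node."""

    def __init__(self) -> None:
        self._cache: Dict[Any, Any] = {}

    def rec(self, expr: Any) -> Any:
        if expr not in self._cache:
            method = getattr(self, "map_" + type(expr).__name__.lower())
            self._cache[expr] = method(expr)
        return self._cache[expr]

    def __call__(self, expr: Any) -> Any:
        return self.rec(expr)


class LeafPathCounter(ToyCachedMapper):
    """Counts root-to-leaf paths while recording how many methods actually ran."""

    def __init__(self) -> None:
        super().__init__()
        self.calls = 0

    def map_leaf(self, expr: Leaf) -> int:
        self.calls += 1
        return 1

    def map_add(self, expr: Add) -> int:
        self.calls += 1
        return self.rec(expr.left) + self.rec(expr.right)


x = Leaf("x")
shared = Add(x, x)
root = Add(shared, shared)
counter = LeafPathCounter()
print(counter(root), counter.calls)  # prints 4 3
```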

class pytato.transform.CopyMapper

Performs a deep copy of a pytato.array.Array. The typical use of this mapper is to override individual map_ methods in subclasses to permit term rewriting on an expression graph.

clone_for_callee(function: FunctionDefinition) → _SelfMapper

Called to clone self before starting traversal of a pytato.function.FunctionDefinition.


This does not copy the data of a pytato.array.DataWrapper.
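A minimal sketch of the term-rewriting pattern, assuming hypothetical frozen-dataclass nodes rather than pytato arrays: ToyCopyMapper rebuilds every node, and a subclass overrides a single map_ method to rename one placeholder while everything else is copied unchanged. The input graph is left untouched.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


# Hypothetical immutable toy nodes (structural equality via dataclass eq).
@dataclass(frozen=True)
class Placeholder:
    name: str


@dataclass(frozen=True)
class Add:
    left: Any
    right: Any


class ToyCopyMapper:
    """Rebuilds each node; results are cached per input node (by identity)."""

    def __init__(self) -> None:
        # id(node) -> (node, copy); storing node keeps its id from being reused.
        self._cache: Dict[int, Tuple[Any, Any]] = {}

    def rec(self, expr: Any) -> Any:
        if id(expr) not in self._cache:
            method = getattr(self, "map_" + type(expr).__name__.lower())
            self._cache[id(expr)] = (expr, method(expr))
        return self._cache[id(expr)][1]

    def __call__(self, expr: Any) -> Any:
        return self.rec(expr)

    def map_placeholder(self, expr: Placeholder) -> Any:
        return Placeholder(expr.name)

    def map_add(self, expr: Add) -> Any:
        return Add(self.rec(expr.left), self.rec(expr.right))


class RenamePlaceholder(ToyCopyMapper):
    """Term rewriting: override one map_ method, inherit copying of the rest."""

    def __init__(self, old: str, new: str) -> None:
        super().__init__()
        self.old, self.new = old, new

    def map_placeholder(self, expr: Placeholder) -> Any:
        return Placeholder(self.new if expr.name == self.old else expr.name)


expr = Add(Placeholder("x"), Add(Placeholder("x"), Placeholder("y")))
new_expr = RenamePlaceholder("x", "z")(expr)
print(new_expr)
```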

class pytato.transform.CopyMapperWithExtraArgs

Similar to CopyMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.

The logic in CopyMapper purposely does not take the extra arguments, to keep the cost of each of its call frames low.

class pytato.transform.CombineMapper

Abstract mapper that recursively combines the results of user nodes of a given expression.

combine(*args: CombineT) → CombineT

Combine the arguments.
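A combine-style mapper only has to say how leaves map and how results merge. The sketch below uses hypothetical toy nodes and no caching; it gathers the leaf nodes an expression depends on by combining frozensets with union, in the spirit of InputGatherer and SizeParamGatherer, but it is an illustration and not pytato's code.

```python
from dataclasses import dataclass
from functools import reduce
from typing import Any, FrozenSet


# Hypothetical toy nodes standing in for pytato arrays.
@dataclass(frozen=True)
class Placeholder:
    name: str


@dataclass(frozen=True)
class SizeParam:
    name: str


@dataclass(frozen=True)
class Add:
    left: Any
    right: Any


class ToyCombineMapper:
    def rec(self, expr: Any) -> Any:
        return getattr(self, "map_" + type(expr).__name__.lower())(expr)

    def __call__(self, expr: Any) -> Any:
        return self.rec(expr)

    def combine(self, *args: Any) -> Any:
        raise NotImplementedError  # subclasses decide how results merge

    def map_add(self, expr: Add) -> Any:
        # Interior nodes only combine the results of their operands.
        return self.combine(self.rec(expr.left), self.rec(expr.right))


class ToyInputGatherer(ToyCombineMapper):
    def combine(self, *args: FrozenSet[Any]) -> FrozenSet[Any]:
        return reduce(frozenset.union, args, frozenset())

    def map_placeholder(self, expr: Placeholder) -> FrozenSet[Any]:
        return frozenset([expr])

    def map_sizeparam(self, expr: SizeParam) -> FrozenSet[Any]:
        return frozenset([expr])


expr = Add(Placeholder("a"), Add(SizeParam("n"), Placeholder("a")))
print(sorted(node.name for node in ToyInputGatherer()(expr)))  # prints ['a', 'n']
```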

class pytato.transform.DependencyMapper

Maps a pytato.array.Array to a frozenset of the pytato.array.Array instances it depends on.


This returns every node in the graph! Consider a custom CombineMapper or a SubsetDependencyMapper instead.

class pytato.transform.InputGatherer

Mapper to combine all instances of pytato.array.InputArgumentBase that an array expression depends on.

class pytato.transform.SizeParamGatherer

Mapper to combine all instances of pytato.array.SizeParam that an array expression depends on.

class pytato.transform.SubsetDependencyMapper(universe: FrozenSet[Array])

Mapper to combine the dependencies of an expression that are a subset of universe.

class pytato.transform.WalkMapper

A mapper that walks over all the arrays in a pytato.Array.

Users may override the specific mapper methods in a derived class or override WalkMapper.visit() and WalkMapper.post_visit().

visit(expr: Any, *args: Any, **kwargs: Any) → bool

If this method returns True, expr is traversed during the walk. If this method returns False, expr is not traversed as a part of the walk.

post_visit(expr: Any, *args: Any, **kwargs: Any) → None

Callback after expr has been traversed.
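The interplay of visit() and post_visit() can be sketched as follows, assuming a hypothetical Node type with a children list. Returning False from visit prunes that subtree, and post_visit fires only after a node's operands have been walked. This mirrors the description above but is not pytato's implementation.

```python
from dataclasses import dataclass, field
from typing import Any, List, Set


# Hypothetical toy node: a name plus the nodes it depends on.
@dataclass(eq=False)
class Node:
    name: str
    children: List[Any] = field(default_factory=list)


class ToyWalkMapper:
    def visit(self, expr: Any) -> bool:
        return True  # default: traverse everything

    def post_visit(self, expr: Any) -> None:
        pass

    def rec(self, expr: Any) -> None:
        if not self.visit(expr):
            return  # prune: neither the children nor post_visit are reached
        for child in expr.children:
            self.rec(child)
        self.post_visit(expr)  # runs after all children were walked

    def __call__(self, expr: Any) -> None:
        self.rec(expr)


class SkipAndRecord(ToyWalkMapper):
    def __init__(self, skip: Set[str]) -> None:
        self.skip = skip
        self.order: List[str] = []

    def visit(self, expr: Node) -> bool:
        return expr.name not in self.skip

    def post_visit(self, expr: Node) -> None:
        self.order.append(expr.name)


root = Node("root", [Node("mid", [Node("a")]), Node("b")])
walker = SkipAndRecord(skip={"mid"})
walker(root)
print(walker.order)  # prints ['b', 'root']
```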

class pytato.transform.CachedWalkMapper

WalkMapper that visits each node in the DAG exactly once. This loses some information compared to WalkMapper as a node is visited only from one of its predecessors.

class pytato.transform.TopoSortMapper

class pytato.transform.CachedMapAndCopyMapper(map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays])

Mapper that applies map_fn to each node and copies it. Results of traversals are memoized, i.e., each node is mapped via map_fn exactly once.

pytato.transform.copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, copy_mapper: CopyMapper) → DictOfNamedArrays

Copy the elements of a DictOfNamedArrays into a new DictOfNamedArrays.

  • source_dict – The DictOfNamedArrays to copy.

  • copy_mapper – A mapper that performs copies of different array types.


Returns: A new DictOfNamedArrays containing copies of the items in source_dict.

pytato.transform.get_dependencies(expr: DictOfNamedArrays) → Dict[str, FrozenSet[Array]]

Returns the dependencies of each named array in expr.

pytato.transform.map_and_copy(expr: MappedT, map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays]) → MappedT

Returns a copy of expr with every array expression reachable from expr mapped via map_fn.


Uses CachedMapAndCopyMapper under the hood; because of its caching, each node is mapped exactly once.
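Below is a minimal sketch of a memoized map-and-copy on hypothetical toy nodes. It is not pytato's code; in particular, applying map_fn to a node before recursing into the operands of its result is an assumption of this sketch. The shared leaf is passed to map_fn only once.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple


# Hypothetical immutable toy nodes.
@dataclass(frozen=True)
class Leaf:
    value: int


@dataclass(frozen=True)
class Add:
    left: Any
    right: Any


def toy_map_and_copy(expr: Any, map_fn: Callable[[Any], Any]) -> Any:
    # id(node) -> (node, result); holding node keeps its id from being reused.
    cache: Dict[int, Tuple[Any, Any]] = {}

    def rec(node: Any) -> Any:
        if id(node) not in cache:
            mapped = map_fn(node)  # each distinct node goes through map_fn once
            if isinstance(mapped, Add):
                result = Add(rec(mapped.left), rec(mapped.right))
            else:
                result = mapped
            cache[id(node)] = (node, result)
        return cache[id(node)][1]

    return rec(expr)


calls: List[Any] = []


def double_leaves(node: Any) -> Any:
    calls.append(node)
    return Leaf(2 * node.value) if isinstance(node, Leaf) else node


shared = Leaf(3)
expr = Add(shared, Add(shared, Leaf(4)))
new_expr = toy_map_and_copy(expr, double_leaves)
print(new_expr)
print(len(calls))  # prints 4
```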

pytato.transform.materialize_with_mpms(expr: DictOfNamedArrays) → DictOfNamedArrays

Materialize nodes in expr with the MPMS materialization strategy. MPMS stands for Multiple-Predecessors, Multiple-Successors.


  • The MPMS materialization strategy is a greedy materialization algorithm in which any node with more than one materialized predecessor and more than one successor is materialized.

  • Materializing here corresponds to tagging a node with ImplStored.

  • Does not attempt to materialize sub-expressions in pytato.Array.shape.


Because this algorithm is greedy, it might be too eager to materialize. Consider the graph below:

I1          I2
 \         /
  \       /
   \     /
    🡦   🡧
      T
     / \
    /   \
   /     \
  🡧       🡦
 O1        O2

where ‘I1’ and ‘I2’ correspond to instances of pytato.array.InputArgumentBase, and ‘O1’ and ‘O2’ are the outputs required to be evaluated in the computation graph. The MPMS materialization algorithm will materialize the intermediate node ‘T’, as it has 2 predecessors and 2 successors. However, the total number of memory accesses after applying MPMS goes up, as shown by the table below.

          Before   After
Reads        4       4
Writes       2       3
Total        6       7

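The materialization decision and the access counts in this example can be reproduced with a short sketch. It is plain Python over a hypothetical string-keyed graph, not pytato's implementation, and it assumes a simple cost model: every stored (materialized) non-input node reads each stored array it depends on once, looking through unstored intermediates, and writes itself once. Inputs and outputs are treated as stored.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

# Hypothetical toy graph: node -> predecessors, listed in topological order.
Graph = Dict[str, List[str]]
preds: Graph = {"I1": [], "I2": [], "T": ["I1", "I2"], "O1": ["T"], "O2": ["T"]}
outputs = {"O1", "O2"}


def stored_sources(node: str, preds: Graph, stored: Set[str],
                   sources: Dict[str, Set[str]]) -> Set[str]:
    """Stored arrays that node reads, looking through unstored predecessors."""
    return set().union(*({p} if p in stored else sources[p] for p in preds[node]))


def toy_mpms(preds: Graph, outputs: Set[str]) -> Set[str]:
    """Greedy MPMS sketch: store inputs, outputs, and every node with more than
    one stored predecessor and more than one successor."""
    nsuccessors: Dict[str, int] = defaultdict(int)
    for ps in preds.values():
        for p in ps:
            nsuccessors[p] += 1

    stored: Set[str] = set()
    sources: Dict[str, Set[str]] = {}
    for node in preds:
        sources[node] = stored_sources(node, preds, stored, sources)
        if (not preds[node] or node in outputs
                or (len(sources[node]) > 1 and nsuccessors[node] > 1)):
            stored.add(node)
    return stored


def count_accesses(preds: Graph, stored: Set[str]) -> Tuple[int, int]:
    """(reads, writes) under the cost model described above."""
    reads = writes = 0
    sources: Dict[str, Set[str]] = {}
    for node in preds:
        sources[node] = stored_sources(node, preds, stored, sources)
        if preds[node] and node in stored:
            reads += len(sources[node])
            writes += 1
    return reads, writes


before = {"I1", "I2", "O1", "O2"}  # only inputs and outputs stored
after = toy_mpms(preds, outputs)   # MPMS additionally stores T
print(sorted(after))
print(count_accesses(preds, before), count_accesses(preds, after))  # prints (4, 2) (4, 3)
```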
pytato.transform.deduplicate_data_wrappers(array_or_names: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays

For the expression graph given as array_or_names, replace all pytato.array.DataWrapper instances containing identical data with a single instance.


Currently only supports numpy.ndarray and pyopencl.array.Array.


This function currently uses addresses of memory buffers to detect duplicate data, and so it may fail to deduplicate some instances of identical-but-separately-stored data. User code must tolerate this, but it must also tolerate this function doing a more thorough job of deduplication.
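The caveat about buffer addresses can be illustrated with a small sketch. Here the identity of a Python bytearray stands in for a memory-buffer address, and ToyDataWrapper and dedupe_by_buffer are hypothetical names, not pytato's API. Wrappers that share one buffer collapse to a single instance, while equal-but-separately-stored data is left alone.

```python
from dataclasses import dataclass
from typing import Dict, List


# Hypothetical wrapper; a bytearray stands in for a numpy/pyopencl buffer.
@dataclass(eq=False)
class ToyDataWrapper:
    data: bytearray


def dedupe_by_buffer(wrappers: List[ToyDataWrapper]) -> List[ToyDataWrapper]:
    """Replace each wrapper by the first one seen that shares its buffer object.

    id() of the buffer plays the role of a memory address, so equal contents
    stored in different buffers are NOT merged.
    """
    canonical: Dict[int, ToyDataWrapper] = {}
    return [canonical.setdefault(id(w.data), w) for w in wrappers]


buf = bytearray(b"\x00\x01")
w1, w2 = ToyDataWrapper(buf), ToyDataWrapper(buf)
w3 = ToyDataWrapper(bytearray(b"\x00\x01"))  # same bytes, separate storage
result = dedupe_by_buffer([w1, w2, w3])
print([r is w1 for r in result])  # prints [True, True, False]
```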

pytato.transform.lower_to_index_lambda.to_index_lambda(expr: Array) → IndexLambda

Lowers expr to IndexLambda if possible, otherwise raises a pytato.diagnostic.CannotBeLoweredToIndexLambda.


Returns: The lowered IndexLambda.

pytato.transform.remove_broadcasts_einsum.rewrite_einsums_with_no_broadcasts(expr: MappedT) → MappedT

Rewrites all instances of Einsum in expr such that the einsum expressions avoid broadcasting axes of their operands. We do so by updating the pytato.array.Einsum.access_descriptors and slicing the operands.

>>> import numpy as np
>>> import pytato as pt
>>> a = pt.make_placeholder("a", (10, 4, 1), np.float64)
>>> b = pt.make_placeholder("b", (10, 1, 4), np.float64)
>>> expr = pt.einsum("ijk,ijk->i", a, b)
>>> new_expr = pt.rewrite_einsums_with_no_broadcasts(expr)
>>> pt.analysis.is_einsum_similar_to_subscript(new_expr, "ij,ik->i")
True


This transformation preserves the semantics of the expression, i.e., it does not alter its value.
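The effect of the rewrite on the subscripts in the example above can be sketched on plain subscript strings and integer shapes. drop_broadcast_axes is a hypothetical helper, not pytato's implementation (which works on Einsum access descriptors): an operand axis of length 1 whose index has a larger extent elsewhere is indexed away with 0, and its index letter is dropped.

```python
from typing import List, Sequence, Tuple


def drop_broadcast_axes(
    subscripts: str, shapes: Sequence[Tuple[int, ...]]
) -> Tuple[str, List[tuple]]:
    """Drop operand axes of length 1 whose index has a larger extent elsewhere.

    Returns the new subscript string and, per operand, the indexing tuple
    (0 for a dropped axis, slice(None) for a kept one).
    """
    inputs, output = subscripts.split("->")
    specs = inputs.split(",")

    # Extent of each index: the largest length any operand gives it.
    extent = {}
    for spec, shape in zip(specs, shapes):
        for idx, n in zip(spec, shape):
            extent[idx] = max(extent.get(idx, 1), n)

    new_specs, indexers = [], []
    for spec, shape in zip(specs, shapes):
        kept, indexer = [], []
        for idx, n in zip(spec, shape):
            if n == 1 and extent[idx] > 1:
                indexer.append(0)  # broadcast axis: slice it away
            else:
                kept.append(idx)
                indexer.append(slice(None))
        new_specs.append("".join(kept))
        indexers.append(tuple(indexer))
    return ",".join(new_specs) + "->" + output, indexers


new_subscripts, indexers = drop_broadcast_axes("ijk,ijk->i", [(10, 4, 1), (10, 1, 4)])
print(new_subscripts)  # prints ij,ik->i
```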

class pytato.transform.einsum_distributive_law.EinsumDistributiveLawDescriptor

Abstract type that informs apply_distributive_property_to_einsums() how the distributive law should be applied along an einsum’s operands.

class pytato.transform.einsum_distributive_law.DoNotDistribute

Tells apply_distributive_property_to_einsums() not to apply the distributive law along any operand of the Einsum.

class pytato.transform.einsum_distributive_law.DoDistribute(ioperand: int)

Tells apply_distributive_property_to_einsums() to apply the distributive law along the ioperand-th operand of the Einsum.

pytato.transform.einsum_distributive_law.apply_distributive_property_to_einsums(expr: MappedT, how_to_distribute: Callable[[Array], EinsumDistributiveLawDescriptor]) → MappedT

Returns a copy of expr after applying the distributive law to Einstein summation nodes in the expression graph.

>>> import numpy as np
>>> import pytato as pt
>>> from pytato.transform.einsum_distributive_law import (
...     DoDistribute, DoNotDistribute, apply_distributive_property_to_einsums)
>>> x1 = pt.make_placeholder("x1", 4, np.float64)
>>> x2 = pt.make_placeholder("x2", 4, np.float64)
>>> A = pt.make_placeholder("A", (10, 4), np.float64)
>>> y = A @ (x1 + x2)

>>> def how_to_distribute(expr):
...     # Distribute matrix-vector products over their vector operand only.
...     if pt.analysis.is_einsum_similar_to_subscript(
...         expr, "ij,j->i"):
...         return DoDistribute(ioperand=1)
...     else:
...         return DoNotDistribute()

>>> y_transformed = apply_distributive_property_to_einsums(y,
...                     how_to_distribute)

>>> y_transformed == A @ x1 + A @ x2
True
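The core rewrite, einsum(A, x1 + x2) → einsum(A, x1) + einsum(A, x2), can be sketched on hypothetical symbolic nodes. Unlike apply_distributive_property_to_einsums(), this distribute helper only looks at the node it is given and takes ioperand directly instead of consulting a how_to_distribute callback.

```python
from dataclasses import dataclass
from typing import Any, Tuple


# Hypothetical symbolic nodes (structural equality via dataclass eq).
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Sum:
    terms: Tuple[Any, ...]


@dataclass(frozen=True)
class Einsum:
    subscripts: str
    args: Tuple[Any, ...]


def distribute(expr: Einsum, ioperand: int) -> Any:
    """einsum(..., t1 + t2 + ..., ...) -> einsum(..., t1, ...) + einsum(..., t2, ...)"""
    operand = expr.args[ioperand]
    if not isinstance(operand, Sum):
        return expr  # nothing to distribute over
    before, after = expr.args[:ioperand], expr.args[ioperand + 1:]
    return Sum(tuple(Einsum(expr.subscripts, before + (term,) + after)
                     for term in operand.terms))


A, x1, x2 = Var("A"), Var("x1"), Var("x2")
y = Einsum("ij,j->i", (A, Sum((x1, x2))))  # A @ (x1 + x2)
y_transformed = distribute(y, ioperand=1)
print(y_transformed == Sum((Einsum("ij,j->i", (A, x1)),
                            Einsum("ij,j->i", (A, x2)))))  # prints True
```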

Dict representation of DAGs

class pytato.transform.UsersCollector

Maps a graph to a dictionary representation mapping a node to its users, i.e., all the nodes that use its value.


node_to_users

Mapping of each node in the graph to its users.

__init__() → None

pytato.transform.tag_user_nodes(graph: Mapping[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]], tag: Any, starting_point: Array | AbstractResultWithNamedArrays, node_to_tags: Dict[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]] | None = None) → Dict[Array | AbstractResultWithNamedArrays, Set[Any]]

Tags all nodes reachable from starting_point with tag.

  • graph – A dict representation of a directed graph, mapping each node to other nodes to which it is connected by edges. A possible use case for this function is the graph in UsersCollector.node_to_users.

  • tag – The value to tag the nodes with.

  • starting_point – A starting point in graph.

  • node_to_tags – The resulting mapping of nodes to tags.
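A minimal sketch of both ideas on a string-keyed toy graph: first invert a node-to-predecessors mapping into the node-to-users form described for UsersCollector.node_to_users, then tag everything reachable from a starting point. Both helpers are hypothetical, and tagging starting_point itself is an assumption of this sketch.

```python
from collections import deque
from typing import Dict, Hashable, Iterable, Mapping, Optional, Set


def users_from_predecessors(
        preds: Mapping[Hashable, Iterable[Hashable]]) -> Dict[Hashable, Set[Hashable]]:
    """Invert node -> predecessors into node -> users (nodes using its value)."""
    users: Dict[Hashable, Set[Hashable]] = {}
    for node, ps in preds.items():
        for p in ps:
            users.setdefault(p, set()).add(node)
    return users


def toy_tag_user_nodes(graph: Mapping[Hashable, Set[Hashable]], tag: Hashable,
                       starting_point: Hashable,
                       node_to_tags: Optional[Dict[Hashable, Set[Hashable]]] = None
                       ) -> Dict[Hashable, Set[Hashable]]:
    """Breadth-first walk along graph's edges, adding tag to every node reached
    (including starting_point itself)."""
    if node_to_tags is None:
        node_to_tags = {}
    seen = {starting_point}
    queue = deque([starting_point])
    while queue:
        node = queue.popleft()
        node_to_tags.setdefault(node, set()).add(tag)
        for succ in graph.get(node, ()):
            if succ not in seen:
                seen.add(succ)
                queue.append(succ)
    return node_to_tags


preds = {"a": [], "b": [], "t": ["a", "b"], "o": ["t"]}
users = users_from_predecessors(preds)  # {'a': {'t'}, 'b': {'t'}, 't': {'o'}}
tags = toy_tag_user_nodes(users, "from_a", "a")
print(sorted(tags))  # prints ['a', 'o', 't']
```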

pytato.transform.rec_get_user_nodes(expr: Array | AbstractResultWithNamedArrays, node: Array | AbstractResultWithNamedArrays) → FrozenSet[Array | AbstractResultWithNamedArrays]

Returns all direct and indirect users of node in expr.

Transforming call sites

pytato.transform.calls.inline_calls(expr: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays

Returns a copy of expr with call sites tagged with pytato.tags.InlineCallTag inlined into the expression graph.

pytato.transform.calls.tag_all_calls_to_be_inlined(expr: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays

Returns a copy of expr with all reachable instances of pytato.function.Call nodes tagged with pytato.tags.InlineCallTag.


This routine does NOT inline calls. To inline the calls, use inline_calls() on this routine’s output.

Internal stuff that is only here because the documentation tool wants it

class pytato.transform.MappedT

A type variable representing the input type of a Mapper.

class pytato.transform.CombineT

A type variable representing the result type of a CombineMapper.

class pytato.transform._SelfMapper

A type variable used to represent the type of a mapper in CopyMapper.clone_for_callee().

Analyzing Array Expression Graphs

pytato.analysis.get_nusers(outputs: Array | DictOfNamedArrays) → Mapping[Array, int]

For the DAG outputs, returns the mapping from each node to the number of nodes using its value within the DAG given by outputs.

pytato.analysis.is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) → bool

Returns True if and only if an einsum with the subscript descriptor string subscripts operated on expr’s pytato.array.Einsum.args would compute the same result as expr.
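One way to picture "similar" for plain subscript strings is equality up to a consistent renaming of index letters. The sketch below canonicalizes by order of first appearance; it is a hypothetical helper that works on strings only, whereas is_einsum_similar_to_subscript() inspects an actual Einsum node, so read it as an illustration of the idea rather than the library's algorithm.

```python
def canonicalize(subscripts: str) -> str:
    """Rename index letters by order of first appearance: 'pq,q->p' -> 'ab,b->a'."""
    mapping = {}
    out = []
    for ch in subscripts.replace(" ", ""):
        if ch.isalpha():
            if ch not in mapping:
                mapping[ch] = chr(ord("a") + len(mapping))
            out.append(mapping[ch])
        else:
            out.append(ch)  # keep ',' and '->' as they are
    return "".join(out)


def subscripts_similar(s1: str, s2: str) -> bool:
    return canonicalize(s1) == canonicalize(s2)


print(subscripts_similar("ij,j->i", "pq,q->p"))  # prints True
print(subscripts_similar("ij,j->i", "ij,i->j"))  # prints False
```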

pytato.analysis.get_num_nodes(outputs: Array | DictOfNamedArrays) → int

Returns the number of nodes in DAG outputs.

pytato.analysis.get_num_call_sites(outputs: Array | DictOfNamedArrays) → int

Returns the number of call sites in DAG outputs.

class pytato.analysis.DirectPredecessorsGetter

Mapper to get the direct predecessors of a node.


We only consider the predecessors of a node in a data-flow sense.

Visualizing Array Expression Graphs

pytato.get_dot_graph(result: Array | DictOfNamedArrays) → str

Return a string in the dot language depicting the graph of the computation of result.


result – Outputs of the computation (cf. pytato.generate_loopy()).

pytato.get_dot_graph_from_partition(partition: DistributedGraphPartition) → str

Return a string in the dot language depicting the graph of the partitioned computation of partition.


partition – Outputs of find_distributed_partition().

pytato.show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, **kwargs: Any) → None

Show a graph representing the computation of result in a browser.

pytato.show_fancy_placeholder_data_flow(dag: Array | DictOfNamedArrays, **kwargs: Any) → None

Visualizes the data-flow from the placeholders into outputs.



This is a heavily opinionated visualization of the data-flow graph in dag. Displaying all the information about each node is not the priority. See pytato.show_dot_graph(), which aims to be more verbose.

Comparing Two Expression Graphs

class pytato.equality.EqualityComparer

A pytato.array.Array visitor to check equality between two expression DAGs.


  • Compares two expression graphs expr1, expr2 in \(O(N)\) comparisons, where \(N\) is the number of nodes in expr1.

  • This visitor was introduced to memoize the comparisons of sub-expressions of the expressions to be compared. Not memoizing these sub-expression comparisons results in \(O(2^N)\) complexity for the comparison operation, where \(N\) is the number of nodes in the expressions. See GH-Issue-163 for more on this.
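The gap between memoized and unmemoized comparison is easy to reproduce on a toy DAG in which every level reuses the previous node twice, giving depth + 1 nodes but 2**depth root-to-leaf paths. ToyEqualityComparer is hypothetical and memoizes on the pair of node identities; it is a sketch of the idea, not pytato's EqualityComparer.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


# Hypothetical toy nodes; eq=False keeps the default identity-based ==,
# so structure is compared only by the comparer below.
@dataclass(eq=False)
class Leaf:
    name: str


@dataclass(eq=False)
class Add:
    left: Any
    right: Any


def build(depth: int) -> Any:
    """depth + 1 distinct nodes, but 2**depth root-to-leaf paths."""
    node: Any = Leaf("x")
    for _ in range(depth):
        node = Add(node, node)  # reuse the previous node twice
    return node


class ToyEqualityComparer:
    def __init__(self, memoize: bool) -> None:
        self.memoize = memoize
        self.cache: Dict[Tuple[int, int], bool] = {}
        self.comparisons = 0

    def equal(self, a: Any, b: Any) -> bool:
        key = (id(a), id(b))  # both graphs stay alive, so ids are stable
        if self.memoize and key in self.cache:
            return self.cache[key]
        self.comparisons += 1
        if type(a) is not type(b):
            result = False
        elif isinstance(a, Leaf):
            result = a.name == b.name
        else:
            result = self.equal(a.left, b.left) and self.equal(a.right, b.right)
        if self.memoize:
            self.cache[key] = result
        return result


e1, e2 = build(16), build(16)  # structurally equal, separately constructed
fast = ToyEqualityComparer(memoize=True)
slow = ToyEqualityComparer(memoize=False)
print(fast.equal(e1, e2), slow.equal(e1, e2))  # prints True True
print(fast.comparisons, slow.comparisons)      # prints 17 131071
```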

Stringifying Expression Graphs

class pytato.stringifier.Reprifier(truncation_depth: int = 3, truncation_string: str = '(...)')

Stringifies pytato-types to closely resemble CPython’s implementation of repr() for its builtin datatypes.
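The role of truncation_depth and truncation_string can be pictured with nested Python tuples. truncated_repr is a hypothetical helper, and how Reprifier actually counts depth for pytato types is not stated here, so treat this as an illustration of the two parameters rather than of Reprifier's algorithm.

```python
from typing import Any


def truncated_repr(obj: Any, truncation_depth: int = 3,
                   truncation_string: str = "(...)", _depth: int = 0) -> str:
    """repr() of nested tuples/lists, eliding containers nested
    truncation_depth or more levels below the top."""
    if isinstance(obj, (tuple, list)):
        if _depth >= truncation_depth:
            return truncation_string
        items = [truncated_repr(x, truncation_depth, truncation_string, _depth + 1)
                 for x in obj]
        if isinstance(obj, list):
            return "[" + ", ".join(items) + "]"
        # One-element tuples keep their trailing comma, as in repr().
        return "(" + ", ".join(items) + ("," if len(items) == 1 else "") + ")"
    return repr(obj)


print(truncated_repr((1, (2, (3, (4,))))))  # prints (1, (2, (3, (...))))
print(truncated_repr([1, (2,)]))            # prints [1, (2,)]
```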

Utilities and Diagnostics

Helper routines

pytato.utils.are_shape_components_equal(dim1: int | integer[Any] | Array, dim2: int | integer[Any] | Array) → bool

Returns True iff dim1 and dim2 have equal SizeParam coefficients in their expressions.

pytato.utils.are_shapes_equal(shape1: Tuple[int | integer[Any] | Array, ...], shape2: Tuple[int | integer[Any] | Array, ...]) → bool

Returns True iff shape1 and shape2 have the same dimensionality and the corresponding components are equal as defined by are_shape_components_equal().

pytato.utils.get_shape_after_broadcasting(exprs: Iterable[Array | number[Any] | int | bool_ | bool | float | complex | Expression]) → Tuple[int | integer[Any] | Array, ...]

Returns the shape after broadcasting exprs in an operation.
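For integer shapes, the rule is NumPy's broadcasting: align shapes at their trailing axes, treat missing axes as length 1, and require every non-1 length on an axis to agree. The sketch below implements that rule for plain integers only; broadcast_shapes here is a hypothetical helper, whereas get_shape_after_broadcasting() also accepts Array-valued (symbolic) shape components.

```python
from itertools import zip_longest
from typing import Iterable, Tuple


def broadcast_shapes(shapes: Iterable[Tuple[int, ...]]) -> Tuple[int, ...]:
    """NumPy-style broadcast of integer shapes; scalars have shape ()."""
    result = []
    # Walk axes from the right; shorter shapes are padded with length-1 axes.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        lengths = {d for d in dims if d != 1}
        if len(lengths) > 1:
            raise ValueError(f"cannot broadcast axis lengths {dims}")
        result.append(lengths.pop() if lengths else 1)
    return tuple(reversed(result))


print(broadcast_shapes([(10, 1, 4), (4, 1), ()]))  # prints (10, 4, 4)
```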

pytato.utils.dim_to_index_lambda_components(expr: int | integer[Any] | Array, vng: UniqueNameGenerator | None = None) → Tuple[number[Any] | int | bool_ | bool | float | complex | Expression, Dict[str, SizeParam]]

Returns the scalar expressions and bindings to use the shape component within an index lambda.

>>> import pytato as pt
>>> from pytools import UniqueNameGenerator
>>> from pytato.utils import dim_to_index_lambda_components
>>> n = pt.make_size_param("n")
>>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
>>> print(expr)
3*_in + 8
>>> bnds
{'_in': SizeParam(name='n')}

pytato.utils.get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[Array | number[Any] | int | bool_ | bool | float | complex]) → dtype

pytato.utils.get_einsum_subscript_str(expr: Einsum) → str

Returns the index subscript expression that was used in constructing expr using the pytato.einsum() routine.

>>> import numpy as np
>>> import pytato as pt
>>> from pytato.utils import get_einsum_subscript_str
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ij,jk,kl->il", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'

Pytato-specific exceptions

class pytato.diagnostic.NameClashError

Raised when two non-identical InputArgumentBase instances are reachable in an Array’s DAG and share the same name. Here, we refer to two objects a and b as being identical iff a is b.

class pytato.diagnostic.CannotBroadcastError

class pytato.diagnostic.UnknownIndexLambdaExpr

Raised when the structure of a pytato.array.IndexLambda could not be inferred.

class pytato.diagnostic.CannotBeLoweredToIndexLambda

Raised when a pytato.Array was expected to be lowered to an IndexLambda, but it cannot be. For example, a pytato.loopy.LoopyCallResult cannot be lowered to an IndexLambda.