Transforming Array Expression Graphs

class pytato.transform.Mapper

A class that, when called with a pytato.Array, recursively iterates over the DAG, dispatching to the mapper method named by each node's _mapper_method. Users of this class are expected to override its methods in a subclass.


This class might visit a node multiple times. Use a CachedMapper if this is not desired.

handle_unsupported_array(expr: MappedT, *args: Any, **kwargs: Any) → Any

Mapper method that is invoked for pytato.Array subclasses for which a mapper method does not exist in this mapper.

map_foreign(expr: Any, *args: Any, **kwargs: Any) → Any

Mapper method that is invoked for objects of a class for which no mapper method exists in this mapper.

rec(expr: MappedT, *args: Any, **kwargs: Any) → Any

Call the mapper method of expr and return the result.

__call__(expr: MappedT, *args: Any, **kwargs: Any) → Any

Handle the mapping of expr.
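The dispatch performed by rec() can be pictured with a self-contained toy model. It is a sketch, not pytato's implementation: ToyPlaceholder, ToySum, ToyUnsupported and ToyMapper are hypothetical, and the rule that a node without a matching map_* method goes to handle_unsupported_array() while an object without any _mapper_method goes to map_foreign() is assumed from the descriptions above.

```python
from typing import Any

# Toy model: every name below is hypothetical, not pytato API.


class ToyPlaceholder:
    # The class attribute names the mapper method to dispatch to.
    _mapper_method = "map_placeholder"

    def __init__(self, name: str) -> None:
        self.name = name


class ToySum:
    _mapper_method = "map_sum"

    def __init__(self, left: Any, right: Any) -> None:
        self.left = left
        self.right = right


class ToyUnsupported:
    # A node type for which the mapper below defines no method.
    _mapper_method = "map_unsupported"


class ToyMapper:
    def rec(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        method_name = getattr(expr, "_mapper_method", None)
        if method_name is None:
            # Not a (toy) array node at all.
            return self.map_foreign(expr, *args, **kwargs)
        method = getattr(self, method_name, None)
        if method is None:
            # An array node type without a matching map_* method.
            return self.handle_unsupported_array(expr, *args, **kwargs)
        return method(expr, *args, **kwargs)

    __call__ = rec

    def handle_unsupported_array(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(f"{type(self).__name__} cannot handle {type(expr).__name__}")

    def map_foreign(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        raise ValueError(f"{type(self).__name__} encountered foreign object {expr!r}")


class ToyStringifier(ToyMapper):
    # Users subclass the mapper and override the map_* methods they need.
    def map_placeholder(self, expr: ToyPlaceholder) -> str:
        return expr.name

    def map_sum(self, expr: ToySum) -> str:
        return f"({self.rec(expr.left)} + {self.rec(expr.right)})"
```

For example, ToyStringifier() applied to ToySum(ToyPlaceholder("a"), ToySum(ToyPlaceholder("b"), ToyPlaceholder("c"))) returns "(a + (b + c))".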

class pytato.transform.CachedMapper

Mapper class that maps each node in the DAG exactly once. This loses some information compared to Mapper as a node is visited only from one of its predecessors.
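The difference between Mapper and CachedMapper shows up on a diamond-shaped DAG in which x is shared by a and b. The sketch below uses hypothetical toy classes (not pytato's) and assumes memoization keyed on node identity: the uncached mapper reaches x once per path, the cached one only once, via a, which is the information loss mentioned above.

```python
from collections import Counter
from typing import Dict, Tuple

# Toy model: all names are hypothetical, not pytato API.


class ToyNode:
    # Default identity-based hashing is enough for this sketch.
    def __init__(self, name: str, *children: "ToyNode") -> None:
        self.name = name
        self.children: Tuple["ToyNode", ...] = children


class ToyCountingMapper:
    """Uncached: a shared node is visited once per path that reaches it."""

    def __init__(self) -> None:
        self.visits: Counter = Counter()

    def rec(self, node: ToyNode) -> int:
        self.visits[node.name] += 1
        # Result: number of nodes in the tree obtained by unfolding the DAG.
        return 1 + sum(self.rec(child) for child in node.children)


class ToyCachedCountingMapper(ToyCountingMapper):
    """Cached: each node is mapped once; later requests reuse the stored result."""

    def __init__(self) -> None:
        super().__init__()
        self._cache: Dict[ToyNode, int] = {}

    def rec(self, node: ToyNode) -> int:
        if node not in self._cache:
            self._cache[node] = super().rec(node)
        return self._cache[node]


# Usage: a diamond out -> (a, b) -> x.
x = ToyNode("x")
out = ToyNode("out", ToyNode("a", x), ToyNode("b", x))
uncached, cached = ToyCountingMapper(), ToyCachedCountingMapper()
uncached_result, cached_result = uncached.rec(out), cached.rec(out)
```

Both mappers return the same result here (5), but uncached.visits["x"] is 2 while cached.visits["x"] is 1.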

class pytato.transform.CopyMapper

Performs a deep copy of a pytato.array.Array. The typical use of this mapper is to override individual map_ methods in subclasses to permit term rewriting on an expression graph.


This does not copy the data of a pytato.array.DataWrapper.

class pytato.transform.CopyMapperWithExtraArgs

Similar to CopyMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.

The logic in CopyMapper purposely does not take the extra arguments, to keep the cost of each of its call frames low.

class pytato.transform.CombineMapper

Abstract mapper that recursively combines, via combine(), the results of mapping the operands of each node of a given expression.

combine(*args: CombineT) → CombineT

Combine the arguments.
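A minimal sketch of the CombineMapper pattern, assuming leaves stand in for input arguments and that results are memoized per node: combine() decides how the operand results are merged, and a set-union combine yields an InputGatherer-like mapper. All class names are hypothetical.

```python
from typing import FrozenSet, Tuple

# Toy model: all names are hypothetical, not pytato API.


class ToyNode:
    # Leaves (nodes without children) play the role of input arguments.
    def __init__(self, name: str, *children: "ToyNode") -> None:
        self.name = name
        self.children: Tuple["ToyNode", ...] = children


class ToyCombineMapper:
    def __init__(self) -> None:
        self._cache: dict = {}

    def combine(self, *args):
        raise NotImplementedError

    def rec(self, node: ToyNode):
        # Memoize so that shared subexpressions are combined only once.
        if node not in self._cache:
            self._cache[node] = self.map_node(node)
        return self._cache[node]

    def map_node(self, node: ToyNode):
        # Default: combine the results of all operands of the node.
        return self.combine(*(self.rec(child) for child in node.children))


class ToyInputGatherer(ToyCombineMapper):
    def combine(self, *args: FrozenSet[str]) -> FrozenSet[str]:
        return frozenset().union(*args)

    def map_node(self, node: ToyNode) -> FrozenSet[str]:
        if not node.children:
            return frozenset([node.name])  # a leaf is an input
        return super().map_node(node)


# Usage: out = sin(x) + x * y depends on the inputs x and y.
x, y = ToyNode("x"), ToyNode("y")
out = ToyNode("add", ToyNode("sin", x), ToyNode("mul", x, y))
inputs = ToyInputGatherer().rec(out)
```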

class pytato.transform.DependencyMapper

Maps a pytato.array.Array to a frozenset of pytato.array.Array’s it depends on.


This returns every node in the graph! Consider a custom CombineMapper or a SubsetDependencyMapper instead.

class pytato.transform.InputGatherer

Mapper to combine all instances of pytato.array.InputArgumentBase that an array expression depends on.

class pytato.transform.SizeParamGatherer

Mapper to combine all instances of pytato.array.SizeParam that an array expression depends on.

class pytato.transform.SubsetDependencyMapper(universe: FrozenSet[Array])

Mapper to combine the dependencies of an expression that are a subset of universe.

class pytato.transform.WalkMapper

A mapper that walks over all the arrays in a pytato.Array.

Users may override the specific mapper methods in a derived class or override WalkMapper.visit() and WalkMapper.post_visit().

visit(expr: Any, *args: Any, **kwargs: Any) → bool

If this method returns True, expr is traversed during the walk. If this method returns False, expr is not traversed as a part of the walk.

post_visit(expr: Any, *args: Any, **kwargs: Any) → None

Callback after expr has been traversed.
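The contract of visit() and post_visit() can be sketched as follows. This toy version is an assumption, not pytato's code: it calls post_visit() after all children have been traversed and skips both the children and post_visit() when visit() returns False. The class names are hypothetical.

```python
from typing import List, Tuple

# Toy model: all names are hypothetical, not pytato API.


class ToyNode:
    def __init__(self, name: str, *children: "ToyNode") -> None:
        self.name = name
        self.children: Tuple["ToyNode", ...] = children


class ToyWalkMapper:
    def visit(self, node: ToyNode) -> bool:
        return True  # traverse everything by default

    def post_visit(self, node: ToyNode) -> None:
        pass

    def rec(self, node: ToyNode) -> None:
        if not self.visit(node):
            return  # prune: neither the children nor post_visit are reached
        for child in node.children:
            self.rec(child)
        self.post_visit(node)


class ToyPostOrderRecorder(ToyWalkMapper):
    """Records nodes in post-order, skipping any subtree rooted at a node named 'stop'."""

    def __init__(self) -> None:
        self.order: List[str] = []

    def visit(self, node: ToyNode) -> bool:
        return node.name != "stop"

    def post_visit(self, node: ToyNode) -> None:
        self.order.append(node.name)


# Usage
walker = ToyPostOrderRecorder()
walker.rec(ToyNode("root", ToyNode("a", ToyNode("x")), ToyNode("stop", ToyNode("hidden"))))
```

After the walk, walker.order is ["x", "a", "root"]: the pruned "stop" subtree never reaches post_visit().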

class pytato.transform.CachedWalkMapper

WalkMapper that visits each node in the DAG exactly once. This loses some information compared to WalkMapper as a node is visited only from one of its predecessors.

class pytato.transform.TopoSortMapper

class pytato.transform.CachedMapAndCopyMapper(map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays])

Mapper that applies map_fn to each node and copies it. Results of traversals are memoized, i.e., each node is mapped via map_fn exactly once.

class pytato.transform.EdgeCachedMapper

Mapper class to execute a rewriting method (handle_edge()) on each edge in the graph.

abstract handle_edge(expr: Array | AbstractResultWithNamedArrays, child: Array | AbstractResultWithNamedArrays) → Any

pytato.transform.copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, copy_mapper: CopyMapper) → DictOfNamedArrays

Copy the elements of a DictOfNamedArrays into a DictOfNamedArrays.

Parameters:

  • source_dict – The DictOfNamedArrays to copy.

  • copy_mapper – A mapper that performs copies of different array types.

Returns: A new DictOfNamedArrays containing copies of the items in source_dict.

pytato.transform.get_dependencies(expr: DictOfNamedArrays) → Dict[str, FrozenSet[Array]]

Returns the dependencies of each named array in expr.

pytato.transform.map_and_copy(expr: Array | AbstractResultWithNamedArrays, map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays]) → Array | AbstractResultWithNamedArrays

Returns a copy of expr with every array expression reachable from expr mapped via map_fn.


Uses CachedMapAndCopyMapper under the hood and because of its caching nature each node is mapped exactly once.
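A toy, self-contained version of the map-and-copy idea. It assumes, as a guess at the ordering, that map_fn is applied to a node before the node's children are copied, and it uses the structural equality of frozen dataclasses as the cache key. ToyNode and toy_map_and_copy are hypothetical names, not pytato API.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Toy model: all names are hypothetical, not pytato API.


@dataclass(frozen=True)
class ToyNode:
    # Frozen dataclass: hashable with structural equality, used as the cache key.
    op: str
    children: Tuple["ToyNode", ...] = ()


def toy_map_and_copy(expr: ToyNode, map_fn: Callable[[ToyNode], ToyNode]) -> ToyNode:
    cache: Dict[ToyNode, ToyNode] = {}

    def rec(node: ToyNode) -> ToyNode:
        if node not in cache:
            mapped = map_fn(node)  # apply map_fn first, then copy the mapped node's children
            cache[node] = ToyNode(mapped.op, tuple(rec(c) for c in mapped.children))
        return cache[node]

    return rec(expr)


# Usage: rename every "x" leaf to "y" in a graph where sin_x is shared.
calls: List[str] = []


def rename_x(node: ToyNode) -> ToyNode:
    calls.append(node.op)
    return ToyNode("y") if node.op == "x" else node


sin_x = ToyNode("sin", (ToyNode("x"),))
result = toy_map_and_copy(ToyNode("add", (sin_x, sin_x)), rename_x)
```

Because of the cache, calls ends up as ["add", "sin", "x"]: the shared sin_x is mapped only once even though it is used twice.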

pytato.transform.materialize_with_mpms(expr: DictOfNamedArrays) → DictOfNamedArrays

Materialize nodes in expr with MPMS materialization strategy. MPMS stands for Multiple-Predecessors, Multiple-Successors.


  • MPMS materialization strategy is a greedy materialization algorithm in which any node with more than one materialized predecessor and more than one successor is materialized.

  • Materializing here corresponds to tagging a node with ImplStored.

  • Does not attempt to materialize sub-expressions in pytato.Array.shape.


This is a greedy materialization algorithm, so it might be too eager to materialize. Consider the graph below:

I1          I2
 \         /
  \       /
   \     /
    ↘   ↙
      T
     / \
    /   \
   /     \
  ↙       ↘
 O1        O2

where ‘I1’ and ‘I2’ correspond to instances of pytato.array.InputArgumentBase, and ‘O1’ and ‘O2’ are the outputs required to be evaluated in the computation graph. The MPMS materialization algorithm will materialize the intermediate node ‘T’ as it has 2 predecessors and 2 successors. However, the total number of memory accesses after applying MPMS goes up, as shown by the table below.

                Without MPMS    With MPMS
    Reads            4              4
    Writes           2              3
    Total            6              7

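The access counts in the table can be reproduced with a small model. The sketch below, with hypothetical names, represents the graph as a dict from each node to its direct predecessors, treats inputs and outputs as materialized, applies the greedy rule from the notes above in topological order, and counts one read per materialized operand and one write per computed materialized array, assuming all arrays have the same number of elements. It illustrates the rule; it is not pytato's implementation.

```python
from graphlib import TopologicalSorter
from typing import Dict, FrozenSet, List, Set, Tuple

# Toy model: all names are hypothetical, not pytato API.
Graph = Dict[str, List[str]]  # node -> direct predecessors; inputs have none


def _reads(node: str, preds: Graph, materialized: Set[str],
           reads: Dict[str, FrozenSet[str]]) -> FrozenSet[str]:
    """Materialized arrays that `node` reads, looking through unmaterialized predecessors."""
    result: Set[str] = set()
    for p in preds[node]:
        # A materialized predecessor is loaded from memory; an unmaterialized
        # one is recomputed inline, so its own reads are inherited.
        result |= {p} if p in materialized else reads[p]
    return frozenset(result)


def toy_mpms(preds: Graph, outputs: Set[str]) -> Set[str]:
    """Greedy rule: materialize nodes with >1 materialized predecessors and >1 successors."""
    n_succs = {n: 0 for n in preds}
    for ps in preds.values():
        for p in ps:
            n_succs[p] += 1
    materialized = {n for n, ps in preds.items() if not ps} | set(outputs)
    reads: Dict[str, FrozenSet[str]] = {}
    for node in TopologicalSorter(preds).static_order():
        reads[node] = _reads(node, preds, materialized, reads)
        if len(reads[node]) > 1 and n_succs[node] > 1:
            materialized.add(node)
    return materialized


def count_accesses(preds: Graph, materialized: Set[str]) -> Tuple[int, int]:
    """(reads, writes), counting one access per element of each materialized array."""
    reads: Dict[str, FrozenSet[str]] = {}
    for node in TopologicalSorter(preds).static_order():
        reads[node] = _reads(node, preds, materialized, reads)
    computed = [n for n in materialized if preds[n]]  # inputs are not written
    return sum(len(reads[n]) for n in computed), len(computed)


# Usage: the graph from the warning above.
preds: Graph = {"I1": [], "I2": [], "T": ["I1", "I2"], "O1": ["T"], "O2": ["T"]}
outputs = {"O1", "O2"}
with_mpms = toy_mpms(preds, outputs)
before = count_accesses(preds, {"I1", "I2", "O1", "O2"})
after = count_accesses(preds, with_mpms)
```

Here before is (4, 2) and after is (4, 3), matching the 6 and 7 total accesses in the table.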
pytato.transform.deduplicate_data_wrappers(array_or_names: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays

For the expression graph given as array_or_names, replace all pytato.array.DataWrapper instances containing identical data with a single instance.


Currently only supports numpy.ndarray and pyopencl.array.Array.


This function currently uses addresses of memory buffers to detect duplicate data, and so it may fail to deduplicate some instances of identical-but-separately-stored data. User code must tolerate this, but it must also tolerate this function doing a more thorough job of deduplication.
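To see why address-based detection can miss identical data, consider this NumPy-only sketch. The key used here (start address of the data buffer plus shape, strides and dtype) and the function names are assumptions for illustration; pytato's exact criterion is not specified above.

```python
from typing import Dict, List, Tuple

import numpy as np

# Toy model: buffer_key and toy_deduplicate are hypothetical, not pytato API.


def buffer_key(a: np.ndarray) -> Tuple[int, Tuple[int, ...], Tuple[int, ...], str]:
    # Start address of the data plus layout: equal keys mean the same memory, viewed the same way.
    return (a.__array_interface__["data"][0], a.shape, a.strides, a.dtype.str)


def toy_deduplicate(arrays: List[np.ndarray]) -> List[np.ndarray]:
    """Replace each array by the first array seen with the same buffer key."""
    representatives: Dict[tuple, np.ndarray] = {}
    return [representatives.setdefault(buffer_key(a), a) for a in arrays]


# Usage: a view shares x's buffer, a copy has equal values in separate storage.
x = np.arange(4.0)
view = x.view()
copy = x.copy()
deduped = toy_deduplicate([x, view, copy])
```

The view is replaced by x, but the copy survives even though its values are identical, which is exactly the kind of missed deduplication the warning allows for.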

pytato.transform.lower_to_index_lambda.to_index_lambda(expr: Array) → IndexLambda

Lowers expr to IndexLambda if possible, otherwise raises a pytato.diagnostic.CannotBeLoweredToIndexLambda.

Returns: The lowered IndexLambda.

pytato.transform.remove_broadcasts_einsum.rewrite_einsums_with_no_broadcasts(expr: MappedT) → MappedT

Rewrites all instances of Einsum in expr such that the einsum expressions avoid broadcasting axes of their operands. We do so by updating the pytato.array.Einsum.access_descriptors and slicing the operands.

>>> import numpy as np
>>> import pytato as pt
>>> a = pt.make_placeholder("a", (10, 4, 1), np.float64)
>>> b = pt.make_placeholder("b", (10, 1, 4), np.float64)
>>> expr = pt.einsum("ijk,ijk->i", a, b)
>>> new_expr = pt.rewrite_einsums_with_no_broadcasts(expr)
>>> pt.analysis.is_einsum_similar_to_subscript(new_expr, "ij,ik->i")
True

This transformation preserves the semantics of the expression, i.e., it does not alter its value.
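Why the rewrite in the example is value-preserving can be checked with NumPy alone (this does not call pytato): summing over the broadcast axes j and k factors into a product of two independent sums, which is what "ij,ik->i" on the sliced operands computes. The random inputs are arbitrary test data.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 4, 1))
b = rng.standard_normal((10, 1, 4))

# Original expression: broadcast the size-1 axes explicitly, then contract j and k.
original = np.einsum("ijk,ijk->i", *np.broadcast_arrays(a, b))

# Rewritten expression: slice away the broadcast axes and contract each operand separately.
rewritten = np.einsum("ij,ik->i", a[:, :, 0], b[:, 0, :])
```

Both results have shape (10,) and agree up to floating-point rounding.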

Dict representation of DAGs

class pytato.transform.UsersCollector

Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value.

node_to_users – Mapping of each node in the graph to its users.

__init__() → None

pytato.transform.tag_user_nodes(graph: Mapping[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]], tag: Any, starting_point: Array | AbstractResultWithNamedArrays, node_to_tags: Dict[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]] | None = None) → Dict[Array | AbstractResultWithNamedArrays, Set[Any]]

Tags all nodes reachable from starting_point with tag.

Parameters:

  • graph – A dict representation of a directed graph, mapping each node to other nodes to which it is connected by edges. A possible use case for this function is the graph in UsersCollector.node_to_users.

  • tag – The value to tag the nodes with.

  • starting_point – A starting point in graph.

  • node_to_tags – The resulting mapping of nodes to tags.

pytato.transform.rec_get_user_nodes(expr: Array | AbstractResultWithNamedArrays, node: Array | AbstractResultWithNamedArrays) → FrozenSet[Array | AbstractResultWithNamedArrays]

Returns all direct and indirect users of node in expr.
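A dict-of-sets representation like UsersCollector.node_to_users can be sketched in plain Python. Here graphs are hypothetical string-labeled dicts mapping each node to its direct predecessors; inverting that relation gives each node's users, and a depth-first search over the result yields all direct and indirect users, in the spirit of rec_get_user_nodes() and tag_user_nodes(). The function names are hypothetical.

```python
from typing import Dict, FrozenSet, List, Set

# Toy model: all names are hypothetical, not pytato API.
Graph = Dict[str, List[str]]  # node -> direct predecessors (the nodes it uses)


def toy_node_to_users(preds: Graph) -> Dict[str, Set[str]]:
    """Invert the predecessor relation: map each node to the nodes that use its value."""
    users: Dict[str, Set[str]] = {n: set() for n in preds}
    for node, ps in preds.items():
        for p in ps:
            users[p].add(node)
    return users


def toy_rec_get_user_nodes(preds: Graph, node: str) -> FrozenSet[str]:
    """All direct and indirect users of `node` (iterative depth-first search)."""
    users = toy_node_to_users(preds)
    seen: Set[str] = set()
    stack = list(users[node])
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(users[n])
    return frozenset(seen)


# Usage: s = f(x), t = g(s, y), u = h(y)
preds: Graph = {"x": [], "y": [], "s": ["x"], "t": ["s", "y"], "u": ["y"]}
```

For this graph, the users of y are t and u, and the direct and indirect users of x are s and t.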

Internal stuff that is only here because the documentation tool wants it

class pytato.transform.MappedT

A type variable representing the input type of a Mapper.

class pytato.transform.CombineT

A type variable representing the type of a CombineMapper.

Analyzing Array Expression Graphs

pytato.analysis.get_nusers(outputs: Array | DictOfNamedArrays) → Mapping[Array, int]

For the DAG outputs, returns the mapping from each node to the number of nodes using its value within the DAG given by outputs.

pytato.analysis.is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) → bool

Returns True if and only if an einsum with the subscript descriptor string subscripts operated on expr’s pytato.array.Einsum.args would compute the same result as expr.

pytato.analysis.get_num_nodes(outputs: Array | DictOfNamedArrays) → int

Returns the number of nodes in DAG outputs.

class pytato.analysis.DirectPredecessorsGetter

Mapper to get the direct predecessors of a node.


We only consider the predecessors of a node in a data-flow sense.

Visualizing Array Expression Graphs

pytato.get_dot_graph(result: Array | DictOfNamedArrays) → str

Return a string in the dot language depicting the graph of the computation of result.

Parameters:

result – Outputs of the computation (cf. pytato.generate_loopy()).

pytato.get_dot_graph_from_partition(partition: GraphPartition) → str

Return a string in the dot language depicting the graph of the partitioned computation of partition.

Parameters:

partition – Outputs of find_partition().

pytato.show_dot_graph(result: str | Array | DictOfNamedArrays | GraphPartition, **kwargs: Any) → None

Show a graph representing the computation of result in a browser.

pytato.get_ascii_graph(result: Array | DictOfNamedArrays, use_color: bool = True) → str

Return a string representing the computation of result using the asciidag package.

pytato.show_ascii_graph(result: Array | DictOfNamedArrays) → None

Print a graph representing the computation of result to stdout using the asciidag package.

Parameters:

result – Outputs of the computation (cf. pytato.generate_loopy()).

Comparing two expression Graphs

class pytato.equality.EqualityComparer

A pytato.array.Array visitor to check equality between two expression DAGs.


  • Compares two expression graphs expr1, expr2 in \(O(N)\) comparisons, where \(N\) is the number of nodes in expr1.

  • This visitor was introduced to memoize the sub-expression comparisons of the expressions to be compared. Not memoizing the sub-expression comparisons results in \(O(2^N)\) complexity for the comparison operation, where \(N\) is the number of nodes in expressions. See GH-Issue-163 <> for more on this.
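The effect of memoization on equality checking can be demonstrated with a chain in which every node uses its predecessor twice: the chain has N + 1 nodes but 2^N paths from the root to the leaf. The sketch compares two structurally equal, separately built chains. All names are hypothetical, and caching on pairs of object identities is an assumption of this toy model, not necessarily how EqualityComparer keys its cache.

```python
from typing import Dict, List, Tuple

# Toy model: all names are hypothetical, not pytato API.


class ToyNode:
    def __init__(self, op: str, *children: "ToyNode") -> None:
        self.op = op
        self.children: Tuple["ToyNode", ...] = children


def make_chain(n: int) -> ToyNode:
    """x_{k+1} = x_k + x_k: n + 1 nodes, but 2**n root-to-leaf paths."""
    x = ToyNode("leaf")
    for _ in range(n):
        x = ToyNode("add", x, x)
    return x


def naive_equal(a: ToyNode, b: ToyNode, counter: List[int]) -> bool:
    counter[0] += 1  # one node comparison
    return (a.op == b.op and len(a.children) == len(b.children)
            and all(naive_equal(c, d, counter) for c, d in zip(a.children, b.children)))


def memoized_equal(a: ToyNode, b: ToyNode, cache: Dict[Tuple[int, int], bool],
                   counter: List[int]) -> bool:
    key = (id(a), id(b))  # both graphs stay alive during the comparison, so ids are stable
    if key not in cache:
        counter[0] += 1
        cache[key] = (a.op == b.op and len(a.children) == len(b.children)
                      and all(memoized_equal(c, d, cache, counter)
                              for c, d in zip(a.children, b.children)))
    return cache[key]


# Usage: two separately built but structurally equal chains.
g1, g2 = make_chain(10), make_chain(10)
naive_count, memo_count = [0], [0]
naive_result = naive_equal(g1, g2, naive_count)
memo_result = memoized_equal(g1, g2, {}, memo_count)
```

With N = 10, both comparisons report equality, but the naive one performs 2^11 - 1 = 2047 node comparisons while the memoized one performs 11.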

Stringifying Expression Graphs

class pytato.stringifier.Reprifier(truncation_depth: int = 3, truncation_string: str = '(...)')

Stringifies pytato-types to closely resemble CPython’s implementation of repr() for its builtin datatypes.

Partitioning Array Expression Graphs

Partitioning of graphs in pytato currently mainly serves to enable distributed computation, i.e. sending and receiving data as part of graph evaluation.

However, as implemented, it is completely general and not specific to this use case. Partitioning of expression graphs is based on a few assumptions:

  • We must be able to execute parts in any dependency-respecting order.

  • Parts are compiled at partitioning time, so what inputs they take from memory vs. what they compute is decided at that time.

  • No part may depend on its own outputs as inputs. (cf. PartitionInducedCycleError)

class pytato.partition.GraphPart(pid: Hashable, needed_pids: FrozenSet[Hashable], user_input_names: FrozenSet[str], partition_input_names: FrozenSet[str], output_names: FrozenSet[str])

pid – An identifier for this part of the graph.

needed_pids – The IDs of parts that are required to be evaluated before this part can be evaluated.

user_input_names – A frozenset of names representing input to the computational graph, i.e. which were not introduced by partitioning.

partition_input_names – A frozenset of names of placeholders the part requires as input from other parts in the partition.

output_names – Names of placeholders this part provides as output.

all_input_names() → FrozenSet[str]

class pytato.partition.GraphPartition(parts: Mapping[Hashable, GraphPart], var_name_to_result: Mapping[str, Array], toposorted_part_ids: List[Hashable])

Store information about a partitioning of an expression graph.


parts – Mapping from part IDs to instances of GraphPart.

var_name_to_result – Mapping of placeholder names to the respective pytato.array.Array they represent.

toposorted_part_ids – One possible topologically sorted ordering of part IDs that is admissible under GraphPart.needed_pids.

Note: This attribute could be recomputed from those dependencies. Since it is computed as part of find_partition() anyway, it is preserved here.

class pytato.partition.GraphPartitioner(get_part_id: Callable[[Array | AbstractResultWithNamedArrays], Hashable])

Given a function get_part_id, produces subgraphs representing the computation. Users should not use this class directly, but use find_partition() instead.

__init__(get_part_id: Callable[[Array | AbstractResultWithNamedArrays], Hashable]) → None

__call__(expr: T, *args: Any, **kwargs: Any) → Any

Handle the mapping of expr.

make_partition(outputs: DictOfNamedArrays) → GraphPartition

exception pytato.partition.PartitionInducedCycleError

Raised by find_partition() if the partitioning induced a cycle in the graph of partitions.

pytato.partition.find_partition(outputs: DictOfNamedArrays, part_func: Callable[[Array | AbstractResultWithNamedArrays], Hashable], partitioner_class: Type[GraphPartitioner] = GraphPartitioner) → GraphPartition

Partitions outputs according to part_func and generates code for each partition. Raises PartitionInducedCycleError if the partitioning induces a cycle, e.g. for a graph like the following:

     ┌───┐
┌────┤ A ├───┐
│    └───┘   │
│          ┌─▼─┐
│          │ B │
│          └─┬─┘
│    ┌───┐   │
└───►│ C │◄──┘
     └───┘

where A and C are in partition 1, and B is in partition 2.

Parameters:

  • outputs – The outputs to partition.

  • part_func – A callable that returns an instance of Hashable for a node.

  • partitioner_class – A GraphPartitioner subclass to guide the partitioning.

Returns: An instance of GraphPartition that contains the partition.
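The cycle condition can be illustrated by collapsing the node-level graph into a part-level graph and topologically sorting it. The sketch below is a toy model using Python's graphlib; ToyPartitionCycleError and toy_toposorted_part_ids are hypothetical names. It reproduces the A/B/C example above: putting A and C in one part and B in another creates a cycle, while putting B and C together does not.

```python
from graphlib import CycleError, TopologicalSorter
from typing import Callable, Dict, Hashable, List, Set

# Toy model: all names are hypothetical, not pytato API.


class ToyPartitionCycleError(ValueError):
    """Stand-in for PartitionInducedCycleError."""


def toy_toposorted_part_ids(preds: Dict[str, List[str]],
                            part_func: Callable[[str], Hashable]) -> List[Hashable]:
    # Every node-level edge that crosses a part boundary becomes a part-level edge.
    part_preds: Dict[Hashable, Set[Hashable]] = {part_func(n): set() for n in preds}
    for node, ps in preds.items():
        for p in ps:
            if part_func(p) != part_func(node):
                part_preds[part_func(node)].add(part_func(p))
    try:
        return list(TopologicalSorter(part_preds).static_order())
    except CycleError as err:
        raise ToyPartitionCycleError("partitioning induces a cycle between parts") from err


# Usage: the graph above, with edges A -> B, A -> C and B -> C.
preds = {"A": [], "B": ["A"], "C": ["A", "B"]}
order = toy_toposorted_part_ids(preds, {"A": 1, "B": 2, "C": 2}.__getitem__)
```

With B and C in part 2, order is [1, 2]; with A and C in part 1 and B in part 2, the call raises ToyPartitionCycleError.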

pytato.partition.execute_partition(partition: GraphPartition, prg_per_partition: Dict[Hashable, BoundProgram], queue: Any, input_args: Dict[str, Any] | None = None) → Dict[str, Any]

Executes a set of partitions on a pyopencl.CommandQueue.

Returns: A dictionary of variable names mapped to their values.

Internal stuff that is only here because the documentation tool wants it

class pytato.partition.T

A type variable for AbstractResultWithNamedArrays.

Support for Distributed-Memory/Message Passing

Distributed-memory evaluation of expression graphs is accomplished by partitioning the graph to reveal communication-free pieces of the computation. Communication (i.e. sending/receiving data) is then accomplished at the boundaries of the parts of the resulting graph partitioning.

Recall the requirement for partitioning that “no part may depend on its own outputs as inputs”. That sounds obvious, but in the distributed-memory case, this is harder to decide than it looks, since we do not have full knowledge of the computation graph: edges go off to other nodes and then come back.

As a first step towards making this tractable, we currently strengthen the requirement to create partition boundaries on every edge that goes between nodes that are/are not a dependency of a receive or that feed/do not feed a send.

class pytato.DistributedSend(data: Array, dest_rank: int, comm_tag: Hashable, tags: FrozenSet[Tag] = frozenset({}))

Class representing a distributed send operation.

data – The Array to be sent.

dest_rank – An int. The rank to which data is to be sent.

comm_tag – A hashable, picklable object to serve as a ‘tag’ for the communication. Only a DistributedRecv with the same tag will be able to receive the data being sent here.

class pytato.DistributedSendRefHolder(send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset({}))

A node acting as an identity on passthrough_data while also holding a reference to a DistributedSend in send. Since pytato represents data flow, and since no data flows ‘out’ of a DistributedSend, no node in all of pytato has a good reason to hold a reference to a send node: there is no useful result of a send (at least not of an Array type).

This is where this node type comes in. Its value is the same as that of passthrough_data, and it holds a reference to the send node.

Note: This all seems a wee bit inelegant, but nobody who has written or reviewed this code so far had a better idea. If you do, please speak up!

send – The DistributedSend to which a reference is to be held.

passthrough_data – An Array. The value of this node.


It is the user’s responsibility to ensure matching sends and receives are part of the computation graph on all ranks. If this rule is not heeded, undefined behavior (in particular deadlock) may result. Notably, by the nature of the data flow graph built by pytato, unused results do not appear in the graph. It is thus possible for a DistributedSendRefHolder to be constructed and yet to not become part of the graph constructed by the user.
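The reachability issue described in this warning can be illustrated with a toy, string-labeled graph (all names hypothetical): a send that no output depends on is simply not part of the graph, whereas a "holder" node that passes y through while referencing the send makes it reachable. That is the role this sketch assumes for DistributedSendRefHolder and staple_distributed_send().

```python
from typing import Dict, List, Set

# Toy model: all names are hypothetical, not pytato API.
Graph = Dict[str, List[str]]  # node -> direct predecessors


def reachable_from(outputs: List[str], preds: Graph) -> Set[str]:
    """Nodes that are part of the data-flow graph rooted at `outputs`."""
    seen: Set[str] = set()
    stack = list(outputs)
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(preds.get(n, []))
    return seen


# The send consumes y, but nothing consumes the send: it is unreachable from "result".
unstapled: Graph = {"x": [], "y": ["x"], "send": ["y"], "result": ["y"]}

# "holder" passes y through and references the send, so the send becomes reachable.
stapled: Graph = {"x": [], "y": ["x"], "send": ["y"],
                  "holder": ["send", "y"], "result": ["holder"]}
```

Only in the stapled graph does reachable_from(["result"], ...) include "send".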

class pytato.DistributedRecv(src_rank: int, comm_tag: CommTagType, shape: ShapeType, dtype: Any, *, axes: AxesT, tags: FrozenSet[Tag])

Class representing a distributed receive operation.

src_rank – An int. The rank from which an array is to be received.

comm_tag – A hashable, picklable object to serve as a ‘tag’ for the communication. Only a DistributedSend with the same tag will be able to send the data being received here.

It is the user’s responsibility to ensure matching sends and receives are part of the computation graph on all ranks. If this rule is not heeded, undefined behavior (in particular deadlock) may result. Notably, by the nature of the data flow graph built by pytato, unused results do not appear in the graph. It is thus possible for a DistributedRecv to be constructed and yet to not become part of the graph constructed by the user.

pytato.make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: Hashable, send_tags: FrozenSet[Tag] = frozenset({})) → DistributedSend

Make a DistributedSend object.

pytato.staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: Hashable, stapled_to: Array, *, send_tags: FrozenSet[Tag] = frozenset({}), ref_holder_tags: FrozenSet[Tag] = frozenset({})) → DistributedSendRefHolder

Make a DistributedSend object wrapped in a DistributedSendRefHolder object.

pytato.make_distributed_recv(src_rank: int, comm_tag: CommTagType, shape: ConvertibleToShape, dtype: Any, axes: AxesT | None = None, tags: FrozenSet[Tag] = frozenset({})) → DistributedRecv

Make a DistributedRecv object.

class pytato.DistributedGraphPart(pid: Hashable, needed_pids: FrozenSet[Hashable], user_input_names: FrozenSet[str], partition_input_names: FrozenSet[str], output_names: FrozenSet[str], input_name_to_recv_node: Dict[str, DistributedRecv], output_name_to_send_node: Dict[str, DistributedSend], distributed_sends: List[DistributedSend])

For one graph partition, record send/receive information for input/output names.

class pytato.DistributedGraphPartition(var_name_to_result: Mapping[str, Array], toposorted_part_ids: List[Hashable], parts: Dict[Hashable, DistributedGraphPart])

Store information about distributed graph partitions. This has the same attributes as GraphPartition, except that parts maps to instances of DistributedGraphPart.

pytato.find_distributed_partition(outputs: DictOfNamedArrays) → DistributedGraphPartition

Partitions outputs into parts. Communication statements (sends/receives) are scheduled between parts.


The partitioning of a DAG generally does not have a unique solution. The heuristic employed by this partitioner is as follows:

  1. The data contained in DistributedSend are marked as mandatory part outputs.

  2. Based on the dependencies in outputs, a DAG is constructed with only the mandatory part outputs as the nodes.

  3. Using a topological sort, the mandatory part outputs are assigned a “time” (an integer) such that evaluating these outputs at that time would preserve dependencies. We maximize the number of part outputs scheduled at each “time”. This requirement ensures our topological sort is deterministic.

  4. We then turn our attention to the other arrays that are allocated to a buffer. These are the materialized arrays and belong to one of the following classes:

       - An Array tagged with pytato.tags.ImplStored.

       - The expressions in a DictOfNamedArrays.

  5. Based on outputs, we compute the predecessors of each materialized array that are mandatory part outputs. A materialized array is scheduled to be evaluated in a part as soon as all of its inputs are available. Note that certain inputs (like DistributedRecv) might not be available until certain mandatory part outputs have been evaluated.

  6. From outputs, we can construct a DAG consisting only of mandatory part outputs and materialized arrays. We mark all materialized arrays that are being used by nodes in a part other than the one in which the materialized array itself was evaluated. Such materialized arrays are also realized as part outputs. This is done to avoid recomputations.
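Step 3 can be sketched as a level-by-level topological sort: at each “time”, every mandatory part output whose dependencies have already been scheduled is scheduled. This is a toy model using Python's graphlib with hypothetical node labels; whether pytato computes the times exactly this way is an assumption.

```python
from graphlib import TopologicalSorter
from typing import Dict, Iterable

# Toy model: assign_times and the node labels are hypothetical, not pytato API.


def assign_times(preds: Dict[str, Iterable[str]]) -> Dict[str, int]:
    """Level-by-level topological sort: every output whose dependencies are done runs now."""
    sorter = TopologicalSorter(preds)
    sorter.prepare()
    times: Dict[str, int] = {}
    time = 0
    while sorter.is_active():
        ready = sorted(sorter.get_ready())  # sorted only for reproducible output
        for node in ready:
            times[node] = time
        sorter.done(*ready)
        time += 1
    return times


# Usage: s3 needs s1; s4 needs s2 and s3.
times = assign_times({"s1": [], "s2": [], "s3": ["s1"], "s4": ["s2", "s3"]})
```

Scheduling all ready outputs at once puts s1 and s2 at time 0, s3 at time 1 and s4 at time 2.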

Knobs to tweak the partition:

  1. Removing dependencies between the mandatory part outputs leads to fewer parts, each containing more nodes. Conversely, adding dependencies between the part outputs leads to smaller parts.

  2. Tagging nodes with pytato.tags.ImplStored helps avoid recomputation.

pytato.verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, partition: DistributedGraphPartition) → None

Verify that

  • a feasible execution order exists among graph parts across the global, partitioned, distributed data flow graph, consisting of all values of partition across all ranks.

  • sends and receives for a given triple of (source rank, destination rank, tag) are unique.

  • there is a one-to-one mapping between instances of DistributedRecv and DistributedSend


This is an MPI-collective operation.
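The second and third checks can be sketched for data that has already been gathered in one place (for example by an MPI gather across ranks). This toy function is hypothetical and does not check the first condition, the existence of a feasible execution order.

```python
from collections import Counter
from typing import Hashable, List, Tuple

# Toy model: check_send_recv_matching is hypothetical, not pytato API.
Triple = Tuple[int, int, Hashable]  # (source rank, destination rank, comm tag)


def check_send_recv_matching(sends: List[Triple], recvs: List[Triple]) -> None:
    """Raise ValueError unless sends and receives pair up one-to-one."""
    for kind, triples in (("send", sends), ("receive", recvs)):
        duplicates = [t for t, count in Counter(triples).items() if count > 1]
        if duplicates:
            raise ValueError(f"duplicate {kind}s for {duplicates}")
    unmatched = set(sends) ^ set(recvs)  # present in one list but not the other
    if unmatched:
        raise ValueError(f"unmatched send/receive triples: {unmatched}")
```

For instance, a single send (0, 1, "halo") with a single matching receive passes, while a duplicated send or a receive with swapped ranks raises ValueError.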

pytato.execute_distributed_partition(partition: DistributedGraphPartition, prg_per_partition: Dict[Hashable, BoundProgram], queue: Any, mpi_communicator: Any, *, allocator: Any | None = None, input_args: Dict[str, Any] | None = None) → Dict[str, Any]

Internal stuff that is only here because the documentation tool wants it

class pytato.Tag

See pytools.tag.Tag.

class pytato.CommTagType

A type representing a communication tag.

class pytato.ShapeType

A type representing a shape.

class pytato.AxesT

A tuple of Axis objects.

Utilities and Diagnostics

Helper routines

pytato.utils.are_shape_components_equal(dim1: int | integer | Array, dim2: int | integer | Array) → bool

Returns True iff dim1 and dim2 have equal SizeParam coefficients in their expressions.

pytato.utils.are_shapes_equal(shape1: Tuple[int | integer | Array, ...], shape2: Tuple[int | integer | Array, ...]) → bool

Returns True iff shape1 and shape2 have the same dimensionality and the corresponding components are equal as defined by are_shape_components_equal().

pytato.utils.get_shape_after_broadcasting(exprs: Iterable[Array | number | int | bool_ | bool | float | Expression]) → Tuple[int | integer | Array, ...]

Returns the shape after broadcasting exprs in an operation.
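For integer shapes, broadcasting follows the NumPy rules: align trailing axes, let size-1 axes stretch, and reject any other mismatch. The sketch below is restricted to integer dimensions and uses a hypothetical function name; pytato's function also handles symbolic, Array-valued dimensions (compared via are_shape_components_equal()), which this toy version does not model.

```python
from typing import Sequence, Tuple

# Toy model: toy_broadcast_shape is hypothetical, not pytato API.


def toy_broadcast_shape(*shapes: Sequence[int]) -> Tuple[int, ...]:
    """NumPy-style broadcasting of integer shapes (scalars have shape ())."""
    ndim = max((len(s) for s in shapes), default=0)
    result = []
    for axis in range(ndim):
        # Align shapes on their trailing axes; shorter shapes are padded with 1 on the left.
        dims = {s[axis - (ndim - len(s))] for s in shapes if axis >= ndim - len(s)}
        dims.discard(1)  # size-1 axes stretch to match any other size
        if len(dims) > 1:
            raise ValueError(f"cannot broadcast shapes {shapes}")
        result.append(dims.pop() if dims else 1)
    return tuple(result)
```

For example, the operand shapes from the einsum example, (10, 4, 1) and (10, 1, 4), broadcast to (10, 4, 4), and (3,) with (2, 1) gives (2, 3).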

pytato.utils.dim_to_index_lambda_components(expr: int | integer | Array, vng: UniqueNameGenerator | None = None) → Tuple[number | int | bool_ | bool | float | Expression, Dict[str, SizeParam]]

Returns the scalar expression and the bindings needed to use the shape component within an index lambda.

>>> import pytato as pt
>>> from pytools import UniqueNameGenerator
>>> from pytato.utils import dim_to_index_lambda_components
>>> n = pt.make_size_param("n")
>>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
>>> print(expr)
3*_in + 8
>>> bnds
{'_in': SizeParam(name='n')}

pytato.utils.get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[Array | number | int | bool_ | bool | float]) → dtype

Pytato-specific exceptions

class pytato.diagnostic.NameClashError

Raised when 2 non-identical InputArgumentBase’s are reachable in an Array’s DAG and share the same name. Here, we refer to 2 objects a and b as being identical iff a is b.

class pytato.diagnostic.CannotBroadcastError

class pytato.diagnostic.UnknownIndexLambdaExpr

Raised when the structure of a pytato.array.IndexLambda could not be inferred.

class pytato.diagnostic.CannotBeLoweredToIndexLambda

Raised when a pytato.Array was expected to be lowered to an IndexLambda but cannot be. For example, a pytato.loopy.LoopyCallResult cannot be lowered to an IndexLambda.