Distributed-Memory/Message Passing

Distributed-memory evaluation of expression graphs is accomplished by partitioning the graph into pieces of the computation that need no communication. Communication (i.e. sending and receiving data) then takes place at the boundaries between the parts of the resulting partition.

Recall the partitioning requirement that “no part may depend on its own outputs as inputs”. That sounds obvious, but in the distributed-memory case it is harder to decide than it looks, since no single rank has full knowledge of the computation graph: edges go off to other nodes and then come back.
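The following is a minimal, self-contained sketch of the problem, using Python's standard graphlib rather than pytato itself. The two-rank computation, the (rank, name) node keys, and the helpers part_graph and has_cycle are all hypothetical. Each rank's local edges are acyclic, yet putting all of one rank's work into a single part creates a cycle that only becomes visible once the other rank's edges are taken into account.

```python
from graphlib import CycleError, TopologicalSorter

# Hypothetical model: each node is (rank, name) and deps[node] is the set of
# nodes it needs. Receives depend on the matching send on the other rank.
deps = {
    # rank 0: computes 'a' and sends it; separately, 'c' uses data from rank 1
    (0, "a"): set(),
    (0, "send_a"): {(0, "a")},
    (0, "recv_b"): {(1, "send_b")},  # cross-rank edge
    (0, "c"): {(0, "recv_b")},
    # rank 1: receives 'a', computes 'b' from it, sends 'b' back to rank 0
    (1, "recv_a"): {(0, "send_a")},  # cross-rank edge
    (1, "b"): {(1, "recv_a")},
    (1, "send_b"): {(1, "b")},
}


def part_graph(deps, part_of):
    """Collapse the node graph into a graph of dependencies between parts."""
    graph = {part: set() for part in part_of.values()}
    for node, needed in deps.items():
        for other in needed:
            if part_of[other] != part_of[node]:
                graph[part_of[node]].add(part_of[other])
    return graph


def has_cycle(graph):
    """True if no topological order of the graph exists."""
    try:
        tuple(TopologicalSorter(graph).static_order())
    except CycleError:
        return True
    return False


# The node-level graph itself is acyclic...
node_graph_is_acyclic = not has_cycle(deps)

# ...but placing everything on a rank into one part creates a cycle, even
# though nothing on rank 0 locally connects send_a to recv_b.
one_part_per_rank = {node: f"rank{node[0]}" for node in deps}

# Moving rank 0's receive (and its consumer) into a later part removes it.
split_rank0 = dict(one_part_per_rank)
split_rank0[(0, "recv_b")] = split_rank0[(0, "c")] = "rank0_late"
```

Rank 0 cannot see the cycle from its own edges alone, which is why the partitioner described below first assembles a global graph of communication dependencies.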


The following nodes represent communication in the DAG:

class pytato.DistributedSend(data: Array, dest_rank: int, comm_tag: Hashable, *, tags: FrozenSet[Tag] = frozenset({}))

Class representing a distributed send operation. See DistributedSendRefHolder for a way to ensure that nodes of this type remain part of a DAG.


The Array to be sent.


An int. The rank to which data is to be sent.


A hashable, picklable object to serve as a ‘tag’ for the communication. Only a DistributedRecv with the same tag will be able to receive the data being sent here.

class pytato.DistributedSendRefHolder(send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset({}))

A node acting as an identity on passthrough_data while also holding a reference to a DistributedSend in send. Since pytato represents data flow, and no data flows ‘out’ of a DistributedSend, no other node in pytato has a good reason to hold a reference to a send node: a send has no useful result (at least none of Array type).

This is where this node type comes in. Its value is the same as that of passthrough_data, and it holds a reference to the send node.


This all seems a wee bit inelegant, but nobody who has written or reviewed this code so far had a better idea. If you do, please speak up!


The DistributedSend to which a reference is to be held.


An Array. The value of this node.


It is the user’s responsibility to ensure that matching sends and receives are part of the computation graph on all ranks. If this rule is not heeded, undefined behavior (in particular deadlock) may result. Notably, by the nature of the data flow graph built by pytato, unused results do not appear in the graph. It is thus possible for a DistributedSendRefHolder to be constructed and yet not become part of the graph constructed by the user.
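A minimal sketch of this pitfall, using a hypothetical Node class and a depth-first reachability helper in place of pytato's own graph traversal: the send only becomes part of the graph if the result actually uses the ref holder.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an expression node (not a pytato class)."""
    name: str
    deps: Tuple["Node", ...] = ()


def reachable(outputs):
    """Names of all nodes reachable from the given outputs (depth-first)."""
    seen, stack = set(), list(outputs)
    while stack:
        node = stack.pop()
        if node.name not in seen:
            seen.add(node.name)
            stack.extend(node.deps)
    return seen


x = Node("x")
y = Node("y")
send = Node("send", (x,))           # sends x; yields no usable value
holder = Node("holder", (y, send))  # value equals y, but keeps the send alive

# Result built from the ref holder: the send is part of the graph.
with_holder = reachable([Node("out", (holder,))])
# Result built from y directly: the holder, and with it the send, drop out.
without_holder = reachable([Node("out", (y,))])
```

In the sketch, holder plays the role of a DistributedSendRefHolder; using y instead of holder downstream is the analogue of constructing a ref holder and then ignoring it.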

class pytato.DistributedRecv(shape: ShapeType, dtype: np.dtype[Any], src_rank: int, comm_tag: CommTagType, *, axes: AxesT, tags: FrozenSet[Tag])

Class representing a distributed receive operation.


An int. The rank from which an array is to be received.


A hashable, picklable object to serve as a ‘tag’ for the communication. Only a DistributedSend with the same tag will be able to send the data being received here.



It is the user’s responsibility to ensure that matching sends and receives are part of the computation graph on all ranks. If this rule is not heeded, undefined behavior (in particular deadlock) may result. Notably, by the nature of the data flow graph built by pytato, unused results do not appear in the graph. It is thus possible for a DistributedRecv to be constructed and yet not become part of the graph constructed by the user.

These functions aid in creating communication nodes:

pytato.staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: Hashable, stapled_to: Array, *, send_tags: FrozenSet[Tag] = frozenset({}), ref_holder_tags: FrozenSet[Tag] = frozenset({})) → DistributedSendRefHolder

Make a DistributedSend object wrapped in a DistributedSendRefHolder object.

pytato.make_distributed_recv(src_rank: int, comm_tag: CommTagType, shape: ConvertibleToShape, dtype: Any, axes: AxesT | None = None, tags: FrozenSet[Tag] = frozenset({})) → DistributedRecv

Make a DistributedRecv object.

For completeness, individual (non-held/“stapled”) DistributedSend nodes can be made via this function:

pytato.make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: Hashable, send_tags: FrozenSet[Tag] = frozenset({})) → DistributedSend

Make a DistributedSend object.

Redirections for the documentation tool

class np.dtype

See numpy.dtype.


Partitioning of graphs in pytato serves to enable distributed computation, i.e. sending and receiving data as part of graph evaluation.

Partitioning of expression graphs is based on a few assumptions:

  • We must be able to execute parts in any dependency-respecting order.

  • Parts are compiled at partitioning time, so what inputs they take from memory vs. what they compute is decided at that time.

  • No part may depend on its own outputs as inputs.
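The first assumption can be illustrated with a small sketch in which each part's prerequisites are a set of part IDs (a hypothetical layout that loosely mirrors the needed_pids attribute below). Any order produced by a topological sort is admissible, and respects_dependencies, a hypothetical helper, checks whether a given order is.

```python
from graphlib import TopologicalSorter

# Hypothetical part IDs, each mapped to the IDs of the parts it needs first.
needed_pids = {
    "p0": set(),
    "p1": {"p0"},
    "p2": {"p0"},
    "p3": {"p1", "p2"},
}


def respects_dependencies(order, needed_pids):
    """True if every part appears after all the parts it needs."""
    position = {pid: i for i, pid in enumerate(order)}
    return all(position[dep] < position[pid]
               for pid, deps in needed_pids.items() for dep in deps)


# One admissible execution order.
order = list(TopologicalSorter(needed_pids).static_order())
```

Here p1 and p2 may run in either order, which is exactly the freedom the assumption requires parts to tolerate.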

class pytato.DistributedGraphPart(pid: Hashable, needed_pids: FrozenSet[Hashable], user_input_names: FrozenSet[str], partition_input_names: FrozenSet[str], output_names: FrozenSet[str], name_to_recv_node: Mapping[str, DistributedRecv], name_to_send_nodes: Mapping[str, Sequence[DistributedSend]])

For one graph part, record send/receive information for input/output names.

Names that occur as keys in name_to_recv_node and name_to_send_nodes are usable as input names by other parts, or in the result of the computation.


An identifier for this part of the graph.


The IDs of parts that are required to be evaluated before this part can be evaluated.


A frozenset of names representing input to the computational graph, i.e. names that were not introduced by partitioning.


A frozenset of names of placeholders the part requires as input from other parts in the partition.


Names of placeholders this part provides as output.

all_input_names() → FrozenSet[str]
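For orientation, here is a simplified, hypothetical stand-in for DistributedGraphPart that keeps only the name-valued fields. The body of all_input_names() is an assumption: it returns the union of user-supplied and partition-introduced input names, which is what the attribute descriptions suggest. The expressions themselves would live in the name_to_output mapping of DistributedGraphPartition, keyed by these names.

```python
from dataclasses import dataclass
from typing import FrozenSet, Hashable


@dataclass(frozen=True)
class PartSketch:
    """Hypothetical, simplified stand-in for DistributedGraphPart."""
    pid: Hashable
    needed_pids: FrozenSet[Hashable]
    user_input_names: FrozenSet[str]
    partition_input_names: FrozenSet[str]
    output_names: FrozenSet[str]

    def all_input_names(self) -> FrozenSet[str]:
        # Assumption: every input is either user-supplied or produced by
        # another part of the partition.
        return self.user_input_names | self.partition_input_names


part = PartSketch(
    pid=1,
    needed_pids=frozenset({0}),
    user_input_names=frozenset({"x"}),
    partition_input_names=frozenset({"tmp0"}),  # hypothetical name from part 0
    output_names=frozenset({"result"}),
)
```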
class pytato.DistributedGraphPartition(parts: Mapping[Hashable, DistributedGraphPart], name_to_output: Mapping[str, Array], overall_output_names: Sequence[str])

Mapping from part IDs to instances of DistributedGraphPart.


Mapping of placeholder names to the respective pytato.array.Array they represent. This is where the actual expressions are stored, for all parts. Observe that the DistributedGraphPart, for the most part, only stores names. These “outputs” may be ‘part outputs’ (i.e. data computed in one part for use by another, effectively temporary variables), or ‘overall outputs’ of the computation.


The names of the outputs (in name_to_output) that were given to find_distributed_partition() to specify the overall computation.

pytato.find_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, outputs: DictOfNamedArrays) → DistributedGraphPartition

Compute a DistributedGraphPartition (for use with execute_distributed_partition()) that evaluates the same result as outputs, such that:

  • communication only happens at the beginning and end of each DistributedGraphPart, and

  • the partition introduces no circular dependencies between parts, mediated by either local data flow or off-rank communication.


This is an MPI-collective operation.

The following sections describe the (non-binding, as far as documentation is concerned) algorithm behind the partitioner.


We identify a communication operation (consisting of a pair of a send and a receive) by a CommunicationOpIdentifier. We keep graphs of these in CommunicationDepGraph.

If graph is a CommunicationDepGraph, then b in graph[a] means that, in order to initiate the communication operation identified by CommunicationOpIdentifier a, the communication operation identified by CommunicationOpIdentifier b must be completed. In other words, the nodes are “communication operations” (pairs of a send and a receive), and the edges represent (rank-local) data flow between them.

Step 1: Build a global graph of data flow between communication operations

As a first step, each rank receives a copy of the global CommunicationDepGraph, as described above. This becomes comm_ids_to_needed_comm_ids.

Step 2: Obtain a “schedule” of “communication batches”

On rank 0, compute and broadcast a topological order of comm_ids_to_needed_comm_ids. The result of this is comm_batches, a sequence of sets of CommunicationOpIdentifier instances, identifying sets of communication operations expected to complete between parts of the computation. (I.e. computation will occur before the first communication batch, then between the first and second, and so on.)


An important restriction of this scheme is that a linear order of communication batches is obtained, meaning that, typically, no overlap of computation and communication occurs.
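One way to compute the communication batches of Step 2 is Kahn-style level grouping, sketched here with Python's graphlib. This is an assumption about the technique, not necessarily what pytato does. CommOpId is a hypothetical stand-in for CommunicationOpIdentifier, and the graph follows the convention above: a in graph[b] means a must complete before b can be initiated.

```python
from dataclasses import dataclass
from graphlib import TopologicalSorter
from typing import FrozenSet, Hashable, List, Mapping, Set


@dataclass(frozen=True)
class CommOpId:
    """Hypothetical stand-in for CommunicationOpIdentifier."""
    src_rank: int
    dest_rank: int
    comm_tag: Hashable


def comm_batches(
        comm_ids_to_needed_comm_ids: Mapping[CommOpId, Set[CommOpId]]
        ) -> List[FrozenSet[CommOpId]]:
    """Group operations into batches: every operation lands in the first batch
    after all operations it needs have completed."""
    sorter = TopologicalSorter(comm_ids_to_needed_comm_ids)
    sorter.prepare()  # raises graphlib.CycleError if the graph has a cycle
    batches = []
    while sorter.is_active():
        ready = frozenset(sorter.get_ready())
        batches.append(ready)
        sorter.done(*ready)
    return batches


a = CommOpId(src_rank=0, dest_rank=1, comm_tag="a")
b = CommOpId(src_rank=1, dest_rank=0, comm_tag="b")  # rank 1 computes b from a
c = CommOpId(src_rank=0, dest_rank=1, comm_tag="c")  # independent of a and b
# graph[b] == {a}: a must complete before b can be initiated.
graph = {a: set(), b: {a}, c: set()}
batches = comm_batches(graph)
```

Operations a and c land in the first batch because neither needs the other, while b has to wait for a.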

Step 3: Create rank-local part descriptors

On each rank, we next rewrite the communication batches into computation parts, each identified by a _PartCommIDs structure, which gathers receives that need to complete before the computation on a part can begin and sends that can begin once computation on a part is complete.
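A sketch of Step 3 under stated assumptions: CommOpId and PartCommIds are hypothetical stand-ins for CommunicationOpIdentifier and _PartCommIDs, and the rule used here (a rank's sends in batch k close a part, and its receives in batch k open the next one) is one plausible reading of the description, not a copy of pytato's code.

```python
from dataclasses import dataclass
from typing import FrozenSet, Hashable, List, Sequence


@dataclass(frozen=True)
class CommOpId:
    """Hypothetical stand-in for CommunicationOpIdentifier."""
    src_rank: int
    dest_rank: int
    comm_tag: Hashable


@dataclass(frozen=True)
class PartCommIds:
    """Hypothetical analogue of _PartCommIDs for one rank-local part."""
    recv_ids: FrozenSet[CommOpId]  # must complete before the part's computation
    send_ids: FrozenSet[CommOpId]  # may start once the part's computation is done


def parts_from_batches(batches: Sequence[FrozenSet[CommOpId]],
                       local_rank: int) -> List[PartCommIds]:
    parts: List[PartCommIds] = []
    recv_ids: FrozenSet[CommOpId] = frozenset()  # nothing precedes the first part
    for batch in batches:
        # Sends from this rank in this batch close the current part...
        send_ids = frozenset(op for op in batch if op.src_rank == local_rank)
        if recv_ids or send_ids:
            parts.append(PartCommIds(recv_ids, send_ids))
        # ...and receives by this rank in this batch open the next one.
        recv_ids = frozenset(op for op in batch if op.dest_rank == local_rank)
    # A final part consumes the last receives; a rank without any
    # communication still gets a single part.
    if recv_ids or not parts:
        parts.append(PartCommIds(recv_ids, frozenset()))
    return parts


a, b, c = CommOpId(0, 1, "a"), CommOpId(1, 0, "b"), CommOpId(0, 1, "c")
batches = [frozenset({a, c}), frozenset({b})]
rank0_parts = parts_from_batches(batches, local_rank=0)
rank1_parts = parts_from_batches(batches, local_rank=1)
```

With these batches, rank 0 gets two parts (compute and send a and c, then receive b), while rank 1 gets one part that receives a and c and then sends b.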

Step 4: Assign materialized arrays to parts

“Stored” arrays are those whose value will be computed and stored in memory. This includes the following:

  • arrays tagged ImplStored by prior processing of the DAG,

  • arrays being sent (because we need to hand a buffer to MPI),

  • arrays being received (because MPI puts the received data in memory), and

  • overall outputs of the computation.

By contrast, the partitioner's code uses the word “materialized” only for arrays of the first type (tagged ImplStored), so that ‘stored’ is a superset of ‘materialized’.

In addition, data computed by one part (in the above sense) of the computation and used by another must be in memory. Evaluating and storing temporary arrays is expensive, so we try to minimize how often this occurs as part of the partitioning. We do this by relying on already-stored arrays as much as possible and by recomputing any intermediate results needed in, say, both an originating and a consuming part.

We begin this process by assigning each materialized array to a part in which it is computed, based on the part in which data depending on such arrays is sent. This choice implies that these computations occur as late as possible.
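The placement rule can be sketched as follows, assuming that arrays feeding no send are placed in the last part; that fallback, the function name assign_to_parts, and the dictionary layout are hypothetical. Choosing the earliest part that sends dependent data is the latest point at which an array can be computed while still being available to every send that needs it.

```python
from typing import Dict, Iterable, Set


def assign_to_parts(materialized: Iterable[str],
                    dependent_send_parts: Dict[str, Set[int]],
                    nparts: int) -> Dict[str, int]:
    """Hypothetical sketch: map each materialized array to the part that
    computes it."""
    assignment = {}
    for ary in materialized:
        send_parts = dependent_send_parts.get(ary, set())
        # Earliest part with a dependent send; with no dependent send, the
        # array can wait until the last part (an assumption).
        assignment[ary] = min(send_parts) if send_parts else nparts - 1
    return assignment


assignment = assign_to_parts(
    materialized=["t1", "t2", "t3"],
    dependent_send_parts={"t1": {0, 2}, "t2": {1}},  # t3 feeds no send
    nparts=3,
)
```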

Step 5: Promote stored arrays to part outputs if needed

In DistributedGraphPart, our description of the partitioned computation, each part can declare named ‘outputs’ that can be used by subsequent parts. Stored arrays are promoted to part outputs if they have users in parts other than the one in which they are computed.
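A minimal sketch of the promotion test, with hypothetical array names and a hypothetical helper promote_to_part_outputs: a stored array becomes a part output exactly when some part other than the one computing it uses it.

```python
from typing import Dict, Set


def promote_to_part_outputs(stored_ary_to_part: Dict[str, int],
                            user_parts: Dict[str, Set[int]]) -> Set[str]:
    """Stored arrays that are used by at least one other part."""
    return {
        ary for ary, part in stored_ary_to_part.items()
        if any(user != part for user in user_parts.get(ary, ()))
    }


stored_ary_to_part = {"t1": 0, "t2": 1, "t3": 2}
user_parts = {"t1": {0, 1}, "t2": {1}, "t3": set()}
promoted = promote_to_part_outputs(stored_ary_to_part, user_parts)
```

Only t1 is promoted: t2 is used solely in its own part, and t3 has no users in any part.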

Step 6: Rewrite the DAG into its parts

In the final step, we traverse the DAG to apply the following changes:

Internal stuff that is only here because the documentation tool wants it

class pytato.distributed.partition.T

A type variable for AbstractResultWithNamedArrays.

class pytato.distributed.partition.CommunicationOpIdentifier(src_rank: int, dest_rank: int, comm_tag: Hashable)

Identifies a communication operation (consisting of a pair of a send and a receive).



In find_distributed_partition(), we use instances of this type as though they identify sends or receives, i.e. just a single end of the communication. Realize that this is only true given the additional context of which rank is the local rank.
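The following sketch shows how a single identifier describes different ends of the same operation depending on the local rank. CommOpId and its local_role method are hypothetical and do not exist in pytato; only the three fields mirror the signature above.

```python
from dataclasses import dataclass
from typing import Hashable


@dataclass(frozen=True)
class CommOpId:
    """Hypothetical stand-in for CommunicationOpIdentifier: one send/receive pair."""
    src_rank: int
    dest_rank: int
    comm_tag: Hashable

    def local_role(self, local_rank: int) -> str:
        """Which end of the pair the given rank sees: 'send', 'recv' or 'none'."""
        if self.src_rank == local_rank:
            return "send"
        if self.dest_rank == local_rank:
            return "recv"
        return "none"


op = CommOpId(src_rank=0, dest_rank=1, comm_tag="halo")
```

Because the dataclass is frozen, identifiers are hashable and compare by value, so they can serve as keys in a graph such as a CommunicationDepGraph.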

class pytato.distributed.partition.CommunicationDepGraph

An alias for Mapping[CommunicationOpIdentifier, AbstractSet[CommunicationOpIdentifier]].


exception pytato.distributed.verify.PartitionInducedCycleError

Raised if the partitioning (e.g. via find_distributed_partition()) erroneously induced a cycle in the graph of partitions.

pytato.verify_distributed_partition(mpi_communicator: mpi4py.MPI.Comm, partition: DistributedGraphPartition) → None

Verify that

  • a feasible execution order exists among graph parts across the global, partitioned, distributed data flow graph, consisting of all values of partition across all ranks.

  • sends and receives for a given triple of (source rank, destination rank, tag) are unique.

  • there is a one-to-one mapping between instances of DistributedRecv and DistributedSend.


This is an MPI-collective operation.
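The uniqueness and one-to-one checks can be sketched as plain counting over (source rank, destination rank, tag) triples. This is a hypothetical helper, not pytato's implementation. In a real run the triples would first have to be gathered from all ranks, for example with mpi4py's Comm.allgather(), and the feasibility check would additionally require a topological sort of the global graph of parts.

```python
from collections import Counter
from typing import Hashable, Iterable, Tuple

OpKey = Tuple[int, int, Hashable]  # (source rank, destination rank, tag)


def check_sends_and_recvs(sends: Iterable[OpKey], recvs: Iterable[OpKey]) -> None:
    """Raise ValueError unless every (src, dest, tag) triple occurs exactly once
    as a send and exactly once as a receive."""
    send_counts, recv_counts = Counter(sends), Counter(recvs)
    # Uniqueness: no triple may appear twice on the same side.
    for kind, counts in (("send", send_counts), ("receive", recv_counts)):
        dupes = [key for key, n in counts.items() if n > 1]
        if dupes:
            raise ValueError(f"duplicate {kind}(s): {dupes}")
    # One-to-one: the sets of send and receive triples must coincide.
    if set(send_counts) != set(recv_counts):
        unmatched = set(send_counts) ^ set(recv_counts)
        raise ValueError(f"unmatched communication: {sorted(unmatched, key=repr)}")


# A consistent set of operations passes silently.
check_sends_and_recvs([(0, 1, "a"), (1, 0, "b")], [(1, 0, "b"), (0, 1, "a")])
```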


pytato.execute_distributed_partition(partition: DistributedGraphPartition, prg_per_partition: Dict[Hashable, BoundProgram], queue: Any, mpi_communicator: Any, *, allocator: Any | None = None, input_args: Dict[str, Any] | None = None) → Dict[str, Any]

Internal stuff that is only here because the documentation tool wants it

class pytato.Tag

See pytools.tag.Tag.

class pytato.CommTagType

A type representing a communication tag.

class pytato.ShapeType

A type representing a shape.

class pytato.AxesT

A tuple of Axis objects.