Transforming Array Expression Graphs
- class pytato.transform.Mapper
A class that, when called with a pytato.Array, recursively iterates over the DAG, calling the _mapper_method of each node. Users of this class are expected to override the methods of this class or create a subclass.
Note
This class might visit a node multiple times. Use a CachedMapper if this is not desired.
- handle_unsupported_array(expr: MappedT, *args: Any, **kwargs: Any) → Any
Mapper method that is invoked for pytato.Array subclasses for which a mapper method does not exist in this mapper.
- map_foreign(expr: Any, *args: Any, **kwargs: Any) → Any
Mapper method that is invoked for an object of a class for which a mapper method does not exist in this mapper.
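The dispatch described above can be sketched in plain Python. This is a toy stand-in, not pytato's implementation: the classes ToyArray, Placeholder, Sum, ToyMapper and NameCollector are hypothetical, and only the `_mapper_method` attribute and the `handle_unsupported_array`/`map_foreign` fallbacks mirror names from the text.

```python
from dataclasses import dataclass
from typing import Any, List


class ToyArray:
    """Hypothetical base node class: each subclass names its mapper method."""
    _mapper_method = ""


@dataclass(frozen=True)
class Placeholder(ToyArray):
    name: str
    _mapper_method = "map_placeholder"


@dataclass(frozen=True)
class Sum(ToyArray):
    left: ToyArray
    right: ToyArray
    _mapper_method = "map_sum"


class ToyMapper:
    def __call__(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        if not isinstance(expr, ToyArray):
            # Not an array at all: fall back to map_foreign.
            return self.map_foreign(expr, *args, **kwargs)
        method = getattr(self, expr._mapper_method, None)
        if method is None:
            # An array type this mapper has no method for.
            return self.handle_unsupported_array(expr, *args, **kwargs)
        return method(expr, *args, **kwargs)

    def handle_unsupported_array(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(
            f"{type(self).__name__} has no method for {type(expr).__name__}")

    def map_foreign(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        raise ValueError(f"{type(self).__name__} got a non-array: {expr!r}")


class NameCollector(ToyMapper):
    """Lists placeholder names; without caching, shared nodes are revisited."""

    def map_placeholder(self, expr: Placeholder) -> List[str]:
        return [expr.name]

    def map_sum(self, expr: Sum) -> List[str]:
        return self(expr.left) + self(expr.right)
```

With `x, y = Placeholder("x"), Placeholder("y")`, calling `NameCollector()(Sum(x, Sum(x, y)))` returns `["x", "x", "y"]`: the shared node x is visited twice, which is exactly the behaviour the Note warns about.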
- class pytato.transform.CachedMapper
Mapper class that maps each node in the DAG exactly once. This loses some information compared to Mapper, as a node is visited only from one of its predecessors.
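To see why caching matters, the sketch below counts visits on a chain of "diamonds", where each level uses the previous node twice. The Node class, both counters and diamond_chain are hypothetical illustrations of the uncached/cached distinction, not pytato code.

```python
from typing import Set


class Node:
    """Hypothetical DAG node; hashing is by object identity (the default)."""

    def __init__(self, *children: "Node") -> None:
        self.children = children


class VisitCounter:
    """Uncached traversal: a node is processed once per path that reaches it."""

    def __init__(self) -> None:
        self.visits = 0

    def __call__(self, expr: Node) -> None:
        self.visits += 1
        for child in expr.children:
            self(child)


class CachedVisitCounter(VisitCounter):
    """Cached traversal: each node is processed exactly once."""

    def __init__(self) -> None:
        super().__init__()
        self._seen: Set[Node] = set()

    def __call__(self, expr: Node) -> None:
        if expr in self._seen:
            return
        self._seen.add(expr)
        super().__call__(expr)


def diamond_chain(depth: int) -> Node:
    """Stack `depth` diamonds: each level reads the previous node twice."""
    node = Node()
    for _ in range(depth):
        node = Node(Node(node), Node(node))
    return node
```

For depth 10 the graph has 31 nodes; the uncached counter performs 2**12 - 3 = 4093 visits, while the cached one performs 31, one per node.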
- class pytato.transform.CopyMapper
Performs a deep copy of a pytato.array.Array. The typical use of this mapper is to override individual map_ methods in subclasses to permit term rewriting on an expression graph.
- clone_for_callee() → _SelfMapper
Called to clone self before starting traversal of a pytato.function.FunctionDefinition.
Note
This does not copy the data of a pytato.array.DataWrapper.
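The "override individual map_ methods" pattern can be illustrated with a toy copy mapper. Everything here (Placeholder, Add, ToyCopyMapper, RenamePlaceholders) is hypothetical; the sketch memoizes copies so that a shared subexpression stays shared in the result, and the subclass rewrites only one node type.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass(frozen=True)
class Placeholder:
    name: str


@dataclass(frozen=True)
class Add:
    left: Any
    right: Any


class ToyCopyMapper:
    """Rebuilds every node; results are memoized so shared nodes stay shared."""

    def __init__(self) -> None:
        self._cache: Dict[Any, Any] = {}

    def __call__(self, expr: Any) -> Any:
        if expr not in self._cache:
            method = getattr(self, "map_" + type(expr).__name__.lower())
            self._cache[expr] = method(expr)
        return self._cache[expr]

    def map_placeholder(self, expr: Placeholder) -> Placeholder:
        return Placeholder(expr.name)

    def map_add(self, expr: Add) -> Add:
        return Add(self(expr.left), self(expr.right))


class RenamePlaceholders(ToyCopyMapper):
    """Term rewriting: override only the method for the node type to rewrite."""

    def map_placeholder(self, expr: Placeholder) -> Placeholder:
        return Placeholder(expr.name.upper())
```

Applied to `Add(Add(x, x), Placeholder("y"))` with `x = Placeholder("x")`, the result is `Add(Add(Placeholder("X"), Placeholder("X")), Placeholder("Y"))`, and both operands of the inner Add are the same object.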
- class pytato.transform.CopyMapperWithExtraArgs
Similar to CopyMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.
The logic in CopyMapper purposely does not take the extra arguments, to keep the cost of each call frame low.
- class pytato.transform.CombineMapper
Abstract mapper that recursively combines the results of user nodes of a given expression.
- class pytato.transform.DependencyMapper
Maps a pytato.array.Array to a frozenset of the pytato.array.Array instances it depends on.
Warning
This returns every node in the graph! Consider a custom CombineMapper or a SubsetDependencyMapper instead.
- class pytato.transform.InputGatherer
Mapper to combine all instances of pytato.array.InputArgumentBase that an array expression depends on.
- class pytato.transform.SizeParamGatherer
Mapper to combine all instances of pytato.array.SizeParam that an array expression depends on.
- class pytato.transform.SubsetDependencyMapper(universe: FrozenSet[Array])
Mapper to combine the dependencies of an expression that are a subset of universe.
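The difference between gathering every dependency and only those inside a universe can be sketched on a toy graph. The Node class and both functions are hypothetical; whether the result includes expr itself is an assumption of this sketch, and intersecting with universe is just a simple way to obtain the same set, not how SubsetDependencyMapper is implemented.

```python
from typing import Dict, FrozenSet, Optional


class Node:
    """Hypothetical DAG node identified by object identity."""

    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = children


def dependencies(expr: Node,
                 cache: Optional[Dict[Node, FrozenSet[Node]]] = None
                 ) -> FrozenSet[Node]:
    """expr together with every node it (transitively) depends on."""
    if cache is None:
        cache = {}
    if expr not in cache:
        deps = frozenset([expr])
        for child in expr.children:
            deps |= dependencies(child, cache)
        cache[expr] = deps
    return cache[expr]


def subset_dependencies(expr: Node, universe: FrozenSet[Node]) -> FrozenSet[Node]:
    """Only those dependencies of expr that belong to universe."""
    return dependencies(expr) & universe
```

On a graph `out(t(a, b), a)`, `dependencies(out)` returns all four nodes, whereas restricting to the universe `{a, b}` returns just the two inputs, which is usually what a caller wants.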
- class pytato.transform.WalkMapper
A mapper that walks over all the arrays in a pytato.Array.
Users may override the specific mapper methods in a derived class or override WalkMapper.visit() and WalkMapper.post_visit().
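A minimal sketch of the visit/post_visit hooks, assuming (as an assumption of this sketch) that visit() returns a bool deciding whether to descend into a node and post_visit() runs after its children. ToyWalkMapper, PostOrderRecorder and Node are hypothetical.

```python
from typing import List


class Node:
    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = children


class ToyWalkMapper:
    """Walks the DAG: visit() decides whether to descend, post_visit() runs after."""

    def __call__(self, expr: Node) -> None:
        if not self.visit(expr):
            return  # prune: expr's children are skipped, post_visit is not called
        for child in expr.children:
            self(child)
        self.post_visit(expr)

    def visit(self, expr: Node) -> bool:
        return True

    def post_visit(self, expr: Node) -> None:
        pass


class PostOrderRecorder(ToyWalkMapper):
    """Records node names in post-order, optionally pruning one node by name."""

    def __init__(self, skip: str = "") -> None:
        self.skip = skip
        self.order: List[str] = []

    def visit(self, expr: Node) -> bool:
        return expr.name != self.skip

    def post_visit(self, expr: Node) -> None:
        self.order.append(expr.name)
```

On `out(t(a, b), a)` the recorder yields `["a", "b", "t", "a", "out"]`; a appears twice because this walk is uncached. Pruning `t` yields `["a", "out"]`.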
- class pytato.transform.CachedWalkMapper
WalkMapper that visits each node in the DAG exactly once. This loses some information compared to WalkMapper, as a node is visited only from one of its predecessors.
- class pytato.transform.TopoSortMapper
- class pytato.transform.CachedMapAndCopyMapper(map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays])
Mapper that applies map_fn to each node and copies it. Results of traversals are memoized, i.e., each node is mapped via map_fn exactly once.
- pytato.transform.copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, copy_mapper: CopyMapper) → DictOfNamedArrays
Copy the elements of a DictOfNamedArrays into a new DictOfNamedArrays.
- Parameters:
source_dict: The DictOfNamedArrays to copy.
copy_mapper: A mapper that performs copies of different array types.
- Returns:
A new DictOfNamedArrays containing copies of the items in source_dict.
- pytato.transform.get_dependencies(expr: DictOfNamedArrays) → Dict[str, FrozenSet[Array]]
Returns the dependencies of each named array in expr.
- pytato.transform.map_and_copy(expr: MappedT, map_fn: Callable[[Array | AbstractResultWithNamedArrays], Array | AbstractResultWithNamedArrays]) → MappedT
Returns a copy of expr with every array expression reachable from expr mapped via map_fn.
Note
Uses CachedMapAndCopyMapper under the hood; because of its caching nature, each node is mapped exactly once.
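The "each node is mapped exactly once" guarantee can be illustrated with a memoized toy version. toy_map_and_copy, Leaf and Add are hypothetical; in particular, applying map_fn to a node before copying the children of its result is an ordering assumed by this sketch.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class Leaf:
    value: int


@dataclass(frozen=True)
class Add:
    left: Any
    right: Any


def toy_map_and_copy(expr: Any, map_fn: Callable[[Any], Any]) -> Any:
    cache: Dict[Any, Any] = {}

    def rec(node: Any) -> Any:
        if node not in cache:
            mapped = map_fn(node)  # map_fn sees each original node once
            if isinstance(mapped, Add):
                # Copy the result, mapping its children recursively.
                mapped = Add(rec(mapped.left), rec(mapped.right))
            cache[node] = mapped
        return cache[node]

    return rec(expr)
```

If map_fn doubles leaf values and records its calls, mapping `Add(shared, shared)` with `shared = Add(Leaf(1), Leaf(2))` gives `Add(Add(Leaf(2), Leaf(4)), Add(Leaf(2), Leaf(4)))` after only four map_fn calls (the root, shared, and the two leaves), even though shared is reachable along two paths.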
- pytato.transform.materialize_with_mpms(expr: DictOfNamedArrays) → DictOfNamedArrays
Materialize nodes in expr with the MPMS materialization strategy. MPMS stands for Multiple-Predecessors, Multiple-Successors.
Note
The MPMS materialization strategy is a greedy materialization algorithm in which any node with more than one materialized predecessor and more than one successor is materialized.
Materializing here corresponds to tagging a node with ImplStored.
Does not attempt to materialize sub-expressions in pytato.Array.shape.
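The selection rule from the Note can be sketched on a dict-of-lists DAG. This is a toy re-implementation of the rule as stated, not pytato's materializer; treating input arguments as already materialized and counting successors as direct users are assumptions of the sketch, and mpms_materialized is a hypothetical name.

```python
from typing import Dict, FrozenSet, List, Set


def mpms_materialized(preds: Dict[str, List[str]],
                      inputs: Set[str]) -> Set[str]:
    """Nodes (besides inputs) that the MPMS rule would materialize.

    preds maps every node to the nodes it reads from; inputs count as
    materialized from the start.
    """
    nusers = {node: 0 for node in preds}
    for node_preds in preds.values():
        for p in node_preds:
            nusers[p] += 1

    materialized = set(inputs)
    # For each node: the materialized arrays it reads through
    # non-materialized intermediates.
    mat_preds: Dict[str, FrozenSet[str]] = {}

    def rec(node: str) -> FrozenSet[str]:
        if node not in mat_preds:
            if node in inputs:
                result = frozenset([node])
            else:
                result = frozenset().union(*(rec(p) for p in preds[node]))
                if len(result) > 1 and nusers[node] > 1:
                    materialized.add(node)
                    result = frozenset([node])
            mat_preds[node] = result
        return mat_preds[node]

    for node in preds:
        rec(node)
    return materialized - set(inputs)
```

For the graph in the Warning below (inputs I1, I2 feeding T, which feeds O1 and O2) this returns `{"T"}`; if T had only one user, nothing would be materialized.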
Warning
Since this is a greedy materialization algorithm, it might be too eager to materialize. Consider the graph below, in which data flows from top to bottom:
    I1     I2
      \   /
       \ /
        T
       / \
      /   \
    O1     O2
Here, "I1" and "I2" correspond to instances of pytato.array.InputArgumentBase, and "O1" and "O2" are the outputs required to be evaluated in the computation graph. The MPMS materialization algorithm will materialize the intermediate node "T", as it has 2 predecessors and 2 successors. However, the total number of memory accesses after applying MPMS goes up, as shown by the table below.
              Before   After
    Reads     4        4
    Writes    2        3
    Total     6        7
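The table's numbers follow from a simple cost model, sketched below as an assumption: every array kept in memory (other than an input) is produced by one kernel that writes it once and reads each distinct in-memory array it depends on once, while non-stored intermediates are recomputed inline. count_accesses is a hypothetical helper.

```python
from typing import Dict, List, Set, Tuple


def count_accesses(preds: Dict[str, List[str]], inputs: Set[str],
                   outputs: Set[str], stored: Set[str]) -> Tuple[int, int]:
    """(reads, writes) when only inputs, outputs and `stored` live in memory."""
    in_memory = inputs | outputs | stored

    def sources(node: str) -> Set[str]:
        # In-memory arrays a kernel computing `node` must read; non-stored
        # intermediates are recomputed inline from their own sources.
        result: Set[str] = set()
        for p in preds[node]:
            result |= {p} if p in in_memory else sources(p)
        return result

    computed = in_memory - inputs
    reads = sum(len(sources(node)) for node in computed)
    writes = len(computed)
    return reads, writes
```

Without storing T, O1 and O2 each read I1 and I2 (4 reads, 2 writes). Storing T adds one write, and the 4 reads become 2 for T plus 1 each for O1 and O2, reproducing the totals of 6 and 7.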
- pytato.transform.deduplicate_data_wrappers(array_or_names: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays
For the expression graph given as array_or_names, replace all pytato.array.DataWrapper instances containing identical data with a single instance.
Note
Currently only supports numpy.ndarray and pyopencl.array.Array.
Note
This function currently uses addresses of memory buffers to detect duplicate data, and so it may fail to deduplicate some instances of identical-but-separately-stored data. User code must tolerate this, but it must also tolerate this function doing a more thorough job of deduplication.
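The address-based detection and its limitation can be sketched with NumPy alone. dedup_by_buffer is hypothetical, and its key (buffer address, shape, strides, dtype) is an assumption of this sketch; pytato's exact key may differ.

```python
from typing import Dict, List, Tuple

import numpy as np


def dedup_by_buffer(arrays: List[np.ndarray]) -> List[np.ndarray]:
    """Replace each array by the first one seen with the same buffer layout."""
    representatives: Dict[Tuple, np.ndarray] = {}
    result = []
    for ary in arrays:
        key = (ary.__array_interface__["data"][0],  # buffer start address
               ary.shape, ary.strides, ary.dtype.str)
        result.append(representatives.setdefault(key, ary))
    return result
```

A view such as `base[:]` shares base's buffer and is collapsed onto base, while `base.copy()` holds identical values in a separate buffer and is left alone, which is the "may fail to deduplicate" case described in the Note.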
- pytato.transform.lower_to_index_lambda.to_index_lambda(expr: Array) → IndexLambda
Lowers expr to IndexLambda if possible, otherwise raises a pytato.diagnostic.CannotBeLoweredToIndexLambda.
- Returns:
The lowered IndexLambda.
- pytato.transform.remove_broadcasts_einsum.rewrite_einsums_with_no_broadcasts(expr: MappedT) → MappedT
Rewrites all instances of Einsum in expr such that the einsum expressions avoid broadcasting axes of their operands. We do so by updating the pytato.array.Einsum.access_descriptors and slicing the operands.
>>> a = pt.make_placeholder("a", (10, 4, 1), np.float64)
>>> b = pt.make_placeholder("b", (10, 1, 4), np.float64)
>>> expr = pt.einsum("ijk,ijk->i", a, b)
>>> new_expr = pt.rewrite_einsums_with_no_broadcasts(expr)
>>> pt.analysis.is_einsum_similar_to_subscript(new_expr, "ij,ik->i")
True
Note
This transformation preserves the semantics of the expression, i.e., it does not alter its value.
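The claim that the rewrite does not alter the value can be spot-checked numerically with NumPy, independently of pytato: the broadcasting form "ijk,ijk->i" on the shapes from the example equals "ij,ik->i" applied to the sliced operands. The random data is purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 4, 1))
b = rng.standard_normal((10, 1, 4))

# "ijk,ijk->i" with broadcasting: form the (10, 4, 4) product and sum over j, k.
broadcast_result = (a * b).sum(axis=(1, 2))

# "ij,ik->i" on sliced operands: a drops its size-1 axis k, b its size-1 axis j.
no_broadcast_result = np.einsum("ij,ik->i", a[:, :, 0], b[:, 0, :])
```

Both arrays have shape (10,) and agree up to floating-point rounding, while the rewritten form never needs to index a broadcast axis.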
- class pytato.transform.einsum_distributive_law.EinsumDistributiveLawDescriptor
Abstract type that informs apply_distributive_property_to_einsums() how the distributive law should be applied along an einsum's operands.
- class pytato.transform.einsum_distributive_law.DoNotDistribute
Tells apply_distributive_property_to_einsums() not to apply the distributive law along any operands of the Einsum.
- class pytato.transform.einsum_distributive_law.DoDistribute(ioperand: int)
Tells apply_distributive_property_to_einsums() to apply the distributive law along the ioperand-th operand of the Einsum.
- pytato.transform.einsum_distributive_law.apply_distributive_property_to_einsums(expr: MappedT, how_to_distribute: Callable[[Array], EinsumDistributiveLawDescriptor]) → MappedT
Returns a copy of expr after applying the distributive law to Einstein summation nodes in the expression graph.
>>> import pytato as pt
>>> x1 = pt.make_placeholder("x1", 4, np.float64)
>>> x2 = pt.make_placeholder("x2", 4, np.float64)
>>> A = pt.make_placeholder("A", (10, 4), np.float64)
>>> y = A @ (x1 + x2)
>>> def how_to_distribute(expr):
...     if pt.analysis.is_einsum_similar_to_subscript(
...             expr, "ij,j->i"):
...         return DoDistribute(ioperand=1)
...     else:
...         return DoNotDistribute()
>>> y_transformed = apply_distributive_property_to_einsums(y,
...                                                        how_to_distribute)
>>> y_transformed == A @ x1 + A @ x2
True
Dict representation of DAGs
- class pytato.transform.UsersCollector
Maps a graph to a dictionary representation mapping a node to its users, i.e., all the nodes using its value.
- node_to_users
Mapping of each node in the graph to its users.
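Building a node-to-users mapping amounts to inverting the predecessor edges of the DAG. The sketch below works on a plain dict keyed by node names; collect_users is hypothetical, and in this sketch nodes without users simply have no entry.

```python
from collections import defaultdict
from typing import Dict, Hashable, Iterable, Mapping, Set


def collect_users(preds: Mapping[Hashable, Iterable[Hashable]]
                  ) -> Dict[Hashable, Set[Hashable]]:
    """Invert a node -> predecessors mapping into node -> users."""
    node_to_users: Dict[Hashable, Set[Hashable]] = defaultdict(set)
    for node, node_preds in preds.items():
        for pred in node_preds:
            node_to_users[pred].add(node)
    return dict(node_to_users)
```

For inputs I1, I2 feeding T, which feeds O1 and O2, this gives `{"I1": {"T"}, "I2": {"T"}, "T": {"O1", "O2"}}`.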
- pytato.transform.tag_user_nodes(graph: Mapping[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]], tag: Any, starting_point: Array | AbstractResultWithNamedArrays, node_to_tags: Dict[Array | AbstractResultWithNamedArrays, Set[Array | AbstractResultWithNamedArrays]] | None = None) → Dict[Array | AbstractResultWithNamedArrays, Set[Any]]
Tags all nodes reachable from starting_point with tag.
- Parameters:
graph: A dict representation of a directed graph, mapping each node to other nodes to which it is connected by edges. A possible use case for this function is the graph in UsersCollector.node_to_users.
tag: The value to tag the nodes with.
starting_point: A starting point in graph.
node_to_tags: The resulting mapping of nodes to tags.
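A minimal sketch of reachability tagging over a users graph, assuming a depth-first traversal that stops at nodes already carrying the tag. toy_tag_user_nodes is hypothetical and operates on plain dicts rather than pytato arrays; passing the same node_to_tags dict to several calls accumulates tags.

```python
from typing import Any, Dict, Hashable, Iterable, Mapping, Optional, Set


def toy_tag_user_nodes(graph: Mapping[Hashable, Iterable[Hashable]], tag: Any,
                       starting_point: Hashable,
                       node_to_tags: Optional[Dict[Hashable, Set[Any]]] = None
                       ) -> Dict[Hashable, Set[Any]]:
    if node_to_tags is None:
        node_to_tags = {}
    stack = [starting_point]
    while stack:
        node = stack.pop()
        tags = node_to_tags.setdefault(node, set())
        if tag in tags:
            continue  # already tagged (and propagated) along another path
        tags.add(tag)
        stack.extend(graph.get(node, ()))
    return node_to_tags
```

Tagging from I1 and then from I2 on the users graph of the diamond leaves T, O1 and O2 with both tags, while I1 and I2 keep only their own.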
- pytato.transform.rec_get_user_nodes(expr: Array | AbstractResultWithNamedArrays, node: Array | AbstractResultWithNamedArrays) → FrozenSet[Array | AbstractResultWithNamedArrays]
Returns all direct and indirect users of node in expr.
Transforming call sites
- pytato.transform.calls.inline_calls(expr: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays
Returns a copy of expr with call sites tagged with pytato.tags.InlineCallTag inlined into the expression graph.
- pytato.transform.calls.tag_all_calls_to_be_inlined(expr: Array | AbstractResultWithNamedArrays) → Array | AbstractResultWithNamedArrays
Returns a copy of expr with all reachable instances of pytato.function.Call nodes tagged with pytato.tags.InlineCallTag.
Note
This routine does NOT inline calls. To inline the calls, use inline_calls() on this routine's output.
Internal stuff that is only here because the documentation tool wants it
- class pytato.transform.CombineT
A type variable representing the type of a CombineMapper.
- class pytato.transform._SelfMapper
A type variable used to represent the type of a mapper in CopyMapper.clone_for_callee().
Analyzing Array Expression Graphs
- pytato.analysis.get_nusers(outputs: Array | DictOfNamedArrays) → Mapping[Array, int]
For the DAG outputs, returns the mapping from each node to the number of nodes using its value within the DAG given by outputs.
- pytato.analysis.is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) → bool
Returns True if and only if an einsum with the subscript descriptor string subscripts operated on expr's pytato.array.Einsum.args would compute the same result as expr.
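One intuition for "similar" is that index names do not matter, only the pattern in which they are shared. The sketch below is a purely syntactic approximation of that idea and not pytato's check: it renames indices in order of first appearance and compares the results, ignoring operand shapes and equivalences such as reordering operands. Both function names are hypothetical.

```python
def canonicalize_subscripts(subscripts: str) -> str:
    """Rename indices in order of first appearance: 'pq,qr->pr' -> 'ab,bc->ac'."""
    renaming = {}
    out = []
    for ch in subscripts.replace(" ", ""):
        if ch.isalpha():
            if ch not in renaming:
                renaming[ch] = chr(ord("a") + len(renaming))
            out.append(renaming[ch])
        else:
            out.append(ch)  # keep ',' and '->' unchanged
    return "".join(out)


def similar_subscripts(s1: str, s2: str) -> bool:
    return canonicalize_subscripts(s1) == canonicalize_subscripts(s2)
```

Under this approximation "ij,jk->ik" and "pq,qr->pr" are similar (both canonicalize to "ab,bc->ac"), while "ij,jk->ki" is not, since it computes the transpose.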
- pytato.analysis.get_num_nodes(outputs: Array | DictOfNamedArrays) → int
Returns the number of nodes in DAG outputs.
- pytato.analysis.get_num_call_sites(outputs: Array | DictOfNamedArrays) → int
Returns the number of call sites (pytato.function.Call nodes) in DAG outputs.
- class pytato.analysis.DirectPredecessorsGetter
Mapper to get the direct predecessors of a node.
Note
We only consider the predecessors of a node in a data-flow sense.
Visualizing Array Expression Graphs
- pytato.get_dot_graph(result: Array | DictOfNamedArrays) → str
Return a string in the dot language depicting the graph of the computation of result.
- Parameters:
result: Outputs of the computation (cf. pytato.generate_loopy()).
- pytato.get_dot_graph_from_partition(partition: DistributedGraphPartition) → str
Return a string in the dot language depicting the graph of the partitioned computation of partition.
- Parameters:
partition: Outputs of find_distributed_partition().
- pytato.show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, **kwargs: Any) → None
Show a graph representing the computation of result in a browser.
- Parameters:
result: Outputs of the computation (cf. pytato.generate_loopy()), the output of get_dot_graph(), or the output of find_distributed_partition().
kwargs: Passed on to pytools.graphviz.show_dot() unmodified.
- pytato.show_fancy_placeholder_data_flow(dag: Array | DictOfNamedArrays, **kwargs: Any) → None
Visualizes the data-flow from the placeholders into outputs.
- Parameters:
dag: The expression to be plotted.
kwargs: Graphviz visualization options to be passed to pytools.graphviz.show_dot().
Note
This is a heavily opinionated visualization of the data-flow graph in dag. Displaying all the information about the nodes is not the priority. See pytato.show_dot_graph(), which aims to be more verbose.
Comparing Two Expression Graphs
- class pytato.equality.EqualityComparer
A pytato.array.Array visitor to check equality between two expression DAGs.
Note
Compares two expression graphs expr1, expr2 in \(O(N)\) comparisons, where \(N\) is the number of nodes in expr1.
This visitor was introduced to memoize the sub-expression comparisons of the expressions to be compared. Not memoizing the sub-expression comparisons results in \(O(2^N)\) complexity for the comparison operation, where \(N\) is the number of nodes in the expressions. See GH-Issue-163 <https://github.com/inducer/pytato/issues/163> for more on this.
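The gap between memoized and naive comparison can be reproduced on a toy graph. The Node class, diamond_chain and both comparison functions are hypothetical sketches, not EqualityComparer; the memoized version caches results per pair of node identities, which is valid only while both graphs stay alive.

```python
from typing import Dict, List, Optional, Tuple


class Node:
    def __init__(self, op: str, *children: "Node") -> None:
        self.op = op
        self.children = children


def diamond_chain(depth: int) -> Node:
    node = Node("input")
    for _ in range(depth):
        node = Node("mul", Node("sin", node), Node("cos", node))
    return node


def naive_equal(a: Node, b: Node, counter: List[int]) -> bool:
    counter[0] += 1
    return (a.op == b.op and len(a.children) == len(b.children)
            and all(naive_equal(x, y, counter)
                    for x, y in zip(a.children, b.children)))


def memoized_equal(a: Node, b: Node, counter: List[int],
                   cache: Optional[Dict[Tuple[int, int], bool]] = None) -> bool:
    if cache is None:
        cache = {}
    key = (id(a), id(b))  # both graphs stay alive, so ids are stable
    if key not in cache:
        counter[0] += 1
        cache[key] = (a.op == b.op and len(a.children) == len(b.children)
                      and all(memoized_equal(x, y, counter, cache)
                              for x, y in zip(a.children, b.children)))
    return cache[key]
```

Comparing two separately built chains of depth 10 (31 nodes each), the naive version performs 2**12 - 3 = 4093 node comparisons, whereas the memoized version performs 31, one per node of expr1.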
Stringifying Expression Graphs
Utilities and Diagnostics
Helper routines
- pytato.utils.are_shape_components_equal(dim1: int | integer | Array, dim2: int | integer | Array) → bool
Returns True iff dim1 and dim2 have equal SizeParam coefficients in their expressions.
- pytato.utils.are_shapes_equal(shape1: Tuple[int | integer | Array, ...], shape2: Tuple[int | integer | Array, ...]) → bool
Returns True iff shape1 and shape2 have the same dimensionality and the corresponding components are equal as defined by are_shape_components_equal().
- pytato.utils.get_shape_after_broadcasting(exprs: Iterable[Array | number | int | bool_ | bool | float | Expression]) → Tuple[int | integer | Array, ...]
Returns the shape after broadcasting exprs in an operation.
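For concrete integer shapes, the result follows NumPy's broadcasting rules: align trailing axes, and at each position the sizes must match or one of them must be 1. The sketch below covers only that case, with scalars represented by the empty shape (); pytato's routine also handles symbolic dimensions, which this hypothetical broadcast_shape does not.

```python
from typing import Iterable, List, Tuple


def broadcast_shape(shapes: Iterable[Tuple[int, ...]]) -> Tuple[int, ...]:
    """NumPy-style broadcasting of concrete integer shapes; scalars use ()."""
    result: List[int] = []
    for shape in shapes:
        ndim = max(len(result), len(shape))
        # Align trailing axes by left-padding the shorter shape with 1s.
        lhs = [1] * (ndim - len(result)) + result
        rhs = [1] * (ndim - len(shape)) + list(shape)
        merged = []
        for d1, d2 in zip(lhs, rhs):
            if d1 == 1 or d1 == d2:
                merged.append(d2)
            elif d2 == 1:
                merged.append(d1)
            else:
                raise ValueError(f"cannot broadcast {tuple(result)} with {shape}")
        result = merged
    return tuple(result)
```

For example, shapes (10, 4, 1) and (10, 1, 4) broadcast to (10, 4, 4), and (5, 1), (4,) and a scalar broadcast to (5, 4), whereas (3,) and (4,) raise ValueError.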
- pytato.utils.dim_to_index_lambda_components(expr: int | integer | Array, vng: UniqueNameGenerator | None = None) → Tuple[number | int | bool_ | bool | float | Expression, Dict[str, SizeParam]]
Returns the scalar expressions and bindings to use the shape component within an index lambda.
>>> n = pt.make_size_param("n")
>>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
>>> print(expr)
3*_in + 8
>>> bnds
{'_in': SizeParam(name='n')}
- pytato.utils.get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[Array | number | int | bool_ | bool | float]) → dtype
- pytato.utils.get_einsum_subscript_str(expr: Einsum) → str
Returns the index subscript expression that was used in constructing expr using the pytato.einsum() routine.
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ij,jk,kl->il", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'
Pytato-specific exceptions
- class pytato.diagnostic.NameClashError
Raised when two non-identical InputArgumentBase instances are reachable in an Array's DAG and share the same name. Here, we refer to two objects a and b as being identical iff a is b.
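The identity-based definition can be made concrete with a small check: reusing the same input object twice is fine, but two distinct objects with the same name clash. Placeholder, NameClash and check_name_clashes are hypothetical stand-ins, not pytato's classes.

```python
from typing import Dict, Iterable


class Placeholder:
    """Hypothetical input node; identity (`is`) distinguishes two inputs."""

    def __init__(self, name: str) -> None:
        self.name = name


class NameClash(ValueError):
    pass


def check_name_clashes(inputs: Iterable[Placeholder]) -> None:
    seen: Dict[str, Placeholder] = {}
    for inp in inputs:
        other = seen.setdefault(inp.name, inp)
        if other is not inp:
            raise NameClash(f"two distinct inputs are named {inp.name!r}")
```

Passing `[a, a, Placeholder("b")]` succeeds because both occurrences of a are the same object, while `[a, Placeholder("a")]` raises, even if the two placeholders would otherwise look alike.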
- class pytato.diagnostic.UnknownIndexLambdaExpr
Raised when the structure of a pytato.array.IndexLambda could not be inferred.
- class pytato.diagnostic.CannotBeLoweredToIndexLambda
Raised when a pytato.Array was expected to be lowered to an IndexLambda, but it cannot be. For example, a pytato.loopy.LoopyCallResult cannot be lowered to an IndexLambda.