Transforming Array Expression Graphs¶
- class pytato.transform.Mapper[source]¶
A class that, when called with a pytato.Array, recursively iterates over the DAG, calling the _mapper_method of each node. Users of this class are expected to override the methods of this class or create a subclass.
Note
This class might visit a node multiple times. Use a CachedMapper if this is not desired.
- handle_unsupported_array(expr: MappedT, *args: P.args, **kwargs: P.kwargs) ResultT [source]¶
Mapper method that is invoked for pytato.Array subclasses for which a mapper method does not exist in this mapper.
- map_foreign(expr: Any, *args: P.args, **kwargs: P.kwargs) ResultT [source]¶
Mapper method that is invoked for an object of a class for which a mapper method does not exist in this mapper.
- class pytato.transform.CachedMapper[source]¶
Mapper class that maps each node in the DAG exactly once. This loses some information compared to Mapper, as a node is visited only from one of its predecessors.
- get_cache_key(expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) Hashable [source]¶
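The difference between an uncached and a cached traversal can be sketched with a toy DAG. This is a minimal illustration only: `Node`, `CountingMapper` and `CachedCountingMapper` are hypothetical stand-ins, not pytato classes, and the `id()`-based key is merely an assumed placeholder for whatever `get_cache_key()` returns.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for an array expression node (not a pytato class)."""
    name: str
    preds: tuple = ()


class CountingMapper:
    """Uncached traversal: recurses into every predecessor along every path."""

    def __init__(self):
        self.visits = 0

    def __call__(self, node):
        self.visits += 1
        for pred in node.preds:
            self(pred)


class CachedCountingMapper(CountingMapper):
    """Cached traversal: a node that was already seen is not visited again."""

    def __init__(self):
        super().__init__()
        self._seen = set()

    def __call__(self, node):
        key = id(node)  # toy cache key (assumption of this sketch)
        if key in self._seen:
            return
        self._seen.add(key)
        super().__call__(node)


# Diamond-shaped DAG: 'out' uses 'a' and 'b', which both use 'x'.
x = Node("x")
a = Node("a", (x,))
b = Node("b", (x,))
out = Node("out", (a, b))

uncached = CountingMapper()
uncached(out)
cached = CachedCountingMapper()
cached(out)
print(uncached.visits, cached.visits)  # prints "5 4": 'x' is reached twice without caching
```

In the cached variant, 'x' is only ever reached via 'a', which is the kind of information loss the description above refers to.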
- class pytato.transform.TransformMapper[source]¶
Base class for mappers that transform pytato.array.Arrays into other pytato.array.Arrays.
Enables certain operations that can only be done if the mapping results are also arrays (e.g., calling get_cache_key() on them). Does not implement default mapper methods; for that, see CopyMapper.
- clone_for_callee(function: FunctionDefinition) Self [source]¶
Called to clone self before starting traversal of a pytato.function.FunctionDefinition.
- class pytato.transform.TransformMapperWithExtraArgs[source]¶
Similar to TransformMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.
The logic in TransformMapper purposely does not take the extra arguments to keep the cost of each of its call frames low.
- clone_for_callee(function: FunctionDefinition) Self [source]¶
Called to clone self before starting traversal of a pytato.function.FunctionDefinition.
- class pytato.transform.CopyMapper[source]¶
Performs a deep copy of a pytato.array.Array. The typical use of this mapper is to override individual map_ methods in subclasses to permit term rewriting on an expression graph.
Note
This does not copy the data of a pytato.array.DataWrapper.
- class pytato.transform.CopyMapperWithExtraArgs[source]¶
Similar to CopyMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.
The logic in CopyMapper purposely does not take the extra arguments to keep the cost of each of its call frames low.
- class pytato.transform.CombineMapper[source]¶
Abstract mapper that recursively combines the results of user nodes of a given expression.
- class pytato.transform.DependencyMapper[source]¶
Maps a pytato.array.Array to a frozenset of the pytato.array.Arrays it depends on.
Warning
This returns every node in the graph! Consider a custom CombineMapper or a SubsetDependencyMapper instead.
- class pytato.transform.InputGatherer[source]¶
Mapper to combine all instances of pytato.array.InputArgumentBase that an array expression depends on.
- class pytato.transform.SizeParamGatherer[source]¶
Mapper to combine all instances of pytato.array.SizeParam that an array expression depends on.
- class pytato.transform.SubsetDependencyMapper(universe: frozenset[Array])[source]¶
Mapper to combine the dependencies of an expression that are a subset of universe.
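The idea of gathering dependencies and then restricting them to a universe can be sketched on a plain dictionary-based DAG. This is a toy illustration under stated assumptions, not pytato's implementation: `dependencies`, `subset_dependencies` and the `preds` mapping are hypothetical names, and including the node itself in its own dependency set is an assumption of this sketch.

```python
def dependencies(node, preds):
    """Return `node` together with every node reachable through `preds`."""
    result = {node}
    for pred in preds.get(node, ()):
        result |= dependencies(pred, preds)
    return frozenset(result)


def subset_dependencies(node, preds, universe):
    """Keep only those dependencies of `node` that are members of `universe`."""
    return dependencies(node, preds) & universe


# Toy graph: maps each node name to the names of its direct predecessors.
preds = {"out": ("a", "b"), "a": ("x",), "b": ("x", "y")}

all_deps = dependencies("out", preds)
inputs_only = subset_dependencies("out", preds, frozenset({"x", "y", "z"}))
print(sorted(all_deps))     # ['a', 'b', 'out', 'x', 'y']
print(sorted(inputs_only))  # ['x', 'y']
```

Restricting to a universe keeps the result small when only a known set of nodes (for instance, the inputs) is of interest.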
- class pytato.transform.WalkMapper[source]¶
A mapper that walks over all the arrays in a pytato.Array.
Users may override the specific mapper methods in a derived class or override WalkMapper.visit() and WalkMapper.post_visit().
- class pytato.transform.CachedWalkMapper[source]¶
WalkMapper that visits each node in the DAG exactly once. This loses some information compared to WalkMapper, as a node is visited only from one of its predecessors.
- class pytato.transform.TopoSortMapper¶
- class pytato.transform.CachedMapAndCopyMapper(map_fn: Callable[[ArrayOrNames], ArrayOrNames])[source]¶
Mapper that applies map_fn to each node and copies it. Results of traversals are memoized, i.e., each node is mapped via map_fn exactly once.
- pytato.transform.copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, copy_mapper: CopyMapper) DictOfNamedArrays [source]¶
Copy the elements of a DictOfNamedArrays into a new DictOfNamedArrays.
- Parameters:
source_dict – The DictOfNamedArrays to copy
copy_mapper – A mapper that performs copies of different array types
- Returns:
A new DictOfNamedArrays containing copies of the items in source_dict
- pytato.transform.get_dependencies(expr: DictOfNamedArrays) dict[str, frozenset[Array]] [source]¶
Returns the dependencies of each named array in expr.
- pytato.transform.map_and_copy(expr: MappedT, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) MappedT [source]¶
Returns a copy of expr with every array expression reachable from expr mapped via map_fn.
Note
Uses CachedMapAndCopyMapper under the hood; because of its caching nature, each node is mapped exactly once.
- pytato.transform.materialize_with_mpms(expr: DictOfNamedArrays) DictOfNamedArrays [source]¶
Materialize nodes in expr with MPMS materialization strategy. MPMS stands for Multiple-Predecessors, Multiple-Successors.
Note
MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessor and more than 1 successor is materialized.
Materializing here corresponds to tagging a node with ImplStored.
Does not attempt to materialize sub-expressions in pytato.Array.shape.
Warning
This is a greedy materialization algorithm and thereby this algorithm might be too eager to materialize. Consider the graph below:
    I1          I2
      \         /
       \       /
        \     /
         🡦   🡧
           T
          / \
         /   \
        /     \
       🡧       🡦
     O1        O2
where ‘I1’ and ‘I2’ correspond to instances of pytato.array.InputArgumentBase, and ‘O1’ and ‘O2’ are the outputs required to be evaluated in the computation graph. The MPMS materialization algorithm will materialize the intermediate node ‘T’ as it has 2 predecessors and 2 successors. However, the total number of memory accesses after applying MPMS goes up, as shown by the table below.
            Before   After
    Reads   4        4
    Writes  2        3
    Total   6        7
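The read and write counts for the I1/I2/T/O1/O2 example can be reproduced with a small counting model. This is a sketch under simplifying assumptions and not pytato code: `count_accesses` is a hypothetical helper, each stored non-input node is assumed to be written once, and each stored node is assumed to read one value per stored node it reaches without crossing another stored node.

```python
def count_accesses(preds, stored):
    """Count (reads, writes) when only the nodes in `stored` live in memory."""

    def stored_inputs(node):
        # Walk through non-stored (fused) predecessors until a stored one is hit.
        found = []
        for pred in preds.get(node, ()):
            if pred in stored:
                found.append(pred)
            else:
                found.extend(stored_inputs(pred))
        return found

    reads = writes = 0
    for node in stored:
        if preds.get(node):  # inputs have no predecessors and are never written
            writes += 1
            reads += len(stored_inputs(node))
    return reads, writes


# The example graph: T depends on I1 and I2; O1 and O2 both depend on T.
preds = {"T": ("I1", "I2"), "O1": ("T",), "O2": ("T",)}

before = count_accesses(preds, {"I1", "I2", "O1", "O2"})
after = count_accesses(preds, {"I1", "I2", "T", "O1", "O2"})
print(before, sum(before))  # (4, 2) 6
print(after, sum(after))    # (4, 3) 7
```

Under this model, storing 'T' saves no reads (each output trades two input reads for one read of 'T', while 'T' itself reads both inputs) and adds one write.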
- pytato.transform.deduplicate_data_wrappers(array_or_names: Array | AbstractResultWithNamedArrays) Array | AbstractResultWithNamedArrays [source]¶
For the expression graph given as array_or_names, replace all pytato.array.DataWrapper instances containing identical data with a single instance.
Note
Currently only supports numpy.ndarray and pyopencl.array.Array.
Note
This function currently uses addresses of memory buffers to detect duplicate data, and so it may fail to deduplicate some instances of identical-but-separately-stored data. User code must tolerate this, but it must also tolerate this function doing a more thorough job of deduplication.
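The consequence of address-based detection can be illustrated with a toy model. Everything here is hypothetical: `Wrapper` and `deduplicate` are stand-ins rather than pytato objects, and Python's `id()` of the buffer object is used as an assumed substitute for a memory address.

```python
class Wrapper:
    """Hypothetical stand-in for a node that wraps a data buffer."""

    def __init__(self, data):
        self.data = data


def deduplicate(wrappers):
    """Replace wrappers that share one buffer object with the first such wrapper."""
    canonical = {}
    return [canonical.setdefault(id(w.data), w) for w in wrappers]


buf = bytearray(b"\x01\x02\x03")
w1, w2 = Wrapper(buf), Wrapper(buf)  # same underlying buffer
w3 = Wrapper(bytearray(buf))         # equal contents, separately stored

result = deduplicate([w1, w2, w3])
print(result[1] is w1)  # True: shared buffer, so w2 collapses onto w1
print(result[2] is w3)  # True: identical data at another address is kept
```

A content-based comparison would also merge `w3`, which is the more thorough outcome that user code is asked to tolerate.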
- pytato.transform.lower_to_index_lambda.to_index_lambda(expr: Array) IndexLambda [source]¶
Lowers expr to IndexLambda if possible, otherwise raises a pytato.diagnostic.CannotBeLoweredToIndexLambda.
- Returns:
The lowered IndexLambda.
- pytato.transform.remove_broadcasts_einsum.rewrite_einsums_with_no_broadcasts(expr: MappedT) MappedT [source]¶
Rewrites all instances of Einsum in expr such that the einsum expressions avoid broadcasting axes of their operands. We do so by updating the pytato.array.Einsum.access_descriptors and slicing the operands.
>>> a = pt.make_placeholder("a", (10, 4, 1), np.float64)
>>> b = pt.make_placeholder("b", (10, 1, 4), np.float64)
>>> expr = pt.einsum("ijk,ijk->i", a, b)
>>> new_expr = pt.rewrite_einsums_with_no_broadcasts(expr)
>>> pt.analysis.is_einsum_similar_to_subscript(new_expr, "ij,ik->i")
True
Note
This transformation preserves the semantics of the expression, i.e., it does not alter its value.
- class pytato.transform.einsum_distributive_law.EinsumDistributiveLawDescriptor[source]¶
Abstract type that informs apply_distributive_property_to_einsums() how the distributive law should be applied along an einsum’s operands.
- class pytato.transform.einsum_distributive_law.DoNotDistribute[source]¶
Tells apply_distributive_property_to_einsums() not to apply the distributive law along any operands of the Einsum.
- class pytato.transform.einsum_distributive_law.DoDistribute(ioperand: int)[source]¶
Tells apply_distributive_property_to_einsums() to apply the distributive law along the ioperand-th operand of the Einsum.
- pytato.transform.einsum_distributive_law.apply_distributive_property_to_einsums(expr: MappedT, how_to_distribute: Callable[[Array], EinsumDistributiveLawDescriptor]) MappedT [source]¶
Returns a copy of expr after applying the distributive law for Einstein summation nodes in the expression graph.
>>> import pytato as pt
>>> x1 = pt.make_placeholder("x1", 4, np.float64)
>>> x2 = pt.make_placeholder("x2", 4, np.float64)
>>> A = pt.make_placeholder("A", (10, 4), np.float64)
>>> y = A @ (x1 + x2)
>>> def how_to_distribute(expr):
...     if pt.analysis.is_einsum_similar_to_subscript(
...             expr, "ij,j->i"):
...         return DoDistribute(ioperand=1)
...     else:
...         return DoNotDistribute()
>>> y_transformed = apply_distributive_property_to_einsums(y,
...                                                         how_to_distribute)
>>> y_transformed == A @ x1 + A @ x2
True
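The identity that this rewrite relies on, A(x1 + x2) = A x1 + A x2, can be checked numerically with plain Python lists. This is only an arithmetic illustration with hypothetical helpers (`matvec`, `vec_add`); it does not touch pytato's expression graph.

```python
def matvec(A, x):
    """Matrix-vector product, i.e. the contraction "ij,j->i"."""
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in A]


def vec_add(u, v):
    """Elementwise sum of two equally long vectors."""
    return [a + b for a, b in zip(u, v)]


A = [[1, 2], [3, 4], [5, 6]]
x1 = [1, 0]
x2 = [2, 3]

lhs = matvec(A, vec_add(x1, x2))             # one contraction of the summed operand
rhs = vec_add(matvec(A, x1), matvec(A, x2))  # two contractions, summed afterwards
print(lhs, rhs)  # [9, 21, 33] [9, 21, 33]
```

Both forms give the same values, but they differ in the number of contractions performed, which is presumably why the choice is left to the how_to_distribute callback.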
Dict representation of DAGs¶
- class pytato.transform.UsersCollector[source]¶
Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value.
- node_to_users¶
Mapping of each node in the graph to its users.
- pytato.transform.rec_get_user_nodes(expr: Array | AbstractResultWithNamedArrays, node: Array | AbstractResultWithNamedArrays) frozenset[Array | AbstractResultWithNamedArrays] [source]¶
Returns all direct and indirect users of node in expr.
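The node-to-users representation and the transitive user query can be sketched on a dictionary-based toy graph. The names `collect_users` and `rec_users` and the `preds` mapping are hypothetical and only illustrate the idea; they are not the pytato implementation.

```python
def collect_users(preds):
    """Invert a node -> predecessors mapping into a node -> users mapping."""
    node_to_users = {}
    for user, node_preds in preds.items():
        node_to_users.setdefault(user, set())
        for pred in node_preds:
            node_to_users.setdefault(pred, set()).add(user)
    return node_to_users


def rec_users(node, node_to_users):
    """Return all direct and indirect users of `node`."""
    result = set()
    stack = list(node_to_users.get(node, ()))
    while stack:
        user = stack.pop()
        if user not in result:
            result.add(user)
            stack.extend(node_to_users.get(user, ()))
    return frozenset(result)


# Toy graph: 'out' uses 'a' and 'b', which both use 'x'.
preds = {"out": ("a", "b"), "a": ("x",), "b": ("x",), "x": ()}

node_to_users = collect_users(preds)
print(sorted(node_to_users["x"]))            # ['a', 'b']
print(sorted(rec_users("x", node_to_users))) # ['a', 'b', 'out']
```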
Transforming call sites¶
- pytato.transform.calls.inline_calls(expr: Array | AbstractResultWithNamedArrays) Array | AbstractResultWithNamedArrays [source]¶
Returns a copy of expr with call sites tagged with pytato.tags.InlineCallTag inlined into the expression graph.
- pytato.transform.calls.tag_all_calls_to_be_inlined(expr: Array | AbstractResultWithNamedArrays) Array | AbstractResultWithNamedArrays [source]¶
Returns a copy of expr with all reachable instances of pytato.function.Call nodes tagged with pytato.tags.InlineCallTag.
Note
This routine does NOT inline calls; to inline the calls, use inline_calls() on this routine’s output.
Internal stuff that is only here because the documentation tool wants it¶
- class pytato.transform.ArrayOrNames¶
- class pytato.transform.CombineT¶
A type variable representing the type of a CombineMapper.
- class pytato.transform.Scalar¶
See pymbolic.Scalar.
Analyzing Array Expression Graphs¶
- pytato.analysis.get_nusers(outputs: Array | DictOfNamedArrays) Mapping[Array, int] [source]¶
For the DAG outputs, returns the mapping from each node to the number of nodes using its value within the DAG given by outputs.
- pytato.analysis.is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) bool [source]¶
Returns True if and only if an einsum with the subscript descriptor string subscripts operated on expr’s pytato.array.Einsum.args would compute the same result as expr.
- pytato.analysis.get_num_nodes(outputs: Array | DictOfNamedArrays, count_duplicates: bool | None = None) int [source]¶
Returns the number of nodes in DAG outputs. Instances of DictOfNamedArrays are excluded from counting.
- pytato.analysis.get_node_type_counts(outputs: Array | DictOfNamedArrays, count_duplicates: bool = False) dict[type[Any], int] [source]¶
Returns a dictionary mapping node types to node count for that type in DAG outputs.
Instances of DictOfNamedArrays are excluded from counting.
- pytato.analysis.get_node_multiplicities(outputs: Array | DictOfNamedArrays) dict[Array, int] [source]¶
Returns the multiplicity of each node in the DAG outputs.
- pytato.analysis.get_num_call_sites(outputs: Array | DictOfNamedArrays) int [source]¶
Returns the number of call sites in DAG outputs.
- class pytato.analysis.DirectPredecessorsGetter[source]¶
Mapper to get the direct predecessors of a node.
Note
We only consider the predecessors of a node in a data-flow sense.
Visualizing Array Expression Graphs¶
- pytato.get_dot_graph(result: Array | DictOfNamedArrays) str [source]¶
Return a string in the dot language depicting the graph of the computation of result.
- Parameters:
result – Outputs of the computation (cf. pytato.generate_loopy()).
- pytato.get_dot_graph_from_partition(partition: DistributedGraphPartition) str [source]¶
Return a string in the dot language depicting the graph of the partitioned computation of partition.
- Parameters:
partition – Outputs of find_distributed_partition().
- pytato.show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, **kwargs: Any) None [source]¶
Show a graph representing the computation of result in a browser.
- Parameters:
result – Outputs of the computation (cf. pytato.generate_loopy()), the output of get_dot_graph(), or the output of find_distributed_partition().
kwargs – Passed on to pytools.graphviz.show_dot() unmodified.
- pytato.show_fancy_placeholder_data_flow(dag: Array | DictOfNamedArrays, **kwargs: Any) None [source]¶
Visualizes the data-flow from the placeholders into outputs.
- Parameters:
dag – The expression to be plotted.
kwargs – Graphviz visualization options to be passed to pytools.graphviz.show_dot().
Note
This is a heavily opinionated visualization of the data-flow graph in dag. Displaying all the information about each node is not the priority. See pytato.show_dot_graph(), which aims to be more verbose.
Comparing two expression Graphs¶
- class pytato.equality.EqualityComparer[source]¶
A pytato.array.Array visitor to check equality between two expression DAGs.
Note
Compares two expression graphs expr1, expr2 in \(O(N)\) comparisons, where \(N\) is the number of nodes in expr1.
This visitor was introduced to memoize the sub-expression comparisons of the expressions to be compared. Not memoizing the sub-expression comparisons results in \(O(2^N)\) complexity for the comparison operation, where \(N\) is the number of nodes in the expressions. See GH-Issue-163 <https://github.com/inducer/pytato/issues/163> for more on this.
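The effect of memoizing sub-expression comparisons can be sketched with a toy comparer. This is an assumption-laden illustration rather than the EqualityComparer implementation: `Node`, `ToyEqualityComparer` and `build_chain` are hypothetical, and the cache is keyed on `id()` pairs purely for simplicity.

```python
from dataclasses import dataclass


@dataclass(frozen=True, eq=False)
class Node:
    """Hypothetical expression node; eq=False keeps Python's identity-based ==."""
    op: str
    args: tuple = ()


class ToyEqualityComparer:
    """Structural equality that compares each pair of nodes at most once."""

    def __init__(self):
        self.comparisons = 0
        self._cache = {}

    def __call__(self, n1, n2):
        key = (id(n1), id(n2))
        if key in self._cache:
            return self._cache[key]
        self.comparisons += 1
        result = n1 is n2 or (
            n1.op == n2.op
            and len(n1.args) == len(n2.args)
            and all(self(a, b) for a, b in zip(n1.args, n2.args))
        )
        self._cache[key] = result
        return result


def build_chain(depth):
    """Build a DAG in which every level uses the level below it twice."""
    node = Node("leaf")
    for _ in range(depth):
        node = Node("add", (node, node))
    return node


e1, e2 = build_chain(20), build_chain(20)
comparer = ToyEqualityComparer()
print(comparer(e1, e2), comparer.comparisons)  # True 21
```

Each of the 21 levels is compared once here; without the cache, the same recursion would perform 2**21 - 1 comparisons on this graph.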
Stringifying Expression Graphs¶
Utilities and Diagnostics¶
Helper routines¶
- pytato.utils.are_shape_components_equal(dim1: int | integer | Array, dim2: int | integer | Array) bool [source]¶
Returns True iff dim1 and dim2 have equal SizeParam coefficients in their expressions.
- pytato.utils.are_shapes_equal(shape1: tuple[int | integer | Array, ...], shape2: tuple[int | integer | Array, ...]) bool [source]¶
Returns True iff shape1 and shape2 have the same dimensionality and the corresponding components are equal as defined by are_shape_components_equal().
- pytato.utils.get_shape_after_broadcasting(exprs: Iterable[Array | Scalar]) ShapeType [source]¶
Returns the shape after broadcasting exprs in an operation.
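For intuition, the shape computation can be sketched for purely integer shapes, assuming NumPy-style broadcasting rules (an assumption of this sketch). The helper `broadcast_shape` is hypothetical and ignores symbolic shape components such as those built from SizeParam.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast integer shapes by aligning them on their trailing axes."""
    result = []
    # Walk the axes from last to first; missing leading axes count as length 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_unit = {d for d in dims if d != 1}
        if len(non_unit) > 1:
            raise ValueError(f"cannot broadcast shapes {shapes}")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(reversed(result))


print(broadcast_shape((10, 4, 1), (10, 1, 4)))  # (10, 4, 4)
print(broadcast_shape((3,), (2, 1)))            # (2, 3)
print(broadcast_shape((), (5,)))                # (5,) since a scalar has shape ()
```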
- pytato.utils.dim_to_index_lambda_components(expr: ShapeComponent, vng: UniqueNameGenerator | None = None) tuple[ScalarExpression, dict[str, SizeParam]] [source]¶
Returns the scalar expressions and bindings to use the shape component within an index lambda.
>>> n = pt.make_size_param("n")
>>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
>>> print(expr)
3*_in + 8
>>> bnds
{'_in': SizeParam(name='n')}
- pytato.utils.get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[ArrayOrScalar]) _dtype_any [source]¶
- pytato.utils.get_einsum_subscript_str(expr: Einsum) str [source]¶
Returns the index subscript expression that can be used in constructing expr using the pytato.einsum() routine.
Deprecated: use get_einsum_specification_str instead.
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ij,jk,kl->il", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'
References¶
Pytato-specific exceptions¶
- class pytato.diagnostic.NameClashError[source]¶
Raised when 2 non-identical InputArgumentBase instances are reachable in an Array’s DAG and share the same name. Here, we refer to 2 objects a and b as being identical iff a is b.
- class pytato.diagnostic.UnknownIndexLambdaExpr[source]¶
Raised when the structure of a pytato.array.IndexLambda could not be inferred.
- class pytato.diagnostic.CannotBeLoweredToIndexLambda[source]¶
Raised when a pytato.Array was expected to be lowered to an IndexLambda, but it cannot be. For example, a pytato.loopy.LoopyCallResult cannot be lowered to an IndexLambda.