Transforming Array Expression Graphs

class pytato.transform.Mapper

A class that, when called with a pytato.Array, recursively iterates over the DAG, calling the _mapper_method of each node. Users of this class are expected to override its methods or create a subclass.

Note

This class might visit a node multiple times. Use a CachedMapper if this is not desired.

handle_unsupported_array(expr: Array, *args: P.args, **kwargs: P.kwargs) → ResultT

Mapper method that is invoked for pytato.Array subclasses for which a mapper method does not exist in this mapper.

rec(expr: Array | AbstractResultWithNamedArrays, *args: P.args, **kwargs: P.kwargs) → ResultT

Call the mapper method of expr and return the result.

__call__(expr: Array | AbstractResultWithNamedArrays, *args: P.args, **kwargs: P.kwargs) → ResultT
__call__(expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs) → FunctionResultT

Handle the mapping of expr.
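The dispatch described above can be modeled in plain Python. The sketch below is a toy, not pytato's implementation: it assumes each node class carries a `_mapper_method` string naming its handler, and that the mapper falls back to a handle_unsupported_array() analogue when it lacks that method. The node classes `Placeholder` and `Add` and the mappers `ToyMapper` and `ToStr` are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Placeholder:
    name: str
    _mapper_method = "map_placeholder"  # class attribute (not a field) naming the handler


@dataclass(frozen=True)
class Add:
    left: object
    right: object
    _mapper_method = "map_add"


class ToyMapper:
    """Dispatches on the node's ``_mapper_method`` attribute."""

    def rec(self, expr, *args, **kwargs):
        method = getattr(self, expr._mapper_method, None)
        if method is None:
            return self.handle_unsupported_array(expr, *args, **kwargs)
        return method(expr, *args, **kwargs)

    __call__ = rec

    def handle_unsupported_array(self, expr, *args, **kwargs):
        raise NotImplementedError(
            f"{type(self).__name__} cannot handle {type(expr).__name__}")


class ToStr(ToyMapper):
    def map_placeholder(self, expr):
        return expr.name

    def map_add(self, expr):
        return f"({self.rec(expr.left)} + {self.rec(expr.right)})"


x = Placeholder("x")
print(ToStr()(Add(x, Add(x, Placeholder("y")))))  # (x + (x + y))
```

Note that the shared node `x` is visited once per use here, which is the behavior the note above warns about.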

class pytato.transform.CacheInputsWithKey(expr: CacheExprT, key: Hashable, *args: P.args, **kwargs: P.kwargs)

Data structure for inputs to CachedMapperCache.

expr

The input expression being mapped.

args

A tuple of extra positional arguments.

kwargs

A dict of extra keyword arguments.

key

The cache key corresponding to expr and any additional inputs that were passed.

class pytato.transform.CachedMapperCache(err_on_collision: bool)

Cache for mappers.

__init__(err_on_collision: bool) → None

Initialize the cache.

Parameters:

err_on_collision – Raise an exception if two distinct input expression instances have the same key.

Compute the key for an input expression.

add(inputs: CacheInputsWithKey[CacheExprT, P], result: CacheResultT) → CacheResultT

Cache a mapping result.

retrieve(inputs: CacheInputsWithKey[CacheExprT, P]) → CacheResultT

Retrieve the cached mapping result.

clear() → None

Reset the cache.
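A toy model of the add()/retrieve()/clear() contract follows. It assumes the key has already been computed and that err_on_collision is enforced by checking, on retrieval, that the stored input is the very same object as the one being looked up. `ToyCache`, `Inputs` and `ToyCollisionError` are hypothetical names; pytato's actual classes and exception types may differ.

```python
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Any


class ToyCollisionError(ValueError):
    """Raised when two distinct inputs share a cache key (hypothetical)."""


@dataclass(frozen=True)
class Inputs:
    expr: Any
    key: Hashable


class ToyCache:
    def __init__(self, err_on_collision: bool):
        self.err_on_collision = err_on_collision
        self._key_to_result = {}
        self._key_to_expr = {}  # only consulted when err_on_collision is set

    def add(self, inputs, result):
        self._key_to_result[inputs.key] = result
        if self.err_on_collision:
            self._key_to_expr[inputs.key] = inputs.expr
        return result

    def retrieve(self, inputs):
        result = self._key_to_result[inputs.key]  # KeyError signals a cache miss
        if self.err_on_collision and self._key_to_expr[inputs.key] is not inputs.expr:
            raise ToyCollisionError(f"distinct inputs share key {inputs.key!r}")
        return result

    def clear(self):
        self._key_to_result.clear()
        self._key_to_expr.clear()


a, b = [1, 2], [1, 2]  # equal but distinct objects
cache = ToyCache(err_on_collision=True)
cache.add(Inputs(a, key=(1, 2)), "result for a")
print(cache.retrieve(Inputs(a, key=(1, 2))))  # result for a
# cache.retrieve(Inputs(b, key=(1, 2))) would raise ToyCollisionError.
```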

class pytato.transform.CachedMapper(err_on_collision: bool = False, _cache: CachedMapperCache[Array | AbstractResultWithNamedArrays, ResultT, P] | None = None, _function_cache: CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None)

Mapper class that maps each node in the DAG exactly once. This loses some information compared to Mapper as a node is visited only from one of its predecessors.

get_cache_key(expr: Array | AbstractResultWithNamedArrays, *args: P.args, **kwargs: P.kwargs) → Hashable
get_function_definition_cache_key(expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs) → Hashable
clone_for_callee(function: FunctionDefinition) → Self

Called to clone self before starting traversal of a pytato.function.FunctionDefinition.
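The difference between Mapper and CachedMapper shows up on shared subexpressions. In this toy sketch (plain `(name, children)` tuples, not pytato arrays), node `T` is used by both `O1` and `O2`: an uncached traversal visits it twice, while a traversal that memoizes visits it once. Memoizing on `id()` is a simplifying assumption; pytato derives keys through get_cache_key().

```python
from collections import Counter


def count_visits(root, cached):
    """Return how often each node is visited; nodes are (name, children) tuples."""
    visits = Counter()
    seen = set()

    def rec(node):
        if cached and id(node) in seen:
            return  # a CachedMapper-like traversal reuses the stored result
        seen.add(id(node))
        name, children = node
        visits[name] += 1
        for child in children:
            rec(child)

    rec(root)
    return visits


i1, i2 = ("I1", ()), ("I2", ())
t = ("T", (i1, i2))
out = ("out", (("O1", (t,)), ("O2", (t,))))

print(count_visits(out, cached=False)["T"])  # 2
print(count_visits(out, cached=True)["T"])   # 1
```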

class pytato.transform.TransformMapperCache(err_on_collision: bool, err_on_created_duplicate: bool)

Cache for TransformMapper and TransformMapperWithExtraArgs.

__init__(err_on_collision: bool, err_on_created_duplicate: bool) → None

Initialize the cache.

Parameters:
  • err_on_collision – Raise an exception if two distinct input expression instances have the same key.

  • err_on_created_duplicate – Raise an exception if mapping produces a new array instance that has the same key as the input array.

add(inputs: CacheInputsWithKey[CacheExprT, P], result: CacheExprT) → CacheExprT

Cache a mapping result.

Returns the cached result (which may not be identical to result if a result was already cached with the same result key).
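A toy sketch of the add() behavior just described: results are canonicalized by a result key, so equal results produced from different inputs collapse to a single instance. Using the result's own equality and hash as its key is an assumption made for brevity, and `ToyTransformCache` is a hypothetical class, not pytato's.

```python
class ToyTransformCache:
    def __init__(self):
        self._input_key_to_result = {}
        self._result_key_to_result = {}

    def add(self, input_key, result):
        # Reuse a previously cached result that compares equal, if there is one.
        canonical = self._result_key_to_result.setdefault(result, result)
        self._input_key_to_result[input_key] = canonical
        return canonical

    def retrieve(self, input_key):
        return self._input_key_to_result[input_key]


cache = ToyTransformCache()
r1 = cache.add("input-1", ("sin", "x"))
r2 = cache.add("input-2", tuple(["sin", "x"]))  # equal value, separately built object
print(r1 is r2)  # True
```

Returning the canonical instance keeps equal subresults shared by identity, which later identity-based caches can exploit.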

class pytato.transform.TransformMapper(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, ()] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, ()] | None = None)

Base class for mappers that transform pytato.array.Arrays into other pytato.array.Arrays.

Enables certain operations that can only be done if the mapping results are also arrays (e.g., computing a cache key from them). Does not implement default mapper methods; for that, see CopyMapper.

__init__(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, ()] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, ()] | None = None) → None
Parameters:
  • err_on_collision – Raise an exception if two distinct input array instances have the same key.

  • err_on_created_duplicate – Raise an exception if mapping produces a new array instance that has the same key as the input array.

clone_for_callee(function: FunctionDefinition) → Self

Called to clone self before starting traversal of a pytato.function.FunctionDefinition.

class pytato.transform.TransformMapperWithExtraArgs(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None)

Similar to TransformMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.

The logic in TransformMapper purposely does not take the extra arguments, to keep the cost of each of its call frames low.

__init__(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None) → None
Parameters:
  • err_on_collision – Raise an exception if two distinct input array instances have the same key.

  • err_on_created_duplicate – Raise an exception if mapping produces a new array instance that has the same key as the input array.

clone_for_callee(function: FunctionDefinition) → Self

Called to clone self before starting traversal of a pytato.function.FunctionDefinition.

class pytato.transform.CopyMapper(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, ()] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, ()] | None = None)

Performs a deep copy of a pytato.array.Array. The typical use of this mapper is to override individual map_ methods in subclasses to permit term rewriting on an expression graph.

Note

This does not copy the data of a pytato.array.DataWrapper.
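The term-rewriting pattern (subclass a copying mapper and override one map_ method) can be sketched with toy frozen-dataclass nodes. `Var`, `Const`, `Mul`, `ToyCopyMapper` and `DropMulByOne` are hypothetical stand-ins; with pytato you would instead subclass CopyMapper and override the relevant map_ methods for its array types.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Mul:
    left: object
    right: object


class ToyCopyMapper:
    """Rebuilds every node; subclasses override map_* methods to rewrite terms."""

    def __init__(self):
        self._cache = {}

    def __call__(self, expr):
        if expr not in self._cache:  # memoize so equal nodes are copied once
            method = getattr(self, "map_" + type(expr).__name__.lower())
            self._cache[expr] = method(expr)
        return self._cache[expr]

    def map_var(self, expr):
        return Var(expr.name)

    def map_const(self, expr):
        return Const(expr.value)

    def map_mul(self, expr):
        return Mul(self(expr.left), self(expr.right))


class DropMulByOne(ToyCopyMapper):
    """Rewrites x*1 and 1*x to x; every other node is copied unchanged."""

    def map_mul(self, expr):
        left, right = self(expr.left), self(expr.right)
        if right == Const(1):
            return left
        if left == Const(1):
            return right
        return Mul(left, right)


expr = Mul(Mul(Var("x"), Const(1)), Var("y"))
print(DropMulByOne()(expr))  # Mul(left=Var(name='x'), right=Var(name='y'))
```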

class pytato.transform.CopyMapperWithExtraArgs(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, P] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, P] | None = None)

Similar to CopyMapper, but each mapper method takes extra *args, **kwargs that are propagated along a path by default.

The logic in CopyMapper purposely does not take the extra arguments, to keep the cost of each of its call frames low.

class pytato.transform.Deduplicator(_cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, ()] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, ()] | None = None)

Removes duplicate nodes from an expression.

class pytato.transform.CombineMapper(err_on_collision: bool = False, _cache: CachedMapperCache[Array | AbstractResultWithNamedArrays, ResultT, P] | None = None, _function_cache: CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None)

Abstract mapper that recursively combines the results of the nodes that a given expression depends on (its operands).

combine(*args: ResultT) → ResultT

Combine the arguments.

class pytato.transform.DependencyMapper(err_on_collision: bool = False, _cache: CachedMapperCache[Array | AbstractResultWithNamedArrays, ResultT, P] | None = None, _function_cache: CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None)

Maps a pytato.array.Array to a frozenset of the pytato.array.Array instances it depends on.

Warning

This returns every node in the graph! Consider a custom CombineMapper or a SubsetDependencyMapper instead.

class pytato.transform.InputGatherer(err_on_collision: bool = False, _cache: CachedMapperCache[Array | AbstractResultWithNamedArrays, ResultT, P] | None = None, _function_cache: CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None)

Mapper to combine all instances of pytato.array.InputArgumentBase that an array expression depends on.

class pytato.transform.SizeParamGatherer(err_on_collision: bool = False, _cache: CachedMapperCache[Array | AbstractResultWithNamedArrays, ResultT, P] | None = None, _function_cache: CachedMapperCache[FunctionDefinition, FunctionResultT, P] | None = None)

Mapper to combine all instances of pytato.array.SizeParam that an array expression depends on.

class pytato.transform.SubsetDependencyMapper(universe: frozenset[Array])

Mapper to combine the dependencies of an expression that are a subset of universe.
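A toy illustration of how the CombineMapper family relates: each node's result is combine() of its own contribution and its operands' results, and the DependencyMapper-, InputGatherer- and SubsetDependencyMapper-like variants differ only in what a node contributes. Nodes are plain `(name, operands)` tuples and `combine_mapper` is a hypothetical helper, not a pytato API.

```python
def combine_mapper(root, contribute):
    """Memoized post-order fold; results are combined by frozenset union."""
    memo = {}

    def rec(node):
        if node not in memo:
            _, operands = node
            memo[node] = frozenset().union(contribute(node),
                                           *(rec(op) for op in operands))
        return memo[node]

    return rec(root)


a = ("a", ())
b = ("b", ())
t = ("t", (a, b))
out = ("out", (t, a))

# DependencyMapper-like: every node, including the root itself.
deps = combine_mapper(out, lambda n: frozenset({n}))
# InputGatherer-like: only leaves (nodes without operands) are inputs.
inputs = combine_mapper(out, lambda n: frozenset() if n[1] else frozenset({n}))
# SubsetDependencyMapper-like: restrict the dependencies to a given universe.
universe = frozenset({t, b})
subset = combine_mapper(out, lambda n: frozenset({n}) & universe)

print(sorted(name for name, _ in deps))    # ['a', 'b', 'out', 't']
print(sorted(name for name, _ in inputs))  # ['a', 'b']
print(sorted(name for name, _ in subset))  # ['b', 't']
```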

class pytato.transform.WalkMapper

A mapper that walks over all the arrays in a pytato.Array.

Users may override the specific mapper methods in a derived class or override WalkMapper.visit() and WalkMapper.post_visit().

visit(expr: Any, *args: P.args, **kwargs: P.kwargs) → bool

If this method returns True, expr is traversed during the walk. If this method returns False, expr is not traversed as a part of the walk.

post_visit(expr: Any, *args: P.args, **kwargs: P.kwargs) → None

Callback after expr has been traversed.
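The visit()/post_visit() protocol can be modeled as follows. This toy walker (hypothetical `ToyWalker` and `PostOrder` classes on tuple nodes) calls visit() before descending and prunes the subtree when it returns False; post_visit() runs after all operands have been walked, which yields a post-order.

```python
class ToyWalker:
    def rec(self, node):
        if not self.visit(node):
            return  # visit() returned False: skip this subtree entirely
        _, operands = node
        for op in operands:
            self.rec(op)
        self.post_visit(node)

    def visit(self, node):
        return True

    def post_visit(self, node):
        pass


class PostOrder(ToyWalker):
    def __init__(self, skip=()):
        self.order = []
        self.skip = set(skip)

    def visit(self, node):
        return node[0] not in self.skip

    def post_visit(self, node):
        self.order.append(node[0])


x, y = ("x", ()), ("y", ())
expr = ("add", (("mul", (x, y)), y))

w = PostOrder()
w.rec(expr)
print(w.order)  # ['x', 'y', 'mul', 'y', 'add']

w = PostOrder(skip={"mul"})
w.rec(expr)
print(w.order)  # ['y', 'add']
```

Since this walker is uncached, the shared node `y` is reported twice; a CachedWalkMapper-style walker would report it once.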

class pytato.transform.CachedWalkMapper(_visited_functions: set[Hashable] | None = None)

WalkMapper that visits each node in the DAG exactly once. This loses some information compared to WalkMapper as a node is visited only from one of its predecessors.

class pytato.transform.TopoSortMapper(_visited_functions: set[Hashable] | None = None)
class pytato.transform.CachedMapAndCopyMapper(map_fn: Callable[[ArrayOrNames], ArrayOrNames], _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None)

Mapper that applies map_fn to each node and copies it. Results of traversals are memoized, i.e., each node is mapped via map_fn exactly once.

pytato.transform.copy_dict_of_named_arrays(source_dict: DictOfNamedArrays, copy_mapper: CopyMapper) → DictOfNamedArrays

Copy the elements of a DictOfNamedArrays into a DictOfNamedArrays.

Parameters:
  • source_dict – The DictOfNamedArrays to copy

  • copy_mapper – A mapper that copies the different array types

Returns:

A new DictOfNamedArrays containing copies of the items in source_dict

pytato.transform.deduplicate(expr: ArrayOrNamesOrFunctionDefTc) → ArrayOrNamesOrFunctionDefTc

Remove duplicate nodes from an expression.

Note

Does not remove distinct instances of data wrappers that point to the same data (as they will not hash the same). For a utility that does that, see deduplicate_data_wrappers().
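Deduplication can be understood as hash-consing: every node is replaced by one canonical instance among all nodes that compare equal. The toy below relies on frozen dataclasses (whose equality and hashing are structural) plus a dictionary of canonical instances; `Leaf`, `Node` and `toy_deduplicate` are hypothetical and not pytato's implementation.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Leaf:
    name: str


@dataclass(frozen=True)
class Node:
    op: str
    operands: tuple


def toy_deduplicate(expr, canonical=None):
    """Return an equal expression in which equal subtrees are the same object."""
    if canonical is None:
        canonical = {}
    if isinstance(expr, Node):
        expr = replace(expr, operands=tuple(
            toy_deduplicate(o, canonical) for o in expr.operands))
    return canonical.setdefault(expr, expr)


# Two separately built but equal subexpressions.
e = Node("add", (Node("sin", (Leaf("x"),)), Node("sin", (Leaf("x"),))))
d = toy_deduplicate(e)
print(e.operands[0] is e.operands[1])  # False
print(d.operands[0] is d.operands[1])  # True
print(d == e)                          # True
```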

pytato.transform.get_dependencies(expr: DictOfNamedArrays) → dict[str, frozenset[Array]]

Returns the dependencies of each named array in expr.

pytato.transform.map_and_copy(expr: ArrayOrNamesTc, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) → ArrayOrNamesTc

Returns a copy of expr with every array expression reachable from expr mapped via map_fn.

Note

Uses CachedMapAndCopyMapper under the hood; because of its caching, each node is mapped exactly once.
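A toy version of this map-then-copy traversal on `(name, operands)` tuples follows. It assumes map_fn is applied to a node before its operands are copied (the exact order in pytato may differ), and it memoizes on the node itself so that each distinct node passes through map_fn once. `toy_map_and_copy` and `rename_x` are hypothetical helpers.

```python
def toy_map_and_copy(root, map_fn):
    """Return (copied root, names of nodes passed to map_fn in call order)."""
    memo = {}
    calls = []

    def rec(node):
        if node not in memo:
            calls.append(node[0])
            name, operands = map_fn(node)  # map first, then copy the operands
            memo[node] = (name, tuple(rec(op) for op in operands))
        return memo[node]

    return rec(root), calls


def rename_x(node):
    name, operands = node
    return ("z", operands) if name == "x" else node


x = ("x", ())
t = ("t", (x, x))
root = ("root", (t, t))
new_root, calls = toy_map_and_copy(root, rename_x)
print(calls)  # ['root', 't', 'x']
```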

pytato.transform.materialize_with_mpms(expr: ArrayOrNamesTc) → ArrayOrNamesTc

Materialize nodes in expr with MPMS materialization strategy. MPMS stands for Multiple-Predecessors, Multiple-Successors.

Note

  • MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessor and more than 1 successor is materialized.

  • Materializing here corresponds to tagging a node with ImplStored.

  • Does not attempt to materialize sub-expressions in pytato.Array.shape.

Warning

This is a greedy materialization algorithm and may therefore be too eager to materialize. Consider the graph below:

I1          I2
 \         /
  \       /
   \     /
    🡦   🡧
      T
     / \
    /   \
   /     \
  🡧       🡦
 O1        O2

where ‘I1’ and ‘I2’ are instances of pytato.array.InputArgumentBase, and ‘O1’ and ‘O2’ are the outputs required to be evaluated in the computation graph. The MPMS materialization algorithm will materialize the intermediate node ‘T’, as it has 2 predecessors and 2 successors. However, as the table below shows, the total number of memory accesses goes up after applying MPMS.

          Before   After
Reads     4        4
Writes    2        3
Total     6        7
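The read/write counts in the table can be reproduced with a small cost model. This is an assumption-laden toy, not pytato's implementation: inputs and outputs count as stored, every stored node reads each distinct materialized source it reaches through unmaterialized intermediates once and writes itself once, and the MPMS rule from the note above is applied only when `use_mpms` is set. `memory_accesses` is a hypothetical helper.

```python
def memory_accesses(preds, inputs, outputs, use_mpms):
    """Return (reads, writes) for a DAG given as {node: [operands]}.

    Assumes ``preds`` lists nodes in topological order (operands first).
    """
    nsuccessors = {node: 0 for node in preds}
    for operands in preds.values():
        for op in operands:
            nsuccessors[op] += 1

    stored = set(inputs)
    mat_preds = {}  # node -> materialized nodes it reads via unstored paths
    reads = writes = 0
    for node, operands in preds.items():
        if node in inputs:
            mat_preds[node] = frozenset()
            continue
        mat_preds[node] = frozenset().union(
            *({op} if op in stored else mat_preds[op] for op in operands))
        mpms = use_mpms and len(mat_preds[node]) > 1 and nsuccessors[node] > 1
        if node in outputs or mpms:
            stored.add(node)
            reads += len(mat_preds[node])
            writes += 1
    return reads, writes


graph = {"I1": [], "I2": [], "T": ["I1", "I2"], "O1": ["T"], "O2": ["T"]}
before = memory_accesses(graph, {"I1", "I2"}, {"O1", "O2"}, use_mpms=False)
after = memory_accesses(graph, {"I1", "I2"}, {"O1", "O2"}, use_mpms=True)
print(before, sum(before))  # (4, 2) 6
print(after, sum(after))    # (4, 3) 7
```

Under these assumptions, materializing ‘T’ saves no reads (each output now reads ‘T’ instead of both inputs) but costs one extra write.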

pytato.transform.deduplicate_data_wrappers(array_or_names: ArrayOrNamesOrFunctionDefTc) → ArrayOrNamesOrFunctionDefTc

For the expression graph given as array_or_names, replace all pytato.array.DataWrapper instances containing identical data with a single instance.

Note

Currently only supports numpy.ndarray and pyopencl.array.Array.

Note

This function currently uses addresses of memory buffers to detect duplicate data, and so it may fail to deduplicate some instances of identical-but-separately-stored data. User code must tolerate this, but it must also tolerate this function doing a more thorough job of deduplication.

pytato.transform.lower_to_index_lambda.to_index_lambda(expr: Array) → IndexLambda

Lowers expr to IndexLambda if possible, otherwise raises a pytato.diagnostic.CannotBeLoweredToIndexLambda.

Returns:

The lowered IndexLambda.

pytato.transform.remove_broadcasts_einsum.rewrite_einsums_with_no_broadcasts(expr: ArrayOrNamesTc) → ArrayOrNamesTc

Rewrites all instances of Einsum in expr such that the einsum expressions avoid broadcasting axes of their operands. We do so by updating the pytato.Einsum.access_descriptors and slicing the operands.

>>> import numpy as np
>>> import pytato as pt
>>> a = pt.make_placeholder("a", (10, 4, 1), np.float64)
>>> b = pt.make_placeholder("b", (10, 1, 4), np.float64)
>>> expr = pt.einsum("ijk,ijk->i", a, b)
>>> new_expr = pt.rewrite_einsums_with_no_broadcasts(expr)
>>> pt.analysis.is_einsum_similar_to_subscript(new_expr, "ij,ik->i")
True

Note

This transformation preserves the semantics of the expression, i.e., it does not alter its value.
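Why the rewrite preserves the value can be checked numerically. In "ijk,ijk->i" with a of shape (10, 4, 1) and b of shape (10, 1, 4), axis k of a and axis j of b have length 1 and are broadcast; slicing those axes away gives "ij,ik->i", which still sums over both j and k. The pure-Python loops below (no pytato involved, small random data with a hypothetical i extent of 3) evaluate both forms.

```python
import random

random.seed(0)
ni, nj, nk = 3, 4, 4
# a has shape (ni, nj, 1) and b has shape (ni, 1, nk), stored as nested lists.
a = [[[random.random()] for _ in range(nj)] for _ in range(ni)]
b = [[[random.random() for _ in range(nk)]] for _ in range(ni)]


def broadcast_einsum(a, b):
    # "ijk,ijk->i": length-1 axes are always indexed at 0, i.e. broadcast.
    return [sum(a[i][j][0] * b[i][0][k] for j in range(nj) for k in range(nk))
            for i in range(ni)]


def sliced_einsum(a2, b2):
    # "ij,ik->i" on the operands with their length-1 axes sliced away.
    return [sum(a2[i][j] * b2[i][k] for j in range(nj) for k in range(nk))
            for i in range(ni)]


a2 = [[a[i][j][0] for j in range(nj)] for i in range(ni)]  # a[:, :, 0]
b2 = [b[i][0] for i in range(ni)]                          # b[:, 0, :]
lhs, rhs = broadcast_einsum(a, b), sliced_einsum(a2, b2)
print(all(abs(x - y) < 1e-12 for x, y in zip(lhs, rhs)))  # True
```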

class pytato.transform.einsum_distributive_law.EinsumDistributiveLawDescriptor

Abstract type that informs apply_distributive_property_to_einsums() how the distributive law should be applied along an einsum’s operands.

class pytato.transform.einsum_distributive_law.DoNotDistribute

Tells apply_distributive_property_to_einsums() not to apply the distributive law along any operand of the Einsum.

class pytato.transform.einsum_distributive_law.DoDistribute(ioperand: int)

Tells apply_distributive_property_to_einsums() to apply the distributive law along the ioperand-th operand of the Einsum.

pytato.transform.einsum_distributive_law.apply_distributive_property_to_einsums(expr: ArrayOrNamesTc, how_to_distribute: Callable[[Array], EinsumDistributiveLawDescriptor]) → ArrayOrNamesTc

Returns a copy of expr after applying the distributive law to the Einstein summation nodes in the expression graph.

>>> import numpy as np
>>> import pytato as pt
>>> from pytato.transform.einsum_distributive_law import (
...     DoDistribute, DoNotDistribute, apply_distributive_property_to_einsums)
>>> x1 = pt.make_placeholder("x1", 4, np.float64)
>>> x2 = pt.make_placeholder("x2", 4, np.float64)
>>> A = pt.make_placeholder("A", (10, 4), np.float64)
>>> y = A @ (x1 + x2)

>>> def how_to_distribute(expr):
...     if pt.analysis.is_einsum_similar_to_subscript(
...         expr, "ij,j->i"):
...         return DoDistribute(ioperand=1)
...     else:
...         return DoNotDistribute()

>>> y_transformed = apply_distributive_property_to_einsums(y,
...                     how_to_distribute)

>>> y_transformed == A @ x1 + A @ x2
True
pytato.unify_axes_tags(expr: ArrayOrNamesOrFunctionDefTc, *, tag_t: type[Tag] = Tag, equations_collector_t: type[AxesTagsEquationCollector] = AxesTagsEquationCollector, unify_redn_descrs: bool = True) → ArrayOrNamesOrFunctionDefTc

Returns a copy of expr with tags of type tag_t propagated along the array operations, using the tag propagation semantics implemented in equations_collector_t. By propagation, we mean that we solve the unification equations assembled in the collector's equations to obtain a mapping from an Array’s axis to the new tags it is to be tagged with. We then use this mapping to add tags to the arrays’ axes.

Note

  • This routine by itself does not raise if a particular array’s axis is tagged with multiple tags of type tag_t. If such behavior is not expected, ensure that tag_t is a subclass of UniqueTag.
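The unification step can be pictured as computing connected components of the equations \(u \doteq v\). The sketch below is a toy, not pytato's solver: axis variables and tag variables are plain strings, a union-find merges the two sides of each equation, and every axis then receives all known tags in its component. The variable naming scheme (`a.0`, `tag:Elements`, ...) and `propagate_tags` are hypothetical.

```python
def propagate_tags(equations, known_tag_to_var, axis_vars):
    parent = {}

    def find(v):
        parent.setdefault(v, v)
        while parent[v] != v:
            parent[v] = parent[parent[v]]  # path halving
            v = parent[v]
        return v

    for u, v in equations:
        parent[find(u)] = find(v)

    root_to_tags = {}
    for tag, var in known_tag_to_var.items():
        root_to_tags.setdefault(find(var), set()).add(tag)
    return {ax: root_to_tags.get(find(ax), set()) for ax in axis_vars}


# y = a + b (elementwise): corresponding axes of a, b and y are straight-through.
equations = [("a.0", "y.0"), ("b.0", "y.0"), ("a.1", "y.1"), ("b.1", "y.1"),
             ("tag:Elements", "a.0")]  # axis 0 of a carries the known tag
known = {"Elements": "tag:Elements"}
axes = ["a.0", "a.1", "b.0", "b.1", "y.0", "y.1"]
result = propagate_tags(equations, known, axes)
print(sorted(ax for ax, tags in result.items() if "Elements" in tags))
# ['a.0', 'b.0', 'y.0']
```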

class pytato.transform.metadata.AxisTagAttacher(axis_to_tags: Mapping[tuple[Array, int | str], Collection[Tag]], tag_corresponding_redn_descr: bool, _cache: TransformMapperCache[ArrayOrNames, []] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None)

A mapper that tags the axes in a DAG as prescribed by axis_to_tags.

class pytato.transform.metadata.AxisIgnoredForPropagationTag

Disallows tags from propagating across axes equipped with this tag. Effectively removes an edge from the undirected graph whose edges represent propagation pathways. The default tag propagation behavior in the case of an einsum is to propagate all tags across non-reduction axes. Since this is not always desirable, this tag can be used to disable the default behavior at axis-granularity.

class pytato.transform.metadata.AxesTagsEquationCollector(tag_t: type[Tag])

Records equations arising from operand/output axes equivalence for an array operation. An equation is recorded for “straight-through” axes in expressions, i.e. ones that pass through an operation unmodified.

tag_t: type[Tag]

The type of the tags that are to be propagated.

equations: list[tuple[str, str]]

A list of equations. Each equation is represented by a 2-tuple ("u", "v"), which is mathematically interpreted as \(\{u \doteq v\}\).

known_tag_to_var: dict[Tag, str]

A mapping from an instance of pytools.tag.Tag to a str by which it will be referenced in equations.

axis_to_var: bidict[tuple[Array, int | str], str]

A bidict from a tuple of the form (array, iaxis) to the str by which it will be referenced in equations.

map_index_lambda(expr: IndexLambda) → None
map_placeholder(expr: InputArgumentBase) → None

A pytato.InputArgumentBase does not have any operands, so no propagation equations are recorded.

map_data_wrapper(expr: InputArgumentBase) → None

A pytato.InputArgumentBase does not have any operands, so no propagation equations are recorded.

map_size_param(expr: InputArgumentBase) → None

A pytato.InputArgumentBase does not have any operands, so no propagation equations are recorded.

map_reshape(expr: Reshape) → None

Reshaping generally does not preserve axes between its input and output, so no constraints are enforced, except when the pytato.Reshape originates from pytato.expand_dims().

map_basic_index(expr: BasicIndex) → None
map_contiguous_advanced_index(expr: AdvancedIndexInContiguousAxes) → None
map_stack(expr: Stack) → None
map_concatenate(expr: Concatenate) → None

Note

Users may subclass this mapper to implement domain-specific axis tags propagation semantics.

class pytato.transform.metadata.Tag

See pytools.tag.Tag.

class pytato.transform.dead_code_elimination.DeadCodeEliminator(err_on_collision: bool = False, err_on_created_duplicate: bool = False, _cache: TransformMapperCache[Array | AbstractResultWithNamedArrays, ()] | None = None, _function_cache: TransformMapperCache[FunctionDefinition, ()] | None = None)
class pytato.eliminate_dead_code(expr: ArrayOrNamesTc)

Removes dead subexpressions from expr.

Note

Currently the following sub-expressions are eliminated:

  • Arguments in calls to pt.zero.

Dict representation of DAGs

class pytato.transform.UsersCollector[source]

Maps a graph to a dictionary representation mapping a node to its users, i.e. all the nodes using its value.

node_to_users

Mapping of each node in the graph to its users.

__init__() → None
pytato.transform.rec_get_user_nodes(expr: Array | AbstractResultWithNamedArrays, node: Array | AbstractResultWithNamedArrays) → frozenset[Array | AbstractResultWithNamedArrays]

Returns all direct and indirect users of node in expr.
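UsersCollector inverts the operand relation. The toy version below works on a `{node: [operands]}` dictionary (a hypothetical representation, not pytato's): `users_of` builds the node-to-users mapping, and `rec_users` collects direct and indirect users with a worklist, analogous to rec_get_user_nodes().

```python
from collections import defaultdict


def users_of(operands_of):
    node_to_users = defaultdict(set)
    for node, operands in operands_of.items():
        for op in operands:
            node_to_users[op].add(node)
    return node_to_users


def rec_users(operands_of, node):
    node_to_users = users_of(operands_of)
    result, worklist = set(), [node]
    while worklist:
        for user in node_to_users[worklist.pop()]:
            if user not in result:
                result.add(user)
                worklist.append(user)
    return frozenset(result)


graph = {"x": [], "y": [], "t": ["x", "y"], "u": ["t", "x"], "out": ["u"]}
print(sorted(users_of(graph)["x"]))   # ['t', 'u']
print(sorted(rec_users(graph, "y")))  # ['out', 't', 'u']
```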

Transforming call sites

pytato.transform.calls.inline_calls(expr: ArrayOrNamesTc) → ArrayOrNamesTc

Returns a copy of expr with call sites tagged with pytato.tags.InlineCallTag inlined into the expression graph.

pytato.transform.calls.tag_all_calls_to_be_inlined(expr: ArrayOrNamesTc) → ArrayOrNamesTc

Returns a copy of expr with all reachable instances of pytato.function.Call nodes tagged with pytato.tags.InlineCallTag.

Note

This routine does NOT inline calls; to inline them, apply inline_calls() to this routine’s output.

Internal stuff that is only here because the documentation tool wants it

class pytato.transform.ArrayOrNames
class pytato.transform.ArrayOrNamesTc

A type variable representing the input type of a Mapper, excluding functions.

class pytato.transform.ArrayOrNamesOrFunctionDefTc

A type variable representing the input type of a Mapper, including functions.

class pytato.transform.ResultT

A type variable representing the result type of a Mapper when mapping a pytato.Array or pytato.AbstractResultWithNamedArrays.

class pytato.transform.FunctionResultT

A type variable representing the result type of a Mapper when mapping a pytato.function.FunctionDefinition.

class pytato.transform.CacheExprT

A type variable representing an input from which to compute a cache key in order to cache a result.

class pytato.transform.CacheKeyT

A type variable representing a key computed from an input expression.

class pytato.transform.CacheResultT

A type variable representing a result to be cached.

class pytato.transform.Scalar

See pymbolic.Scalar.

class pytato.transform.P

A typing.ParamSpec used to annotate *args and **kwargs.

Analyzing Array Expression Graphs

pytato.analysis.get_nusers(outputs: ArrayOrNames) → Mapping[Array, int]

For the DAG outputs, returns the mapping from each array node to the number of nodes using its value within the DAG given by outputs.

pytato.analysis.is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) → bool

Returns True if and only if an einsum with the subscript descriptor string subscripts operated on expr’s pytato.Einsum.args would compute the same result as expr.

pytato.analysis.get_num_nodes(outputs: Array | AbstractResultWithNamedArrays, count_duplicates: bool | None = None) → int

Returns the number of nodes in DAG outputs. Instances of DictOfNamedArrays are excluded from counting.

pytato.analysis.get_node_type_counts(outputs: Array | AbstractResultWithNamedArrays, count_duplicates: bool = False) → dict[type[Any], int]

Returns a dictionary mapping node types to node count for that type in DAG outputs.

Instances of DictOfNamedArrays are excluded from counting.

pytato.analysis.get_node_multiplicities(outputs: Array | AbstractResultWithNamedArrays) → dict[Array, int]

Returns the multiplicity per expr.

pytato.analysis.get_num_call_sites(outputs: Array | AbstractResultWithNamedArrays) → int

Returns the number of call sites in DAG outputs.

class pytato.analysis.DirectPredecessorsGetter(*, include_functions: bool = False)

Helper to get the direct predecessors of a node.

Note

We only consider the predecessors of a node in a data-flow sense.

class pytato.analysis.ListOfDirectPredecessorsGetter(*, include_functions: bool = False)

Helper to get the direct predecessors of a node.

Note

We only consider the predecessors of a node in a data-flow sense.

class pytato.analysis.TagCountMapper(tag_types: type[pytools.tag.Tag] | Iterable[type[pytools.tag.Tag]])

Returns the number of nodes in a DAG that are tagged with all the tag types in tag_types.

pytato.analysis.get_num_tags_of_type(outputs: ArrayOrNames, tag_types: type[pytools.tag.Tag] | Iterable[type[pytools.tag.Tag]]) → int

Returns the number of nodes in DAG outputs that are tagged with all the tag types in tag_types.

Visualizing Array Expression Graphs

pytato.get_dot_graph(result: Array | AbstractResultWithNamedArrays) → str

Return a string in the dot language depicting the graph of the computation of result.

Parameters:

result – Outputs of the computation (cf. pytato.generate_loopy()).

pytato.get_dot_graph_from_partition(partition: DistributedGraphPartition) → str

Return a string in the dot language depicting the graph of the partitioned computation of partition.

Parameters:

partition – Outputs of find_distributed_partition().

pytato.show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, **kwargs: Any) → None

Show a graph representing the computation of result in a browser.

pytato.show_fancy_placeholder_data_flow(dag: Array | DictOfNamedArrays, **kwargs: Any) → None

Visualizes the data-flow from the placeholders into outputs.

Note

This is a heavily opinionated visualization of the data-flow graph in dag; displaying all the information about each node is not a priority. See pytato.show_dot_graph() for a more verbose visualization.

Comparing two expression Graphs

class pytato.equality.EqualityComparer

A pytato.array.Array visitor to check equality between two expression DAGs.

Note

  • Compares two expression graphs expr1, expr2 in \(O(N)\) comparisons, where \(N\) is the number of nodes in expr1.

  • This visitor was introduced to memoize the sub-expression comparisons of the expressions to be compared. Not memoizing the sub-expression comparisons results in \(O(2^N)\) complexity for the comparison operation, where \(N\) is the number of nodes in expressions. See GH-Issue-163 <https://github.com/inducer/pytato/issues/163> for more on this.
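The complexity gap is easy to reproduce. In the toy below (a hypothetical `Node` class compared structurally, not pytato code), each level of a chain uses one child object twice, so a depth-d DAG has d + 1 nodes but 2^(d+1) - 1 root-to-node paths. Naive recursion performs one comparison per path, while memoizing on the pair of node identities performs one comparison per node pair.

```python
class Node:
    def __init__(self, op, children=()):
        self.op = op
        self.children = tuple(children)


def build_chain(depth):
    node = Node("x")
    for _ in range(depth):
        node = Node("add", (node, node))  # both operands are the same object
    return node


def naive_equal(a, b, stats):
    stats["comparisons"] += 1
    return (a.op == b.op and len(a.children) == len(b.children)
            and all(naive_equal(x, y, stats) for x, y in zip(a.children, b.children)))


def memo_equal(a, b, stats, cache):
    key = (id(a), id(b))  # valid while both DAGs stay alive
    if key not in cache:
        stats["comparisons"] += 1
        cache[key] = (a.op == b.op and len(a.children) == len(b.children)
                      and all(memo_equal(x, y, stats, cache)
                              for x, y in zip(a.children, b.children)))
    return cache[key]


depth = 12
e1, e2 = build_chain(depth), build_chain(depth)  # equal but not identical
naive, memo = {"comparisons": 0}, {"comparisons": 0}
print(naive_equal(e1, e2, naive), naive["comparisons"])   # True 8191
print(memo_equal(e1, e2, memo, {}), memo["comparisons"])  # True 13
```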

Stringifying Expression Graphs

class pytato.stringifier.Reprifier(truncation_depth: int = 3, truncation_string: str = '(...)')

Stringifies pytato-types to closely resemble CPython’s implementation of repr() for its builtin datatypes.

Utilities and Diagnostics

Helper routines

pytato.utils.are_shape_components_equal(dim1: ShapeComponent, dim2: ShapeComponent) → bool

Returns True iff dim1 and dim2 have equal SizeParam coefficients in their expressions.

pytato.utils.are_shapes_equal(shape1: ShapeType, shape2: ShapeType) → bool

Returns True iff shape1 and shape2 have the same dimensionality and the corresponding components are equal as defined by are_shape_components_equal().

pytato.utils.get_shape_after_broadcasting(exprs: Iterable[Array | Scalar]) → ShapeType

Returns the shape after broadcasting exprs in an operation.
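For concrete integer shapes, broadcasting follows the NumPy rule: align shapes from the right, and in each position the lengths must agree or be 1. The toy below handles only integer dimensions and represents a scalar by the empty shape `()`; pytato's routine also deals with symbolic (SizeParam-based) dimensions and presumably reports failures through pytato.diagnostic.CannotBroadcastError, whereas this sketch simply raises ValueError. `broadcast_shape` is a hypothetical helper.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast integer shapes; a scalar is represented by the empty shape ()."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        non_unit = {d for d in dims if d != 1}
        if len(non_unit) > 1:
            raise ValueError(f"cannot broadcast shapes {shapes}")
        result.append(non_unit.pop() if non_unit else 1)
    return tuple(reversed(result))


print(broadcast_shape((10, 4, 1), (10, 1, 4)))  # (10, 4, 4)
print(broadcast_shape((3, 1), (5,), ()))        # (3, 5)
```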

pytato.utils.dim_to_index_lambda_components(expr: ShapeComponent, vng: UniqueNameGenerator | None = None) → tuple[ScalarExpression, dict[str, SizeParam]]

Returns the scalar expressions and bindings to use the shape component within an index lambda.

>>> import pytato as pt
>>> from pytato.utils import dim_to_index_lambda_components
>>> from pytools import UniqueNameGenerator
>>> n = pt.make_size_param("n")
>>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator())
>>> print(expr)
3*_in + 8
>>> bnds
{'_in': SizeParam(name='n')}
pytato.utils.get_common_dtype_of_ary_or_scalars(ary_or_scalars: Sequence[ArrayOrScalar]) → _dtype_any
pytato.utils.get_einsum_subscript_str(expr: Einsum) → str

Returns the index subscript expression that can be used in constructing expr using the pytato.einsum() routine.

Deprecated: use get_einsum_specification_str instead.

>>> import numpy as np
>>> import pytato as pt
>>> from pytato.utils import get_einsum_subscript_str
>>> A = pt.make_placeholder("A", (10, 6), np.float64)
>>> B = pt.make_placeholder("B", (6, 5), np.float64)
>>> C = pt.make_placeholder("C", (5, 4), np.float64)
>>> ABC = pt.einsum("ij,jk,kl->il", A, B, C)
>>> get_einsum_subscript_str(ABC)
'ij,jk,kl->il'

References

class pytato.utils.UniqueNameGenerator

See pytools.UniqueNameGenerator.

Pytato-specific exceptions

class pytato.diagnostic.NameClashError

Raised when two non-identical InputArgumentBase instances are reachable in an Array’s DAG and share the same name. Here, we refer to two objects a and b as being identical iff a is b.

class pytato.diagnostic.CannotBroadcastError
class pytato.diagnostic.UnknownIndexLambdaExpr

Raised when the structure of a pytato.array.IndexLambda could not be inferred.

class pytato.diagnostic.CannotBeLoweredToIndexLambda

Raised when a pytato.Array was expected to be lowered to an IndexLambda, but cannot be. For example, a pytato.loopy.LoopyCallResult cannot be lowered to an IndexLambda.