# Source code for pytato.transform

from __future__ import annotations

__copyright__ = """
Copyright (C) 2020 Matt Wala
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic

from pytato.array import (
        Array, IndexLambda, Placeholder, MatrixProduct, Stack, Roll,
        AxisPermutation, Slice, DataWrapper, SizeParam, DictOfNamedArrays,
        AbstractResultWithNamedArrays, Reshape, Concatenate, NamedArray,
        IndexRemappingBase, Einsum, InputArgumentBase)
from pytato.loopy import LoopyCall

# Type variable constrained to the two kinds of node a mapper is applied to.
T = TypeVar("T", Array, AbstractResultWithNamedArrays)
CombineT = TypeVar("CombineT")  # used in CombineMapper
# Either kind of node; used as the key type of the mapper caches below.
ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
# Result type of DependencyMapper: a frozenset of arrays.
R = FrozenSet[Array]

__doc__ = """
.. currentmodule:: pytato.transform

.. autoclass:: CopyMapper
.. autoclass:: DependencyMapper
.. autoclass:: SubsetDependencyMapper
.. autoclass:: WalkMapper
.. autoclass:: CachedWalkMapper
.. autofunction:: copy_dict_of_named_arrays
.. autofunction:: get_dependencies

"""


# {{{ mapper classes

class UnsupportedArrayError(ValueError):
    """Raised by :meth:`Mapper.handle_unsupported_array` when a mapper has
    no mapper method for the type of array it was asked to process.
    """


class Mapper:
    """Base class for mappers that dispatch on the type of an expression.

    :meth:`rec` looks up the method named by ``expr._mapper_method`` on the
    mapper and calls it with the expression. When that lookup fails, one of
    the two fallback methods below is called instead.
    """

    def handle_unsupported_array(self, expr: T, *args: Any, **kwargs: Any) -> Any:
        """Mapper method that is invoked for
        :class:`pytato.Array` subclasses for which a mapper
        method does not exist in this mapper.

        :raises UnsupportedArrayError: always, in this base implementation.
        """
        mapper_name = type(self).__name__
        raise UnsupportedArrayError("%s cannot handle expressions of type %s"
                % (mapper_name, type(expr)))

    def map_foreign(self, expr: Any, *args: Any, **kwargs: Any) -> Any:
        """Mapper method that is invoked by :meth:`rec` for objects that are
        not :class:`pytato.Array` instances and for which no mapper method
        could be found.

        :raises ValueError: always, in this base implementation.
        """
        mapper_name = type(self).__name__
        raise ValueError("%s encountered invalid foreign object: %s"
                % (mapper_name, repr(expr)))

    def rec(self, expr: T, *args: Any, **kwargs: Any) -> Any:
        """Dispatch *expr* to the mapper method named by
        ``expr._mapper_method`` and return that method's result.

        Extra positional and keyword arguments are forwarded unchanged.
        """
        handler: Callable[..., Any]

        try:
            handler = getattr(self, expr._mapper_method)
        except AttributeError:
            # Either *expr* has no ``_mapper_method`` attribute or this
            # mapper does not implement the method it names.
            fallback = (self.handle_unsupported_array
                        if isinstance(expr, Array)
                        else self.map_foreign)
            return fallback(expr, *args, **kwargs)
        else:
            # Called outside the ``try`` body so that an AttributeError
            # raised inside the handler is not mistaken for a failed lookup.
            return handler(expr, *args, **kwargs)

    def __call__(self, expr: Array, *args: Any, **kwargs: Any) -> Any:
        """Apply the mapper to *expr*; equivalent to calling :meth:`rec`."""
        return self.rec(expr, *args, **kwargs)


class CopyMapper(Mapper):
    """Performs a deep copy of a :class:`pytato.array.Array`.

    The typical use of this mapper is to override individual ``map_`` methods
    in subclasses to permit term rewriting on an expression graph.
    """

    def __init__(self) -> None:
        # Maps each node that has already been copied to its copy, so a node
        # reached along several paths is only copied once.
        self.cache: Dict[ArrayOrNames, ArrayOrNames] = {}

    def rec(self, expr: T) -> T:  # type: ignore
        """Return the copy of *expr*, reusing the cached copy if present."""
        try:
            return self.cache[expr]  # type: ignore
        except KeyError:
            pass

        copied: T = super().rec(expr)
        self.cache[expr] = copied
        return copied

    def map_index_lambda(self, expr: IndexLambda) -> Array:
        new_bindings: Dict[str, Array] = {}
        for name, subexpr in expr.bindings.items():
            new_bindings[name] = self.rec(subexpr)

        # Shape entries that are arrays are copied; others are kept as-is.
        new_shape = tuple(
                self.rec(dim) if isinstance(dim, Array) else dim
                for dim in expr.shape)

        return IndexLambda(expr=expr.expr,
                shape=new_shape,
                dtype=expr.dtype,
                bindings=new_bindings,
                tags=expr.tags)

    def map_placeholder(self, expr: Placeholder) -> Array:
        new_shape = tuple(
                self.rec(dim) if isinstance(dim, Array) else dim
                for dim in expr.shape)

        return Placeholder(name=expr.name,
                shape=new_shape,
                dtype=expr.dtype,
                tags=expr.tags)

    def map_matrix_product(self, expr: MatrixProduct) -> Array:
        lhs = self.rec(expr.x1)
        rhs = self.rec(expr.x2)
        return MatrixProduct(x1=lhs, x2=rhs, tags=expr.tags)

    def map_stack(self, expr: Stack) -> Array:
        new_arrays = tuple([self.rec(ary) for ary in expr.arrays])
        return Stack(arrays=new_arrays, axis=expr.axis, tags=expr.tags)

    def map_roll(self, expr: Roll) -> Array:
        new_array = self.rec(expr.array)
        return Roll(array=new_array,
                shift=expr.shift,
                axis=expr.axis,
                tags=expr.tags)

    def map_axis_permutation(self, expr: AxisPermutation) -> Array:
        new_array = self.rec(expr.array)
        return AxisPermutation(array=new_array,
                axes=expr.axes,
                tags=expr.tags)

    def map_slice(self, expr: Slice) -> Array:
        new_array = self.rec(expr.array)
        return Slice(array=new_array,
                starts=expr.starts,
                stops=expr.stops,
                tags=expr.tags)

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        new_shape = tuple(
                self.rec(dim) if isinstance(dim, Array) else dim
                for dim in expr.shape)

        return DataWrapper(name=expr.name,
                data=expr.data,
                shape=new_shape,
                tags=expr.tags)

    def map_size_param(self, expr: SizeParam) -> Array:
        return SizeParam(name=expr.name, tags=expr.tags)

    def map_einsum(self, expr: Einsum) -> Array:
        # NOTE(review): unlike the other map_ methods, ``expr.tags`` is not
        # forwarded here -- confirm whether Einsum accepts ``tags``.
        new_args = tuple([self.rec(arg) for arg in expr.args])
        return Einsum(expr.access_descriptors, new_args)

    def map_named_array(self, expr: NamedArray) -> NamedArray:
        # Rebuild with the same concrete class as *expr*.
        new_container = self.rec(expr._container)
        return type(expr)(new_container, expr.name)

    def map_dict_of_named_arrays(self,
            expr: DictOfNamedArrays) -> DictOfNamedArrays:
        new_data = {}
        for key, val in expr.items():
            new_data[key] = self.rec(val.expr)
        return DictOfNamedArrays(new_data)
class CombineMapper(Mapper, Generic[CombineT]):
    """Mapper that computes one value of type *CombineT* per node by
    passing the results obtained for the node's children to :meth:`combine`.

    Results are cached per node in :attr:`cache`. Subclasses must implement
    :meth:`combine`.
    """

    def __init__(self) -> None:
        # Result already computed for each node.
        self.cache: Dict[ArrayOrNames, CombineT] = {}

    def rec(self, expr: ArrayOrNames) -> CombineT:  # type: ignore
        """Return the result for *expr*, reusing the cached one if present."""
        try:
            return self.cache[expr]
        except KeyError:
            pass

        # type-ignore reason: type not compatible with super.rec() type
        combined: CombineT = super().rec(expr)  # type: ignore
        self.cache[expr] = combined
        return combined

    # type-ignore reason: incompatible ret. type with super class
    def __call__(self, expr: ArrayOrNames) -> CombineT:  # type: ignore
        return self.rec(expr)

    def combine(self, *args: CombineT) -> CombineT:
        """Merge the results of several children into one result.

        :raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def map_index_lambda(self, expr: IndexLambda) -> CombineT:
        from_bindings = [self.rec(bnd) for bnd in expr.bindings.values()]
        # Only shape entries that are themselves arrays are recursed into.
        from_shape = [self.rec(dim) for dim in expr.shape
                      if isinstance(dim, Array)]
        return self.combine(*from_bindings, *from_shape)

    def map_placeholder(self, expr: Placeholder) -> CombineT:
        from_shape = [self.rec(dim) for dim in expr.shape
                      if isinstance(dim, Array)]
        return self.combine(*from_shape)

    def map_data_wrapper(self, expr: DataWrapper) -> CombineT:
        from_shape = [self.rec(dim) for dim in expr.shape
                      if isinstance(dim, Array)]
        return self.combine(*from_shape)

    def map_matrix_product(self, expr: MatrixProduct) -> CombineT:
        lhs = self.rec(expr.x1)
        rhs = self.rec(expr.x2)
        return self.combine(lhs, rhs)

    def map_stack(self, expr: Stack) -> CombineT:
        children = [self.rec(ary) for ary in expr.arrays]
        return self.combine(*children)

    def map_roll(self, expr: Roll) -> CombineT:
        child = self.rec(expr.array)
        return self.combine(child)

    def map_axis_permutation(self, expr: AxisPermutation) -> CombineT:
        child = self.rec(expr.array)
        return self.combine(child)

    def map_slice(self, expr: Slice) -> CombineT:
        child = self.rec(expr.array)
        return self.combine(child)

    def map_reshape(self, expr: Reshape) -> CombineT:
        child = self.rec(expr.array)
        return self.combine(child)

    def map_concatenate(self, expr: Concatenate) -> CombineT:
        children = [self.rec(ary) for ary in expr.arrays]
        return self.combine(*children)

    def map_einsum(self, expr: Einsum) -> CombineT:
        children = [self.rec(ary) for ary in expr.args]
        return self.combine(*children)

    def map_named_array(self, expr: NamedArray) -> CombineT:
        container = self.rec(expr._container)
        return self.combine(container)

    def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> CombineT:
        children = [self.rec(ary.expr) for ary in expr.values()]
        return self.combine(*children)

    def map_loopy_call(self, expr: LoopyCall) -> CombineT:
        # Bindings that are not arrays are skipped.
        children = [self.rec(ary) for ary in expr.bindings.values()
                    if isinstance(ary, Array)]
        return self.combine(*children)
class DependencyMapper(CombineMapper[R]):
    """
    Maps a :class:`pytato.array.Array` to a :class:`frozenset` of
    :class:`pytato.array.Array`'s it depends on.

    .. warning::

       This returns every node in the graph! Consider a custom
       :class:`CombineMapper` or a :class:`SubsetDependencyMapper` instead.
    """

    def combine(self, *args: R) -> R:
        """Return the union of all the sets in *args*."""
        merged: R = frozenset()
        for deps in args:
            merged = merged | deps
        return merged

    # Each method below adds the node itself to whatever the corresponding
    # CombineMapper method gathered from the node's children.

    def map_index_lambda(self, expr: IndexLambda) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_index_lambda(expr))

    def map_placeholder(self, expr: Placeholder) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_placeholder(expr))

    def map_data_wrapper(self, expr: DataWrapper) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_data_wrapper(expr))

    def map_size_param(self, expr: SizeParam) -> R:
        # Returned directly, without going through combine().
        return frozenset((expr,))

    def map_matrix_product(self, expr: MatrixProduct) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_matrix_product(expr))

    def map_stack(self, expr: Stack) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_stack(expr))

    def map_roll(self, expr: Roll) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_roll(expr))

    def map_axis_permutation(self, expr: AxisPermutation) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_axis_permutation(expr))

    def map_slice(self, expr: Slice) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_slice(expr))

    def map_reshape(self, expr: Reshape) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_reshape(expr))

    def map_concatenate(self, expr: Concatenate) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_concatenate(expr))

    def map_einsum(self, expr: Einsum) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_einsum(expr))

    def map_named_array(self, expr: NamedArray) -> R:
        own = frozenset((expr,))
        return self.combine(own, super().map_named_array(expr))
class SubsetDependencyMapper(DependencyMapper):
    """
    Mapper to combine the dependencies of an expression that are a subset of
    *universe*.
    """

    def __init__(self, universe: FrozenSet[Array]):
        # The set that every combined argument is intersected with.
        self.universe = universe
        super().__init__()

    def combine(self, *args: FrozenSet[Array]) -> FrozenSet[Array]:
        """Return the union of the members of *args* that are in
        :attr:`universe`.
        """
        restricted: FrozenSet[Array] = frozenset()
        for deps in args:
            restricted = restricted | (deps & self.universe)
        return restricted
class InputGatherer(CombineMapper[FrozenSet[InputArgumentBase]]):
    """
    Mapper to combine all instances of :class:`pytato.InputArgumentBase` that
    an array expression depends on.
    """

    def combine(self, *args: FrozenSet[InputArgumentBase]
                ) -> FrozenSet[InputArgumentBase]:
        """Return the union of all the sets in *args*."""
        gathered: FrozenSet[InputArgumentBase] = frozenset()
        for inputs in args:
            gathered = gathered | inputs
        return gathered

    def map_placeholder(self, expr: Placeholder) -> FrozenSet[InputArgumentBase]:
        # The placeholder itself plus whatever its shape contributes.
        own = frozenset((expr,))
        return self.combine(own, super().map_placeholder(expr))

    def map_data_wrapper(self, expr: DataWrapper) -> FrozenSet[InputArgumentBase]:
        # The data wrapper itself plus whatever its shape contributes.
        own = frozenset((expr,))
        return self.combine(own, super().map_data_wrapper(expr))

    def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]:
        return frozenset((expr,))

# }}}
class WalkMapper(Mapper):
    """
    A mapper that walks over all the arrays in a :class:`pytato.array.Array`.

    Users may override the specific mapper methods in a derived class or
    override :meth:`WalkMapper.visit` and :meth:`WalkMapper.post_visit`.

    Every ``map_`` method follows the same protocol: it returns immediately
    if :meth:`visit` returns *False*; otherwise it recurses into the node's
    children and then calls :meth:`post_visit` on the node.

    .. automethod:: visit
    .. automethod:: post_visit
    """

    def visit(self, expr: Any) -> bool:
        """
        If this method returns *True*, *expr* is traversed during the walk.
        If this method returns *False*, *expr* is not traversed as a part of
        the walk.
        """
        return True

    def post_visit(self, expr: Any) -> None:
        """
        Callback after *expr* has been traversed.
        """
        pass

    def map_index_lambda(self, expr: IndexLambda) -> None:
        if not self.visit(expr):
            return

        for child in expr.bindings.values():
            self.rec(child)

        # Shape entries may themselves be arrays; walk those too.
        for dim in expr.shape:
            if isinstance(dim, Array):
                self.rec(dim)

        self.post_visit(expr)

    def map_placeholder(self, expr: Placeholder) -> None:
        if not self.visit(expr):
            return

        for dim in expr.shape:
            if isinstance(dim, Array):
                self.rec(dim)

        self.post_visit(expr)

    map_data_wrapper = map_placeholder
    map_size_param = map_placeholder

    def map_matrix_product(self, expr: MatrixProduct) -> None:
        if not self.visit(expr):
            return

        self.rec(expr.x1)
        self.rec(expr.x2)

        self.post_visit(expr)

    def _map_index_remapping_base(self, expr: IndexRemappingBase) -> None:
        if not self.visit(expr):
            return

        self.rec(expr.array)
        self.post_visit(expr)

    map_roll = _map_index_remapping_base
    map_axis_permutation = _map_index_remapping_base
    map_slice = _map_index_remapping_base
    map_reshape = _map_index_remapping_base

    def map_stack(self, expr: Stack) -> None:
        if not self.visit(expr):
            return

        for child in expr.arrays:
            self.rec(child)

        self.post_visit(expr)

    map_concatenate = map_stack

    def map_einsum(self, expr: Einsum) -> None:
        if not self.visit(expr):
            return

        for child in expr.args:
            self.rec(child)

        for dim in expr.shape:
            if isinstance(dim, Array):
                self.rec(dim)

        # Previously missing: without this call einsum nodes were the only
        # node type in this class that never triggered the post_visit
        # callback.
        self.post_visit(expr)

    def map_loopy_call(self, expr: LoopyCall) -> None:
        if not self.visit(expr):
            return

        # Bindings that are not arrays are skipped.
        for child in expr.bindings.values():
            if isinstance(child, Array):
                self.rec(child)

        self.post_visit(expr)

    def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
        if not self.visit(expr):
            return

        for child in expr._data.values():
            self.rec(child)

        self.post_visit(expr)

    def map_named_array(self, expr: NamedArray) -> None:
        if not self.visit(expr):
            return

        self.rec(expr._container)

        self.post_visit(expr)
class CachedWalkMapper(WalkMapper):
    """
    WalkMapper that visits each node in the DAG exactly once. This loses some
    information compared to :class:`WalkMapper` as a node is visited only from
    one of its predecessors.
    """

    def __init__(self) -> None:
        # Nodes whose walk has already completed.
        self._visited_nodes: Set[ArrayOrNames] = set()

    # type-ignore reason: CachedWalkMapper.rec's type does not match
    # WalkMapper.rec's type
    def rec(self, expr: ArrayOrNames) -> None:  # type: ignore
        """Walk *expr* unless it has already been walked."""
        if expr not in self._visited_nodes:
            # type-ignore reason: super().rec expects either 'Array' or
            # 'AbstractResultWithNamedArrays', passed 'ArrayOrNames'
            super().rec(expr)  # type: ignore
            # Recorded only after the walk of *expr* has returned.
            self._visited_nodes.add(expr)
# }}}


# {{{ mapper frontends
def copy_dict_of_named_arrays(source_dict: DictOfNamedArrays,
        copy_mapper: CopyMapper) -> DictOfNamedArrays:
    """Copy the elements of a :class:`~pytato.DictOfNamedArrays` into a
    :class:`~pytato.DictOfNamedArrays`.

    :param source_dict: The :class:`~pytato.DictOfNamedArrays` to copy
    :param copy_mapper: A mapper that performs copies of different array types
    :returns: A new :class:`~pytato.DictOfNamedArrays` containing copies of the
        items in *source_dict*
    """
    copied = {}

    # A falsy source yields an empty result without touching the mapper.
    if source_dict:
        for name, named_ary in source_dict.items():
            copied[name] = copy_mapper(named_ary.expr)

    return DictOfNamedArrays(copied)
def get_dependencies(expr: DictOfNamedArrays) -> Dict[str, FrozenSet[Array]]:
    """Returns the dependencies of each named array in *expr*.

    A single :class:`DependencyMapper` is shared across all entries of
    *expr*.

    :returns: A mapping from each name in *expr* to the mapper's result for
        that entry's ``expr`` attribute.
    """
    mapper = DependencyMapper()

    deps_by_name: Dict[str, FrozenSet[Array]] = {}
    for name, named_ary in expr.items():
        deps_by_name[name] = mapper(named_ary.expr)

    return deps_by_name
# }}}

# vim: foldmethod=marker