Source code for pytato.codegen

from __future__ import annotations

__copyright__ = """Copyright (C) 2020 Matt Wala"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import dataclasses
from functools import partialmethod
from typing import Union, Dict, Tuple, Callable, List, Any

import pymbolic.primitives as prim
from pymbolic import var

from pytato.array import (Array, DictOfNamedArrays, IndexLambda,
                          DataWrapper, Roll, AxisPermutation, Slice,
                          IndexRemappingBase, Stack, Placeholder, Reshape,
                          Concatenate, DataInterface, SizeParam,
                          InputArgumentBase, MatrixProduct, Einsum)

from pytato.scalar_expr import ScalarExpression, IntegralScalarExpression
from pytato.transform import CopyMapper, CachedWalkMapper, SubsetDependencyMapper
from pytato.target import Target
from pytato.loopy import LoopyCall
from pytools import UniqueNameGenerator
import loopy as lp
SymbolicIndex = Tuple[IntegralScalarExpression, ...]


__doc__ = """
.. currentmodule:: pytato.codegen

.. autoclass:: CodeGenPreprocessor
.. autoclass:: PreprocessResult

.. autofunction:: preprocess
.. autofunction:: normalize_outputs
"""


# {{{ preprocessing for codegen

class CodeGenPreprocessor(CopyMapper):
    """A mapper that preprocesses graphs to simplify code generation.

    The following node simplifications are performed:

    ======================================  =====================================
    Source Node Type                        Target Node Type
    ======================================  =====================================
    :class:`~pytato.array.DataWrapper`      :class:`~pytato.array.Placeholder`

    :class:`~pytato.array.Roll`             :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.AxisPermutation`  :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.Slice`            :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.Reshape`          :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.Concatenate`      :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.MatrixProduct`    :class:`~pytato.array.IndexLambda`

    :class:`~pytato.array.Einsum`           :class:`~pytato.array.IndexLambda`
    ======================================  =====================================
    """

    # TODO:
    # Stack -> IndexLambda
    # MatrixProduct -> Einsum

    def __init__(self, target: Target) -> None:
        """
        :param target: The code generation target. It is consulted when
            processing :class:`~pytato.loopy.LoopyCall` nodes.
        """
        super().__init__()
        # Data of every DataWrapper encountered, keyed by the name of the
        # Placeholder that replaced it.
        self.bound_arguments: Dict[str, Any] = {}
        self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
        self.target = target
        # Callable kernels of the LoopyCalls processed so far, keyed by name.
        self.kernels_seen: Dict[str, lp.LoopKernel] = {}

    def map_size_param(self, expr: SizeParam) -> Array:
        """Return a fresh :class:`SizeParam` with the same name and tags."""
        name = expr.name
        assert name is not None
        return SizeParam(name=name, tags=expr.tags)

    def map_placeholder(self, expr: Placeholder) -> Array:
        """Copy *expr*, generating a name for it if it has none."""
        name = expr.name
        if name is None:
            name = self.var_name_gen("_pt_in")
        return Placeholder(name=name,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                tags=expr.tags)

    def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:
        """Retarget the translation unit of *expr* to the loopy target and
        rename callable kernels whose names clash with different kernels
        seen in earlier calls.

        :raises ValueError: if the code generation target is not a
            ``LoopyTarget``.
        """
        from pytato.target.loopy import LoopyTarget
        if not isinstance(self.target, LoopyTarget):
            raise ValueError("Got a LoopyCall for a non-loopy target.")
        translation_unit = expr.translation_unit.copy(
                                        target=self.target.get_loopy_target())
        namegen = UniqueNameGenerator(set(self.kernels_seen))
        entrypoint = expr.entrypoint

        # {{{ eliminate callable name collision

        for name, clbl in translation_unit.callables_table.items():
            if isinstance(clbl, lp.kernel.function_interface.CallableKernel):
                if name in self.kernels_seen and (
                        translation_unit[name] != self.kernels_seen[name]):
                    # callee name collision => must rename

                    # {{{ see if it's one of the other kernels

                    for other_knl in self.kernels_seen.values():
                        if other_knl.copy(name=name) == translation_unit[name]:
                            new_name = other_knl.name
                            break
                    else:
                        # didn't find any other equivalent kernel, rename to
                        # something unique
                        new_name = namegen(name)

                    # }}}

                    if name == entrypoint:
                        # if the colliding name is the entrypoint, then rename the
                        # entrypoint as well.
                        entrypoint = new_name

                    translation_unit = lp.rename_callable(
                                            translation_unit, name, new_name)
                    name = new_name

                self.kernels_seen[name] = translation_unit[name]

        # }}}

        bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
                           else subexpr)
                    for name, subexpr in expr.bindings.items()}

        return LoopyCall(translation_unit=translation_unit,
                         bindings=bindings,  # type: ignore
                         entrypoint=entrypoint)

    def map_data_wrapper(self, expr: DataWrapper) -> Array:
        """Replace *expr* with a :class:`Placeholder` and record its data in
        :attr:`bound_arguments` under the placeholder's name.
        """
        name = expr.name
        if name is None:
            name = self.var_name_gen("_pt_in")

        self.bound_arguments[name] = expr.data
        return Placeholder(name=name,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                tags=expr.tags)

    def map_stack(self, expr: Stack) -> Array:
        """Rewrite a :class:`Stack` as an :class:`IndexLambda` that selects
        the input array based on the index along the stacking axis.
        """
        from pymbolic.primitives import If, Comparison

        # Every input is indexed by all output indices except the one along
        # the stacking axis; this subscript is the same for each input.
        subscript: SymbolicIndex = tuple(var(f"_{i}")
                                         for i in range(expr.ndim)
                                         if i != expr.axis)

        # I = axis index
        #
        # => If(_I == 0,
        #        _in0[_0, _1, ...],
        #        If(_I == 1,
        #            _in1[_0, _1, ...],
        #            ...
        #                _inNm1[_0, _1, ...] ...))
        for i in range(len(expr.arrays) - 1, -1, -1):
            subarray_expr = var(f"_in{i}")[subscript]
            if i == len(expr.arrays) - 1:
                # innermost 'else' branch: the last array
                stack_expr = subarray_expr
            else:
                stack_expr = If(Comparison(var(f"_{expr.axis}"), "==", i),
                        subarray_expr,
                        stack_expr)

        bindings = {f"_in{i}": self.rec(array)
                for i, array in enumerate(expr.arrays)}

        return IndexLambda(expr=stack_expr,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                bindings=bindings)

    def map_concatenate(self, expr: Concatenate) -> Array:
        """Rewrite a :class:`Concatenate` as an :class:`IndexLambda` that
        selects the input array based on where the index along the
        concatenation axis falls.
        """
        from pymbolic.primitives import If, Comparison, Subscript

        def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript:
            # Index into input *array_index*, shifting the index along the
            # concatenation axis by *offset*.
            aggregate = var(f"_in{array_index}")
            index = [var(f"_{i}") if i != expr.axis else (var(f"_{i}") - offset)
                     for i in range(len(expr.shape))]
            return Subscript(aggregate, tuple(index))

        # lbounds[i], ubounds[i]: half-open range of indices along the
        # concatenation axis that is covered by the i-th input array.
        lbounds: List[Any] = [0]
        ubounds: List[Any] = [expr.arrays[0].shape[expr.axis]]

        for i, array in enumerate(expr.arrays[1:], start=1):
            ubounds.append(ubounds[i-1]+array.shape[expr.axis])
            lbounds.append(ubounds[i-1])

        # I = axis index
        #
        # => If(_I < arrays[0].shape[axis],
        #        _in0[_0, _1, ..., _I, ...],
        #        If(_I < (arrays[1].shape[axis]+arrays[0].shape[axis]),
        #            _in1[_0, _1, ..., _I-arrays[0].shape[axis], ...],
        #            ...
        #                _inNm1[_0, _1, ...] ...))
        #
        # Only the upper bound is tested: the segments are tried in
        # increasing order, so reaching the i-th test already implies
        # _I >= lbounds[i]. (An earlier formulation combined a '>=' and a '<'
        # comparison using Python's 'and', which evaluates to its second
        # operand for expression objects and therefore produced exactly this
        # expression.)
        for i in range(len(expr.arrays) - 1, -1, -1):
            lbound, ubound = lbounds[i], ubounds[i]
            subarray_expr = get_subscript(i, lbound)
            if i == len(expr.arrays) - 1:
                # innermost 'else' branch: the last array
                stack_expr = subarray_expr
            else:
                stack_expr = If(Comparison(var(f"_{expr.axis}"), "<", ubound),
                                subarray_expr,
                                stack_expr)

        bindings = {f"_in{i}": self.rec(array)
                for i, array in enumerate(expr.arrays)}

        return IndexLambda(expr=stack_expr,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                bindings=bindings)

    def map_roll(self, expr: Roll) -> Array:
        """Rewrite a :class:`Roll` as an :class:`IndexLambda` whose index
        along the rolled axis is shifted modulo the axis length.
        """
        from pytato.utils import dim_to_index_lambda_components

        index_expr = var("_in0")
        indices = [var(f"_{d}") for d in range(expr.ndim)]
        axis = expr.axis
        # The axis length may itself be an array expression; obtain it as a
        # scalar expression together with the bindings it refers to.
        axis_len_expr, bindings = dim_to_index_lambda_components(
            expr.shape[axis],
            UniqueNameGenerator({"_in0"}))

        indices[axis] = (indices[axis] - expr.shift) % axis_len_expr

        if indices:
            index_expr = index_expr[tuple(indices)]

        bindings["_in0"] = expr.array  # type: ignore

        return IndexLambda(expr=index_expr,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                bindings={name: self.rec(bnd)
                          for name, bnd in bindings.items()})

    def map_matrix_product(self, expr: MatrixProduct) -> Array:
        """Rewrite a :class:`MatrixProduct` as an :class:`IndexLambda`
        holding a sum-reduction over the last axis of ``x1`` and the first
        axis of ``x2``.
        """
        from pytato.utils import dim_to_index_lambda_components
        from pytato.scalar_expr import Reduce

        x1 = prim.Subscript(prim.Variable("in0"),
                            (tuple(prim.Variable(f"_{i}")
                                   for i in range(len(expr.x1.shape)-1))
                             + (prim.Variable("_r0"),))
                            )

        # output indices consumed by x2 start after those consumed by x1
        x2_i_start = len(expr.x1.shape) - 1
        x2 = prim.Subscript(prim.Variable("in1"),
                            (prim.Variable("_r0"),)
                            + tuple(prim.Variable(f"_{i+x2_i_start}")
                                    for i in range(len(expr.x2.shape)-1)))

        namegen = UniqueNameGenerator({"in0", "in1"})
        redn_bound, redn_bound_bindings = dim_to_index_lambda_components(
            expr.x1.shape[-1], namegen)

        bindings = {k: self.rec(v) for k, v in redn_bound_bindings.items()}
        bindings["in0"] = self.rec(expr.x1)
        bindings["in1"] = self.rec(expr.x2)

        inner_expr = Reduce(
                x1*x2,
                "sum",
                {"_r0": (0, redn_bound)})

        return IndexLambda(
                expr=inner_expr,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                bindings=bindings)

    def map_einsum(self, expr: Einsum) -> Array:
        """Rewrite an :class:`Einsum` as an :class:`IndexLambda` holding the
        product of the subscripted arguments, wrapped in a sum-reduction if
        any reduction axes are present.
        """
        import operator
        from functools import reduce
        from pytato.scalar_expr import Reduce
        from pytato.utils import dim_to_index_lambda_components
        from pytato.array import ElementwiseAxis, ReductionAxis

        bindings = {f"in{k}": self.rec(arg) for k, arg in enumerate(expr.args)}
        redn_bounds: Dict[str, Tuple[ScalarExpression, ScalarExpression]] = {}
        args_as_pym_expr: List[prim.Subscript] = []
        namegen = UniqueNameGenerator(set(bindings))

        # {{{ add bindings coming from the shape expressions

        for access_descr, (iarg, arg) in zip(expr.access_descriptors,
                                            enumerate(expr.args)):
            subscript_indices = []
            for iaxis, axis in enumerate(access_descr):
                if isinstance(axis, ElementwiseAxis):
                    subscript_indices.append(prim.Variable(f"_{axis.dim}"))
                else:
                    assert isinstance(axis, ReductionAxis)
                    redn_idx_name = f"_r{axis.dim}"
                    if redn_idx_name not in redn_bounds:
                        # convert the ShapeComponent to a ScalarExpression
                        redn_bound, redn_bound_bindings = (
                            dim_to_index_lambda_components(
                                arg.shape[iaxis], namegen))
                        redn_bounds[redn_idx_name] = (0, redn_bound)

                        bindings.update({k: self.rec(v)
                                         for k, v in redn_bound_bindings.items()})

                    subscript_indices.append(prim.Variable(redn_idx_name))

            args_as_pym_expr.append(prim.Subscript(prim.Variable(f"in{iarg}"),
                                                   tuple(subscript_indices)))

        # }}}

        inner_expr = reduce(operator.mul, args_as_pym_expr[1:],
                            args_as_pym_expr[0])

        if redn_bounds:
            inner_expr = Reduce(inner_expr,
                                "sum",
                                redn_bounds)

        return IndexLambda(expr=inner_expr,
                           shape=tuple(self.rec(s) if isinstance(s, Array) else s
                                       for s in expr.shape),
                           dtype=expr.dtype,
                           bindings=bindings)

    # {{{ index remapping (roll, axis permutation, slice)

    def handle_index_remapping(self,
            indices_getter: Callable[[CodeGenPreprocessor, Array], SymbolicIndex],
            expr: IndexRemappingBase) -> Array:
        """Rewrite *expr* as an :class:`IndexLambda` that subscripts the
        (recursively processed) input array with the indices returned by
        *indices_getter*.

        :param indices_getter: Called as ``indices_getter(self, expr)``;
            returns the indices into the input array in terms of the output
            indices ``_0, _1, ...``.
        """
        indices = indices_getter(self, expr)

        index_expr = var("_in0")
        if indices:
            index_expr = index_expr[indices]

        array = self.rec(expr.array)

        return IndexLambda(expr=index_expr,
                shape=tuple(self.rec(s) if isinstance(s, Array) else s
                            for s in expr.shape),
                dtype=expr.dtype,
                bindings=dict(_in0=array))

    def _indices_for_axis_permutation(self, expr: AxisPermutation) -> SymbolicIndex:
        """Return input-array indices for an axis permutation."""
        indices = [None] * expr.ndim
        for from_index, to_index in enumerate(expr.axes):
            indices[to_index] = var(f"_{from_index}")
        return tuple(indices)

    def _indices_for_slice(self, expr: Slice) -> SymbolicIndex:
        """Return input-array indices for a slice: each output index offset
        by the corresponding start.
        """
        return tuple(var(f"_{d}") + expr.starts[d] for d in range(expr.ndim))

    def _indices_for_reshape(self, expr: Reshape) -> SymbolicIndex:
        """Return input-array indices for a reshape.

        The output indices are combined into a flattened index using strides
        in which the last axis varies fastest, and that flattened index is
        then split along the axes of the input array. All axis lengths
        involved are asserted to be integers.
        """
        newstrides = [1]  # reshaped array strides
        for new_axis_len in reversed(expr.shape[1:]):
            assert isinstance(new_axis_len, int)
            newstrides.insert(0, newstrides[0]*new_axis_len)

        flattened_idx = sum(prim.Variable(f"_{i}")*stride
                            for i, stride in enumerate(newstrides))

        oldstrides = [1]  # input array strides
        for axis_len in reversed(expr.array.shape[1:]):
            assert isinstance(axis_len, int)
            oldstrides.insert(0, oldstrides[0]*axis_len)

        assert isinstance(expr.array.shape[-1], int)
        oldsizetills = [expr.array.shape[-1]]  # input array size till for axes idx
        for old_axis_len in reversed(expr.array.shape[:-1]):
            assert isinstance(old_axis_len, int)
            oldsizetills.insert(0, oldsizetills[0]*old_axis_len)

        return tuple(((flattened_idx % sizetill) // stride)
                     for stride, sizetill in zip(oldstrides, oldsizetills))

    # https://github.com/python/mypy/issues/8619
    map_axis_permutation = (
            partialmethod(handle_index_remapping, _indices_for_axis_permutation))  # type: ignore # noqa
    map_slice = partialmethod(handle_index_remapping, _indices_for_slice)  # type: ignore # noqa
    map_reshape = partialmethod(handle_index_remapping, _indices_for_reshape)  # noqa

    # }}}

# }}}
def normalize_outputs(result: Union[Array, DictOfNamedArrays,
                                    Dict[str, Array]]) -> DictOfNamedArrays:
    """Bring the outputs of a computation into canonical form.

    A single :class:`~pytato.Array` is wrapped into a
    :class:`~pytato.DictOfNamedArrays` under the name ``"_pt_out"``, a plain
    :class:`dict` is wrapped as-is, and a :class:`~pytato.DictOfNamedArrays`
    is returned unchanged.

    :param result: Outputs of the computation.
    :raises TypeError: if *result* is none of the accepted types.
    """
    if isinstance(result, Array):
        return DictOfNamedArrays({"_pt_out": result})

    if isinstance(result, dict):
        return DictOfNamedArrays(result)

    if isinstance(result, DictOfNamedArrays):
        return result

    raise TypeError("outputs of the computation should be "
                    "either an Array or a DictOfNamedArrays")
# {{{ input naming check

class NamesValidityChecker(CachedWalkMapper):
    """Walks expressions and verifies that no two distinct input objects
    share the same name.
    """

    def __init__(self) -> None:
        # first input object seen for each name
        self.name_to_input: Dict[str, InputArgumentBase] = {}
        super().__init__()

    def post_visit(self, expr: Any) -> None:
        """Record *expr* if it is a named input; raise ``NameClashError`` if
        a different input object with the same name was recorded earlier.
        """
        if not isinstance(expr, InputArgumentBase):
            return

        if expr.name is None:
            # Name to be automatically assigned
            return

        # Registers *expr* on first sight, otherwise yields the earlier entry.
        previously_seen = self.name_to_input.setdefault(expr.name, expr)

        if previously_seen is not expr:
            from pytato.diagnostic import NameClashError
            raise NameClashError("Received two separate instances of inputs "
                                 f"named '{expr.name}'.")


def check_validity_of_outputs(exprs: DictOfNamedArrays) -> None:
    """Run a single :class:`NamesValidityChecker` over every array in
    *exprs*.
    """
    checker = NamesValidityChecker()

    for output in exprs.values():
        checker(output)

# }}}
@dataclasses.dataclass(repr=False, eq=False)
class PreprocessResult:
    """Container for the values produced by :func:`preprocess`.

    .. attribute:: outputs

        The preprocessed outputs.

    .. attribute:: compute_order

        Names of the outputs, in the order in which they are to be computed.

    .. attribute:: bound_arguments

        The bound arguments collected while preprocessing.
    """
    outputs: DictOfNamedArrays
    compute_order: Tuple[str, ...]
    bound_arguments: Dict
def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
    """Preprocess a computation for code generation.

    Checks the input names used by *outputs*, determines an order in which
    the outputs are to be computed and runs :class:`CodeGenPreprocessor`
    over them.

    :param outputs: The outputs of the computation.
    :param target: The code generation target.
    """
    from pytato.transform import copy_dict_of_named_arrays

    check_validity_of_outputs(outputs)

    # {{{ compute the order in which the outputs must be computed

    # semantically order does not matter, but doing a toposort ordering of the
    # outputs leads to a FLOP optimal choice

    from pytools.graph import compute_topological_order

    # only look for dependencies between the outputs
    output_exprs = frozenset(out.expr for out in outputs.values())
    find_output_deps = SubsetDependencyMapper(output_exprs)
    deps_by_name = {name: find_output_deps(output.expr)
                    for name, output in outputs.items()}

    # represent deps in terms of output names
    name_of_expr = {output.expr: name for name, output in outputs.items()}

    dag = {}
    for name, dep_exprs in deps_by_name.items():
        # an output does not count as a dependency of itself
        dag[name] = frozenset(name_of_expr[dep]
                              for dep in dep_exprs
                              if name_of_expr[dep] != name)

    toposorted_names: List[str] = compute_topological_order(dag)
    compute_order = tuple(reversed(toposorted_names))

    # }}}

    preprocessor = CodeGenPreprocessor(target)
    preprocessed_outputs = copy_dict_of_named_arrays(outputs, preprocessor)

    return PreprocessResult(outputs=preprocessed_outputs,
            compute_order=compute_order,
            bound_arguments=preprocessor.bound_arguments)
# vim: fdm=marker