Source code for pytato.utils

from __future__ import annotations

__copyright__ = "Copyright (C) 2021 Kaushik Kulkarni"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np
import pymbolic.primitives as prim

from numbers import Number
from typing import Tuple, List, Union, Callable, Any, Sequence, Dict
from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent,
                          DtypeOrScalar, ArrayOrScalar)
from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression,
                                SCALAR_CLASSES)
from pytools import UniqueNameGenerator
from pytato.transform import Mapper


__doc__ = """
Helper routines
---------------

.. autofunction:: are_shape_components_equal
.. autofunction:: are_shapes_equal
.. autofunction:: get_shape_after_broadcasting
.. autofunction:: dim_to_index_lambda_components
"""


[docs]def get_shape_after_broadcasting( exprs: List[Union[Array, ScalarExpression]]) -> ShapeType: """ Returns the shape after broadcasting *exprs* in an operation. """ shapes = [expr.shape if isinstance(expr, Array) else () for expr in exprs] result_dim = max(len(s) for s in shapes) # append leading dimensions of all the shapes with 1's to match result_dim. augmented_shapes = [((1,)*(result_dim-len(s)) + s) for s in shapes] def _get_result_axis_length(axis_lengths: List[IntegralScalarExpression] ) -> IntegralScalarExpression: result_axis_len = axis_lengths[0] for axis_len in axis_lengths[1:]: if are_shape_components_equal(axis_len, result_axis_len): pass elif are_shape_components_equal(axis_len, 1): pass elif are_shape_components_equal(result_axis_len, 1): result_axis_len = axis_len else: raise ValueError("operands could not be broadcasted together with " f"shapes {' '.join(str(s) for s in shapes)}.") return result_axis_len return tuple(_get_result_axis_length([s[i] for s in augmented_shapes]) for i in range(result_dim))
def get_indexing_expression(shape: ShapeType, result_shape: ShapeType) -> Tuple[ScalarExpression, ...]: """ Returns the indices while broadcasting an array of shape *shape* into one of shape *result_shape*. """ assert len(shape) <= len(result_shape) i_start = len(result_shape) - len(shape) indices = [] for i, (dim1, dim2) in enumerate(zip(shape, result_shape[i_start:])): if not are_shape_components_equal(dim1, dim2): assert are_shape_components_equal(dim1, 1) indices.append(0) else: indices.append(prim.Variable(f"_{i+i_start}")) return tuple(indices) def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, result_shape: ShapeType) -> prim.Expression: if len(shape) == 0: # scalar expr => do not index return val else: return val[get_indexing_expression(shape, result_shape)] def extract_dtypes_or_scalars( exprs: Sequence[ArrayOrScalar]) -> List[DtypeOrScalar]: dtypes: List[DtypeOrScalar] = [] for expr in exprs: if isinstance(expr, Array): dtypes.append(expr.dtype) else: assert isinstance(expr, SCALAR_CLASSES) dtypes.append(expr) return dtypes def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, bnd_name: str, bindings: Dict[str, Array], result_shape: ShapeType ) -> ScalarExpression: """ Returns an instance of :class:`~pytato.scalar_expr.ScalarExpression` to address *arr* in a :class:`pytato.array.IndexLambda` of shape *result_shape*. """ if isinstance(arr, SCALAR_CLASSES): return arr assert isinstance(arr, Array) bindings[bnd_name] = arr return with_indices_for_broadcasted_shape(prim.Variable(bnd_name), arr.shape, result_shape) def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: if isinstance(a1, Number) and isinstance(a2, Number): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore result_shape = get_shape_after_broadcasting([a1, a2]) dtypes = extract_dtypes_or_scalars([a1, a2]) result_dtype = get_result_type(*dtypes) bindings: Dict[str, Array] = {} expr1 = update_bindings_and_get_broadcasted_expr(a1, "_in0", bindings, result_shape) expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) return IndexLambda(op(expr1, expr2), shape=result_shape, dtype=result_dtype, bindings=bindings) # {{{ dim_to_index_lambda_components class ShapeExpressionMapper(Mapper): """ Mapper that takes a shape component and returns it as a scalar expression. """ def __init__(self, var_name_gen: UniqueNameGenerator): self.cache: Dict[Array, ScalarExpression] = {} self.var_name_gen = var_name_gen self.bindings: Dict[str, SizeParam] = {} def rec(self, expr: Array) -> ScalarExpression: # type: ignore if expr in self.cache: return self.cache[expr] result: Array = super().rec(expr) self.cache[expr] = result return result def map_index_lambda(self, expr: IndexLambda) -> ScalarExpression: from pytato.scalar_expr import substitute return substitute(expr.expr, {name: self.rec(val) for name, val in expr.bindings.items()}) def map_size_param(self, expr: SizeParam) -> ScalarExpression: name = self.var_name_gen("_in") self.bindings[name] = expr return prim.Variable(name)
[docs]def dim_to_index_lambda_components(expr: ShapeComponent, vng: UniqueNameGenerator, ) -> Tuple[ScalarExpression, Dict[str, SizeParam]]: """ Returns the scalar expressions and bindings to use the shape component within an index lambda. .. testsetup:: >>> import pytato as pt >>> from pytato.utils import dim_to_index_lambda_components >>> from pytools import UniqueNameGenerator .. doctest:: >>> n = pt.make_size_param("n") >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator()) >>> print(expr) 3*_in + 8 >>> bnds # doctest: +ELLIPSIS {'_in': <pytato.array.SizeParam ...>} """ if isinstance(expr, int): return expr, {} assert isinstance(expr, Array) mapper = ShapeExpressionMapper(vng) result = mapper(expr) return result, mapper.bindings
# }}}
[docs]def are_shape_components_equal(dim1: ShapeComponent, dim2: ShapeComponent) -> bool: """ Returns *True* iff *dim1* and *dim2* are have equal :class:`~pytato.array.SizeParam` coefficients in their expressions. """ from pytato.scalar_expr import substitute, distribute def to_expr(dim: ShapeComponent) -> ScalarExpression: expr, bnds = dim_to_index_lambda_components(dim, UniqueNameGenerator()) return substitute(expr, {name: prim.Variable(bnd.name) for name, bnd in bnds.items()}) dim1_expr = to_expr(dim1) dim2_expr = to_expr(dim2) # ScalarExpression.__eq__ returns Any return (distribute(dim1_expr-dim2_expr) == 0) # type: ignore
[docs]def are_shapes_equal(shape1: ShapeType, shape2: ShapeType) -> bool: """ Returns *True* iff *shape1* and *shape2* have the same dimensionality and the correpsonding components are equal as defined by :func:`~pytato.utils.are_shape_components_equal`. """ return ((len(shape1) == len(shape2)) and all(are_shape_components_equal(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)))