from __future__ import annotations
__copyright__ = """
Copyright (C) 2020 Andreas Kloeckner
"""
__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
from numbers import Number
from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Optional
from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as
IdentityMapperBase)
from pymbolic.mapper.substitutor import (SubstitutionMapper as
SubstitutionMapperBase)
from pymbolic.mapper.dependency import (DependencyMapper as
DependencyMapperBase)
from pymbolic.mapper.evaluator import (EvaluationMapper as
EvaluationMapperBase)
from pymbolic.mapper.distributor import (DistributeMapper as
DistributeMapperBase)
from pymbolic.mapper.stringifier import (StringifyMapper as
StringifyMapperBase)
from pymbolic.mapper.collector import TermCollector as TermCollectorBase
import pymbolic.primitives as prim
import numpy as np
import re
# Module documentation. Sphinx directive bodies must be indented and
# directives must be separated from preceding paragraphs by a blank line,
# otherwise they are swallowed into the surrounding paragraph.
__doc__ = """
.. currentmodule:: pytato.scalar_expr

.. data:: ScalarExpression

    A :class:`type` for scalar-valued symbolic expressions. Expressions are
    composable and manipulable via :mod:`pymbolic`.

    Concretely, this is an alias for
    ``Union[Number, np.bool_, bool, pymbolic.primitives.Expression]``.

.. autofunction:: parse
.. autofunction:: get_dependencies
.. autofunction:: substitute
"""
# {{{ scalar expressions
# Non-symbolic scalar constants.
ScalarType = Union[Number, np.bool_, bool]
# Any scalar-valued expression: a constant or a pymbolic expression node.
ScalarExpression = Union[ScalarType, prim.Expression]
# Scalar expressions whose constants are restricted to ``int``.
IntegralScalarExpression = Union[int, prim.Expression]
# Re-export of pymbolic's VALID_CONSTANT_CLASSES.
SCALAR_CLASSES = prim.VALID_CONSTANT_CLASSES
def parse(s: str) -> ScalarExpression:
    """Parse the string *s* into a :data:`ScalarExpression` using
    :class:`pymbolic.parser.Parser`.
    """
    from pymbolic.parser import Parser
    parser = Parser()
    return parser(s)
# }}}
# {{{ mapper classes
class WalkMapper(WalkMapperBase):
    """A :class:`pymbolic.mapper.WalkMapper` that also traverses
    :class:`Reduce` nodes.
    """

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> None:
        """Visit *expr*, then its inner expression and its bounds.

        Extra positional/keyword arguments are forwarded to :meth:`visit`,
        :meth:`rec` and :meth:`post_visit`, like the other walk methods.
        """
        if not self.visit(expr, *args, **kwargs):
            return

        self.rec(expr.inner_expr, *args, **kwargs)

        # The bounds are subexpressions of the reduction too (the
        # DependencyMapper in this module also recurses into them).
        for lower, upper in expr.bounds.values():
            self.rec(lower, *args, **kwargs)
            self.rec(upper, *args, **kwargs)

        self.post_visit(expr, *args, **kwargs)
class IdentityMapper(IdentityMapperBase):
    """A :class:`pymbolic.mapper.IdentityMapper` that also rebuilds
    :class:`Reduce` nodes.
    """

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> Reduce:
        """Return a new :class:`Reduce` with the inner expression and every
        bound mapped recursively.

        The rebuilt ``bounds`` is a plain :class:`dict`; the iname keys are
        kept unchanged. NOTE(review): if callers rely on a specific mapping
        type for ``bounds``, confirm a ``dict`` is acceptable.
        """
        return Reduce(
            self.rec(expr.inner_expr, *args, **kwargs),
            expr.op,
            {name: (self.rec(lower, *args, **kwargs),
                    self.rec(upper, *args, **kwargs))
             for name, (lower, upper) in expr.bounds.items()})
class SubstitutionMapper(SubstitutionMapperBase):
    """A :class:`pymbolic.mapper.substitutor.SubstitutionMapper` that also
    substitutes inside :class:`Reduce` nodes.
    """

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> Reduce:
        """Return a new :class:`Reduce` with substitution applied to the
        inner expression and to every bound.

        The iname keys of ``bounds`` are kept unchanged; the rebuilt
        ``bounds`` is a plain :class:`dict`. NOTE(review): inames are not
        treated as bound variables here, so a substitution keyed by an iname
        name would also be applied inside the inner expression.
        """
        return Reduce(
            self.rec(expr.inner_expr, *args, **kwargs),
            expr.op,
            {name: (self.rec(lower, *args, **kwargs),
                    self.rec(upper, *args, **kwargs))
             for name, (lower, upper) in expr.bounds.items()})
# Matches names of the form ``_<n>`` or ``_r<n>``, where ``<n>`` is a
# non-negative integer without leading zeros. DependencyMapper uses it to
# recognize index-lambda index names (use with ``fullmatch``).
IDX_LAMBDA_RE = re.compile(r"_r?(0|([1-9][0-9]*))")
class DependencyMapper(DependencyMapperBase):
    """A :class:`pymbolic.mapper.dependency.DependencyMapper` that handles
    :class:`Reduce` nodes and can optionally omit index-lambda indices.

    :param include_idx_lambda_indices: if *False*, variables whose name fully
        matches :data:`IDX_LAMBDA_RE` are not reported as dependencies.

    The remaining keyword arguments are passed on to the base class.
    """

    def __init__(self, *,
                 include_idx_lambda_indices: bool = True,
                 include_subscripts: bool = True,
                 include_lookups: bool = True,
                 include_calls: bool = True,
                 include_cses: bool = False,
                 composite_leaves: Optional[bool] = None) -> None:
        super().__init__(include_subscripts=include_subscripts,
                         include_lookups=include_lookups,
                         include_calls=include_calls,
                         include_cses=include_cses,
                         composite_leaves=composite_leaves)
        self.include_idx_lambda_indices = include_idx_lambda_indices

    def map_variable(self, expr: prim.Variable,
                     *args: Any, **kwargs: Any) -> Set[prim.Variable]:
        """Report *expr*, unless it is an index-lambda index that should be
        omitted."""
        if (not self.include_idx_lambda_indices
                and IDX_LAMBDA_RE.fullmatch(expr.name)):
            return set()
        return super().map_variable(expr, *args, **kwargs)  # type: ignore

    def map_reduce(self, expr: Reduce,
                   *args: Any, **kwargs: Any) -> Set[prim.Variable]:
        """Combine the dependencies of the inner expression with those of
        all lower and upper bounds. Extra arguments are forwarded to
        :meth:`rec`."""
        inner_deps = self.rec(expr.inner_expr, *args, **kwargs)
        bound_deps = set().union(*(
            self.rec((lower, upper), *args, **kwargs)
            for lower, upper in expr.bounds.values()))
        return self.combine([inner_deps, bound_deps])  # type: ignore
class EvaluationMapper(EvaluationMapperBase):
    """A :class:`pymbolic.mapper.evaluator.EvaluationMapper` that raises
    :exc:`NotImplementedError` on :class:`Reduce` nodes."""

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> None:
        # TODO: not trivial to evaluate symbolic reduction nodes
        raise NotImplementedError
class DistributeMapper(DistributeMapperBase):
    """A :class:`pymbolic.mapper.distributor.DistributeMapper` that raises
    :exc:`NotImplementedError` on :class:`Reduce` nodes."""

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> None:
        # TODO: not trivial to distribute symbolic reduction nodes
        raise NotImplementedError
class TermCollector(TermCollectorBase):
    """A :class:`pymbolic.mapper.collector.TermCollector` that raises
    :exc:`NotImplementedError` on :class:`Reduce` nodes."""

    def map_reduce(self, expr: Reduce, *args: Any, **kwargs: Any) -> None:
        # Collecting terms across a symbolic reduction is not supported.
        raise NotImplementedError
class StringifyMapper(StringifyMapperBase):
    """A :class:`pymbolic.mapper.stringifier.StringifyMapper` that can also
    print :class:`Reduce` nodes, as ``op({lb<=iname<ub and ...}, inner)``.
    """

    def map_reduce(self, expr: Any, enclosing_prec: Any,
                   *args: Any, **kwargs: Any) -> str:
        """Stringify a :class:`Reduce` node.

        Extra positional/keyword arguments are forwarded to :meth:`rec`, as
        the base stringifier's own ``map_*`` methods do.
        """
        from pymbolic.mapper.stringifier import (
            PREC_COMPARISON as PC,
            PREC_NONE as PN)

        # Each bound sits next to "<=" or "<", so print it at comparison
        # precedence.
        bounds_str = " and ".join(
            f"{self.rec(lb, PC, *args, **kwargs)}<={name}"
            f"<{self.rec(ub, PC, *args, **kwargs)}"
            for name, (lb, ub) in expr.bounds.items())
        bounds_str = "{" + bounds_str + "}"
        inner_str = self.rec(expr.inner_expr, PN, *args, **kwargs)

        # The output reads like a function call, so *enclosing_prec* never
        # requires parentheses here.
        return f"{expr.op}({bounds_str}, {inner_str})"
# }}}
# {{{ mapper frontends
def get_dependencies(expression: Any,
                     include_idx_lambda_indices: bool = True) -> FrozenSet[str]:
    """Return the set of variable names in an expression.

    :param expression: A scalar expression, or an expression derived from such
        (e.g., a tuple of scalar expressions)
    :param include_idx_lambda_indices: if *False*, names fully matching
        :data:`IDX_LAMBDA_RE` are left out of the result.
    """
    # composite_leaves=False asks pymbolic to report plain variables rather
    # than composite leaves, so each dependency has a ``.name``.
    dep_mapper = DependencyMapper(
        include_idx_lambda_indices=include_idx_lambda_indices,
        composite_leaves=False)
    deps = dep_mapper(expression)
    return frozenset(var.name for var in deps)
def substitute(expression: Any,
               variable_assigments: Optional[Mapping[str, Any]] = None) -> Any:
    """Perform variable substitution in an expression.

    :param expression: A scalar expression, or an expression derived from such
        (e.g., a tuple of scalar expressions)
    :param variable_assigments: A mapping from variable names to
        substitutions. *None* (the default) means no substitutions. The
        parameter name is kept as spelled for keyword-argument compatibility.
    """
    if variable_assigments is None:
        variable_assigments = {}

    from pymbolic.mapper.substitutor import make_subst_func
    return SubstitutionMapper(make_subst_func(variable_assigments))(expression)
def evaluate(expression: Any, context: Optional[Mapping[str, Any]] = None) -> Any:
    """Evaluate *expression* with :class:`EvaluationMapper`, taking variable
    values from *context* (an empty mapping when *context* is *None*).
    """
    evaluator = EvaluationMapper({} if context is None else context)
    return evaluator(expression)
def distribute(expr: Any, parameters: FrozenSet[Any] = frozenset(),
               commutative: bool = True) -> Any:
    """Apply :class:`DistributeMapper` to *expr*.

    :param parameters: passed to :class:`TermCollector`, which is used as the
        collector when *commutative* is *True*.
    :param commutative: if *False*, an identity function is used as the
        collector instead of :class:`TermCollector`.
    """
    if not commutative:
        return DistributeMapper(lambda term: term)(expr)
    return DistributeMapper(TermCollector(parameters))(expr)
# }}}
# {{{ custom scalar expression nodes
class ExpressionBase(prim.Expression):
    """Base class for the custom scalar expression nodes of this module.

    Instances are stringified with this module's :class:`StringifyMapper`.
    """

    def make_stringifier(self, originating_stringifier: Any = None
                         ) -> StringifyMapper:
        # Annotation fixed: this returns a mapper instance, not a str.
        return StringifyMapper()
class Reduce(ExpressionBase):
    """
    A symbolic reduction of :attr:`inner_expr` over the inames in
    :attr:`bounds`.

    .. attribute:: inner_expr

        A :class:`ScalarExpression` to be reduced over.

    .. attribute:: op

        One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``.

    .. attribute:: bounds

        A mapping from reduction inames to tuples ``(lower_bound, upper_bound)``
        identifying half-open bounds intervals. Its keys and values must be
        hashable.
    """

    inner_expr: ScalarExpression
    op: str
    bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]]

    # Attribute names in the order of __getinitargs__. pymbolic's Expression
    # pickling support (__setstate__) restores attributes by these names.
    init_arg_names = ("inner_expr", "op", "bounds")

    def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None:
        if op not in ("sum", "product", "max", "min"):
            raise ValueError(f"unsupported op: {op}")
        self.inner_expr = inner_expr
        self.op = op
        self.bounds = bounds

    def __hash__(self) -> int:
        # pymbolic's Expression equality compares __getinitargs__, i.e. it
        # compares ``bounds`` with ``==``. For mappings such as dict that
        # comparison ignores insertion order, so the hash must ignore it too.
        return hash((self.inner_expr,
                     self.op,
                     frozenset(self.bounds.items())))

    def __getinitargs__(self) -> Tuple[ScalarExpression, str, Any]:
        return (self.inner_expr, self.op, self.bounds)

    mapper_method = "map_reduce"
# }}}
# vim: foldmethod=marker