Package hedge :: Package backends :: Module vector_expr
[hide private]
[frames] | [no frames]

Source Code for Module hedge.backends.vector_expr

  1  """Base facility for C generation from vector expressions.""" 
  2   
  3  from __future__ import division 
  4   
  5  __copyright__ = "Copyright (C) 2009 Andreas Kloeckner" 
  6   
  7  __license__ = """ 
  8  This program is free software: you can redistribute it and/or modify 
  9  it under the terms of the GNU General Public License as published by 
 10  the Free Software Foundation, either version 3 of the License, or 
 11  (at your option) any later version. 
 12   
 13  This program is distributed in the hope that it will be useful, 
 14  but WITHOUT ANY WARRANTY; without even the implied warranty of 
 15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 16  GNU General Public License for more details. 
 17   
 18  You should have received a copy of the GNU General Public License 
 19  along with this program.  If not, see U{http://www.gnu.org/licenses/}. 
 20  """ 
 21   
 22   
 23   
 24   
 25  import numpy 
 26  import pymbolic.mapper.substitutor 
 27  import hedge.optemplate 
 28  from pytools import memoize_method, Record 
class DefaultingSubstitutionMapper(
        pymbolic.mapper.substitutor.SubstitutionMapper,
        hedge.optemplate.IdentityMapperMixin):
    """Substitution mapper that also offers unsupported node types to
    ``self.subst_func`` before giving up.

    ``self.subst_func(expr)`` must return the replacement for *expr*, or
    ``None`` when *expr* has no replacement.
    """

    def map_operator_binding(self, expr):
        # An operator binding has to be replaced by subst_func; one that
        # would survive substitution is an error.
        result = self.subst_func(expr)
        if result is None:
            raise ValueError("operator binding may not survive "
                    "vector expression substitution")

        return result

    def handle_unsupported_expression(self, expr):
        result = self.subst_func(expr)
        if result is not None:
            return result

        # No substitution available: defer to the base-class fallback and
        # propagate whatever it produces instead of implicitly returning None.
        return pymbolic.mapper.substitutor.SubstitutionMapper \
                .handle_unsupported_expression(self, expr)

    def map_scalar_parameter(self, expr):
        # Scalar parameters are substituted exactly like plain variables.
        return pymbolic.mapper.substitutor.SubstitutionMapper.map_variable(
                self, expr)
55
class ConstantGatherMapper(
        hedge.optemplate.CombineMapper,
        hedge.optemplate.CollectorMixin,
        hedge.optemplate.OperatorReducerMixin):
    """Mapper that gathers the numeric constants occurring in an expression.

    Constants contribute a one-element set; non-constant leaves and
    operators contribute an empty set.
    """

    def map_constant(self, expr):
        # Constants are the only leaves that get collected.
        return set((expr,))

    def map_algebraic_leaf(self, expr):
        # Non-constant leaves (e.g. variables) contain no constants.
        return set()

    def map_operator(self, expr):
        # Operators contribute nothing to the gathered set.
        return set()
72
class KernelRecord(Record):
    """Record returned by ``CompiledVectorExpressionBase.get_kernel``.

    It carries the built kernel (field ``kernel``) and the dtype used for the
    generated outputs and temporaries (field ``result_dtype``). All behavior
    is inherited from ``Record``.
    """
78
class VectorExpressionInfo(Record):
    """Description of one named vector expression to be compiled."""

    __slots__ = [
            # identifier of the result; also used as its name in the
            # generated code
            "name",
            # the expression to evaluate element-wise
            "expr",
            # if true, the value becomes a kernel-local temporary instead of
            # an output array
            "do_not_return",
            ]
84
def simple_result_dtype_getter(vector_dtype_map, scalar_dtype_map, const_dtypes):
    """Determine the dtype of the results of a compiled vector expression.

    Start from the common dtype of all vector operands. If any scalar
    operands or constants are present, take their common dtype, match it to
    the vector dtype's precision via ``pytools.match_precision``, and fold it
    in with ``pytools.common_dtype``.

    :arg vector_dtype_map: mapping from vector dependency to its dtype.
    :arg scalar_dtype_map: mapping from scalar dependency to its dtype.
    :arg const_dtypes: iterable of dtypes of the literal constants.
    :returns: the combined dtype (presumably a :class:`numpy.dtype`, as
        produced by ``pytools.common_dtype``).
    """
    from pytools import common_dtype, match_precision

    result = common_dtype(list(vector_dtype_map.values()))

    # Materialize both sources as lists: on Python 3 ``dict.values()`` is a
    # view that cannot be concatenated with ``+``, and *const_dtypes* may be
    # any iterable (e.g. a tuple).
    scalar_dtypes = list(scalar_dtype_map.values()) + list(const_dtypes)
    if scalar_dtypes:
        prec_matched_scalar_dtype = match_precision(
                common_dtype(scalar_dtypes),
                dtype_to_match=result)
        result = common_dtype([result, prec_matched_scalar_dtype])

    return result
101
class CompiledVectorExpressionBase(object):
    """Base class for element-wise kernels generated from vector expressions.

    :meth:`get_kernel` uses ``self.elementwise_mod`` (providing ``VectorArg``
    and ``ScalarArg``) and ``self.make_kernel_internal(args, code)``. Neither
    is defined in this part of the class; presumably subclasses supply them.
    """

    def __init__(self, vec_expr_info_list, result_dtype_getter):
        """
        :arg vec_expr_info_list: list of :class:`VectorExpressionInfo`.
            Expressions may refer to the names of other entries.
        :arg result_dtype_getter: callable invoked as
            ``result_dtype_getter(vector_dtype_map, scalar_dtype_map,
            constant_dtypes)``, e.g. :func:`simple_result_dtype_getter`.
        """
        self.result_dtype_getter = result_dtype_getter

        from hedge.optemplate import \
                DependencyMapper, ScalarParameter
        from functools import reduce
        from operator import or_
        from pymbolic import var

        dep_mapper = DependencyMapper(
                include_subscripts=True,
                include_lookups=True,
                include_calls="descend_args")
        # Start from a fresh empty set. An empty expression list then yields
        # no dependencies instead of a TypeError. The mapper's own result set
        # is also never mutated by the in-place difference below.
        deps = reduce(or_,
                (dep_mapper(vei.expr) for vei in vec_expr_info_list),
                set())

        # Names defined by the expression list itself are not inputs.
        deps -= set(var(vei.name) for vei in vec_expr_info_list)

        from pytools import partition

        def is_vector_pred(dep):
            return not isinstance(dep, ScalarParameter)

        vdeps, sdeps = partition(is_vector_pred, deps)

        # Order dependencies by their string form so argument order is
        # deterministic. Sorting by key never compares expression objects.
        self.vector_deps = sorted(vdeps, key=str)
        self.scalar_deps = sorted(sdeps, key=str)

        self.vector_dep_names = [
                "v%d" % i for i in range(len(self.vector_deps))]
        self.scalar_dep_names = [
                "s%d" % i for i in range(len(self.scalar_deps))]

        self.constant_dtypes = [
                numpy.array(const).dtype
                for vei in vec_expr_info_list
                for const in ConstantGatherMapper()(vei.expr)]

        # Rewrite each dependency into its element-wise form:
        #   vector input   -> v<k>[i]
        #   scalar input   -> s<k>
        #   returned result -> <name>[i]
        var_i = var("i")
        subst_map = {}
        for dep, vecname in zip(self.vector_deps, self.vector_dep_names):
            subst_map[dep] = var(vecname)[var_i]
        for dep, scaname in zip(self.scalar_deps, self.scalar_dep_names):
            subst_map[dep] = var(scaname)
        for vei in vec_expr_info_list:
            if not vei.do_not_return:
                subst_map[var(vei.name)] = var(vei.name)[var_i]

        def subst_func(expr):
            # None means "no replacement" to the substitution mapper.
            return subst_map.get(expr)

        self.vec_expr_info_list = [
                vei.copy(expr=DefaultingSubstitutionMapper(subst_func)(vei.expr))
                for vei in vec_expr_info_list]

        # Built from the input (unsubstituted) list.
        self.result_vec_expr_info_list = [
                vei for vei in vec_expr_info_list if not vei.do_not_return]

    @memoize_method
    def result_names(self):
        """Return the names of the expressions whose values are returned."""
        return [rvei.name for rvei in self.result_vec_expr_info_list]

    @memoize_method
    def get_kernel(self, vector_dtypes, scalar_dtypes):
        """Generate and build the element-wise kernel for the given dtypes.

        :arg vector_dtypes: dtypes of the vector inputs, ordered like
            ``self.vector_deps``.
        :arg scalar_dtypes: dtypes of the scalar inputs, ordered like
            ``self.scalar_deps``.
        :returns: a :class:`KernelRecord` with fields ``kernel`` (from
            ``self.make_kernel_internal``) and ``result_dtype``.

        Results are cached by ``memoize_method``. Presumably the arguments
        must therefore be hashable (e.g. tuples, not lists).
        """
        from pymbolic.mapper.stringifier import PREC_NONE
        from pymbolic.mapper.c_code import CCodeMapper

        elwise = self.elementwise_mod

        result_dtype = self.result_dtype_getter(
                dict(zip(self.vector_deps, vector_dtypes)),
                dict(zip(self.scalar_deps, scalar_dtypes)),
                self.constant_dtypes)

        # Output arrays come first in the kernel argument list.
        args = [elwise.VectorArg(result_dtype, vei.name)
                for vei in self.vec_expr_info_list
                if not vei.do_not_return]

        def real_const_mapper(num):
            # Emit literals without a decimal point as double(...).
            # Presumably this keeps C from doing integer arithmetic on them.
            r = repr(num)
            if "." not in r:
                return "double(%s)" % r
            else:
                return r

        code_mapper = CCodeMapper(constant_mapper=real_const_mapper)

        code_lines = []
        for vei in self.vec_expr_info_list:
            expr_code = code_mapper(vei.expr, PREC_NONE)
            if vei.do_not_return:
                # Non-returned expressions become kernel-local temporaries.
                from codepy.cgen import dtype_to_ctype
                code_lines.append(
                        "%s %s = %s;" % (
                            dtype_to_ctype(result_dtype), vei.name, expr_code))
            else:
                code_lines.append(
                        "%s[i] = %s;" % (vei.name, expr_code))

        # common subexpressions have been taken care of by the compiler
        assert not code_mapper.cses

        args.extend(
                elwise.VectorArg(dtype, name)
                for dtype, name in zip(vector_dtypes, self.vector_dep_names))
        args.extend(
                elwise.ScalarArg(dtype, name)
                for dtype, name in zip(scalar_dtypes, self.scalar_dep_names))

        return KernelRecord(
                kernel=self.make_kernel_internal(args, "\n".join(code_lines)),
                result_dtype=result_dtype)