1 """Base facility for C generation from vector expressions."""
2
3 from __future__ import division
4
5 __copyright__ = "Copyright (C) 2009 Andreas Kloeckner"
6
7 __license__ = """
8 This program is free software: you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation, either version 3 of the License, or
11 (at your option) any later version.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License
19 along with this program. If not, see U{http://www.gnu.org/licenses/}.
20 """
21
22
23
24
25 import numpy
26 import pymbolic.mapper.substitutor
27 import hedge.optemplate
28 from pytools import memoize_method, Record
37 result = self.subst_func(expr)
38 if result is None:
39 raise ValueError("operator binding may not survive "
40 "vector expression subsitution")
41
42 return result
43
51
55
56
57
58
59
60 -class ConstantGatherMapper(
61 hedge.optemplate.CombineMapper,
62 hedge.optemplate.CollectorMixin,
63 hedge.optemplate.OperatorReducerMixin):
72
78
83 __slots__ = ["name", "expr", "do_not_return"]
84
89 from pytools import common_dtype, match_precision
90
91 result = common_dtype(vector_dtype_map.values())
92
93 scalar_dtypes = scalar_dtype_map.values() + const_dtypes
94 if scalar_dtypes:
95 prec_matched_scalar_dtype = match_precision(
96 common_dtype(scalar_dtypes),
97 dtype_to_match=result)
98 result = common_dtype([result, prec_matched_scalar_dtype])
99
100 return result
101
def __init__(self, vec_expr_info_list, result_dtype_getter):
    """Analyze a list of vector expression assignments for C generation.

    The dependencies of all expressions are gathered, split into vector
    and scalar arguments, ordered by their string form and given the
    positional names ``v0, v1, ...`` and ``s0, s1, ...``. The expressions
    are then rewritten in terms of those names, with vector quantities
    subscripted by the loop variable ``i``.

    :param vec_expr_info_list: iterable of records providing the
        attributes ``name``, ``expr`` and ``do_not_return`` as well as a
        ``copy(expr=...)`` method.
    :param result_dtype_getter: callable stored on the instance;
        :meth:`get_kernel` calls it with a vector-dependency-to-dtype
        mapping, a scalar-dependency-to-dtype mapping and the list of
        constant dtypes.
    """
    # functools.reduce exists on Python 2.6+ and is the only spelling
    # available on Python 3, where the builtin was removed.
    from functools import reduce
    from operator import or_

    from hedge.optemplate import \
            DependencyMapper, ScalarParameter
    from pymbolic import var
    from pytools import partition

    self.result_dtype_getter = result_dtype_getter

    # The list is traversed several times below; materialize it so that a
    # one-shot iterable is not silently exhausted after the first pass.
    vec_expr_info_list = list(vec_expr_info_list)

    dep_mapper = DependencyMapper(
            include_subscripts=True,
            include_lookups=True,
            include_calls="descend_args")

    # Seeding with a fresh set keeps the in-place subtraction below from
    # modifying a set object handed back by the mapper (in case the mapper
    # still holds on to it), and lets an empty list pass through.
    deps = reduce(
            or_,
            (dep_mapper(vei.expr) for vei in vec_expr_info_list),
            set())

    # Names assigned by this very list are not external dependencies.
    deps -= set(var(vei.name) for vei in vec_expr_info_list)

    def is_vector_pred(dep):
        # Everything that is not a ScalarParameter is treated as a vector.
        return not isinstance(dep, ScalarParameter)

    vdeps, sdeps = partition(is_vector_pred, deps)

    # Order by string form only, so that the argument order does not
    # depend on set iteration order and expression objects themselves
    # are never compared with each other.
    self.vector_deps = sorted(vdeps, key=str)
    self.scalar_deps = sorted(sdeps, key=str)

    self.vector_dep_names = ["v%d" % i for i in range(len(self.vector_deps))]
    self.scalar_dep_names = ["s%d" % i for i in range(len(self.scalar_deps))]

    self.constant_dtypes = [
            numpy.array(const).dtype
            for vei in vec_expr_info_list
            for const in ConstantGatherMapper()(vei.expr)]

    var_i = var("i")

    # vector dependency -> v<n>[i], scalar dependency -> s<n>,
    # returned result <name> -> <name>[i]
    subst_map = dict(
            list(zip(self.vector_deps,
                [var(vecname)[var_i]
                    for vecname in self.vector_dep_names]))
            + list(zip(self.scalar_deps,
                [var(scaname) for scaname in self.scalar_dep_names]))
            + [(var(vei.name), var(vei.name)[var_i])
                for vei in vec_expr_info_list
                if not vei.do_not_return]
            )

    def subst_func(expr):
        # None tells the substitution mapper "no replacement for this one".
        return subst_map.get(expr)

    self.vec_expr_info_list = [
            vei.copy(expr=DefaultingSubstitutionMapper(subst_func)(vei.expr))
            for vei in vec_expr_info_list]

    # Only the names of these entries are used by the class as shown, so
    # the unsubstituted records are kept here, as before.
    self.result_vec_expr_info_list = [
            vei for vei in vec_expr_info_list if not vei.do_not_return]
167
168 @memoize_method
170 return [rvei.name for rvei in self.result_vec_expr_info_list]
171
@memoize_method
def get_kernel(self, vector_dtypes, scalar_dtypes):
    """Build the elementwise kernel for the given argument dtypes.

    :param vector_dtypes: dtypes of the vector arguments, in the order of
        ``self.vector_deps``.
    :param scalar_dtypes: dtypes of the scalar arguments, in the order of
        ``self.scalar_deps``.
    :returns: a ``KernelRecord`` with the attributes ``kernel`` (whatever
        ``self.make_kernel_internal`` returns) and ``result_dtype``.

    NOTE(review): the method is wrapped in ``memoize_method``, so both
    arguments presumably have to be hashable (e.g. tuples) -- confirm
    against the callers. ``self.elementwise_mod`` and
    ``self.make_kernel_internal`` are not defined in the part of the class
    visible here.
    """
    from pymbolic.mapper.stringifier import PREC_NONE
    from pymbolic.mapper.c_code import CCodeMapper

    elwise = self.elementwise_mod

    result_dtype = self.result_dtype_getter(
            dict(zip(self.vector_deps, vector_dtypes)),
            dict(zip(self.scalar_deps, scalar_dtypes)),
            self.constant_dtypes)

    # Kernel signature: result vectors first, then the vector and scalar
    # dependencies (appended further down).
    args = [elwise.VectorArg(result_dtype, vei.name)
            for vei in self.vec_expr_info_list
            if not vei.do_not_return]

    def real_const_mapper(num):
        # Constants whose repr carries no "." are wrapped in double(...),
        # presumably to keep the generated C from treating them as
        # integers.
        r = repr(num)
        if "." not in r:
            return "double(%s)" % r
        else:
            return r

    code_mapper = CCodeMapper(constant_mapper=real_const_mapper)

    # C type name used to declare temporaries; looked up on first use so
    # that codepy is only imported when a temporary is actually emitted.
    result_ctype = None

    code_lines = []
    for vei in self.vec_expr_info_list:
        expr_code = code_mapper(vei.expr, PREC_NONE)
        if vei.do_not_return:
            if result_ctype is None:
                from codepy.cgen import dtype_to_ctype
                result_ctype = dtype_to_ctype(result_dtype)

            # Not returned: becomes a local variable in the loop body.
            code_lines.append(
                    "%s %s = %s;" % (result_ctype, vei.name, expr_code))
        else:
            code_lines.append(
                    "%s[i] = %s;" % (vei.name, expr_code))

    # The generated lines make no provision for emitting common
    # subexpressions, so the mapper must not have collected any.
    assert not code_mapper.cses

    args.extend(
            elwise.VectorArg(dtype, name)
            for dtype, name in zip(vector_dtypes, self.vector_dep_names))
    args.extend(
            elwise.ScalarArg(dtype, name)
            for dtype, name in zip(scalar_dtypes, self.scalar_dep_names))

    return KernelRecord(
            kernel=self.make_kernel_internal(args, "\n".join(code_lines)),
            result_dtype=result_dtype)
223