Package hedge :: Package backends :: Package jit :: Module vector_expr
[hide private]
[frames] | [no frames]

Source Code for Module hedge.backends.jit.vector_expr

  1  """C code generation for vector expressions.""" 
  2   
  3  from __future__ import division 
  4   
  5  __copyright__ = "Copyright (C) 2008 Andreas Kloeckner" 
  6   
  7  __license__ = """ 
  8  This program is free software: you can redistribute it and/or modify 
  9  it under the terms of the GNU General Public License as published by 
 10  the Free Software Foundation, either version 3 of the License, or 
 11  (at your option) any later version. 
 12   
 13  This program is distributed in the hope that it will be useful, 
 14  but WITHOUT ANY WARRANTY; without even the implied warranty of 
 15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 16  GNU General Public License for more details. 
 17   
 18  You should have received a copy of the GNU General Public License 
 19  along with this program.  If not, see U{http://www.gnu.org/licenses/}. 
 20  """ 
 21   
 22   
 23   
 24   
 25  import numpy 
 26  import codepy.elementwise 
 27  from hedge.backends.vector_expr import CompiledVectorExpressionBase 
 28   
 29   
 30   
 31   
class CompiledVectorExpression(CompiledVectorExpressionBase):
    """Vector expression evaluated by a :mod:`codepy` elementwise kernel.

    Kernel construction is delegated to ``ElementwiseKernel`` from
    :attr:`elementwise_mod`. Dependency bookkeeping (``vector_deps``,
    ``scalar_deps``, ``result_vec_expr_info_list``) and kernel lookup
    (``get_kernel``) come from :class:`CompiledVectorExpressionBase`, which
    is not shown here.
    """

    # Module providing ElementwiseKernel. It is looked up through ``self``
    # in make_kernel_internal, so a subclass could override it.
    elementwise_mod = codepy.elementwise

    def __init__(self, vec_expr_info_list, result_dtype_getter,
            toolchain=None, wait_on_error=False):
        """
        :param vec_expr_info_list: passed unchanged to
            ``CompiledVectorExpressionBase.__init__``.
        :param result_dtype_getter: passed unchanged to
            ``CompiledVectorExpressionBase.__init__``.
        :param toolchain: forwarded as ``toolchain=`` to every
            ``ElementwiseKernel`` this object builds (None is passed
            through as-is).
        :param wait_on_error: forwarded as ``wait_on_error=`` to every
            ``ElementwiseKernel`` this object builds.
        """
        CompiledVectorExpressionBase.__init__(self,
                vec_expr_info_list, result_dtype_getter)

        self.toolchain = toolchain
        self.wait_on_error = wait_on_error

    def make_kernel_internal(self, args, instructions):
        """Build an elementwise kernel named ``vector_expression``.

        *args* and *instructions* are forwarded unchanged, together with
        this instance's ``toolchain`` and ``wait_on_error`` settings.
        """
        return self.elementwise_mod.ElementwiseKernel(
                args, instructions, name="vector_expression",
                toolchain=self.toolchain,
                wait_on_error=self.wait_on_error)

    def __call__(self, evaluate_subexpr, stats_callback=None):
        """Evaluate the compiled expression(s) and return the results.

        :param evaluate_subexpr: callable mapping each vector and scalar
            dependency expression to its value. Vector values need
            ``.shape`` and ``.dtype``; scalar values need ``.dtype``.
        :param stats_callback: optional callable invoked as
            ``stats_callback(size, self)`` that returns a timer. If given,
            the kernel invocation is measured with
            ``timer.start_sub_timer()`` / ``.stop().submit()``.
        :returns: a list containing one newly allocated output array per
            entry of ``self.result_vec_expr_info_list``. Each array has the
            common operand shape and the kernel's ``result_dtype``.
        """
        vectors = [evaluate_subexpr(vec_expr)
                for vec_expr in self.vector_deps]
        scalars = [evaluate_subexpr(scal_expr)
                for scal_expr in self.scalar_deps]

        # All vector operands must share one shape. pytools' single_valued
        # is expected to reject mismatched (or absent) shapes.
        from pytools import single_valued
        shape = single_valued(vec.shape for vec in vectors)

        # The kernel is selected by the operand dtype signature.
        kernel_rec = self.get_kernel(
                tuple(v.dtype for v in vectors),
                tuple(s.dtype for s in scalars))

        results = [numpy.empty(shape, kernel_rec.result_dtype)
                for _ in self.result_vec_expr_info_list]

        # Argument order expected by the kernel: outputs, then vector
        # inputs, then scalar inputs.
        args = results + vectors + scalars

        if stats_callback is None:
            kernel_rec.kernel(*args)
        else:
            # The element count is derived from the shape, so it does not
            # require at least one result array to exist.
            size = int(numpy.prod(shape))
            timer = stats_callback(size, self)
            sub_timer = timer.start_sub_timer()
            kernel_rec.kernel(*args)
            sub_timer.stop().submit()

        return results
if __name__ == "__main__":
    # Small manual smoke test of the CPU (codepy) JIT backend.
    # No CUDA context is needed here: the operands are plain numpy arrays
    # and kernels are built with codepy.elementwise.
    test_dtype = numpy.float32

    from pymbolic import parse, var

    expr = parse("2*x+3*y+4*z")
    print(expr)

    # Only the two required arguments are passed. Previously test_dtype
    # was passed as a third positional argument, which binds to the
    # ``toolchain`` parameter and would be handed to ElementwiseKernel.
    # NOTE(review): __init__ expects a list of vector-expression info
    # records (vec_expr_info_list), not a bare pymbolic expression. This
    # demo appears to predate that API; confirm against
    # hedge.backends.vector_expr before relying on it.
    cexpr = CompiledVectorExpression(expr,
            lambda subexpr: (True, test_dtype))

    ctx = {
        var("x"): numpy.arange(5, dtype=test_dtype),
        var("y"): numpy.arange(5, dtype=test_dtype),
        var("z"): numpy.arange(5, dtype=test_dtype),
        }

    print(cexpr(lambda subexpr: ctx[subexpr]))