Package hedge :: Module optemplate
[hide private]
[frames] | [no frames]

Source Code for Module hedge.optemplate

   1  """Building blocks and mappers for operator expression trees.""" 
   2   
   3  from __future__ import division 
   4   
   5  __copyright__ = "Copyright (C) 2008 Andreas Kloeckner" 
   6   
   7  __license__ = """ 
   8  This program is free software: you can redistribute it and/or modify 
   9  it under the terms of the GNU General Public License as published by 
  10  the Free Software Foundation, either version 3 of the License, or 
  11  (at your option) any later version. 
  12   
  13  This program is distributed in the hope that it will be useful, 
  14  but WITHOUT ANY WARRANTY; without even the implied warranty of 
  15  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
  16  GNU General Public License for more details. 
  17   
  18  You should have received a copy of the GNU General Public License 
  19  along with this program.  If not, see U{http://www.gnu.org/licenses/}. 
  20  """ 
  21   
  22   
  23   
  24   
  25  import numpy 
  26  import pymbolic.primitives 
  27  import pymbolic.mapper.stringifier 
  28  import pymbolic.mapper.evaluator 
  29  import pymbolic.mapper.dependency 
  30  import pymbolic.mapper.substitutor 
  31  import pymbolic.mapper.constant_folder 
  32  import pymbolic.mapper.flop_counter 
  33  import hedge.mesh 
  34  from pymbolic.mapper import CSECachingMapperMixin 
def make_common_subexpression(fields):
    """Wrap *fields* in pymbolic ``CommonSubexpression`` nodes.

    The wrapping is delegated to ``hedge.tools.with_object_array_or_scalar``
    (presumably mapping over object arrays and applying directly to a
    scalar expression -- see hedge.tools).
    """
    from pymbolic.primitives import CommonSubexpression
    from hedge.tools import with_object_array_or_scalar

    return with_object_array_or_scalar(CommonSubexpression, fields)


# A field in an operator template is simply a named pymbolic variable.
Field = pymbolic.primitives.Variable
def make_field(var_or_string):
    """Return *var_or_string* as a pymbolic expression.

    Existing pymbolic expressions are passed through untouched; any other
    value (typically a name string) is wrapped in a ``Field``.
    """
    if isinstance(var_or_string, pymbolic.primitives.Expression):
        return var_or_string
    return Field(var_or_string)
class ScalarParameter(pymbolic.primitives.Variable):
    """A named scalar parameter inside an operator template.

    Mappers handle it through their ``map_scalar_parameter`` method.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_scalar_parameter

    def stringifier(self):
        return StringifyMapper
class BoundaryNormalComponent(pymbolic.primitives.AlgebraicLeaf):
    """One component of the normal vector on a tagged boundary.

    :param tag: boundary tag this normal belongs to.
    :param axis: index of the normal component.

    Mappers handle it through their ``map_normal_component`` method.
    """

    def __init__(self, tag, axis):
        self.tag = tag
        self.axis = axis

    def __getinitargs__(self):
        # Every other leaf class in this module exposes its constructor
        # arguments this way; pymbolic presumably relies on it for generic
        # reconstruction/pickling of nodes (TODO confirm against pymbolic).
        return (self.tag, self.axis)

    def stringifier(self):
        return StringifyMapper

    def get_hash(self):
        return hash((self.__class__, self.tag, self.axis))

    def is_equal(self, other):
        return (other.__class__ == self.__class__
                and other.tag == self.tag
                and other.axis == self.axis)

    def get_mapper_method(self, mapper):
        return mapper.map_normal_component
def make_normal(tag, dimensions):
    """Return the symbolic normal on boundary *tag* as an object array.

    Entry ``i`` (for ``i`` in ``range(dimensions)``) is
    ``BoundaryNormalComponent(tag, i)``.
    """
    components = [BoundaryNormalComponent(tag, axis)
            for axis in range(dimensions)]
    return numpy.array(components, dtype=object)
class PrioritizedSubexpression(pymbolic.primitives.CommonSubexpression):
    """A common subexpression that carries an evaluation priority.

    When the optemplate-to-code transformation is performed, prioritized
    subexpressions behave like common subexpressions in that they are
    assigned their own separate identifier/register location. In addition,
    they are evaluated with a settable priority, allowing the user to
    expedite or delay the evaluation of the subexpression.

    :param child: the wrapped expression.
    :param priority: evaluation priority (default 0), reported by
        :meth:`get_extra_properties` under the key ``"priority"``.
    """

    def __init__(self, child, priority=0):
        pymbolic.primitives.CommonSubexpression.__init__(self, child)
        self.priority = priority

    def get_extra_properties(self):
        return {"priority": self.priority}

    def __getinitargs__(self):
        return (self.child, self.priority)
# operators -------------------------------------------------------------------
class Operator(pymbolic.primitives.Leaf):
    """Base class of all symbolic operators in an operator template.

    Operators are applied by multiplication (``op * field``), never by
    calling them.
    """

    def stringifier(self):
        return StringifyMapper

    def __call__(self, *args, **kwargs):
        # Prevent lazy-eval semantics from kicking in. The call form of
        # ``raise`` is valid under both Python 2 and Python 3, unlike the
        # old ``raise E, "msg"`` statement.
        raise RuntimeError("symbolic operators are not callable")

    def apply(self, discr, field):
        """Compile ``self * Field("f")`` on *discr* and evaluate it with
        ``f=field``."""
        return discr.compile(self * Field("f"))(f=field)
class StatelessOperator(Operator):
    """An operator without parameters.

    Two instances are equal exactly when they are of the same class, and
    the hash depends on the class only.
    """

    def is_equal(self, other):
        return other.__class__ == self.__class__

    def get_hash(self):
        return hash(self.__class__)

    def __getinitargs__(self):
        return ()
class OperatorBinding(pymbolic.primitives.AlgebraicLeaf):
    """An operator together with the field it is applied to.

    :param op: the bound operator.
    :param field: the operand. It is compared with
        ``hedge.tools.field_equal`` and hashed through
        ``hedge.tools.hashable_field``, presumably so that object arrays
        work as operands too.
    """

    def __init__(self, op, field):
        self.op = op
        self.field = field

    def __getinitargs__(self):
        return self.op, self.field

    def stringifier(self):
        return StringifyMapper

    def get_mapper_method(self, mapper):
        return mapper.map_operator_binding

    def get_hash(self):
        from hedge.tools import hashable_field
        key = (self.__class__, self.op, hashable_field(self.field))
        return hash(key)

    def is_equal(self, other):
        from hedge.tools import field_equal
        return (other.__class__ == self.__class__
                and other.op == self.op
                and field_equal(other.field, self.field))
# diff operators --------------------------------------------------------------
class DiffOperatorBase(Operator):
    """Common base of the differentiation-type operators.

    :param xyz_axis: index of the axis the operator acts along; it is the
        only state and determines equality and hashing together with the
        class.
    """

    def __init__(self, xyz_axis):
        Operator.__init__(self)

        self.xyz_axis = xyz_axis

    def is_equal(self, other):
        return (other.__class__ == self.__class__
                and other.xyz_axis == self.xyz_axis)

    def get_hash(self):
        return hash((self.__class__, self.xyz_axis))

    def __getinitargs__(self):
        return (self.xyz_axis,)
class DifferentiationOperator(DiffOperatorBase):
    """Differentiation along ``xyz_axis``.

    Uses the element group's ``differentiation_matrices`` and
    ``diff_coefficients``; mappers handle it through ``map_diff``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_diff

    @staticmethod
    def matrices(element_group):
        return element_group.differentiation_matrices

    @staticmethod
    def coefficients(element_group):
        return element_group.diff_coefficients
class MInvSTOperator(DiffOperatorBase):
    """Inverse mass times transposed stiffness along ``xyz_axis``.

    Uses the element group's ``minv_st`` matrices together with
    ``diff_coefficients``; mappers handle it through ``map_minv_st``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_minv_st

    @staticmethod
    def matrices(element_group):
        return element_group.minv_st

    @staticmethod
    def coefficients(element_group):
        return element_group.diff_coefficients
class StiffnessOperator(DiffOperatorBase):
    """Stiffness operator along ``xyz_axis``.

    Uses the element group's ``stiffness_matrices`` and
    ``stiffness_coefficients``; mappers handle it through
    ``map_stiffness``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_stiffness

    @staticmethod
    def matrices(element_group):
        return element_group.stiffness_matrices

    @staticmethod
    def coefficients(element_group):
        return element_group.stiffness_coefficients
class StiffnessTOperator(DiffOperatorBase):
    """Transposed stiffness operator along ``xyz_axis``.

    Uses the element group's ``stiffness_t_matrices`` and
    ``stiffness_coefficients``; mappers handle it through
    ``map_stiffness_t``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_stiffness_t

    @staticmethod
    def matrices(element_group):
        return element_group.stiffness_t_matrices

    @staticmethod
    def coefficients(element_group):
        return element_group.stiffness_coefficients
def DiffOperatorVector(els):
    """Combine the entries of *els* with ``hedge.tools.join_fields``."""
    from hedge.tools import join_fields

    entries = list(els)
    return join_fields(*entries)
# mass operators --------------------------------------------------------------
class MassOperatorBase(StatelessOperator):
    """Common (stateless) base class of the mass-type operators."""
class MassOperator(MassOperatorBase):
    """Mass operator.

    Uses the element group's ``mass_matrix`` with its ``jacobians`` as
    coefficients; mappers handle it through ``map_mass``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_mass

    @staticmethod
    def matrix(element_group):
        return element_group.mass_matrix

    @staticmethod
    def coefficients(element_group):
        return element_group.jacobians
class InverseMassOperator(MassOperatorBase):
    """Inverse mass operator.

    Uses the element group's ``inverse_mass_matrix`` with its
    ``inverse_jacobians`` as coefficients; mappers handle it through
    ``map_inverse_mass``.
    """

    def get_mapper_method(self, mapper):
        return mapper.map_inverse_mass

    @staticmethod
    def matrix(element_group):
        return element_group.inverse_mass_matrix

    @staticmethod
    def coefficients(element_group):
        return element_group.inverse_jacobians
# misc operators --------------------------------------------------------------
class ElementwiseMaxOperator(StatelessOperator):
    """Stateless operator handled by mappers through
    ``map_elementwise_max``."""

    def get_mapper_method(self, mapper):
        return mapper.map_elementwise_max
class BoundarizeOperator(Operator):
    """Operator tied to the boundary with tag *tag*.

    Its argument is treated as volume data (see BCToFluxRewriter, which
    registers it as a local field component). Mappers handle it through
    ``map_boundarize``.

    :param tag: the boundary tag; determines equality and hashing together
        with the class.
    """

    def __init__(self, tag):
        # Initialize the Operator base like the other parametrized
        # operators in this module (DiffOperatorBase, FluxOperatorBase).
        Operator.__init__(self)
        self.tag = tag

    def get_hash(self):
        return hash((self.__class__, self.tag))

    def is_equal(self, other):
        return (other.__class__ == self.__class__
                and other.tag == self.tag)

    def get_mapper_method(self, mapper):
        return mapper.map_boundarize

    def __getinitargs__(self):
        return (self.tag,)
class FluxExchangeOperator(Operator):
    """An operator that results in the sending and receiving of
    boundary information for its argument fields.

    :param idx: exchange index (stored as ``self.index``).
    :param rank: rank of the partner; BCToFluxRewriter matches it against
        ``TAG_RANK_BOUNDARY(rank)``.
    """

    def __init__(self, idx, rank):
        # Initialize the Operator base like the other parametrized
        # operators in this module (DiffOperatorBase, FluxOperatorBase).
        Operator.__init__(self)
        self.index = idx
        self.rank = rank

    def __getinitargs__(self):
        return (self.index, self.rank)

    def get_hash(self):
        return hash((self.__class__, self.index, self.rank))

    def is_equal(self, other):
        return (other.__class__ == self.__class__
                and other.index == self.index
                and other.rank == self.rank)

    def get_mapper_method(self, mapper):
        return mapper.map_flux_exchange
# other parts of an operator template -----------------------------------------
class BoundaryPair(pymbolic.primitives.AlgebraicLeaf):
    """Pairing of a volume field with a boundary field, used when applying
    boundary fluxes.

    :param field: the volume-side field.
    :param bfield: the boundary-side field.
    :param tag: boundary tag, ``hedge.mesh.TAG_ALL`` by default.
    """

    def __init__(self, field, bfield, tag=hedge.mesh.TAG_ALL):
        self.field = field
        self.bfield = bfield
        self.tag = tag

    def __getinitargs__(self):
        return (self.field, self.bfield, self.tag)

    def stringifier(self):
        return StringifyMapper

    def get_mapper_method(self, mapper):
        return mapper.map_boundary_pair

    def get_hash(self):
        from hedge.tools import hashable_field

        key = (self.__class__,
                hashable_field(self.field),
                hashable_field(self.bfield),
                self.tag)
        return hash(key)

    def is_equal(self, other):
        from hedge.tools import field_equal
        return (self.__class__ == other.__class__
                and field_equal(other.field, self.field)
                and field_equal(other.bfield, self.bfield)
                and other.tag == self.tag)
def pair_with_boundary(field, bfield, tag=hedge.mesh.TAG_ALL):
    """Return ``BoundaryPair(field, bfield, tag)``.

    If *tag* is (by identity) ``hedge.mesh.TAG_NONE``, return the integer
    ``0`` instead, presumably so the boundary term drops out of sums.
    """
    if tag is not hedge.mesh.TAG_NONE:
        return BoundaryPair(field, bfield, tag)
    return 0
# flux-like operators ---------------------------------------------------------
class FluxOperatorBase(Operator):
    """Base class of operators parametrized by a flux expression.

    Multiplying by a ``Field`` or an object array binds the operator
    immediately into an :class:`OperatorBinding`. Any other operand is
    handled by ``Operator.__mul__``.

    :param flux: the flux expression; determines equality and hashing
        together with the class.
    """

    def __init__(self, flux):
        Operator.__init__(self)
        self.flux = flux

    def __getinitargs__(self):
        return (self.flux, )

    def get_hash(self):
        return hash((self.__class__, self.flux))

    def is_equal(self, other):
        return (self.__class__ == other.__class__
                and self.flux == other.flux)

    def __mul__(self, arg):
        from hedge.tools import is_obj_array
        binds_directly = isinstance(arg, Field) or is_obj_array(arg)
        if not binds_directly:
            return Operator.__mul__(self, arg)
        return OperatorBinding(self, arg)
class FluxOperator(FluxOperatorBase):
    """Flux operator, handled by mappers through ``map_flux``."""

    def get_mapper_method(self, mapper):
        return mapper.map_flux
class LiftingFluxOperator(FluxOperatorBase):
    """Lifting flux operator, handled by mappers through ``map_lift``."""

    def get_mapper_method(self, mapper):
        return mapper.map_lift
class VectorFluxOperator(object):
    """A vector of fluxes that is applied component by component.

    ``vfo * arg`` gives an object array holding
    ``OperatorBinding(FluxOperator(flux), arg)`` for every flux in
    *fluxes*. Multiplying by the integer zero short-circuits to ``0``.
    """

    def __init__(self, fluxes):
        self.fluxes = fluxes

    def __mul__(self, arg):
        if isinstance(arg, int) and arg == 0:
            return 0

        from hedge.tools import make_obj_array
        bindings = [OperatorBinding(FluxOperator(flux), arg)
                for flux in self.fluxes]
        return make_obj_array(bindings)
# convenience functions -------------------------------------------------------
def make_vector_field(name, components):
    """Return the subscripted variables ``name[i]`` joined into one field.

    :param components: an ``int`` *n* (meaning indices ``0 .. n-1``) or an
        explicit iterable of indices.
    """
    if isinstance(components, int):
        components = range(components)

    from hedge.tools import join_fields
    vector_var = pymbolic.primitives.Variable(name)
    entries = [vector_var[idx] for idx in components]
    return join_fields(*entries)
def get_flux_operator(flux):
    """Return a flux operator for *flux*.

    The result can be multiplied with a volume field to obtain the lifted
    interior fluxes, or with a boundary pair to obtain the lifted boundary
    flux. An object array of fluxes gives a :class:`VectorFluxOperator`;
    a single flux gives a :class:`FluxOperator`.
    """
    from hedge.tools import is_obj_array

    if not is_obj_array(flux):
        return FluxOperator(flux)
    return VectorFluxOperator(flux)
def make_nabla(dim):
    """Return ``DifferentiationOperator(i)`` for ``i < dim`` as an object
    array."""
    from hedge.tools import make_obj_array
    operators = [DifferentiationOperator(axis) for axis in range(dim)]
    return make_obj_array(operators)
def make_minv_stiffness_t(dim):
    """Return ``MInvSTOperator(i)`` for ``i < dim`` as an object array."""
    from hedge.tools import make_obj_array
    operators = [MInvSTOperator(axis) for axis in range(dim)]
    return make_obj_array(operators)
def make_stiffness(dim):
    """Return ``StiffnessOperator(i)`` for ``i < dim`` as an object array."""
    from hedge.tools import make_obj_array
    operators = [StiffnessOperator(axis) for axis in range(dim)]
    return make_obj_array(operators)
def make_stiffness_t(dim):
    """Return ``StiffnessTOperator(i)`` for ``i < dim`` as an object
    array."""
    from hedge.tools import make_obj_array
    operators = [StiffnessTOperator(axis) for axis in range(dim)]
    return make_obj_array(operators)
# mappers ---------------------------------------------------------------------
class LocalOpReducerMixin(object):
    """Funnel the local-operator mapper methods into two hooks.

    ``map_diff``, ``map_minv_st``, ``map_stiffness`` and ``map_stiffness_t``
    forward to ``map_diff_base``. ``map_mass`` and ``map_inverse_mass``
    forward to ``map_mass_base``. Extra arguments are passed through as
    given.
    """

    def map_diff(self, expr, *args, **kwargs):
        return self.map_diff_base(expr, *args, **kwargs)

    map_minv_st = map_stiffness = map_stiffness_t = map_diff

    def map_mass(self, expr, *args, **kwargs):
        return self.map_mass_base(expr, *args, **kwargs)

    map_inverse_mass = map_mass
class FluxOpReducerMixin(object):
    """Funnel ``map_flux`` and ``map_lift`` into ``map_flux_base``.

    Extra arguments are passed through as given.
    """

    def map_flux(self, expr, *args, **kwargs):
        return self.map_flux_base(expr, *args, **kwargs)

    map_lift = map_flux
class OperatorReducerMixin(LocalOpReducerMixin, FluxOpReducerMixin):
    """Reduces calls to *any* operator mapping function to just one.

    Every operator kind ends up in ``self.map_operator(expr, ...)``.
    """

    def map_diff_base(self, expr, *args, **kwargs):
        return self.map_operator(expr, *args, **kwargs)

    map_mass_base = map_flux_base = map_diff_base
    map_elementwise_max = map_boundarize = map_flux_exchange = map_diff_base
class CombineMapperMixin(object):
    """Combine-mapper support for bindings and boundary pairs.

    A binding is treated as having children ``(op, field)`` and a boundary
    pair as having ``(field, bfield)``. The mapped children are merged with
    ``self.combine``.
    """

    def map_operator_binding(self, expr):
        children = [expr.op, expr.field]
        return self.combine([self.rec(child) for child in children])

    def map_boundary_pair(self, expr):
        children = [expr.field, expr.bfield]
        return self.combine([self.rec(child) for child in children])
class CombineMapper(CombineMapperMixin, pymbolic.mapper.CombineMapper):
    """pymbolic's CombineMapper extended with :class:`CombineMapperMixin`."""
class IdentityMapperMixin(LocalOpReducerMixin, FluxOpReducerMixin):
    """Identity-mapping support for the node types of this module.

    Bindings and boundary pairs are rebuilt from their recursively mapped
    children. Operators, normal components and scalar parameters are
    leaves and come back unchanged. Combining this mixin with
    BoundOpMapperMixin is rejected by an assertion.
    """

    def _check_not_bound_op_mapper(self):
        assert not isinstance(self, BoundOpMapperMixin), \
                "IdentityMapper instances cannot be combined with " \
                "the BoundOpMapperMixin"

    def map_operator_binding(self, expr, *args, **kwargs):
        self._check_not_bound_op_mapper()

        new_op = self.rec(expr.op, *args, **kwargs)
        new_field = self.rec(expr.field, *args, **kwargs)
        return expr.__class__(new_op, new_field)

    def map_boundary_pair(self, expr, *args, **kwargs):
        self._check_not_bound_op_mapper()

        new_field = self.rec(expr.field, *args, **kwargs)
        new_bfield = self.rec(expr.bfield, *args, **kwargs)
        return expr.__class__(new_field, new_bfield, expr.tag)

    def map_mass_base(self, expr, *args, **kwargs):
        self._check_not_bound_op_mapper()

        # leaf: there are no children to map
        return expr

    def map_scalar_parameter(self, expr, *args, **kwargs):
        # leaf: there are no children to map
        return expr

    map_diff_base = map_flux_base = map_mass_base
    map_elementwise_max = map_boundarize = map_flux_exchange = map_mass_base
    map_normal_component = map_mass_base
class DependencyMapper(
        CombineMapperMixin,
        pymbolic.mapper.dependency.DependencyMapper,
        OperatorReducerMixin):
    """Dependency mapper for operator templates.

    :param include_operator_bindings: if true, an :class:`OperatorBinding`
        is itself reported as a dependency instead of being looked into.
    :param composite_leaves: forwarded to pymbolic's DependencyMapper. When
        it compares equal to ``False`` or ``True`` it also overrides
        *include_operator_bindings* with that value.

    Remaining keyword arguments go to pymbolic's DependencyMapper. Unbound
    operators and normal components contribute no dependencies; scalar
    parameters are dependencies.
    """

    def __init__(self,
            include_operator_bindings=True,
            composite_leaves=None,
            **kwargs):
        # ``==`` (not ``is``) so that values equal to the booleans count;
        # the default None matches neither.
        if composite_leaves == False:
            include_operator_bindings = False
        if composite_leaves == True:
            include_operator_bindings = True

        pymbolic.mapper.dependency.DependencyMapper.__init__(self,
                composite_leaves=composite_leaves, **kwargs)

        self.include_operator_bindings = include_operator_bindings

    def map_operator_binding(self, expr):
        if not self.include_operator_bindings:
            return CombineMapperMixin.map_operator_binding(self, expr)
        return set([expr])

    def map_operator(self, expr):
        return set()

    def map_normal_component(self, expr):
        return set()

    def map_scalar_parameter(self, expr):
        return set([expr])
class FlopCounter(
        CombineMapperMixin,
        pymbolic.mapper.flop_counter.FlopCounter):
    """Flop counter for operator templates.

    A binding costs only what its bound field costs, and scalar parameters
    cost nothing.
    """

    def map_scalar_parameter(self, expr):
        return 0

    def map_operator_binding(self, expr):
        return self.rec(expr.field)
class CommutativeConstantFoldingMapper(
        pymbolic.mapper.constant_folder.CommutativeConstantFoldingMapper,
        IdentityMapperMixin):
    """Constant folder that understands operator-template nodes.

    An expression is considered constant exactly when
    :class:`DependencyMapper` finds no dependencies in it.
    """

    def __init__(self):
        folder_base = \
                pymbolic.mapper.constant_folder.CommutativeConstantFoldingMapper
        folder_base.__init__(self)
        self.dep_mapper = DependencyMapper()

    def is_constant(self, expr):
        return not self.dep_mapper(expr)
class IdentityMapper(
        IdentityMapperMixin,
        pymbolic.mapper.IdentityMapper):
    """pymbolic's IdentityMapper extended to the node types of this
    module."""
class SubstitutionMapper(pymbolic.mapper.substitutor.SubstitutionMapper,
        IdentityMapperMixin):
    """pymbolic's SubstitutionMapper extended with
    :class:`IdentityMapperMixin`."""
class StringifyMapper(pymbolic.mapper.stringifier.StringifyMapper):
    """Stringifier that knows the node types of operator templates.

    A single value handed to ``%`` is always wrapped in a one-element tuple.
    With a bare operand, a value that happened to be a tuple (for example a
    composite tag or flux) would be unpacked into several format arguments
    and raise ``TypeError`` instead of being printed.
    """

    def map_boundary_pair(self, expr, enclosing_prec):
        return "BPair(%s, %s, %s)" % (expr.field, expr.bfield, repr(expr.tag))

    def map_diff(self, expr, enclosing_prec):
        return "Diff%d" % (expr.xyz_axis,)

    def map_minv_st(self, expr, enclosing_prec):
        return "MInvST%d" % (expr.xyz_axis,)

    def map_stiffness(self, expr, enclosing_prec):
        return "Stiff%d" % (expr.xyz_axis,)

    def map_stiffness_t(self, expr, enclosing_prec):
        return "StiffT%d" % (expr.xyz_axis,)

    def map_mass(self, expr, enclosing_prec):
        return "M"

    def map_inverse_mass(self, expr, enclosing_prec):
        return "InvM"

    def map_flux(self, expr, enclosing_prec):
        return "Flux(%s)" % (expr.flux,)

    def map_lift(self, expr, enclosing_prec):
        return "Lift(%s)" % (expr.flux,)

    def map_elementwise_max(self, expr, enclosing_prec):
        return "ElWMax"

    def map_boundarize(self, expr, enclosing_prec):
        return "Boundarize<tag=%s>" % (expr.tag,)

    def map_flux_exchange(self, expr, enclosing_prec):
        return "FExch<idx=%d,rank=%d>" % (expr.index, expr.rank)

    def map_normal_component(self, expr, enclosing_prec):
        return "Normal<tag=%s>[%d]" % (expr.tag, expr.axis)

    def map_operator_binding(self, expr, enclosing_prec):
        return "<%s>(%s)" % (expr.op, expr.field)

    def map_scalar_parameter(self, expr, enclosing_prec):
        return "ScalarPar[%s]" % (expr.name,)
class NoCSEStringifyMapper(StringifyMapper):
    """Stringifier that prints common subexpressions inline.

    Each CSE is rendered as its child, at the enclosing precedence.
    """

    def map_common_subexpression(self, expr, enclosing_prec):
        child = expr.child
        return self.rec(child, enclosing_prec)
class BoundOpMapperMixin(object):
    """Dispatch each binding to its operator's mapper method.

    The selected method receives the operator and the bound field as
    separate leading arguments, followed by any extra arguments.
    """

    def map_operator_binding(self, expr, *args, **kwargs):
        op = expr.op
        handler = op.get_mapper_method(self)
        return handler(op, expr.field, *args, **kwargs)
class EmptyFluxKiller(CSECachingMapperMixin, IdentityMapper):
    """Replace boundary fluxes on node-less boundaries by zero.

    A flux operator bound to a :class:`BoundaryPair` whose tag selects a
    boundary of *discr* without nodes is rewritten to ``0``. Everything
    else is mapped like in :class:`IdentityMapper`.
    """

    def __init__(self, discr):
        IdentityMapper.__init__(self)
        self.discr = discr

    map_common_subexpression_uncached = \
            IdentityMapper.map_common_subexpression

    def _is_flux_on_empty_boundary(self, expr):
        # LiftingFluxOperator derives from FluxOperatorBase, so this single
        # check covers both flux operator kinds.
        if not isinstance(expr.op, FluxOperatorBase):
            return False
        if not isinstance(expr.field, BoundaryPair):
            return False
        boundary = self.discr.get_boundary(expr.field.tag)
        return len(boundary.nodes) == 0

    def map_operator_binding(self, expr):
        if self._is_flux_on_empty_boundary(expr):
            return 0
        return IdentityMapper.map_operator_binding(self, expr)
class OperatorBinder(CSECachingMapperMixin, IdentityMapper):
    """Turn products that start with an operator into bindings.

    ``op * rest`` becomes ``OperatorBinding(op, rest')``, where ``rest'`` is
    the remaining factors with binding applied recursively.
    """

    map_common_subexpression_uncached = \
            IdentityMapper.map_common_subexpression

    def map_product(self, expr):
        children = expr.children
        if not children:
            return expr

        from pymbolic.primitives import flattened_product
        head = children[0]
        bound_rest = self.rec(flattened_product(children[1:]))

        if isinstance(head, Operator):
            return OperatorBinding(head, bound_rest)
        return head * bound_rest
class _InnerInverseMassContractor(pymbolic.mapper.RecursiveMapper):
    """Apply an inverse mass operator to an expression, contracting it
    with operators where possible.

    Calling an instance on *field* yields an expression standing for
    ``InverseMassOperator()`` applied to *field*:

    - on a MassOperator binding it cancels, leaving the bound field;
    - on a StiffnessOperator binding it becomes a DifferentiationOperator
      binding along the same axis;
    - on a StiffnessTOperator binding it becomes an MInvSTOperator binding;
    - on a FluxOperator binding it becomes a LiftingFluxOperator binding;
    - it is distributed over sums, and moved past the scalar
      (int/float/complex) factors of a product that has exactly one
      non-scalar factor;
    - anything else keeps an explicit InverseMassOperator binding.
    """

    def _bind_inverse_mass(self, expr):
        # Fallback: keep M^{-1} explicit around an expression that is not
        # (or cannot be) contracted.
        return OperatorBinding(InverseMassOperator(), expr)

    def map_constant(self, expr):
        return self._bind_inverse_mass(expr)

    def map_algebraic_leaf(self, expr):
        return self._bind_inverse_mass(expr)

    def map_operator_binding(self, binding):
        op = binding.op
        if isinstance(op, MassOperator):
            return binding.field
        elif isinstance(op, StiffnessOperator):
            return OperatorBinding(
                    DifferentiationOperator(op.xyz_axis),
                    binding.field)
        elif isinstance(op, StiffnessTOperator):
            return OperatorBinding(
                    MInvSTOperator(op.xyz_axis),
                    binding.field)
        elif isinstance(op, FluxOperator):
            return OperatorBinding(
                    LiftingFluxOperator(op.flux),
                    binding.field)
        else:
            return self._bind_inverse_mass(binding)

    def map_sum(self, expr):
        # M^{-1} is linear: distribute it over the summands.
        return expr.__class__(tuple(self.rec(child) for child in expr.children))

    def map_product(self, expr):
        def is_scalar(child):
            return isinstance(child, (int, float, complex))

        nonscalar_count = len([child for child in expr.children
                if not is_scalar(child)])

        if nonscalar_count != 1:
            # Either no non-scalar factor at all, or several of them, which
            # we do not try to contract. M^{-1} must still be applied to the
            # whole product: returning *expr* unchanged would silently drop
            # the inverse mass operator the caller asked us to apply.
            return self._bind_inverse_mass(expr)

        # Exactly one non-scalar factor: scalar factors commute with M^{-1},
        # so contract into that factor and keep the scalars as they are.
        def do_map(child):
            if is_scalar(child):
                return child
            else:
                return self.rec(child)

        return expr.__class__(tuple(
            do_map(child) for child in expr.children))
class InverseMassContractor(CSECachingMapperMixin, IdentityMapper):
    """Contract inverse mass operators with the operators they act on.

    Assumes all operators are already bound. Each InverseMassOperator
    binding is replaced by what :class:`_InnerInverseMassContractor` makes
    of its field. Other bindings keep their operator and have their field
    mapped recursively.
    """

    map_common_subexpression_uncached = \
            IdentityMapper.map_common_subexpression

    def map_boundary_pair(self, bp):
        new_field = self.rec(bp.field)
        new_bfield = self.rec(bp.bfield)
        return BoundaryPair(new_field, new_bfield, bp.tag)

    def map_operator_binding(self, binding):
        if isinstance(binding.op, InverseMassOperator):
            return _InnerInverseMassContractor()(binding.field)
        return binding.__class__(binding.op, self.rec(binding.field))
# BC-to-flux rewriting --------------------------------------------------------
class BCToFluxRewriter(CSECachingMapperMixin, IdentityMapper):
    """Operates on L{FluxOperator} instances bound to L{BoundaryPair}s. If the
    boundary pair's C{bfield} is an expression of what's available in the
    C{field}, we can avoid fetching the data for the explicit boundary
    condition and just substitute the C{bfield} expression into the flux. This
    mapper does exactly that.

    Raises C{RuntimeError} in these cases:
      - a variable occurs in both the volume and the boundary field of a pair;
      - the boundary field uses a boundary tag that disagrees with the
        pair's tag;
      - the boundary field applies an operator other than
        L{BoundarizeOperator}/L{FluxExchangeOperator}.
    """

    map_common_subexpression_uncached = \
            IdentityMapper.map_common_subexpression

    def map_operator_binding(self, expr):
        if not (isinstance(expr.op, FluxOperator)
                and isinstance(expr.field, BoundaryPair)):
            return IdentityMapper.map_operator_binding(self, expr)

        bpair = expr.field
        vol_field = bpair.field
        bdry_field = bpair.bfield
        flux = expr.op.flux

        bdry_dependencies = DependencyMapper(
                include_calls="descend_args",
                include_operator_bindings=True)(bdry_field)

        vol_dependencies = DependencyMapper(
                include_operator_bindings=True)(vol_field)

        vol_bdry_intersection = bdry_dependencies & vol_dependencies
        if vol_bdry_intersection:
            raise RuntimeError("Variables are being used as both "
                    "boundary and volume quantities: %s"
                    % ", ".join(str(v) for v in vol_bdry_intersection))

        # Step 1: Find maximal flux-evaluable subexpression of boundary field
        # in given BoundaryPair.

        class MaxBoundaryFluxEvaluableExpressionFinder(
                IdentityMapper, OperatorReducerMixin):
            """Rewrite the boundary field as a flux expression.

            Arguments of BoundarizeOperator bindings become local
            FieldComponents (volume data). Variables, subscripts and flux
            exchange bindings become non-local FieldComponents (boundary
            data). Both kinds are numbered in order of first registration.
            """

            def __init__(self, vol_expr_list):
                # vol_expr_list must support append(): see
                # register_volume_expr().
                self.vol_expr_list = vol_expr_list
                self.vol_expr_to_idx = dict((vol_expr, idx)
                        for idx, vol_expr in enumerate(vol_expr_list))

                self.bdry_expr_list = []
                self.bdry_expr_to_idx = {}

            def register_boundary_expr(self, expr):
                try:
                    return self.bdry_expr_to_idx[expr]
                except KeyError:
                    idx = len(self.bdry_expr_to_idx)
                    self.bdry_expr_to_idx[expr] = idx
                    self.bdry_expr_list.append(expr)
                    return idx

            def register_volume_expr(self, expr):
                try:
                    return self.vol_expr_to_idx[expr]
                except KeyError:
                    idx = len(self.vol_expr_to_idx)
                    self.vol_expr_to_idx[expr] = idx
                    self.vol_expr_list.append(expr)
                    return idx

            def map_normal_component(self, expr):
                if expr.tag != bpair.tag:
                    raise RuntimeError("BoundaryNormalComponent and BoundaryPair "
                            "do not agree about boundary tag: %s vs %s"
                            % (expr.tag, bpair.tag))

                from hedge.flux import Normal
                return Normal(expr.axis)

            def map_variable(self, expr):
                from hedge.flux import FieldComponent
                return FieldComponent(
                        self.register_boundary_expr(expr),
                        is_local=False)

            map_subscript = map_variable

            def map_operator_binding(self, expr):
                from hedge.flux import FieldComponent
                if isinstance(expr.op, BoundarizeOperator):
                    if expr.op.tag != bpair.tag:
                        raise RuntimeError("BoundarizeOperator and BoundaryPair "
                                "do not agree about boundary tag: %s vs %s"
                                % (expr.op.tag, bpair.tag))

                    return FieldComponent(
                            self.register_volume_expr(expr.field),
                            is_local=True)
                elif isinstance(expr.op, FluxExchangeOperator):
                    from hedge.mesh import TAG_RANK_BOUNDARY
                    op_tag = TAG_RANK_BOUNDARY(expr.op.rank)
                    if bpair.tag != op_tag:
                        # The two tags compared here are the exchange
                        # operator's and the boundary pair's.
                        raise RuntimeError("FluxExchangeOperator and BoundaryPair "
                                "do not agree about boundary tag: %s vs %s"
                                % (op_tag, bpair.tag))
                    return FieldComponent(
                            self.register_boundary_expr(expr),
                            is_local=False)
                else:
                    raise RuntimeError("Found '%s' in a boundary term. "
                            "To the best of my knowledge, no hedge operator applies "
                            "directly to boundary data, so this is likely in error."
                            % (expr.op,))

        from hedge.tools import is_obj_array
        if is_obj_array(vol_field):
            # Copy into a plain list: register_volume_expr() may append
            # further volume expressions, and numpy arrays have no append().
            vol_expr_list = list(vol_field)
        else:
            vol_expr_list = [vol_field]

        mbfeef = MaxBoundaryFluxEvaluableExpressionFinder(vol_expr_list)
        new_bdry_field = mbfeef(bdry_field)

        # Step 2: Substitute the new_bdry_field into the flux.
        from hedge.flux import FluxSubstitutionMapper, FieldComponent

        def sub_bdry_into_flux(expr):
            # Replace every non-local (boundary) field component of the
            # flux by the matching entry of the rewritten boundary field.
            if isinstance(expr, FieldComponent) and not expr.is_local:
                if expr.index == 0 and not is_obj_array(bdry_field):
                    return new_bdry_field
                else:
                    return new_bdry_field[expr.index]
            else:
                return None

        new_flux = FluxSubstitutionMapper(
                sub_bdry_into_flux)(flux)

        from hedge.tools import is_zero
        if is_zero(new_flux):
            return 0
        else:
            return OperatorBinding(
                    FluxOperator(new_flux), BoundaryPair(
                        numpy.array(mbfeef.vol_expr_list, dtype=object),
                        numpy.array(mbfeef.bdry_expr_list, dtype=object),
                        bpair.tag))
# collecting ------------------------------------------------------------------
class CollectorMixin(LocalOpReducerMixin, FluxOpReducerMixin):
    """Mapper mixin that gathers sets of sub-expressions.

    Constants, variables, unbound operators, normal components and scalar
    parameters contribute the empty set. Child results are merged by
    :meth:`combine` into one flat set.
    """

    def combine(self, values):
        from pytools import flatten
        return set(flatten(values))

    def map_constant(self, bpair):
        return set()

    def map_mass_base(self, expr):
        return set()

    map_diff_base = map_flux_base = map_variable = map_mass_base
    map_normal_component = map_scalar_parameter = map_mass_base
class FluxCollector(CSECachingMapperMixin, CollectorMixin, CombineMapper):
    """Collect every binding of a flux-type operator (FluxOperatorBase).

    Bindings nested inside bound fields are collected as well.
    """

    map_common_subexpression_uncached = \
            CombineMapper.map_common_subexpression

    def map_operator_binding(self, expr):
        found = set()
        if isinstance(expr.op, FluxOperatorBase):
            found.add(expr)
        return found | self.rec(expr.field)
class BoundaryTagCollector(CollectorMixin, CombineMapper):
    """Collect the tags of all :class:`BoundaryPair` nodes."""

    def map_boundary_pair(self, bpair):
        tags = set()
        tags.add(bpair.tag)
        return tags
class BoundOperatorCollector(CSECachingMapperMixin, CollectorMixin, CombineMapper):
    """Collect every binding whose operator is an instance of *op_class*.

    :param op_class: anything ``isinstance`` accepts as its second argument.

    Bindings nested inside bound fields are collected as well.
    """

    def __init__(self, op_class):
        self.op_class = op_class

    map_common_subexpression_uncached = \
            CombineMapper.map_common_subexpression

    def map_operator_binding(self, expr):
        found = set()
        if isinstance(expr.op, self.op_class):
            found.add(expr)
        return found | self.rec(expr.field)
# evaluation ------------------------------------------------------------------
class Evaluator(pymbolic.mapper.evaluator.EvaluationMapper):
    """Evaluation mapper that also evaluates inside boundary pairs.

    A :class:`BoundaryPair` evaluates to a new pair holding the evaluated
    field and boundary field, with the tag unchanged.
    """

    def map_boundary_pair(self, bp):
        evaluated_field = self.rec(bp.field)
        evaluated_bfield = self.rec(bp.bfield)
        return BoundaryPair(evaluated_field, evaluated_bfield, bp.tag)
# optemplate tools ------------------------------------------------------------
def split_optemplate_for_multirate(state_vector, op_template,
        index_groups):
    """Split *op_template* into per-index-group pieces for multirate use.

    For every index group ``ig`` and every index group ``k`` (both in
    order), the result contains ``op_template[ig]`` with all entries of the
    index groups other than ``k`` substituted by zero, then constant-folded.
    The result is a flat list of ``len(index_groups)**2`` expressions, with
    ``ig`` varying slowest.

    NOTE(review): *state_vector* is not used. The kill sets are built from
    the index-group entries themselves, which are also used to index
    *op_template*. The substitution callback is presumably handed
    variables/subscripts by SubstitutionMapper, so unless the index groups
    hold expressions that also work as indices, nothing is killed. Confirm
    whether the kill sets should be built from ``state_vector[idx]``.
    """
    class IndexGroupKillerSubstMap(object):
        """Substitution callback that maps members of *kill_set* to 0."""

        def __init__(self, kill_set):
            self.kill_set = kill_set

        def __call__(self, expr):
            # Use the instance's own set. The enclosing function's
            # ``kill_set`` is rebound on every loop iteration below, so a
            # closure reference would make every killer use the set built
            # last.
            if expr in self.kill_set:
                return 0
            else:
                return None

    # make IndexGroupKillerSubstMap that kill everything
    # *except* what's in that index group
    killers = []
    for keep in range(len(index_groups)):
        kill_set = set()
        for other in range(len(index_groups)):
            if other != keep:
                kill_set |= set(index_groups[other])

        killers.append(IndexGroupKillerSubstMap(kill_set))

    from hedge.optemplate import \
            SubstitutionMapper, \
            CommutativeConstantFoldingMapper

    return [
            CommutativeConstantFoldingMapper()(
                SubstitutionMapper(killer)(
                    op_template[ig]))
            for ig in index_groups
            for killer in killers]