Package hedge :: Module compiler
[hide private]
[frames] | [no frames]

Source Code for Module hedge.compiler

  1  """Compiler to turn operator expression tree into (imperative) bytecode.""" 
  2   
  3  from __future__ import division 
  4   
  5  __copyright__ = "Copyright (C) 2008 Andreas Kloeckner" 
  6   
  7  __license__ = """ 
  8  Permission is hereby granted, free of charge, to any person 
  9  obtaining a copy of this software and associated documentation 
 10  files (the "Software"), to deal in the Software without 
 11  restriction, including without limitation the rights to use, 
 12  copy, modify, merge, publish, distribute, sublicense, and/or sell 
 13  copies of the Software, and to permit persons to whom the 
 14  Software is furnished to do so, subject to the following 
 15  conditions: 
 16   
 17  The above copyright notice and this permission notice shall be 
 18  included in all copies or substantial portions of the Software. 
 19   
 20  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 
 21  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 
 22  OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 
 23  NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 
 24  HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 
 25  WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 
 26  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 
 27  OTHER DEALINGS IN THE SOFTWARE. 
 28  """ 
 29   
 30   
 31   
 32   
 33  from pytools import Record, memoize_method 
 34  from hedge.optemplate import IdentityMapper 
# instructions ----------------------------------------------------------------
class Instruction(Record):
    """Abstract base class for one imperative bytecode instruction.

    Subclasses declare which names they assign, which variables they read,
    and which method of an executor object runs them.
    """
    # every instance carries a dep_mapper_factory callable, supplied by the
    # compiler (see OperatorCompilerBase.dep_mapper_factory)
    __slots__ = ["dep_mapper_factory"]
    # scheduling priority: Code.get_next_step picks the runnable instruction
    # with the highest priority (via argmax2)
    priority = 0

    def get_assignees(self):
        """Return the set of names this instruction assigns to."""
        raise NotImplementedError("no get_assignees in %s" % self.__class__)

    def get_dependencies(self):
        """Return the set of variable expressions this instruction reads."""
        raise NotImplementedError("no get_dependencies in %s" % self.__class__)

    def __str__(self):
        raise NotImplementedError

    def get_executor_method(self, executor):
        """Return the bound method of *executor* that executes this instruction."""
        raise NotImplementedError
class Assign(Instruction):
    """Assign one or more expressions to one or more names, in one step."""

    # attributes: names, exprs, do_not_return, priority
    #
    # do_not_return is a list of bools indicating whether the corresponding
    # entry in names and exprs describes an expression that is not needed
    # beyond this assignment

    comment = ""

    def __init__(self, names, exprs, **kwargs):
        Instruction.__init__(self, names=names, exprs=exprs, **kwargs)

        # default: every assigned name may be needed later
        if not hasattr(self, "do_not_return"):
            self.do_not_return = [False] * len(names)

    @memoize_method
    def flop_count(self):
        """Return the total floating-point operation count of all exprs."""
        from hedge.optemplate import FlopCounter
        return sum(FlopCounter()(expr) for expr in self.exprs)

    def get_assignees(self):
        return set(self.names)

    def get_dependencies(self, each_vector=False):
        """Return the variables this assignment depends on.

        :param each_vector: if True, subscripted accesses count individually
            (passed through to the dependency mapper); the result is then
            not cached.
        """
        try:
            if each_vector:
                # force recomputation below--the each_vector variant is
                # never cached
                raise AttributeError
            else:
                return self._dependencies
        except AttributeError:
            # FIX: was a bare "except:", which would also swallow
            # KeyboardInterrupt/SystemExit and any unrelated error.
            # Only the (possibly deliberately raised) AttributeError--
            # including a missing self._dependencies--should land here.

            # arg is include_subscripts
            dep_mapper = self.dep_mapper_factory(each_vector)

            from operator import or_
            deps = reduce(
                    or_, (dep_mapper(expr)
                        for expr in self.exprs))

            # names assigned here do not count as external dependencies
            from pymbolic.primitives import Variable
            deps -= set(Variable(name) for name in self.names)

            if not each_vector:
                self._dependencies = deps

            return deps

    def __str__(self):
        comment = self.comment
        if len(self.names) == 1:
            if comment:
                comment = "/* %s */ " % comment

            return "%s <- %s%s" % (self.names[0], comment, self.exprs[0])
        else:
            if comment:
                comment = " /* %s */" % comment

            lines = []
            lines.append("{"+comment)
            for n, e, dnr in zip(self.names, self.exprs, self.do_not_return):
                if dnr:
                    dnr_indicator = "-i"
                else:
                    dnr_indicator = ""

                lines.append(" %s <%s- %s" % (n, dnr_indicator, e))
            lines.append("}")
            return "\n".join(lines)

    def get_executor_method(self, executor):
        return executor.exec_assign
class FluxBatchAssign(Instruction):
    """Assign the results of one batch of flux evaluations--all of the same
    *kind*--to a list of names, one name per flux.
    """
    __slots__ = ["names", "fluxes", "kind"]

    def get_assignees(self):
        return set(self.names)

    def __str__(self):
        entries = [" %s <- %s" % (name, flux)
                for name, flux in zip(self.names, self.fluxes)]
        return "\n".join(["{ /* %s */" % self.kind] + entries + ["}"])

    def get_executor_method(self, executor):
        return executor.exec_flux_batch_assign
class DiffBatchAssign(Instruction):
    """Apply a batch of differentiation operators to a single field,
    assigning one name per operator.

    Attributes: names, op_class, operators, field.
    """

    def get_assignees(self):
        return set(self.names)

    @memoize_method
    def get_dependencies(self):
        return self.dep_mapper_factory()(self.field)

    def __str__(self):
        assignments = ["%s <- %s * %s" % (name, op, self.field)
                for name, op in zip(self.names, self.operators)]

        if len(self.names) > 1:
            # multi-assignment: wrap in braces, indent each entry
            return "\n".join(["{"] + [" " + a for a in assignments] + ["}"])
        else:
            return "\n".join(assignments)

    def get_executor_method(self, executor):
        return executor.exec_diff_batch_assign
class MassAssign(Instruction):
    """Assign the application of a mass-type operator (an instance of
    *op_class*, constructed without arguments) to *field*.
    """
    __slots__ = ["name", "op_class", "field"]

    def get_assignees(self):
        return set([self.name])

    def get_dependencies(self):
        return self.dep_mapper_factory()(self.field)

    def __str__(self):
        op_str = str(self.op_class())
        return "%s <- %s * %s" % (self.name, op_str, self.field)

    def get_executor_method(self, executor):
        return executor.exec_mass_assign
190 191 -class FluxExchangeBatchAssign(Instruction):
192 __slots__ = ["names", "indices_and_ranks", "rank_to_index_and_name", "field"] 193 priority = 1 194
195 - def __init__(self, names, indices_and_ranks, field, dep_mapper_factory):
208
    def get_assignees(self):
        """Return the set of names assigned by this receive batch."""
        return set(self.names)
211
    def get_dependencies(self):
        """Return the variables that self.field depends on."""
        return self.dep_mapper_factory()(self.field)
214
215 - def __str__(self):
216 lines = [] 217 218 lines.append("{") 219 for n, (index, rank) in zip(self.names, self.indices_and_ranks): 220 lines.append(" %s <- receive index %s from rank %d [%s]" % ( 221 n, index, rank, self.field)) 222 lines.append("}") 223 224 return "\n".join(lines)
225
    def get_executor_method(self, executor):
        """Dispatch to the executor's flux-exchange handler."""
        return executor.exec_flux_exchange_batch_assign
228
def dot_dataflow_graph(code, max_node_label_length=30):
    """Return a graphviz "dot" representation of *code*'s dataflow.

    Each instruction becomes a box node; edges run from the node that
    assigns a variable to every node that depends on it. Two artificial
    nodes, "initial" and "result", stand for externally supplied inputs
    and the final result.

    :param max_node_label_length: truncate node labels to this many
        characters; None disables truncation.
    """
    origins = {}
    node_names = {}

    # FIX: a missing comma made these two adjacent string literals
    # concatenate into a single list element, so both node statements
    # ended up on one dot line with no separator--invalid dot output.
    result = [
            "initial [label=\"initial\"]",
            "result [label=\"result\"]",
            ]

    for num, insn in enumerate(code.instructions):
        node_name = "node%d" % num
        node_names[insn] = node_name
        node_label = str(insn).replace("\n", "\\l")+"\\l"

        if max_node_label_length is not None:
            # NOTE(review): truncation may cut a "\l" escape in half,
            # leaving a dangling backslash before the closing quote--
            # consider truncating before appending escapes. Left as-is
            # to preserve existing output.
            node_label = node_label[:max_node_label_length]

        result.append("%s [ label=\"p%d: %s\" shape=box ];" % (
            node_name, insn.priority, node_label))

        for assignee in insn.get_assignees():
            origins[assignee] = node_name

    def get_orig_node(expr):
        # a variable comes from whichever node assigned it; anything
        # else (constants, externally supplied fields) from "initial"
        from pymbolic.primitives import Variable
        if isinstance(expr, Variable):
            return origins.get(expr.name, "initial")
        else:
            return "initial"

    def gen_expr_arrow(expr, target_node):
        result.append("%s -> %s [label=\"%s\"];"
                % (get_orig_node(expr), target_node, expr))

    for insn in code.instructions:
        for dep in insn.get_dependencies():
            gen_expr_arrow(dep, node_names[insn])

    from hedge.tools import is_obj_array

    if is_obj_array(code.result):
        for subexp in code.result:
            gen_expr_arrow(subexp, "result")
    else:
        gen_expr_arrow(code.result, "result")

    return "digraph dataflow {\n%s\n}\n" % "\n".join(result)
281 282 283 284 285 # code ------------------------------------------------------------------------ 286 -class Code(object):
287 - def __init__(self, instructions, result):
288 self.instructions = instructions 289 self.result = result
290
291 - def dump_dataflow_graph(self):
292 from hedge.tools import get_rank 293 from hedge.compiler import dot_dataflow_graph 294 i = 0 295 while True: 296 dot_name = "dataflow-%d.dot" % i 297 from os.path import exists 298 if exists(dot_name): 299 i += 1 300 continue 301 302 open(dot_name, "w").write( 303 dot_dataflow_graph(self, max_node_label_length=None)) 304 break
305 306
307 - class NoInstructionAvailable(Exception):
308 pass
309 310 @memoize_method
311 - def get_next_step(self, available_names, done_insns):
312 from pytools import all, argmax2 313 available_insns = [ 314 (insn, insn.priority) for insn in self.instructions 315 if insn not in done_insns 316 and all(dep.name in available_names 317 for dep in insn.get_dependencies())] 318 319 if not available_insns: 320 raise self.NoInstructionAvailable 321 322 from pytools import flatten 323 discardable_vars = set(available_names) - set(flatten( 324 [dep.name for dep in insn.get_dependencies()] 325 for insn in self.instructions 326 if insn not in done_insns )) 327 328 from hedge.tools import with_object_array_or_scalar 329 with_object_array_or_scalar( 330 lambda var: discardable_vars.discard(var.name), 331 self.result) 332 333 return argmax2(available_insns), discardable_vars
334
335 - def __str__(self):
336 lines = [] 337 for insn in self.instructions: 338 lines.extend(str(insn).split("\n")) 339 lines.append(str(self.result)) 340 341 return "\n".join(lines)
342
343 - def execute(self, exec_mapper):
344 context = exec_mapper.context 345 346 futures = [] 347 done_insns = set() 348 349 quit_flag = False 350 force_future = False 351 while not quit_flag: 352 # check futures for completion 353 i = 0 354 while i < len(futures): 355 future = futures[i] 356 if force_future or future.is_ready(): 357 assignments, new_futures = future() 358 for target, value in assignments: 359 context[target] = value 360 futures.extend(new_futures) 361 futures.pop(i) 362 force_future = False 363 else: 364 i += 1 365 366 del future 367 368 # pick the next insn 369 try: 370 insn, discardable_vars = self.get_next_step( 371 frozenset(context.keys()), 372 frozenset(done_insns)) 373 except self.NoInstructionAvailable: 374 if futures: 375 # no insn ready: we need a future to complete to continue 376 force_future = True 377 else: 378 # no futures, no available instructions: we're done 379 quit_flag = True 380 else: 381 for name in discardable_vars: 382 del context[name] 383 384 done_insns.add(insn) 385 assignments, new_futures = \ 386 insn.get_executor_method(exec_mapper)(insn) 387 for target, value in assignments: 388 context[target] = value 389 390 futures.extend(new_futures) 391 392 if len(done_insns) < len(self.instructions): 393 print "Unreachable instructions:" 394 for insn in set(self.instructions) - done_insns: 395 print " ", insn 396 397 raise RuntimeError("not all instructions are reachable" 398 "--did you forget to pass a value for a placeholder?") 399 400 from hedge.tools import with_object_array_or_scalar 401 return with_object_array_or_scalar(exec_mapper, self.result)
402
403 404 405 406 # compiler -------------------------------------------------------------------- 407 -class OperatorCompilerBase(IdentityMapper):
408 from hedge.optemplate import BoundOperatorCollector \ 409 as bound_op_collector_class 410
    class FluxRecord(Record):
        # one flux found in the operator expression tree:
        #   flux_expr    -- the flux expression itself
        #   dependencies -- what this flux depends on (rewritten to a set of
        #                   flux expressions in __call__)
        #   kind         -- batching key; fluxes of equal kind are batched
        __slots__ = ["flux_expr", "dependencies", "kind"]
413
    class FluxBatch(Record):
        # a group of flux expressions, all of one kind, that can be
        # evaluated together
        __slots__ = ["flux_exprs", "kind"]
416
    def __init__(self, prefix="_expr", max_vectors_in_batch_expr=None):
        """
        :param prefix: prefix for compiler-generated variable names
        :param max_vectors_in_batch_expr: if not None, an upper bound on the
            combined assignee/dependency count of an aggregated assignment
            (enforced in aggregate_assignments)
        """
        IdentityMapper.__init__(self)
        self.prefix = prefix

        self.max_vectors_in_batch_expr = max_vectors_in_batch_expr

        # instructions generated so far
        self.code = []
        # counter used by get_var_name
        self.assigned_var_count = 0
        # maps already-compiled subexpressions to the variables holding them
        self.expr_to_var = {}
426 427 @memoize_method
428 - def dep_mapper_factory(self, include_subscripts=False):
429 from hedge.optemplate import DependencyMapper 430 self.dep_mapper = DependencyMapper( 431 include_operator_bindings=False, 432 include_subscripts=include_subscripts, 433 include_calls="descend_args") 434 435 return self.dep_mapper
436
    def get_contained_fluxes(self, expr):
        """Recursively enumerate all flux expressions in the expression tree
        `expr`. The returned list consists of `FluxRecord`
        instances with fields `flux_expr` and `dependencies`.
        """

        # overridden by subclasses
        raise NotImplementedError
445
446 - def collect_diff_ops(self, expr):
447 from hedge.optemplate import DiffOperatorBase 448 return self.bound_op_collector_class(DiffOperatorBase)(expr)
449
450 - def collect_flux_exchange_ops(self, expr):
451 from hedge.optemplate import FluxExchangeOperator 452 return self.bound_op_collector_class(FluxExchangeOperator)(expr)
453
454 - def __call__(self, expr):
455 # Fluxes can be evaluated faster in batches. Here, we find flux batches 456 # that we can evaluate together. 457 458 # For each FluxRecord, find the other fluxes its flux depends on. 459 flux_queue = self.get_contained_fluxes(expr) 460 for fr in flux_queue: 461 fr.dependencies = set() 462 for d in fr.dependencies: 463 fr.dependencies |= set(sf.flux_expr 464 for sf in self.get_contained_fluxes(d)) 465 466 # Then figure out batches of fluxes to evaluate 467 self.flux_batches = [] 468 admissible_deps = set() 469 while flux_queue: 470 present_batch = set() 471 i = 0 472 while i < len(flux_queue): 473 fr = flux_queue[i] 474 if fr.dependencies <= admissible_deps: 475 present_batch.add(fr) 476 flux_queue.pop(i) 477 else: 478 i += 1 479 480 if present_batch: 481 482 batches_by_kind = {} 483 for fr in present_batch: 484 batches_by_kind[fr.kind] = \ 485 batches_by_kind.get(fr.kind, set()) | set([fr.flux_expr]) 486 487 for kind, batch in batches_by_kind.iteritems(): 488 self.flux_batches.append( 489 self.FluxBatch(kind=kind, flux_exprs=list(batch))) 490 491 admissible_deps |= present_batch 492 else: 493 raise RuntimeError, "cannot resolve flux evaluation order" 494 495 # Once flux batching is figured out, we also need to know which 496 # derivatives are going to be needed, because once the 497 # rst-derivatives are available, it's best to calculate the 498 # xyz ones and throw the rst ones out. It's naturally good if 499 # we can avoid computing (or storing) some of the xyz ones. 500 # So figure out which XYZ derivatives of what are needed. 501 502 self.diff_ops = self.collect_diff_ops(expr) 503 504 # Flux exchange also works better when batched. 505 self.flux_exchange_ops = self.collect_flux_exchange_ops(expr) 506 507 # Finally, walk the expression and build the code. 508 result = IdentityMapper.__call__(self, expr) 509 510 # Then, put the toplevel expressions into variables as well. 
511 from hedge.tools import with_object_array_or_scalar 512 result = with_object_array_or_scalar(self.assign_to_new_var, result) 513 return Code(self.aggregate_assignments(self.code, result), result)
514
515 - def get_var_name(self):
516 new_name = self.prefix+str(self.assigned_var_count) 517 self.assigned_var_count += 1 518 return new_name
519
520 - def map_common_subexpression(self, expr):
521 try: 522 return self.expr_to_var[expr.child] 523 except KeyError: 524 priority = getattr(expr, "priority", 0) 525 cse_var = self.assign_to_new_var(self.rec(expr.child), 526 priority=priority) 527 self.expr_to_var[expr.child] = cse_var 528 return cse_var
529
530 - def map_operator_binding(self, expr):
531 from hedge.optemplate import \ 532 DiffOperatorBase, \ 533 MassOperatorBase, \ 534 FluxExchangeOperator, \ 535 OperatorBinding, \ 536 Field 537 if isinstance(expr.op, DiffOperatorBase): 538 return self.map_diff_op_binding(expr) 539 elif isinstance(expr.op, MassOperatorBase): 540 return self.map_mass_op_binding(expr) 541 elif isinstance(expr.op, FluxExchangeOperator): 542 return self.map_flux_exchange_op_binding(expr) 543 else: 544 field_var = self.assign_to_new_var( 545 self.rec(expr.field)) 546 result_var = self.assign_to_new_var( 547 OperatorBinding(expr.op, field_var)) 548 return result_var
549
550 - def map_diff_op_binding(self, expr):
551 try: 552 return self.expr_to_var[expr] 553 except KeyError: 554 all_diffs = [diff 555 for diff in self.diff_ops 556 if diff.op.__class__ == expr.op.__class__ 557 and diff.field == expr.field] 558 559 from pytools import single_valued 560 names = [self.get_var_name() for d in all_diffs] 561 self.code.append( 562 DiffBatchAssign( 563 names=names, 564 op_class=single_valued( 565 d.op.__class__ for d in all_diffs), 566 operators=[d.op for d in all_diffs], 567 field=self.rec(single_valued(d.field for d in all_diffs)), 568 dep_mapper_factory=self.dep_mapper_factory)) 569 570 from pymbolic import var 571 for n, d in zip(names, all_diffs): 572 self.expr_to_var[d] = var(n) 573 574 return self.expr_to_var[expr]
575
576 - def map_mass_op_binding(self, expr):
577 try: 578 return self.expr_to_var[expr] 579 except KeyError: 580 ma = MassAssign( 581 name=self.get_var_name(), 582 op_class=expr.op.__class__, 583 field=self.rec(expr.field), 584 dep_mapper_factory=self.dep_mapper_factory) 585 self.code.append(ma) 586 587 from pymbolic import var 588 v = var(ma.name) 589 self.expr_to_var[expr] = v 590 return v
591
592 - def map_flux_exchange_op_binding(self, expr):
593 try: 594 return self.expr_to_var[expr] 595 except KeyError: 596 from hedge.tools import is_field_equal 597 all_flux_xchgs = [fe 598 for fe in self.flux_exchange_ops 599 if is_field_equal(fe.field, expr.field)] 600 601 assert len(all_flux_xchgs) > 0 602 603 from pytools import single_valued 604 names = [self.get_var_name() for d in all_flux_xchgs] 605 self.code.append( 606 FluxExchangeBatchAssign( 607 names=names, 608 indices_and_ranks=[ 609 (fe.op.index, fe.op.rank) for fe in all_flux_xchgs], 610 field=self.rec( 611 single_valued( 612 (fe.field for fe in all_flux_xchgs), 613 equality_pred=is_field_equal)), 614 dep_mapper_factory=self.dep_mapper_factory)) 615 616 from pymbolic import var 617 for n, d in zip(names, all_flux_xchgs): 618 self.expr_to_var[d] = var(n) 619 620 return self.expr_to_var[expr]
621
    def map_planned_flux(self, expr):
        """Map a flux expression to the variable that will hold its value,
        emitting the batch-assign instruction for its entire batch on
        first encounter.
        """
        try:
            return self.expr_to_var[expr]
        except KeyError:
            # find the batch this flux belongs to
            for fb in self.flux_batches:
                try:
                    idx = fb.flux_exprs.index(expr)
                except ValueError:
                    pass
                else:
                    # found at idx--emit the whole batch at once
                    mapped_fluxes = [self.internal_map_flux(f) for f in fb.flux_exprs]

                    names = [self.get_var_name() for f in mapped_fluxes]
                    self.code.append(
                        self.make_flux_batch_assign(names, mapped_fluxes, fb.kind))

                    # record a variable for every flux of the batch, so
                    # later lookups hit self.expr_to_var directly
                    from pymbolic import var
                    for n, f in zip(names, fb.flux_exprs):
                        self.expr_to_var[f] = var(n)

                    return var(names[idx])

            raise RuntimeError("flux '%s' not in any flux batch" % expr)
646
647 - def assign_to_new_var(self, expr, priority=0):
648 from pymbolic.primitives import Variable 649 if isinstance(expr, Variable): 650 return expr 651 652 new_name = self.get_var_name() 653 self.code.append(self.make_assign(new_name, expr, priority)) 654 655 return Variable(new_name)
656 657 # instruction producers ---------------------------------------------------
658 - def make_assign(self, name, expr, priority):
662
663 - def make_flux_batch_assign(self, names, fluxes, kind):
665 666 # assignment aggregration pass --------------------------------------------
    def aggregate_assignments(self, instructions, result):
        """Greedily merge compatible Assign instructions into
        multi-assignments.

        Two assignments are candidates for merging when they share
        dependencies or one's dependencies overlap the other's assignees,
        their priorities match, and neither (indirectly, beyond one level)
        originates from the other. If max_vectors_in_batch_expr is set,
        merges whose combined assignee+dependency count would exceed it are
        skipped. Returns the new instruction list, with each aggregate's
        component assignments internally ordered so that intra-aggregate
        dependencies are satisfied.
        """
        from pymbolic.primitives import Variable

        # aggregation helpers --------------------------------------------------
        def get_complete_origins_set(insn, skip_levels=0):
            # transitively collect the instructions whose assignments *insn*
            # depends on; the first skip_levels levels are not themselves
            # added to the result (but are still recursed through)
            if skip_levels < 0:
                skip_levels = 0

            result = set()
            for dep in insn.get_dependencies():
                if isinstance(dep, Variable):
                    dep_origin = origins_map.get(dep.name, None)
                    if dep_origin is not None:
                        if skip_levels <= 0:
                            result.add(dep_origin)
                        result |= get_complete_origins_set(dep_origin, skip_levels-1)

            return result

        var_assignees_cache = {}
        def get_var_assignees(insn):
            # cached: insn's assignees wrapped as Variable instances
            try:
                return var_assignees_cache[insn]
            except KeyError:
                result = set(Variable(assignee)
                        for assignee in insn.get_assignees())
                var_assignees_cache[insn] = result
                return result

        def aggregate_two_assignments(ass_1, ass_2):
            # merge two Assigns into one; dependencies are the union minus
            # whatever the merged instruction itself assigns
            names = ass_1.names+ass_2.names

            from pymbolic.primitives import Variable
            deps = (ass_1.get_dependencies() | ass_2.get_dependencies()) \
                    - set(Variable(name) for name in names)

            return Assign(
                    names=names, exprs=ass_1.exprs+ass_2.exprs,
                    _dependencies=deps,
                    dep_mapper_factory=self.dep_mapper_factory,
                    priority=max(ass_1.priority, ass_2.priority))

        # main aggregation pass -----------------------------------------------
        # maps each assigned name to the instruction producing it
        origins_map = dict(
                (assignee, insn)
                for insn in instructions
                for assignee in insn.get_assignees())

        from pytools import partition
        unprocessed_assigns, other_insns = partition(
                lambda insn: isinstance(insn, Assign),
                instructions)

        # filter out zero-flop-count assigns--no need to bother with those
        processed_assigns, unprocessed_assigns = partition(
                lambda ass: ass.flop_count() == 0,
                unprocessed_assigns)

        # greedy aggregation
        while unprocessed_assigns:
            my_assign = unprocessed_assigns.pop()

            my_deps = my_assign.get_dependencies()
            my_assignees = get_var_assignees(my_assign)

            agg_candidates = []
            for i, other_assign in enumerate(unprocessed_assigns):
                other_deps = other_assign.get_dependencies()
                other_assignees = get_var_assignees(other_assign)

                # related (shared inputs, or one feeds the other) and of
                # equal priority
                if ((my_deps & other_deps
                        or my_deps & other_assignees
                        or other_deps & my_assignees)
                        and my_assign.priority == other_assign.priority):
                    agg_candidates.append((i, other_assign))

            did_work = False

            if agg_candidates:
                my_indirect_origins = get_complete_origins_set(
                        my_assign, skip_levels=1)

                for other_assign_index, other_assign in agg_candidates:
                    if self.max_vectors_in_batch_expr is not None:
                        # enforce the vector-count cap on the merged result
                        new_assignee_count = len(
                                set(my_assign.get_assignees())
                                | set(other_assign.get_assignees()))
                        new_dep_count = len(
                                my_assign.get_dependencies(each_vector=True)
                                | other_assign.get_dependencies(each_vector=True))

                        if (new_assignee_count + new_dep_count \
                                > self.max_vectors_in_batch_expr):
                            continue

                    other_indirect_origins = get_complete_origins_set(
                            other_assign, skip_levels=1)

                    # only merge if neither depends on the other through
                    # some intermediate instruction (would create a cycle)
                    if (my_assign not in other_indirect_origins and
                            other_assign not in my_indirect_origins):
                        did_work = True

                        # aggregate the two assignments
                        new_assignment = aggregate_two_assignments(my_assign, other_assign)
                        del unprocessed_assigns[other_assign_index]
                        unprocessed_assigns.append(new_assignment)
                        for assignee in new_assignment.get_assignees():
                            origins_map[assignee] = new_assignment

                        break

            if not did_work:
                processed_assigns.append(my_assign)

        # names still read by some surviving instruction or by the result
        externally_used_names = set(
                expr
                for insn in processed_assigns+other_insns
                for expr in insn.get_dependencies())

        from hedge.tools import is_obj_array
        if is_obj_array(result):
            externally_used_names |= set(expr for expr in result)
        else:
            externally_used_names |= set([result])

        def schedule_and_finalize_assignment(ass):
            # order ass's component (name, expr) pairs so that entries that
            # read names assigned within the same aggregate come after the
            # entries assigning those names, then hand off to
            # finalize_multi_assign
            dep_mapper = self.dep_mapper_factory()

            names_exprs = zip(ass.names, ass.exprs)

            my_assignees = set(name for name, expr in names_exprs)
            names_exprs_deps = [
                    (name, expr,
                        set(dep.name for dep in dep_mapper(expr) if
                            isinstance(dep, Variable)) & my_assignees)
                    for name, expr in names_exprs]

            ordered_names_exprs = []
            available_names = set()

            while names_exprs_deps:
                schedulable = []

                i = 0
                while i < len(names_exprs_deps):
                    name, expr, deps = names_exprs_deps[i]

                    unsatisfied_deps = deps - available_names

                    if not unsatisfied_deps:
                        schedulable.append((str(expr), name, expr))
                        del names_exprs_deps[i]
                    else:
                        i += 1

                # make sure these come out in a constant order
                schedulable.sort()

                if schedulable:
                    for key, name, expr in schedulable:
                        ordered_names_exprs.append((name, expr))
                        available_names.add(name)
                else:
                    raise RuntimeError("aggregation resulted in an impossible assignment")

            return self.finalize_multi_assign(
                    names=[name for name, expr in ordered_names_exprs],
                    exprs=[expr for name, expr in ordered_names_exprs],
                    do_not_return=[Variable(name) not in externally_used_names
                        for name, expr in ordered_names_exprs],
                    priority=ass.priority)

        return [schedule_and_finalize_assignment(ass)
                for ass in processed_assigns] + other_insns