Package hedge :: Module partition
[hide private]
[frames] | [no frames]

Source Code for Module hedge.partition

  1  """Mesh partitioning subsystem. 
  2   
  3  This is used by parallel execution (MPI) and local timestepping. 
  4  """ 
  5   
  6  from __future__ import division 
  7   
  8  __copyright__ = "Copyright (C) 2009 Andreas Kloeckner" 
  9   
 10  __license__ = """ 
 11  This program is free software: you can redistribute it and/or modify 
 12  it under the terms of the GNU General Public License as published by 
 13  the Free Software Foundation, either version 3 of the License, or 
 14  (at your option) any later version. 
 15   
 16  This program is distributed in the hope that it will be useful, 
 17  but WITHOUT ANY WARRANTY; without even the implied warranty of 
 18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 19  GNU General Public License for more details. 
 20   
 21  You should have received a copy of the GNU General Public License 
 22  along with this program.  If not, see U{http://www.gnu.org/licenses/}. 
 23  """ 
 24   
 25   
 26   
 27   
 28  import numpy 
 29  import numpy.linalg as la 
 30  import pytools 
 31  from pytools import memoize_method 
 32  import hedge.mesh 
 33  import hedge.optemplate 
class PartitionData(pytools.Record):
    """Record describing one part of a partitioned mesh.

    Fields (as filled in by :func:`partition_mesh`):

    * ``part_nr`` -- the number of this part.
    * ``mesh`` -- the part's own, locally-numbered mesh.
    * ``global2local_elements`` -- dict mapping global element id to the
      element's index in the part mesh.
    * ``global2local_vertex_indices`` -- dict mapping global vertex index
      to local vertex index.
    * ``neighbor_parts`` -- the parts sharing an interface with this one.
    * ``global_periodic_opposite_faces`` -- the whole mesh's
      ``periodic_opposite_faces``.
    * ``part_boundary_tags`` -- dict mapping neighbor part number to the
      boundary tag of the interface with that neighbor.
    """

    def __init__(self,
            part_nr,
            mesh,
            global2local_elements,
            global2local_vertex_indices,
            neighbor_parts,
            global_periodic_opposite_faces,
            part_boundary_tags
            ):
        fields = dict(
                part_nr=part_nr,
                mesh=mesh,
                global2local_elements=global2local_elements,
                global2local_vertex_indices=global2local_vertex_indices,
                neighbor_parts=neighbor_parts,
                global_periodic_opposite_faces=global_periodic_opposite_faces,
                part_boundary_tags=part_boundary_tags,
                )
        pytools.Record.__init__(self, fields)
49
def partition_from_tags(mesh, tag_to_number):
    """Build a partition array (indexed by element id) from element tags.

    Every element starts out in part 0. For each ``(tag, number)`` pair in
    *tag_to_number*, *number* is added to the entry of every element that
    carries *tag*. An element listed under several of the given tags
    therefore ends up with the sum of their numbers.

    :returns: a :mod:`numpy` ``int32`` array of length
        ``len(mesh.elements)``.
    """
    parts = numpy.zeros(len(mesh.elements), dtype=numpy.int32)
    for tag, part_number in tag_to_number.items():
        tagged_elements = mesh.tag_to_elements[tag]
        for element in tagged_elements:
            parts[element.id] += part_number
    return parts
61
def partition_mesh(mesh, partition, part_bdry_tag_factory):
    """Split *mesh* into one sub-mesh per part.

    *partition* is a mapping that maps element id to
    integers that represent different pieces of the mesh.
    For historical reasons, the values in partition are called
    'parts'.

    :arg part_bdry_tag_factory: called with a neighboring part number.
        Its return value tags the faces of a part's mesh that border
        that neighbor.
    :returns: a generator yielding one :class:`PartitionData` per part.
    """
    from pytools import flatten
    from hedge.mesh import TAG_ALL, TAG_NO_BOUNDARY, make_conformal_mesh

    # Find parts to which we need to distribute.
    all_parts = list(set(partition[el.id] for el in mesh.elements))

    # element -> [element tags], to speed up copy_el_tagger, below.
    el2tags = {}
    for tag, elements in mesh.tag_to_elements.items():
        if tag == TAG_ALL:
            continue
        for el in elements:
            el2tags.setdefault(el, []).append(tag)

    # (el, face_nr) -> [boundary tags], to speed up
    # partition_bdry_tagger, below.
    elface2tags = {}
    for tag, elfaces in mesh.tag_to_boundary.items():
        if tag == TAG_ALL:
            continue
        for el, fn in elfaces:
            elface2tags.setdefault((el, fn), []).append(tag)

    # (el, face_nr) -> part at the other end of the interface, recorded
    # only where that part differs from the current one. Concurrently,
    # build part -> set([parts that border me]).
    elface2part = {}
    neighboring_parts = {}
    for elface1, elface2 in mesh.interfaces:
        r1 = partition[elface1[0].id]
        r2 = partition[elface2[0].id]

        if r1 != r2:
            neighboring_parts.setdefault(r1, set()).add(r2)
            neighboring_parts.setdefault(r2, set()).add(r1)

            elface2part[elface1] = r2
            elface2part[elface2] = r1

    def make_part_data(part):
        # Builds one part's mesh and PartitionData. It lives in its own
        # function so that the tagger closures below capture this part's
        # element list, not a loop variable rebound for the next part.
        part_global_elements = [el for el in mesh.elements
                if partition[el.id] == part]

        # Pick out this part's vertices. Materialized as a list so that
        # the vertex list and the numbering below share one fixed order.
        part_global_vertex_indices = list(set(flatten(
            el.vertex_indices for el in part_global_elements)))

        part_local_vertices = [mesh.points[vi]
                for vi in part_global_vertex_indices]

        # Global-to-local maps.
        part_global2local_vertex_indices = dict(
                (gvi, lvi) for lvi, gvi in
                enumerate(part_global_vertex_indices))
        part_global2local_elements = dict(
                (el.id, i) for i, el in enumerate(part_global_elements))

        # Elements in local vertex numbering.
        part_local_elements = [
                [part_global2local_vertex_indices[vi]
                    for vi in el.vertex_indices]
                for el in part_global_elements]

        # Make the new local Mesh object, including
        # boundary and element tagging.
        def partition_bdry_tagger(fvi, local_el, fn, all_vertices):
            el = part_global_elements[local_el.id]

            # Copy: the list stored in elface2tags is shared and must
            # not pick up the tags appended below.
            result = list(elface2tags.get((el, fn), []))
            try:
                opp_part = elface2part[el, fn]
            except KeyError:
                pass
            else:
                result.append(part_bdry_tag_factory(opp_part))

                # keeps this part of the boundary from falling
                # under TAG_ALL.
                result.append(TAG_NO_BOUNDARY)

            return result

        def copy_el_tagger(local_el, all_vertices):
            return el2tags.get(part_global_elements[local_el.id], [])

        def is_partbdry_face(local_el_and_face_nr):
            # Receives one (local_el, face_nr) tuple. It is unpacked here
            # rather than in the signature, which Python 3 disallows.
            local_el, face_nr = local_el_and_face_nr
            return (part_global_elements[local_el.id], face_nr) in elface2part

        part_mesh = make_conformal_mesh(
                part_local_vertices,
                part_local_elements,
                partition_bdry_tagger, copy_el_tagger,
                mesh.periodicity,
                is_partbdry_face)

        # Assemble per-part data.
        my_nb_parts = neighboring_parts.get(part, [])
        return PartitionData(
                part,
                part_mesh,
                part_global2local_elements,
                part_global2local_vertex_indices,
                my_nb_parts,
                mesh.periodic_opposite_faces,
                part_boundary_tags=dict(
                    (nb_part, part_bdry_tag_factory(nb_part))
                    for nb_part in my_nb_parts),
                )

    for part in all_parts:
        yield make_part_data(part)
def find_neighbor_vol_indices(
        my_discr, my_part_data,
        nb_discr, nb_part_data,
        debug=False):
    """Map the nodes of my boundary with the neighbor part onto the
    neighbor's volume node indices.

    The function visits each face carrying my boundary tag for the
    neighbor part and locates the matching face on the neighbor's side.
    For ordinary faces the match goes through shared global vertices. For
    periodic faces it goes through ``global_periodic_opposite_faces``. The
    neighbor's face nodes are then recorded in an order matching my own
    face nodes.

    As a side effect, the ``h`` attributes of each matched pair of flux
    faces are both set to the larger of the two values.

    :arg debug: if true, and the element discretization has facial nodes,
        assert that matched node positions agree to within ``1e-14``
        (ignoring the periodic axis on periodic faces).
    :returns: the entries of the neighbor boundary's ``vol_indices``
        selected by the matched neighbor boundary node indices. There is
        one entry per node of my boundary, in the order my boundary faces
        are visited.
    """
    from pytools import reverse_dictionary
    l2g_vertex_indices = reverse_dictionary(
            my_part_data.global2local_vertex_indices)
    nb_l2g_vertex_indices = reverse_dictionary(
            nb_part_data.global2local_vertex_indices)
    nb_g2l_vertex_indices = nb_part_data.global2local_vertex_indices

    my_bdry_tag = my_part_data.part_boundary_tags[nb_part_data.part_nr]
    nb_bdry_tag = nb_part_data.part_boundary_tags[my_part_data.part_nr]

    my_mesh_bdry = my_part_data.mesh.tag_to_boundary[my_bdry_tag]
    nb_mesh_bdry = nb_part_data.mesh.tag_to_boundary[nb_bdry_tag]

    my_discr_bdry = my_discr.get_boundary(my_bdry_tag)
    nb_discr_bdry = nb_discr.get_boundary(nb_bdry_tag)

    # set of neighbor-local face vertices -> (neighbor element, face number)
    nb_vertices_to_face = dict(
            (frozenset(el.faces[face_nr]), (el, face_nr))
            for el, face_nr in nb_mesh_bdry)

    from_indices = []

    shuffled_indices_cache = {}

    def get_shuffled_indices(face_node_count, shuffle_op):
        # The result depends on the node count as well as on the shuffle,
        # so both form the cache key.
        key = (face_node_count, shuffle_op)
        try:
            return shuffled_indices_cache[key]
        except KeyError:
            result = shuffled_indices_cache[key] = \
                    shuffle_op(range(face_node_count))
            return result

    def add_face_pair(my_face_nr, eslice, ldis, face_node_count,
            my_global_vertices, nb_el, nb_face_nr, nb_global_vertices,
            periodic_axis):
        # Appends the neighbor boundary node indices of one matched face
        # pair to from_indices, in my face's node order. Optionally checks
        # node positions. periodic_axis is None for non-periodic faces.
        nb_face_start = nb_discr_bdry \
                .find_facepair((nb_el, nb_face_nr)) \
                .opp.el_base_index

        shuffle_op = ldis.get_face_index_shuffle_to_match(
                my_global_vertices, nb_global_vertices)

        shuffled_nb_node_indices = [nb_face_start+i
                for i in get_shuffled_indices(face_node_count, shuffle_op)]

        from_indices.extend(shuffled_nb_node_indices)

        # Check if the nodes really match up.
        if debug and ldis.has_facial_nodes:
            my_node_indices = [eslice.start+i
                    for i in ldis.face_indices()[my_face_nr]]

            for my_i, nb_i in zip(my_node_indices, shuffled_nb_node_indices):
                dist = my_discr.nodes[my_i]-nb_discr_bdry.nodes[nb_i]
                if periodic_axis is not None:
                    # Periodic partners are offset along this axis.
                    dist[periodic_axis] = 0
                assert la.norm(dist) < 1e-14

    for my_el, my_face_nr in my_mesh_bdry:
        eslice, ldis = my_discr.find_el_data(my_el.id)
        face_node_count = ldis.face_node_count()

        my_global_vertices = tuple(l2g_vertex_indices[vi]
                for vi in my_el.faces[my_face_nr])

        try:
            nb_vertices = frozenset(nb_g2l_vertex_indices[vi]
                    for vi in my_global_vertices)
        except KeyError:
            # This happens if my_global_vertices is not a permutation
            # of the neighbor's face vertices. Periodicity is the only
            # reason why that would be so.
            my_global_vertices_there, axis = \
                    my_part_data.global_periodic_opposite_faces[
                            my_global_vertices]

            nb_vertices = frozenset(nb_g2l_vertex_indices[vi]
                    for vi in my_global_vertices_there)

            nb_el, nb_face_nr = nb_vertices_to_face[nb_vertices]
            nb_global_vertices_there = tuple(nb_l2g_vertex_indices[vi]
                    for vi in nb_el.faces[nb_face_nr])

            nb_global_vertices, axis2 = \
                    nb_part_data.global_periodic_opposite_faces[
                            nb_global_vertices_there]

            assert axis == axis2
        else:
            # Non-periodic case: the neighbor face has the same vertices.
            axis = None
            nb_el, nb_face_nr = nb_vertices_to_face[nb_vertices]
            nb_global_vertices = tuple(nb_l2g_vertex_indices[vi]
                    for vi in nb_el.faces[nb_face_nr])

        add_face_pair(my_face_nr, eslice, ldis, face_node_count,
                my_global_vertices, nb_el, nb_face_nr, nb_global_vertices,
                axis)

        # Finally, unify FluxFace.h values across boundary.
        my_flux_face = my_discr_bdry.find_facepair_side((my_el, my_face_nr))
        nb_flux_face = nb_discr_bdry.find_facepair_side((nb_el, nb_face_nr))
        my_flux_face.h = nb_flux_face.h = max(my_flux_face.h, nb_flux_face.h)

    assert len(from_indices) \
            == len(my_discr_bdry.nodes) \
            == len(nb_discr_bdry.nodes)

    # Convert nb's boundary indices to nb's volume indices.
    return nb_discr_bdry.vol_indices[
            numpy.asarray(from_indices, dtype=numpy.intp)]
class StupidInterdomainFluxMapper(hedge.optemplate.IdentityMapper):
    """Attempts to map a regular optemplate into one that is
    suitable for inter-domain flux computation.

    Maps everything to zero that is not an interior flux or
    inverse mass operator. Interior fluxes on the other hand are
    mapped to boundary fluxes on the specified tag.
    """

    def __init__(self, bdry_tag, vol_var, bdry_val_var):
        # bdry_tag: tag of the boundary the rewritten fluxes act on.
        # vol_var: expression replaced by bdry_val_var in the exterior
        #   side of each rewritten flux.
        self.bdry_tag = bdry_tag
        self.vol_var = vol_var
        self.bdry_val_var = bdry_val_var

    def map_operator_binding(self, expr):
        from hedge.optemplate import \
                FluxOperatorBase, \
                BoundaryPair, \
                OperatorBinding, \
                IdentityMapperMixin, \
                InverseMassOperator

        from pymbolic.mapper.substitutor import SubstitutionMapper

        class FieldIntoBdrySubstitutionMapper(
                SubstitutionMapper,
                IdentityMapperMixin):
            # Substitution mapper that also traverses hedge optemplate
            # nodes (via IdentityMapperMixin) and leaves normals untouched.
            def map_normal(self, expr):
                return expr

        if isinstance(expr.op, FluxOperatorBase):
            if isinstance(expr.field, BoundaryPair):
                return 0

            # Finally, an interior flux. Rewrite it as a boundary flux
            # whose exterior value has vol_var replaced by bdry_val_var.
            def subst_func(subexpr):
                if subexpr == self.vol_var:
                    return self.bdry_val_var
                else:
                    return None

            return OperatorBinding(expr.op,
                    BoundaryPair(
                        expr.field,
                        FieldIntoBdrySubstitutionMapper(subst_func)(
                            expr.field),
                        self.bdry_tag))
        elif isinstance(expr.op, InverseMassOperator):
            return OperatorBinding(expr.op, self.rec(expr.field))
        else:
            return 0
def compile_interdomain_flux(optemplate, vol_var, bdry_var,
        my_discr, my_part_data,
        nb_discr, nb_part_data,
        use_stupid_substitution=False):
    """Compile *optemplate* on *my_discr* and find the neighbor volume
    indices that feed the interface with the neighbor part.

    If *use_stupid_substitution* is set, :class:`StupidInterdomainFluxMapper`
    is passed as ``post_bind_mapper`` to pare a full optemplate down to one
    suitable for interdomain flux computation. The technique is crude, but
    it works for many common DG operators. See that class for exactly what
    it does.

    Node-matching checks are enabled when ``"node_permutation"`` appears in
    either discretization's ``debug`` set.

    :returns: a tuple ``(compiled_optemplate, neighbor_indices)``.
    """
    from hedge.optemplate import make_field

    debug_flags = my_discr.debug | nb_discr.debug
    neighbor_indices = find_neighbor_vol_indices(
            my_discr, my_part_data,
            nb_discr, nb_part_data,
            debug="node_permutation" in debug_flags)

    bdry_tag = my_part_data.part_boundary_tags[nb_part_data.part_nr]

    compile_kwargs = {}
    if use_stupid_substitution:
        compile_kwargs["post_bind_mapper"] = StupidInterdomainFluxMapper(
                bdry_tag, make_field(vol_var), make_field(bdry_var))

    compiled = my_discr.compile(optemplate, **compile_kwargs)
    return compiled, neighbor_indices
420
class Transformer:
    """Move volume vectors between a whole discretization and the
    per-part discretizations built from its partition.
    """

    def __init__(self, whole_discr, parts_data, parts_discr):
        # parts_data[i] (a PartitionData) and parts_discr[i] are expected
        # to describe the same part; they are zipped together below.
        self.whole_discr = whole_discr
        self.parts_data = parts_data
        self.parts_discr = parts_discr

    @memoize_method
    def _embeddings(self):
        """Return one index array per part. Entry ``k`` of a part's array
        is the whole-discretization node index of that part's volume
        node ``k``.
        """
        result = []
        for part_data, part_discr in zip(self.parts_data, self.parts_discr):
            part_emb = numpy.zeros((len(part_discr),), dtype=numpy.intp)
            result.append(part_emb)

            for g_el, l_el in part_data.global2local_elements.items():
                g_slice = self.whole_discr.find_el_range(g_el)
                part_emb[part_discr.find_el_range(l_el)] = \
                        numpy.arange(g_slice.start, g_slice.stop)
        return result

    def reassemble(self, parts_vol_vectors):
        """Assemble a whole-discretization volume vector from one volume
        vector per part.

        All part vectors must share one logical shape (enforced by
        ``pytools.single_valued``). Object arrays are reassembled
        component by component.
        """
        from pytools import single_valued, indices_in_shape
        from hedge.tools import log_shape
        ls = single_valued(log_shape(pvv) for pvv in parts_vol_vectors)

        def remap_scalar_field(idx):
            result = self.whole_discr.volume_zeros()
            for part_emb, part_vol_vector in zip(
                    self._embeddings(), parts_vol_vectors):
                result[part_emb] = part_vol_vector[idx]

            return result

        if ls != ():
            result = numpy.zeros(ls, dtype=object)
            for i in indices_in_shape(ls):
                result[i] = remap_scalar_field(i)
            return result
        else:
            return remap_scalar_field(())

    def split(self, whole_vol_vector):
        """Split a whole-discretization volume vector into a list with
        one volume vector per part.

        Object arrays are split component by component. Each part then
        receives an object array of the same logical shape.
        """
        from pytools import indices_in_shape
        from hedge.tools import log_shape

        ls = log_shape(whole_vol_vector)
        embeddings = self._embeddings()

        if ls == ():
            return [whole_vol_vector[part_emb] for part_emb in embeddings]

        result = []
        for part_emb in embeddings:
            part_vec = numpy.zeros(ls, dtype=object)
            for i in indices_in_shape(ls):
                # Restrict component i to this part, not the outer
                # object array.
                part_vec[i] = whole_vol_vector[i][part_emb]
            result.append(part_vec)
        return result
480