Implementations of the Array Context Abstraction

Array context based on numpy

Array context based on pyopencl.array

class arraycontext.PyOpenCLArrayContext(queue: pyopencl.CommandQueue, allocator: pyopencl.tools.AllocatorBase | None = None, wait_event_queue_length: int | None = None, force_device_scalars: bool | None = None)

An ArrayContext that uses pyopencl.array.Array instances for its base array class.

context

A pyopencl.Context.

queue

A pyopencl.CommandQueue.

allocator

A PyOpenCL memory allocator. Can also be None (default) or False to use the default allocator. Note that the default allocator allocates and deallocates OpenCL buffers directly. If many arrays are created (e.g. as results of computation), the associated cost may become significant. Using a pooling allocator such as pyopencl.tools.MemoryPool can help avoid this cost.
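As a minimal sketch of the memory-pool advice above, the following helper builds a PyOpenCLArrayContext whose allocator is a pyopencl.tools.MemoryPool. The helper name make_pooled_actx is made up for illustration, and the imports sit inside the function so that defining it does not require an OpenCL installation; calling it does require pyopencl, arraycontext, and a usable OpenCL device.

```python
def make_pooled_actx():
    """Return a PyOpenCLArrayContext backed by a memory pool (illustrative helper)."""
    # Imported lazily: these packages and an OpenCL device are only needed
    # when the helper is actually called.
    import pyopencl as cl
    import pyopencl.tools as cl_tools
    from arraycontext import PyOpenCLArrayContext

    ctx = cl.create_some_context(interactive=False)
    queue = cl.CommandQueue(ctx)

    # The pool recycles device buffers instead of returning them to the
    # OpenCL runtime after every temporary array is freed.
    pool = cl_tools.MemoryPool(cl_tools.ImmediateAllocator(queue))
    return PyOpenCLArrayContext(queue, allocator=pool)
```

On a machine with an OpenCL device, `actx = make_pooled_actx()` would then be used like any other array context, e.g. `actx.np.sin(actx.from_numpy(host_array))`.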

transform_loopy_program(t_unit: lp.TranslationUnit) -> lp.TranslationUnit

class arraycontext.impl.pyopencl.taggable_cl_array.TaggableCLArray(cq, shape, dtype, order='C', allocator=None, data=None, offset=0, strides=None, events=None, _flags=None, _fast=False, _size=None, _context=None, _queue=None, axes=None, tags=frozenset({}))

A pyopencl.array.Array with additional metadata. This is used by PytatoPyOpenCLArrayContext to preserve tags for data while frozen, and also in a similar capacity by PyOpenCLArrayContext.

axes

A tuple of instances of Axis, with one Axis for each dimension of the array.

tags

A frozenset of pytools.tag.Tag. Typically intended to record application-specific metadata to drive the optimizations in arraycontext.PyOpenCLArrayContext.transform_loopy_program().

class arraycontext.impl.pyopencl.taggable_cl_array.Axis(tags: FrozenSet[Tag])

Records the tags corresponding to a dimension of TaggableCLArray.

arraycontext.impl.pyopencl.taggable_cl_array.to_tagged_cl_array(ary: Array, axes: Tuple[Axis, ...] | None = None, tags: FrozenSet[Tag] = frozenset({})) -> TaggableCLArray

Returns a TaggableCLArray that is constructed from the data in ary along with the metadata from axes and tags. If ary is already a TaggableCLArray, the new tags and axes are added to the existing ones.

Parameters:

axes – An instance of Axis for each dimension of the array. If passed None, then initialized to a pytato.Axis with no tags attached for each dimension.
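The "added to the existing ones" behavior for an already-tagged input can be pictured as a set union per axis and for the array as a whole. The sketch below is a pure-Python toy model, not the library implementation: ToyAxis, ToyTaggedArray, and toy_to_tagged are hypothetical stand-ins that carry metadata only, and tags are plain strings rather than pytools.tag.Tag instances.

```python
from dataclasses import dataclass
from typing import FrozenSet, Optional, Tuple


@dataclass(frozen=True)
class ToyAxis:
    """Hypothetical stand-in for Axis: the tags of one dimension."""
    tags: FrozenSet[str] = frozenset()


@dataclass(frozen=True)
class ToyTaggedArray:
    """Hypothetical stand-in for TaggableCLArray (metadata only, no device data)."""
    shape: Tuple[int, ...]
    axes: Tuple[ToyAxis, ...]
    tags: FrozenSet[str] = frozenset()


def toy_to_tagged(ary: ToyTaggedArray,
                  axes: Optional[Tuple[ToyAxis, ...]] = None,
                  tags: FrozenSet[str] = frozenset()) -> ToyTaggedArray:
    """Merge new axis tags and array tags into those already on *ary*."""
    if axes is None:
        # One untagged axis per dimension, mirroring the axes=None default.
        axes = tuple(ToyAxis() for _ in ary.shape)
    if len(axes) != len(ary.shape):
        raise ValueError("need exactly one axis per array dimension")

    merged_axes = tuple(ToyAxis(old.tags | new.tags)
                        for old, new in zip(ary.axes, axes))
    return ToyTaggedArray(ary.shape, merged_axes, ary.tags | tags)


base = ToyTaggedArray((4, 3),
                      (ToyAxis(frozenset({"dof"})), ToyAxis()),
                      frozenset({"state"}))
result = toy_to_tagged(base,
                       axes=(ToyAxis(), ToyAxis(frozenset({"component"}))),
                       tags=frozenset({"output"}))
```

Here `result` keeps the "dof" tag on its first axis and the "state" tag on the array, and gains "component" on the second axis and "output" on the array.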

Lazy/Deferred evaluation array context based on pytato

A pytato-based array context defers the evaluation of an array until it is frozen. The execution contexts for the evaluations are specific to an ArrayContext type. For example, PytatoPyOpenCLArrayContext uses pyopencl to JIT-compile and execute the array expressions.

The following pytato-based array contexts are provided:

class arraycontext.PytatoPyOpenCLArrayContext(queue: CommandQueue, allocator=None, *, use_memory_pool: bool | None = None, compile_trace_callback: Callable[[Any, str, Any], None] | None = None, _force_svm_arg_limit: int | None = None)

An ArrayContext that uses pytato data types to represent the arrays targeting OpenCL for offloading operations.

queue

A pyopencl.CommandQueue.

allocator

A pyopencl memory allocator. Can also be None (default) or False to use the default allocator.

__init__(queue: CommandQueue, allocator=None, *, use_memory_pool: bool | None = None, compile_trace_callback: Callable[[Any, str, Any], None] | None = None, _force_svm_arg_limit: int | None = None) -> None

Parameters:

compile_trace_callback – A function of three arguments (what, stage, ir), where what identifies the object being compiled, stage is a string describing the compilation pass, and ir is an object containing the intermediate representation. This interface should be considered unstable.
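A callback matching the three-argument shape described above might look like the following sketch. The names record_compile_stage and trace_log are made up for illustration, and the arguments in the stand-alone call at the end are dummy values: the actual stage strings and IR objects are chosen by arraycontext and, as noted, the interface is unstable.

```python
from typing import Any, List, Tuple

# Collected (what, stage) pairs; the IR itself is not retained.
trace_log: List[Tuple[str, str]] = []


def record_compile_stage(what: Any, stage: str, ir: Any) -> None:
    """Remember which object was compiled and which stage was reached."""
    trace_log.append((str(what), stage))


# Intended use (needs arraycontext, pytato, and a pyopencl command queue):
#     actx = PytatoPyOpenCLArrayContext(
#         queue, compile_trace_callback=record_compile_stage)

# Stand-alone demonstration with made-up arguments.
record_compile_stage("my_rhs", "example_stage", object())
```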

transform_dag(dag: pytato.DictOfNamedArrays) -> pytato.DictOfNamedArrays

Returns a transformed version of dag. Subclasses are expected to override this method to implement context-specific transformations on dag (most likely to perform domain-specific optimizations). Every pytato DAG that is compiled to a GPU kernel is passed through this routine.

Parameters:

dag – An instance of pytato.DictOfNamedArrays

Returns:

A transformed version of dag.
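An override could be structured roughly as follows. This is a sketch under assumptions: the factory make_logging_actx_class and the class LoggingPytatoActx are hypothetical names, the override only reports the number of outputs rather than performing a real optimization, and the import is deferred into the factory so that the snippet can be defined without arraycontext installed.

```python
def make_logging_actx_class():
    """Build a PytatoPyOpenCLArrayContext subclass that logs each DAG it sees."""
    # Imported lazily: arraycontext (and its dependencies) are only needed
    # when the factory is called.
    from arraycontext import PytatoPyOpenCLArrayContext

    class LoggingPytatoActx(PytatoPyOpenCLArrayContext):
        def transform_dag(self, dag):
            # Keep whatever the parent class already does to the DAG.
            dag = super().transform_dag(dag)
            # A real override would rewrite the DAG here; this one only reports.
            print(f"compiling a DAG with {len(dag)} named output(s)")
            return dag

    return LoggingPytatoActx
```

With arraycontext available, `make_logging_actx_class()(queue)` would create a context whose compilations pass through the logging override.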

compile(f: Callable[[...], Any]) -> Callable[[...], Any]

Compiles f for repeated use on this array context. f is expected to be a pure function performing an array computation.

Control flow statements (if, while) that might take different paths depending on the data lead to undefined behavior and are illegal. Any data-dependent control flow must be expressed via array functions, such as actx.np.where.

f may be called on placeholder data, to obtain a representation of the computation performed, or it may be called as part of the actual computation, on actual data. If f is called on placeholder data, it may be called only once (or a few times).

Parameters:

f – the function executing the computation.

Returns:

a function with the same signature as f.
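The restriction on control flow follows from f being run on placeholder data. The toy tracer below illustrates the idea in pure Python; Placeholder, where, evaluate, and toy_compile are hypothetical stand-ins and not part of arraycontext or pytato. One deliberate difference: this toy raises an error when a placeholder is used in an `if`, whereas the text above only promises undefined behavior for the real implementation.

```python
class Placeholder:
    """Toy symbolic value that records operations instead of computing them."""

    def __init__(self, expr):
        self.expr = expr

    def __gt__(self, other):
        return Placeholder((">", self.expr, _expr(other)))

    def __neg__(self):
        return Placeholder(("neg", self.expr))

    def __bool__(self):
        # The value is unknown while tracing, so no branch can be chosen.
        raise TypeError("truth value of a placeholder is unknown while tracing")


def _expr(value):
    return value.expr if isinstance(value, Placeholder) else value


def where(cond, if_true, if_false):
    """Record a selection between two values as part of the expression."""
    return Placeholder(("where", _expr(cond), _expr(if_true), _expr(if_false)))


def evaluate(expr, env):
    if isinstance(expr, str):      # variable name
        return env[expr]
    if not isinstance(expr, tuple):  # constant
        return expr
    op, *args = expr
    vals = [evaluate(arg, env) for arg in args]
    if op == ">":
        return vals[0] > vals[1]
    if op == "neg":
        return -vals[0]
    if op == "where":
        return vals[1] if vals[0] else vals[2]
    raise ValueError(f"unknown operation: {op}")


def toy_compile(f):
    traced = f(Placeholder("x"))   # f runs once, on placeholder data
    return lambda x: evaluate(traced.expr, {"x": x})


def abs_with_if(x):
    if x > 0:                      # data-dependent branch: cannot be traced
        return x
    return -x


def abs_with_where(x):
    return where(x > 0, x, -x)     # the branch becomes part of the expression


compiled_abs = toy_compile(abs_with_where)
print(compiled_abs(-3), compiled_abs(5))
```

Calling `toy_compile(abs_with_if)` fails in this toy, because the `if` needs a concrete truth value at tracing time. With a real array context, the corresponding rewrite uses `actx.np.where`.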

class arraycontext.PytatoJAXArrayContext(*, compile_trace_callback: Callable[[Any, str, Any], None] | None = None)

An arraycontext that uses pytato to represent the thawed state of the arrays and compiles the expressions using pytato.target.python.JAXPythonTarget.

Compiling a Python callable (Internal)

class arraycontext.impl.pytato.compile.BaseLazilyCompilingFunctionCaller(actx: _BasePytatoArrayContext, f: Callable[[...], Any], program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], CompiledFunction] = <factory>)

Records a side-effect-free callable f that can be specialized for the input types with which __call__() is invoked.

f

The callable that will be called to obtain pytato DAGs.

__call__(*args: Any, **kwargs: Any) -> Any

Returns the result of f’s function application on args.

Before f is applied, it is compiled to a pytato DAG that applies f to args lazily. The intermediate pytato DAG for args is memoized in self.
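The memoization can be pictured as a dictionary keyed on a description of the inputs, so that a repeated call with inputs of the same kind reuses the earlier compilation. The sketch below is a pure-Python toy: ToyLazilyCompilingCaller and describe are hypothetical, the "descriptor" is just a type name plus a length, and "compiling" merely returns f itself.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


def describe(arg: Any) -> Tuple[Any, ...]:
    """Toy input descriptor: type name, plus the length for lists."""
    if isinstance(arg, list):
        return ("list", len(arg))
    return (type(arg).__name__,)


@dataclass
class ToyLazilyCompilingCaller:
    f: Callable[..., Any]
    program_cache: Dict[Tuple[Any, ...], Callable[..., Any]] = field(
        default_factory=dict)
    compile_count: int = 0

    def _compile(self) -> Callable[..., Any]:
        # Stand-in for tracing f and generating code; here f is reused as-is.
        self.compile_count += 1
        return self.f

    def __call__(self, *args: Any) -> Any:
        key = tuple(describe(arg) for arg in args)
        try:
            program = self.program_cache[key]
        except KeyError:
            # First call with this kind of input: compile and remember.
            program = self.program_cache[key] = self._compile()
        return program(*args)


double = ToyLazilyCompilingCaller(lambda xs: [2 * x for x in xs])
first = double([1, 2, 3])    # compiles for lists of length 3
second = double([4, 5, 6])   # same descriptor: cached program is reused
third = double([7, 8])       # new descriptor: compiles again
```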

class arraycontext.impl.pytato.compile.LazilyPyOpenCLCompilingFunctionCaller(actx: _BasePytatoArrayContext, f: Callable[[...], Any], program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], CompiledFunction] = <factory>)

class arraycontext.impl.pytato.compile.LazilyJAXCompilingFunctionCaller(actx: _BasePytatoArrayContext, f: Callable[[...], Any], program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], CompiledFunction] = <factory>)

class arraycontext.impl.pytato.compile.CompiledFunction

A callable which captures the pytato.target.BoundProgram resulting from calling f with a given set of input types, and generating loopy IR from it.

pytato_program

input_id_to_name_in_program

A mapping from input id to the placeholder name in CompiledFunction.pytato_program. Input id is represented as the position of f’s argument augmented with the leaf array’s key if the argument is an array container.

abstract __call__(arg_id_to_arg) -> Any

Parameters:

arg_id_to_arg – Mapping from input id to the passed argument. See CompiledFunction.input_id_to_name_in_program for input id’s representation.
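To make the notion of an input id concrete, the following toy builds such a mapping for positional arguments, using a plain dict as a stand-in for an array container. The functions toy_arg_ids and toy_names and the generated placeholder names are hypothetical; the sketch only mirrors the description above (position, extended by the leaf key for containers) and not the library's actual traversal or naming.

```python
from typing import Any, Dict, Tuple

ArgId = Tuple[Any, ...]


def toy_arg_ids(args: Tuple[Any, ...]) -> Dict[ArgId, Any]:
    """Map each leaf input to an id: (position,) or (position, leaf_key)."""
    arg_id_to_arg: Dict[ArgId, Any] = {}
    for position, arg in enumerate(args):
        if isinstance(arg, dict):
            # Container argument: one id per leaf, extended by the leaf's key.
            for key, leaf in arg.items():
                arg_id_to_arg[(position, key)] = leaf
        else:
            # Bare array argument: the position alone identifies it.
            arg_id_to_arg[(position,)] = arg
    return arg_id_to_arg


def toy_names(arg_id_to_arg: Dict[ArgId, Any]) -> Dict[ArgId, str]:
    """Assign a made-up placeholder name to every input id."""
    return {arg_id: "_in_" + "_".join(str(part) for part in arg_id)
            for arg_id in arg_id_to_arg}


velocity = [1.0, 2.0]
state = {"mass": [3.0], "momentum": [4.0]}

ids = toy_arg_ids((velocity, state))
names = toy_names(ids)
```

In this toy, `names` plays the role of input_id_to_name_in_program and `ids` plays the role of the arg_id_to_arg mapping passed to `__call__`.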

class arraycontext.impl.pytato.compile.FromArrayContextCompile

Tagged to the entrypoint kernel of every translation unit that is generated by compile().

Typically this tag serves as a branch condition in implementing a specialized transform strategy for kernels compiled by compile().

Array context based on jax.numpy

class arraycontext.EagerJAXArrayContext

An ArrayContext that uses jax.Array instances for its base array class and performs all array operations eagerly. See PytatoJAXArrayContext for a lazier version.

Note

JAX stores global configuration state in jax.config. Callers are expected to maintain it. The most important setting for scientific computing workloads is jax_enable_x64.
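As a small sketch of maintaining that configuration, the helper below enables 64-bit floating point support through jax.config. The helper name enable_jax_x64 is made up for illustration, and the import is deferred so that defining the helper does not require JAX; it is typically called once at program start, before any JAX arrays are created.

```python
def enable_jax_x64() -> None:
    """Switch JAX to double precision by default (illustrative helper)."""
    # Imported lazily so that JAX is only required when the helper is called.
    import jax

    jax.config.update("jax_enable_x64", True)
```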

numpy coverage

This is a list of functionality implemented by arraycontext.ArrayContext.np.

Note

Only functions and methods that have at least one implementation are listed.

Array creation routines

Array manipulation routines

Linear algebra

Logic Functions

| Function | PyOpenCLArrayContext | EagerJAXArrayContext | PytatoPyOpenCLArrayContext | PytatoJAXArrayContext |
|---|---|---|---|---|
| numpy.all() | Yes | Yes | Yes | Yes |
| numpy.any() | Yes | Yes | Yes | Yes |
| numpy.greater | Yes | Yes | Yes | Yes |
| numpy.greater_equal | Yes | Yes | Yes | Yes |
| numpy.less | Yes | Yes | Yes | Yes |
| numpy.less_equal | Yes | Yes | Yes | Yes |
| numpy.equal | Yes | Yes | Yes | Yes |
| numpy.not_equal | Yes | Yes | Yes | Yes |

Mathematical functions

| Function | PyOpenCLArrayContext | EagerJAXArrayContext | PytatoPyOpenCLArrayContext | PytatoJAXArrayContext |
|---|---|---|---|---|
| numpy.sin | Yes | Yes | Yes | Yes |
| numpy.cos | Yes | Yes | Yes | Yes |
| numpy.tan | Yes | Yes | Yes | Yes |
| numpy.arcsin | Yes | Yes | Yes | Yes |
| numpy.arccos | Yes | Yes | Yes | Yes |
| numpy.arctan | Yes | Yes | Yes | Yes |
| numpy.arctan2 | Yes | Yes | Yes | Yes |
| numpy.sinh | Yes | Yes | Yes | Yes |
| numpy.cosh | Yes | Yes | Yes | Yes |
| numpy.tanh | Yes | Yes | Yes | Yes |
| numpy.floor | Yes | Yes | Yes | Yes |
| numpy.ceil | Yes | Yes | Yes | Yes |
| numpy.sum() | Yes | Yes | Yes | Yes |
| numpy.exp | Yes | Yes | Yes | Yes |
| numpy.log | Yes | Yes | Yes | Yes |
| numpy.log10 | Yes | Yes | Yes | Yes |
| numpy.real() | Yes | Yes | Yes | Yes |
| numpy.imag() | Yes | Yes | Yes | Yes |
| numpy.conjugate | Yes | Yes | Yes | Yes |
| numpy.maximum | Yes | Yes | Yes | Yes |
| numpy.amax() | Yes | Yes | Yes | Yes |
| numpy.minimum | Yes | Yes | Yes | Yes |
| numpy.amin() | Yes | Yes | Yes | Yes |
| numpy.sqrt | Yes | Yes | Yes | Yes |
| numpy.absolute | Yes | Yes | Yes | Yes |
| numpy.fabs | Yes | Yes | Yes | Yes |