Skip to content

Commit

Permalink
Add dedicated parfor injection pass for kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
ZzEeKkAa committed Apr 29, 2024
1 parent eb0acbb commit 97ad569
Show file tree
Hide file tree
Showing 10 changed files with 218 additions and 428 deletions.
2 changes: 2 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class DpexTargetOptions(CPUTargetOptions):
no_compile = _option_mapping("no_compile")
inline_threshold = _option_mapping("inline_threshold")
_compilation_mode = _option_mapping("_compilation_mode")
_parfor_args = _option_mapping("_parfor_args")

def finalize(self, flags, options):
super().finalize(flags, options)
Expand All @@ -63,6 +64,7 @@ def finalize(self, flags, options):
_inherit_if_not_set(
flags, options, "_compilation_mode", CompilationMode.KERNEL
)
_inherit_if_not_set(flags, options, "_parfor_args", None)


class DpexKernelTarget(TargetDescriptor):
Expand Down
246 changes: 23 additions & 223 deletions numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@
from numba.parfors import parfor

from numba_dpex.core import config
from numba_dpex.core.decorators import kernel
from numba_dpex.core.parfors.parfor_inject_kernel_pass import ParforArguments
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
from numba_dpex.kernel_api_impl.spirv import spirv_generator
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)

from ..descriptor import dpex_kernel_target
from ..types import DpnpNdArray
Expand All @@ -38,79 +45,19 @@
class ParforKernel:
def __init__(
self,
name,
kernel,
signature,
kernel_args,
kernel_arg_types,
queue: dpctl.SyclQueue,
local_accessors=None,
work_group_size=None,
kernel_module=None,
):
self.name = name
self.kernel = kernel
self.signature = signature
self.kernel_args = kernel_args
self.kernel_arg_types = kernel_arg_types
self.queue = queue
self.local_accessors = local_accessors
self.work_group_size = work_group_size


def _print_block(block):
for i, inst in enumerate(block.body):
print(" ", i, inst)


def _print_body(body_dict):
"""Pretty-print a set of IR blocks."""
for label, block in body_dict.items():
print("label: ", label)
_print_block(block)


def _compile_kernel_parfor(
sycl_queue, kernel_name, func_ir, argtypes, debug=False
):
with target_override(dpex_kernel_target.target_context.target_name):
cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
cres.library.inline_threshold = config.INLINE_THRESHOLD
cres.library._optimize_final_module()
func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel = dpex_kernel_target.target_context.prepare_spir_kernel(
func, cres.signature.args
)
spirv_module = spirv_generator.llvm_to_spirv(
dpex_kernel_target.target_context,
kernel.module.__str__(),
kernel.module.as_bitcode(),
)

dpctl_create_program_from_spirv_flags = []
if debug or config.DPEX_OPT == 0:
# if debug is ON we need to pass additional flags to igc.
dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"]

# create a sycl::kernel_bundle
kernel_bundle = dpctl_prog.create_program_from_spirv(
sycl_queue,
spirv_module,
" ".join(dpctl_create_program_from_spirv_flags),
)
# create a sycl::kernel
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name)

return sycl_kernel
self.kernel_module = kernel_module


def _legalize_names_with_typemap(names, typemap):
Expand Down Expand Up @@ -189,76 +136,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
typemap[v] = types.npytypes.Array(el_typ, 1, "C")


def _find_setitems_block(setitems, block, typemap):
for inst in block.body:
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
setitems.add(inst.target.name)
elif isinstance(inst, parfor.Parfor):
_find_setitems_block(setitems, inst.init_block, typemap)
_find_setitems_body(setitems, inst.loop_body, typemap)


def _find_setitems_body(setitems, loop_body, typemap):
"""
Find the arrays that are written into (goes into setitems)
"""
for label, block in loop_body.items():
_find_setitems_block(setitems, block, typemap)


def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body):
# new label for splitting sentinel block
new_label = max(loop_body.keys()) + 1

# Search all the block in the kernel function for the sentinel assignment.
for label, block in kernel_ir.blocks.items():
for i, inst in enumerate(block.body):
if (
isinstance(inst, ir.Assign)
and inst.target.name == sentinel_name
):
# We found the sentinel assignment.
loc = inst.loc
scope = block.scope
# split block across __sentinel__
# A new block is allocated for the statements prior to the
# sentinel but the new block maintains the current block label.
prev_block = ir.Block(scope, loc)
prev_block.body = block.body[:i]

# The current block is used for statements after the sentinel.
block.body = block.body[i + 1 :] # noqa: E203
# But the current block gets a new label.
body_first_label = min(loop_body.keys())

# The previous block jumps to the minimum labelled block of the
# parfor body.
prev_block.append(ir.Jump(body_first_label, loc))
# Add all the parfor loop body blocks to the kernel function's
# IR.
for loop, b in loop_body.items():
kernel_ir.blocks[loop] = b
body_last_label = max(loop_body.keys())
kernel_ir.blocks[new_label] = block
kernel_ir.blocks[label] = prev_block
# Add a jump from the last parfor body block to the block
# containing statements after the sentinel.
kernel_ir.blocks[body_last_label].append(
ir.Jump(new_label, loc)
)
break
else:
continue
break


def create_kernel_for_parfor(
lowerer,
parfor_node,
typemap,
flags,
loop_ranges,
has_aliases,
races,
parfor_outputs,
) -> ParforKernel:
Expand Down Expand Up @@ -367,120 +249,38 @@ def create_kernel_for_parfor(
loop_ranges=loop_ranges,
param_dict=param_dict,
)
kernel_ir = kernel_template.kernel_ir

if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump ", type(kernel_ir))
kernel_ir.dump()
print("loop_body dump ", type(loop_body))
_print_body(loop_body)

# rename all variables in kernel_ir afresh
var_table = get_name_var_table(kernel_ir.blocks)
new_var_dict = {}
reserved_names = (
[sentinel_name] + list(param_dict.values()) + legal_loop_indices
kernel_dispatcher: SPIRVKernelDispatcher = kernel(
kernel_template.py_func,
_parfor_args=ParforArguments(
loop_body=loop_body,
param_dict=param_dict,
legal_loop_indices=legal_loop_indices,
),
)
for name, var in var_table.items():
if not (name in reserved_names):
new_var_dict[name] = mk_unique_var(name)
replace_var_names(kernel_ir.blocks, new_var_dict)
if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump after renaming ")
kernel_ir.dump()

kernel_param_types = param_types

if config.DEBUG_ARRAY_OPT:
print(
"kernel_param_types = ",
type(kernel_param_types),
"\n",
kernel_param_types,
)

kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1

# Add kernel stub last label to each parfor.loop_body label to prevent
# label conflicts.
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)

_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)

if config.DEBUG_ARRAY_OPT:
print("kernel_ir last dump before renaming")
kernel_ir.dump()

kernel_ir.blocks = rename_labels(kernel_ir.blocks)
remove_dels(kernel_ir.blocks)

old_alias = flags.noalias
if not has_aliases:
if config.DEBUG_ARRAY_OPT:
print("No aliases found so adding noalias flag.")
flags.noalias = True

remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)

if config.DEBUG_ARRAY_OPT:
print("kernel_ir after remove dead")
kernel_ir.dump()

# The first argument to a range kernel is a kernel_api.Item object. The
# ``Item`` object is used by the kernel_api.spirv backend to generate the
# The first argument to a range kernel is a kernel_api.NdItem object. The
# ``NdItem`` object is used by the kernel_api.spirv backend to generate the
# correct SPIR-V indexing instructions. Since, the argument is not something
# available originally in the kernel_param_types, we add it at this point to
# make sure the kernel signature matches the actual generated code.
ty_item = ItemType(parfor_dim)
kernel_param_types = (ty_item, *kernel_param_types)
kernel_param_types = (ty_item, *param_types)
kernel_sig = signature(types.none, *kernel_param_types)

if config.DEBUG_ARRAY_OPT:
sys.stdout.flush()

if config.DEBUG_ARRAY_OPT:
print("after DUFunc inline".center(80, "-"))
kernel_ir.dump()

# The ParforLegalizeCFD pass has already ensured that the LHS and RHS
# arrays are on same device. We can take the queue from the first input
# array and use that to compile the kernel.

exec_queue: dpctl.SyclQueue = None

for arg in parfor_args:
obj = typemap[arg]
if isinstance(obj, DpnpNdArray):
filter_string = obj.queue.sycl_device
# FIXME: A better design is required so that we do not have to
# create a queue every time.
exec_queue = dpctl.get_device_cached_queue(filter_string)

if not exec_queue:
raise AssertionError(
"No execution found for parfor. No way to compile the kernel!"
)

sycl_kernel = _compile_kernel_parfor(
exec_queue,
kernel_name,
kernel_ir,
kernel_param_types,
debug=flags.debuginfo,
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
types.void(*kernel_param_types) # kernel signature
)

flags.noalias = old_alias
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module

if config.DEBUG_ARRAY_OPT:
print("kernel_sig = ", kernel_sig)

return ParforKernel(
name=kernel_name,
kernel=sycl_kernel,
signature=kernel_sig,
kernel_args=parfor_args,
kernel_arg_types=func_arg_types,
queue=exec_queue,
kernel_module=kernel_module,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ def _generate_kernel_ir(self):
def dump_kernel_string(self):
raise NotImplementedError

@abc.abstractmethod
def dump_kernel_ir(self):
raise NotImplementedError

@property
@abc.abstractmethod
def kernel_ir(self):
def py_func(self):
raise NotImplementedError

@property
Expand Down
21 changes: 7 additions & 14 deletions numba_dpex/core/parfors/kernel_templates/range_kernel_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys

import dpnp
from numba.core import compiler

import numba_dpex as dpex

Expand Down Expand Up @@ -51,7 +50,7 @@ def __init__(
self._param_dict = param_dict

self._kernel_txt = self._generate_kernel_stub_as_string()
self._kernel_ir = self._generate_kernel_ir()
self._py_func = self._generate_kernel_ir()

def _generate_kernel_stub_as_string(self):
"""Generates a stub dpex kernel for the parfor as a string.
Expand Down Expand Up @@ -109,17 +108,15 @@ def _generate_kernel_ir(self):
globls = {"dpnp": dpnp, "dpex": dpex}
locls = {}
exec(self._kernel_txt, globls, locls)
kernel_fn = locls[self._kernel_name]

return compiler.run_frontend(kernel_fn)
return locls[self._kernel_name]

@property
def kernel_ir(self):
"""Returns the Numba IR generated for a RangeKernelTemplate.
Returns: The Numba functionIR object for the compiled kernel_txt string.
def py_func(self):
"""Returns the python function generated for a
TreeReduceIntermediateKernelTemplate.
Returns: The python function object for the compiled kernel_txt string.
"""
return self._kernel_ir
return self._py_func

@property
def kernel_string(self):
Expand All @@ -134,7 +131,3 @@ def dump_kernel_string(self):
"""Helper to print the kernel function string."""
print(self._kernel_txt)
sys.stdout.flush()

def dump_kernel_ir(self):
"""Helper to dump the Numba IR for the RangeKernelTemplate."""
self._kernel_ir.dump()
Loading

0 comments on commit 97ad569

Please sign in to comment.