Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate parfor to SPIRVKernelDispatcher #1435

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numba_dpex/core/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class DpexTargetOptions(CPUTargetOptions):
no_compile = _option_mapping("no_compile")
inline_threshold = _option_mapping("inline_threshold")
_compilation_mode = _option_mapping("_compilation_mode")
# TODO: create separate parfor kernel target
_parfor_body_args = _option_mapping("_parfor_body_args")

def finalize(self, flags, options):
super().finalize(flags, options)
Expand All @@ -63,6 +65,7 @@ def finalize(self, flags, options):
_inherit_if_not_set(
flags, options, "_compilation_mode", CompilationMode.KERNEL
)
_inherit_if_not_set(flags, options, "_parfor_body_args", None)


class DpexKernelTarget(TargetDescriptor):
Expand Down
248 changes: 25 additions & 223 deletions numba_dpex/core/parfors/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,17 @@
from numba.parfors import parfor

from numba_dpex.core import config
from numba_dpex.core.decorators import kernel
from numba_dpex.core.parfors.parfor_sentinel_replace_pass import (
ParforBodyArguments,
)
from numba_dpex.core.types.kernel_api.index_space_ids import ItemType
from numba_dpex.core.utils.call_kernel_builder import SPIRVKernelModule
from numba_dpex.kernel_api_impl.spirv import spirv_generator
from numba_dpex.kernel_api_impl.spirv.dispatcher import (
SPIRVKernelDispatcher,
_SPIRVKernelCompileResult,
)

from ..descriptor import dpex_kernel_target
from ..types import DpnpNdArray
Expand All @@ -38,79 +47,19 @@
class ParforKernel:
def __init__(
self,
name,
kernel,
signature,
kernel_args,
kernel_arg_types,
queue: dpctl.SyclQueue,
local_accessors=None,
work_group_size=None,
kernel_module=None,
):
self.name = name
self.kernel = kernel
self.signature = signature
self.kernel_args = kernel_args
self.kernel_arg_types = kernel_arg_types
self.queue = queue
self.local_accessors = local_accessors
self.work_group_size = work_group_size


def _print_block(block):
for i, inst in enumerate(block.body):
print(" ", i, inst)


def _print_body(body_dict):
"""Pretty-print a set of IR blocks."""
for label, block in body_dict.items():
print("label: ", label)
_print_block(block)


def _compile_kernel_parfor(
sycl_queue, kernel_name, func_ir, argtypes, debug=False
):
# Compiles the Numba IR ``func_ir`` into a SYCL kernel for ``sycl_queue``.
# Steps, as written below: compile with compile_numba_ir_with_dpex under a
# target_override for the dpex kernel target, optimize the resulting library,
# wrap the compiled function with prepare_spir_kernel, translate its module to
# SPIR-V, build a program from the SPIR-V and return the kernel that is looked
# up by ``kernel.name``.
# NOTE(review): indentation is not visible in this view, so the extent of the
# ``with`` body below cannot be confirmed from here.
with target_override(dpex_kernel_target.target_context.target_name):
cres = compile_numba_ir_with_dpex(
pyfunc=func_ir,
pyfunc_name=kernel_name,
args=argtypes,
return_type=None,
debug=debug,
is_kernel=True,
typing_context=dpex_kernel_target.typing_context,
target_context=dpex_kernel_target.target_context,
extra_compile_flags=None,
)
# The inline threshold is taken from the package configuration before the
# final module is optimized.
cres.library.inline_threshold = config.INLINE_THRESHOLD
cres.library._optimize_final_module()
func = cres.library.get_function(cres.fndesc.llvm_func_name)
kernel = dpex_kernel_target.target_context.prepare_spir_kernel(
func, cres.signature.args
)
# llvm_to_spirv receives both the textual form and the bitcode of the module.
spirv_module = spirv_generator.llvm_to_spirv(
dpex_kernel_target.target_context,
kernel.module.__str__(),
kernel.module.as_bitcode(),
)

# Build flags stay empty unless debug is requested or DPEX_OPT is 0.
dpctl_create_program_from_spirv_flags = []
if debug or config.DPEX_OPT == 0:
# if debug is ON we need to pass additional flags to igc.
dpctl_create_program_from_spirv_flags = ["-g", "-cl-opt-disable"]

# create a sycl::kernel_bundle
kernel_bundle = dpctl_prog.create_program_from_spirv(
sycl_queue,
spirv_module,
" ".join(dpctl_create_program_from_spirv_flags),
)
# create a sycl::kernel
sycl_kernel = kernel_bundle.get_sycl_kernel(kernel.name)

return sycl_kernel
self.kernel_module = kernel_module


def _legalize_names_with_typemap(names, typemap):
Expand Down Expand Up @@ -189,76 +138,11 @@ def _replace_var_with_array(vars, loop_body, typemap, calltypes):
typemap[v] = types.npytypes.Array(el_typ, 1, "C")


def _find_setitems_block(setitems, block, typemap):
for inst in block.body:
if isinstance(inst, ir.StaticSetItem) or isinstance(inst, ir.SetItem):
setitems.add(inst.target.name)
elif isinstance(inst, parfor.Parfor):
_find_setitems_block(setitems, inst.init_block, typemap)
_find_setitems_body(setitems, inst.loop_body, typemap)


def _find_setitems_body(setitems, loop_body, typemap):
"""
Find the arrays that are written into (goes into setitems)
"""
for label, block in loop_body.items():
_find_setitems_block(setitems, block, typemap)


def _replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body):
# Splices the blocks of ``loop_body`` into ``kernel_ir.blocks`` at the position
# of the assignment whose target is named ``sentinel_name``. The block holding
# the sentinel is split in two: the statements before it keep the original
# label and jump to the smallest loop-body label; the statements after it are
# stored under ``new_label`` and are reached by a jump appended to the block
# with the largest loop-body label. ``kernel_ir.blocks`` is modified in place
# and nothing is returned.
# NOTE(review): indentation is not visible in this view; the trailing
# ``break`` / ``else: continue`` / ``break`` presumably form a for-else that
# stops the search after the first sentinel found -- confirm against the file.
# new label for splitting sentinel block
new_label = max(loop_body.keys()) + 1

# Search all the block in the kernel function for the sentinel assignment.
for label, block in kernel_ir.blocks.items():
for i, inst in enumerate(block.body):
if (
isinstance(inst, ir.Assign)
and inst.target.name == sentinel_name
):
# We found the sentinel assignment.
loc = inst.loc
scope = block.scope
# split block across __sentinel__
# A new block is allocated for the statements prior to the
# sentinel but the new block maintains the current block label.
prev_block = ir.Block(scope, loc)
prev_block.body = block.body[:i]

# The current block is used for statements after the sentinel.
block.body = block.body[i + 1 :] # noqa: E203
# But the current block gets a new label.
body_first_label = min(loop_body.keys())

# The previous block jumps to the minimum labelled block of the
# parfor body.
prev_block.append(ir.Jump(body_first_label, loc))
# Add all the parfor loop body blocks to the kernel function's
# IR.
for loop, b in loop_body.items():
kernel_ir.blocks[loop] = b
body_last_label = max(loop_body.keys())
kernel_ir.blocks[new_label] = block
kernel_ir.blocks[label] = prev_block
# Add a jump from the last parfor body block to the block
# containing statements after the sentinel.
kernel_ir.blocks[body_last_label].append(
ir.Jump(new_label, loc)
)
break
else:
continue
break


def create_kernel_for_parfor(
lowerer,
parfor_node,
typemap,
flags,
loop_ranges,
has_aliases,
races,
parfor_outputs,
) -> ParforKernel:
Expand Down Expand Up @@ -367,120 +251,38 @@ def create_kernel_for_parfor(
loop_ranges=loop_ranges,
param_dict=param_dict,
)
kernel_ir = kernel_template.kernel_ir

if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump ", type(kernel_ir))
kernel_ir.dump()
print("loop_body dump ", type(loop_body))
_print_body(loop_body)

# rename all variables in kernel_ir afresh
var_table = get_name_var_table(kernel_ir.blocks)
new_var_dict = {}
reserved_names = (
[sentinel_name] + list(param_dict.values()) + legal_loop_indices
kernel_dispatcher: SPIRVKernelDispatcher = kernel(
kernel_template.py_func,
_parfor_body_args=ParforBodyArguments(
loop_body=loop_body,
param_dict=param_dict,
legal_loop_indices=legal_loop_indices,
),
)
for name, var in var_table.items():
if not (name in reserved_names):
new_var_dict[name] = mk_unique_var(name)
replace_var_names(kernel_ir.blocks, new_var_dict)
if config.DEBUG_ARRAY_OPT:
print("kernel_ir dump after renaming ")
kernel_ir.dump()

kernel_param_types = param_types

if config.DEBUG_ARRAY_OPT:
print(
"kernel_param_types = ",
type(kernel_param_types),
"\n",
kernel_param_types,
)

kernel_stub_last_label = max(kernel_ir.blocks.keys()) + 1

# Add kernel stub last label to each parfor.loop_body label to prevent
# label conflicts.
loop_body = add_offset_to_labels(loop_body, kernel_stub_last_label)

_replace_sentinel_with_parfor_body(kernel_ir, sentinel_name, loop_body)

if config.DEBUG_ARRAY_OPT:
print("kernel_ir last dump before renaming")
kernel_ir.dump()

kernel_ir.blocks = rename_labels(kernel_ir.blocks)
remove_dels(kernel_ir.blocks)

old_alias = flags.noalias
if not has_aliases:
if config.DEBUG_ARRAY_OPT:
print("No aliases found so adding noalias flag.")
flags.noalias = True

remove_dead(kernel_ir.blocks, kernel_ir.arg_names, kernel_ir, typemap)

if config.DEBUG_ARRAY_OPT:
print("kernel_ir after remove dead")
kernel_ir.dump()

# The first argument to a range kernel is a kernel_api.Item object. The
ZzEeKkAa marked this conversation as resolved.
Show resolved Hide resolved
# ``Item`` object is used by the kernel_api.spirv backend to generate the
# The first argument to a range kernel is a kernel_api.NdItem object. The
# ``NdItem`` object is used by the kernel_api.spirv backend to generate the
# correct SPIR-V indexing instructions. Since, the argument is not something
# available originally in the kernel_param_types, we add it at this point to
# make sure the kernel signature matches the actual generated code.
ty_item = ItemType(parfor_dim)
kernel_param_types = (ty_item, *kernel_param_types)
kernel_param_types = (ty_item, *param_types)
kernel_sig = signature(types.none, *kernel_param_types)

if config.DEBUG_ARRAY_OPT:
sys.stdout.flush()

if config.DEBUG_ARRAY_OPT:
print("after DUFunc inline".center(80, "-"))
kernel_ir.dump()

# The ParforLegalizeCFD pass has already ensured that the LHS and RHS
# arrays are on same device. We can take the queue from the first input
# array and use that to compile the kernel.

exec_queue: dpctl.SyclQueue = None

for arg in parfor_args:
obj = typemap[arg]
if isinstance(obj, DpnpNdArray):
filter_string = obj.queue.sycl_device
# FIXME: A better design is required so that we do not have to
# create a queue every time.
exec_queue = dpctl.get_device_cached_queue(filter_string)

if not exec_queue:
raise AssertionError(
"No execution found for parfor. No way to compile the kernel!"
)

sycl_kernel = _compile_kernel_parfor(
exec_queue,
kernel_name,
kernel_ir,
kernel_param_types,
debug=flags.debuginfo,
kcres: _SPIRVKernelCompileResult = kernel_dispatcher.get_compile_result(
types.void(*kernel_param_types) # kernel signature
)

flags.noalias = old_alias
kernel_module: SPIRVKernelModule = kcres.kernel_device_ir_module

if config.DEBUG_ARRAY_OPT:
print("kernel_sig = ", kernel_sig)

return ParforKernel(
name=kernel_name,
kernel=sycl_kernel,
signature=kernel_sig,
kernel_args=parfor_args,
kernel_arg_types=func_arg_types,
queue=exec_queue,
kernel_module=kernel_module,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,9 @@ def _generate_kernel_ir(self):
def dump_kernel_string(self):
raise NotImplementedError

@abc.abstractmethod
def dump_kernel_ir(self):
raise NotImplementedError

@property
@abc.abstractmethod
def kernel_ir(self):
def py_func(self):
raise NotImplementedError

@property
Expand Down
21 changes: 7 additions & 14 deletions numba_dpex/core/parfors/kernel_templates/range_kernel_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys

import dpnp
from numba.core import compiler

import numba_dpex as dpex

Expand Down Expand Up @@ -51,7 +50,7 @@ def __init__(
self._param_dict = param_dict

self._kernel_txt = self._generate_kernel_stub_as_string()
self._kernel_ir = self._generate_kernel_ir()
self._py_func = self._generate_kernel_ir()

def _generate_kernel_stub_as_string(self):
"""Generates a stub dpex kernel for the parfor as a string.
Expand Down Expand Up @@ -109,17 +108,15 @@ def _generate_kernel_ir(self):
globls = {"dpnp": dpnp, "dpex": dpex}
locls = {}
exec(self._kernel_txt, globls, locls)
kernel_fn = locls[self._kernel_name]

return compiler.run_frontend(kernel_fn)
return locls[self._kernel_name]

@property
def kernel_ir(self):
"""Returns the Numba IR generated for a RangeKernelTemplate.

Returns: The Numba functionIR object for the compiled kernel_txt string.
def py_func(self):
"""Returns the python function generated for a
RangeKernelTemplate.
Returns: The python function object for the compiled kernel_txt string.
"""
return self._kernel_ir
return self._py_func

@property
def kernel_string(self):
Expand All @@ -134,7 +131,3 @@ def dump_kernel_string(self):
"""Helper to print the kernel function string."""
print(self._kernel_txt)
sys.stdout.flush()

def dump_kernel_ir(self):
"""Helper to dump the Numba IR for the RangeKernelTemplate."""
self._kernel_ir.dump()
Loading
Loading