Skip to content

Commit

Permalink
Merge pull request #75 from martindurant/match_kwargs
Browse files Browse the repository at this point in the history
Allow match_kwargs for kernels that need records
  • Loading branch information
martindurant authored Aug 8, 2024
2 parents eb24912 + 781dfcb commit 6d46786
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
37 changes: 24 additions & 13 deletions src/akimbo/apply_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,33 @@ class NoDtype:
kind = ""


def leaf(*layout):
def leaf(*layout, **_):
return layout[0].is_leaf


def run_with_transform(
arr: ak.Array, op, match=leaf, outtype=None, inmode="arrow", others=(), **kw
arr: ak.Array,
op,
match=leaf,
outtype=None,
inmode="arrow",
others=(),
match_kwargs=None,
**kw,
) -> ak.Array:
def func(layout, **kwargs):
if not isinstance(layout, tuple):
layout = (layout,)
if match(*layout):
if match(*layout, **(match_kwargs or {})):
if inmode == "arrow":
out = ak.str._apply_through_arrow(op, *layout, **kw)
out = ak.str._apply_through_arrow(
op, *layout, **kw, **(match_kwargs or {})
)
elif inmode == "numpy":
# works on numpy/cupy contents
out = op(*(lay.data for lay in layout), **kw)
out = op(*(lay.data for lay in layout), **kw, **(match_kwargs or {}))
else:
out = op(*layout, **kw)
out = op(*layout, **kw, **(match_kwargs or {}))
return outtype(out) if callable(outtype) else out

return ak.transform(func, arr, *others)
Expand All @@ -37,7 +46,7 @@ def dec(func, match=leaf, outtype=None, inmode="arrow"):
"""Make a nested/ragged version of an operation to apply throughout a tree"""

@functools.wraps(func)
def f(self, *args, where=None, **kwargs):
def f(self, *args, where=None, match_kwargs=None, **kwargs):
if not (
where is None
or isinstance(where, str)
Expand Down Expand Up @@ -71,6 +80,7 @@ def f(self, *args, where=None, **kwargs):
outtype=outtype,
inmode=inmode,
others=others,
match_kwargs=match_kwargs,
**kwargs,
)
final = ak.with_field(arr, out, where=where)
Expand All @@ -84,24 +94,25 @@ def f(self, *args, where=None, **kwargs):
outtype=outtype,
inmode=inmode,
others=others,
match_kwargs=match_kwargs,
**kwargs,
)
)

f.__doc__ = """Run vectorized functions on nested/ragged/complex array
f.__doc__ = f"""Run vectorized functions on nested/ragged/complex array
where: None | str | Sequence[str, ...]
if None, will attempt to apply the kernel throughout the nested structure,
wherever correct types are encountered. If where is given, only the selected
part of the structure will be considered, but the output will retain
the original shape. A fieldname or sequence of fieldnames to descend into
the tree are acceptable
match_kwargs: None | dict
any extra field identifiers for matching a record as OK to process
-Kernel documentation follows from the original function-
{'-Kernel documentation follows from the original function-' if f.__doc__ else ''}
===
""" + (
f.__doc__ or str(f)
)
{f.__doc__ or str(f)}
"""

return f
2 changes: 1 addition & 1 deletion src/akimbo/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_json_schema(
):
"""Get JSONSchema representation of the contents of a line-delimited JSON file
Currently requires dask_awkward to be installed, which in turn required dask
Currently, requires dask_awkward to be installed, which in turn required dask
Parameters
----------
Expand Down

0 comments on commit 6d46786

Please sign in to comment.