Skip to content

Commit

Permalink
#204 Add utility to get phase label of 1D DataArray
Browse files Browse the repository at this point in the history
  • Loading branch information
astropenguin committed Aug 17, 2024
1 parent 409f4df commit 7f3b3b0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
34 changes: 33 additions & 1 deletion decode/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["mad"]
__all__ = ["mad", "phaseof"]


# dependencies
Expand Down Expand Up @@ -38,3 +38,35 @@ def median(da: xr.DataArray) -> xr.DataArray:
)

return median(cast(xr.DataArray, np.abs(da - median(da))))


def phaseof(
da: xr.DataArray,
/,
*,
keep_attrs: bool = False,
keep_coords: bool = False,
) -> xr.DataArray:
"""Assign a phase to each value in a 1D DataArray.
The function assigns a unique phase (int64) to consecutive
identical values in the DataArray. The phase increases
sequentially whenever the value in the DataArray changes.
Args:
da: Input 1D DataArray.
keep_attrs: Whether to keep attributes of the input.
keep_coords: Whether to keep coordinates of the input.
Returns:
1D int64 DataArray of phases.
"""
if da.ndim != 1:
raise ValueError("Input DataArray must be 1D.")

is_transision = xr.zeros_like(da, bool)
is_transision.data[1:] = da.data[1:] != da.data[:-1]

phase = is_transision.cumsum(keep_attrs=keep_attrs)
return phase if keep_coords else phase.reset_coords(drop=True)
6 changes: 6 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@
def test_mad() -> None:
dems = MS.new(np.arange(25).reshape(5, 5))
assert (utils.mad(dems, "time") == 5.0).all()


def test_phaseof() -> None:
tester = xr.DataArray([0, 1, 1, 2, 2, 2, 1, 0])
expected = xr.DataArray([0, 1, 1, 2, 2, 2, 3, 4])
assert (utils.phaseof(tester) == expected).all()

0 comments on commit 7f3b3b0

Please sign in to comment.