From ddc2b58cc45820a3ad8c70246d16046890411911 Mon Sep 17 00:00:00 2001 From: Shing Chan Date: Thu, 10 Oct 2024 15:30:22 +0100 Subject: [PATCH] feat: add functionality to plot time series of steps --- setup.py | 3 +- src/stepcount/stepcount.py | 73 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 964b57c..d64eb71 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,8 @@ def get_string(string, rel_path="src/stepcount/__init__.py"): "torch==1.13.*", "torchvision==0.14.*", "transforms3d==0.4.*", - "numba==0.58.*" + "numba==0.58.*", + "matplotlib==3.7.*", ], extras_require={ "dev": [ diff --git a/src/stepcount/stepcount.py b/src/stepcount/stepcount.py index 2a61383..c4a36fd 100644 --- a/src/stepcount/stepcount.py +++ b/src/stepcount/stepcount.py @@ -11,6 +11,8 @@ import numpy as np import pandas as pd import joblib +import matplotlib.pyplot as plt +import matplotlib.dates as mdates from numba import njit from stepcount import utils @@ -388,6 +390,10 @@ def main(): print(daily_adj.set_index('Date').drop(columns='Filename')) print("\nOutput files saved in:", outdir) + print("\nPlotting...") + fig = plot(Y, title=basename) + fig.savefig(f"{outdir}/{basename}-Steps.png", bbox_inches='tight', pad_inches=0) + after = time.time() print(f"Done! ({round(after - before,2)}s)") @@ -1085,6 +1091,73 @@ def numba_detect_bouts( return bouts +def plot(Y, title=None): + """ + Plot time series of steps per minute for each day. + + Parameters: + - Y: pandas Series or DataFrame with a 'Steps' column. Must have a DatetimeIndex. + + Returns: + - fig: matplotlib figure object + """ + + MAX_STEPS_PER_MINUTE = 180 + + if isinstance(Y, pd.DataFrame): + Y = Y['Steps'] + + assert isinstance(Y, pd.Series), "Y must be a pandas Series, or a DataFrame with a 'Steps' column" + + # Resample to 1 minute intervals + # Note: .sum() returns 0 when all values are NaN, so we need to use a custom function + def _sum(x): + if x.isna().all(): + return np.nan + return x.sum() + + Y = Y.resample('1T').agg(_sum) + + dates_index = Y.index.normalize() + unique_dates = dates_index.unique() + + # Set the plot figure and size + fig = plt.figure(figsize=(10, len(unique_dates) * 2)) + + # Group by each day + for i, (day, y) in enumerate(Y.groupby(dates_index)): + ax = fig.add_subplot(len(unique_dates), 1, i + 1) + + # Plot steps + ax.plot(y.index, y, label='steps/min') + + # Grey shading where NA + ax.fill_between(y.index, 0, MAX_STEPS_PER_MINUTE, where=y.isna(), color='grey', alpha=0.3, interpolate=True, label='missing') + + # Formatting the x-axis to show hours and minutes + ax.xaxis.set_major_locator(mdates.HourLocator(interval=1)) + ax.xaxis.set_minor_locator(mdates.MinuteLocator(interval=15)) + ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M")) + + # Set x-axis limits to start at 00:00 and end at 24:00 + ax.set_xlim(day, day + pd.DateOffset(days=1)) + # Set y-axis limits + ax.set_ylim(-10, MAX_STEPS_PER_MINUTE) + + ax.tick_params(axis='x', rotation=45) + ax.set_ylabel('steps/min') + ax.set_title(day.strftime('%Y-%m-%d')) + ax.grid(True) + ax.legend(loc='upper left') + + if title: + fig.suptitle(title) + + fig.tight_layout() + + return fig + + if __name__ == '__main__': main()