Curated updates II #51

Merged
merged 12 commits on Mar 12, 2024
16 changes: 15 additions & 1 deletion docs/source/submodules/curated_spike_analysis.rst
@@ -21,10 +21,24 @@ To make generation of this easier there is a built-in :code:`get_responsive_neurons`
st.get_responsive_neurons(z_parameters=my_parameters)  # my_parameters is a previously created parameters dict
st.save_responsive_neurons()

curation = sa.read_responsive_neurons()
curation = sa.read_responsive_neurons(file_path)

curated_st = sa.CuratedSpikeAnalysis(curation=curation)

Loading the Data
----------------

Because :code:`CuratedSpikeAnalysis` is a :code:`SpikeAnalysis` object, it requires :code:`StimulusData` and
:code:`SpikeData`. These can be loaded with the normal methods :code:`set_stimulus_data()` and
:code:`set_spike_data()`, but there is also a convenience method, :code:`set_spike_analysis()`, which loads
all necessary data from an existing :code:`SpikeAnalysis`.

.. code-block:: python

# This will collect stim and spike data all at once.
curated_st.set_spike_analysis(st)
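
Alternatively, the :code:`SpikeAnalysis` can be passed straight to the constructor via the
:code:`st` argument, which calls :code:`set_spike_analysis()` internally:

.. code-block:: python

    # equivalent to constructing first and then calling set_spike_analysis(st)
    curated_st = sa.CuratedSpikeAnalysis(curation=curation, st=st)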


Curating the Data
-----------------

36 changes: 17 additions & 19 deletions docs/source/submodules/merged_spike_analysis.rst
@@ -19,34 +19,32 @@ similar fashion to other classes
# or we can use lists
merged_data.add_analysis(analysis=[st3,st4], name=['animal3', 'animal4'])

Once the data to merge is ready, one can use the :code:`merge()` function. This takes
the argument :code:`psth`, which can either be set to :code:`True`, meaning load the balanced
:code:`psth` values, or be a list of the potential merge values, e.g. :code:`zscore` or :code:`fr`.
Once the data to merge is ready, one can use the :code:`merge_data()` function.

.. code-block:: python

# will attempt to merge the psths of each dataset
merged_data.merge(psth=True)
merged_data.merge_data()

# will attempt to merge z scores
merged_data.merge(psth=['zscore'])

Note that the datasets to be merged must be balanced. For example, a dataset with 5 neurons,
10 trials, and 200 timepoints can only be merged with another dataset of :code:`x` neurons, 10
trials, and 200 timepoints. The concatenation occurs along the neuron axis (:code:`axis 0`),
so everything else must have the same dimensionality.
After merging the datasets, the standard :code:`SpikeAnalysis` functions can be run. Under the hood, each dataset
is run with the exact same conditions to ensure the time bins are balanced. At a fundamental level, the data
is set up as a series of matrices of shape :code:`(n_neurons, n_trialgroups, n_time_bins)`.
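
As a minimal :code:`numpy` sketch of that layout (with made-up shapes), merging concatenates the
per-dataset matrices along the neuron axis, so the trial-group and time-bin axes must already agree:

.. code-block:: python

    import numpy as np

    # hypothetical psth matrices: (n_neurons, n_trialgroups, n_time_bins)
    animal1 = np.zeros((5, 10, 200))  # 5 neurons
    animal2 = np.zeros((8, 10, 200))  # 8 neurons, same trial groups and bins

    merged = np.concatenate([animal1, animal2], axis=0)
    assert merged.shape == (13, 10, 200)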

Since different animals may have different numbers of trial groups, the functions after :code:`get_raw_psth()` are
run with a :code:`fill` argument, which takes animals missing a trial group and fills those entries with :code:`fill`.
The default is :code:`np.nan`.

.. code-block:: python

merged_data.get_raw_psth(window=[-1, 2], time_bin_ms=1)
merged_data.z_score_data(time_bin_ms=10, bsl_window=[-1,-.1], z_window=[-1,2], fill=np.nan)

Finally, the merged dataset can be returned for use in the :code:`SpikePlotter` class.

.. code-block:: python

msa = merged_data.get_merged_data()
plotter = sa.SpikePlotter()
plotter.set_analysis(msa)
plotter.set_analysis(merged_data)

This works because the :code:`MSA` returned is a :code:`SpikeAnalysis` object that has specific
guardrails around methods which can no longer be accessed. For example, if the data was merged with
:code:`psth=True`, then z scores can be regenerated across the data with a different :code:`time_bin_ms`,
but if :code:`psth=['zscore']` was used then new z scores cannot be generated and the :code:`MSA` will
raise a :code:`NotImplementedError`.
This works because :code:`merged_data` is a :code:`SpikeAnalysis` object that has specific
guardrails around methods which can no longer be accessed. Plotting works as it normally would.
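
For example (a sketch; see :code:`MergedSpikeAnalysis` for exactly which methods are guarded),
:code:`trial_correlation()` must be run on the individual datasets and raises on the merged object:

.. code-block:: python

    try:
        merged_data.trial_correlation(window=[-1, 2], time_bin_ms=10)
    except NotImplementedError:
        # run trial_correlation on the base SpikeAnalysis objects instead
        pass
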
15 changes: 15 additions & 0 deletions docs/source/submodules/spike_analysis.rst
@@ -26,6 +26,9 @@ Setting Stimulus and Spike Data
-------------------------------

:code:`SpikeAnalysis` requires both :code:`StimulusData` and :code:`SpikeData` to perform analyses. It has a setter method for each of these datasets.
To leverage the power of the SpikeInterface project, there is also a separate method, :code:`set_spike_data_si()`, which takes
any :code:`spikeinterface.BaseSorting`.


.. code-block:: python

@@ -38,6 +41,18 @@ Setting Stimulus and Spike Data
spiketrain.set_stimulus_data(event_times=stim)
spiketrain.set_spike_data(sp=spikes)

or

.. code-block:: python

# sorting is a spikeinterface.BaseSorting object

import spikeanalysis as sa
spiketrain = sa.SpikeAnalysis()
spiketrain.set_stimulus_data(event_times=stim)
spiketrain.set_spike_data_si(sp=sorting)


Calculating Peristimulus Histogram (PSTH)
-----------------------------------------

67 changes: 62 additions & 5 deletions src/spikeanalysis/curated_spike_analysis.py
@@ -6,6 +6,7 @@
import numpy as np

from .spike_analysis import SpikeAnalysis
from .spike_data import SpikeData


def read_responsive_neurons(folder_path) -> dict:
@@ -44,7 +45,10 @@ class CuratedSpikeAnalysis(SpikeAnalysis):
"""Class for analyzing curated spiketrain data
based on a curation dictionary"""

def __init__(self, curation: dict | None = None):
def __init__(
self, curation: dict | None = None, st: SpikeAnalysis | None = None, save_parameters=False, verbose=False
):
"""
Parameters
----------
@@ -53,22 +57,73 @@ def __init__(self, curation: dict | None = None):

"""

self.curation = curation
super().__init__()
self.curation = curation or {}
if st is not None:
self.set_spike_analysis(st=st)
super().__init__(save_parameters=save_parameters, verbose=verbose)

def set_curation(self, curation: dict):
def set_curation(
self,
curation: dict,
):
"""
Function for setting the curation dictionary

Parameters
----------
curation: dict
The curation dict for curating
"""
if not isinstance(curation, dict):
raise TypeError(f"curation must be dict not a {type(curation)}")
self.curation = curation

def set_spike_data(self, sp: "SpikeData"):
def set_spike_data(self, sp: SpikeData):
"""
Function for setting a SpikeData object

Parameters
----------
sp: SpikeData
A spikeanalysis.SpikeData object to be curated
"""
if not isinstance(sp, SpikeData):
raise TypeError("Set with spike data")
from copy import deepcopy

super().set_spike_data(sp=sp)
self._original_cluster_ids = deepcopy(self.cluster_ids)


def set_spike_data_si(self, sp: "Sorting"):
"""
Function for setting a spikeinterface sorting

Parameters
----------
sp: spikeinterface.BaseSorting
The spikeinterface Sorting object to load
"""

from copy import deepcopy

super().set_spike_data_si(sp=sp)
self._original_cluster_ids = deepcopy(self.cluster_ids)

def set_spike_analysis(self, st: SpikeAnalysis):
"""
Function for setting a SpikeAnalysis

Parameters
----------
st: spikeanalysis.SpikeAnalysis
The SpikeAnalysis (containing the Stim and Spike Data to load)
"""
from copy import deepcopy

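# copy the stimulus events and spike attributes over from the donor SpikeAnalysis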
self.events = st.events
self._sampling_rate = st._sampling_rate
self._original_cluster_ids = deepcopy(st.cluster_ids)
self.raw_spike_times = st.spike_times
self.spike_clusters = st.spike_clusters
self._cids = st._cids
self.cluster_ids = st.cluster_ids

def curate(
self,
criteria: str | dict,
Expand All @@ -93,6 +148,8 @@ def curate(

"""
curation = self.curation
if len(curation) == 0:
raise RuntimeError("Must set curation first. Run `set_curation`")

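# curating by both stimulus and response requires criteria given as a dict of stim -> response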
if by_stim and by_response:
assert isinstance(criteria, dict), "must give both stim and response as a dict to run"
33 changes: 29 additions & 4 deletions src/spikeanalysis/merged_spike_analysis.py
@@ -9,6 +9,19 @@ class MergedSpikeAnalysis(SpikeAnalysis):

def __init__(self, spikeanalysis_list=None, name_list=None, save_parameters=False, verbose=False):

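# accept either a list of analyses or a single (Curated)SpikeAnalysis, which gets wrapped in a list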
if spikeanalysis_list is not None:
if not isinstance(spikeanalysis_list, (list, SpikeAnalysis, CuratedSpikeAnalysis)):
raise TypeError("spikeanalysis must be a list or an individual spikeanalysis")
if isinstance(spikeanalysis_list, (SpikeAnalysis, CuratedSpikeAnalysis)):
spikeanalysis_list = [spikeanalysis_list]
if name_list is not None:
if not isinstance(name_list, (list, str)):
raise TypeError("name list must be a list or a str")
if isinstance(name_list, str):
name_list = [name_list]

self.spikeanalysis_list = spikeanalysis_list or []
self.name_list = name_list or []
super().__init__(save_parameters=save_parameters, verbose=verbose)
@@ -19,16 +32,24 @@ def add_analysis(self, spikeanalysis, name):
if len(spikeanalysis) != len(name):
raise RuntimeError(f"{len(spikeanalysis)=} != {len(name)=}")
for idx, sa in enumerate(spikeanalysis):
self._verify_obj(sa)
self.spikeanalysis_list.append(sa)
if name[idx] in self.name_list:
raise RuntimeError("The same name can not be used for multiple datasets")
self.name_list.append(name[idx])
else:
if not isinstance(spikeanalysis, (SpikeAnalysis | CuratedSpikeAnalysis)):
if not isinstance(spikeanalysis, (SpikeAnalysis, CuratedSpikeAnalysis)):
raise TypeError(f"Spikeanalysis must be a list or a spikeanalysis not a type {type(spikeanalysis)}")
if not isinstance(name, str):
raise TypeError("if spikeanalysis is type SpikeAnalysis, then name must be a string")
self.spikeanalysis_list.append(spikeanalysis)
self.name_list.append(name)

def _verify_obj(self, obj):

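# compare object identity: the exact same analysis instance cannot be merged twice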
if any(sa is obj for sa in self.spikeanalysis_list):
raise RuntimeError("Cannot merge the same data twice")

def merge_data(self):

self.events = self.spikeanalysis_list[0].events
@@ -101,7 +122,7 @@ def _fill_merged_data(
if not isinstance(fill, (int, float)):
raise TypeError(f"fill should be nan or ideally 0; it is {fill}")

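# collect the unique trial groups per dataset, plus the union across all datasets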
tg_list = [np.unique(x[self._get_stim_key(x, stim)]['trial_groups']) for x in self.events_list]
tg_list = [np.unique(x[self._get_stim_key(x, stim)]["trial_groups"]) for x in self.events_list]
flat_tg_list = list(set([y for x in tg_list for y in x]))
for psth_idx, fr in enumerate(data_list):
current_tg = tg_list[psth_idx]
@@ -144,6 +165,8 @@ def get_raw_firing_rate(

self.mean_firing_rate = merged_firing_rates
self.fr_bins = self.spikeanalysis_list[0].fr_bins
self.fr_windows = self.spikeanalysis_list[0].fr_windows


def z_score_data(self, time_bin_ms, bsl_window, z_window, eps=0, fill=np.nan):

@@ -163,6 +186,8 @@ def z_score_data(self, time_bin_ms, bsl_window, z_window, eps=0, fill=np.nan):

self.z_scores = merged_z_scores
self.z_bins = self.spikeanalysis_list[0].z_bins
self.z_windows = self.spikeanalysis_list[0].z_windows


def latencies(self):
raise NotImplementedError("To do")
@@ -177,7 +202,7 @@ def trial_correlation(
self,
window: list | list[list],
time_bin_ms: float | None = None,
dataset = "psth",
method = "pearson",
dataset="psth",
method="pearson",
):
raise NotImplementedError("Should run in the base SpikeAnalysis")
16 changes: 9 additions & 7 deletions test/test_merged_spike_analysis.py
@@ -1,6 +1,7 @@
import pytest
import numpy as np
from pathlib import Path
from copy import deepcopy


from spikeanalysis.merged_spike_analysis import MergedSpikeAnalysis
@@ -32,9 +33,9 @@ def test_ma_init(sa):


def test_add_analysis(sa):

sa2 = deepcopy(sa)
test_msa = MergedSpikeAnalysis()
test_msa.add_analysis([sa, sa], name=["test1", "test2"])
test_msa.add_analysis([sa, sa2], name=["test1", "test2"])

assert len(test_msa.spikeanalysis_list) == len(test_msa.name_list)

@@ -54,8 +55,8 @@ def sa_mocked(sa):


def test_merge(sa_mocked):

test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked], ["test1", "test2"])
sa_mocked2 = deepcopy(sa_mocked)
test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked2], ["test1", "test2"])
test_msa.merge_data()

assert len(test_msa.raw_spike_times) == 2 * len(sa_mocked.raw_spike_times) # same data twice
@@ -65,7 +66,8 @@ def test_merge(sa_mocked):

def test_fr_z_psth(sa_mocked):

test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked], ["test1", "test2"])
sa_mocked2 = deepcopy(sa_mocked)
test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked2], ["test1", "test2"])
test_msa.merge_data()
test_msa.get_raw_psth(
window=[0, 300],
@@ -85,7 +87,6 @@ def test_fr_z_psth(sa_mocked):


def test_fr_z_psth_different_trials(sa_mocked):
from copy import deepcopy

sa_mocked1 = deepcopy(sa_mocked)
sa_mocked1.events = {
@@ -117,7 +118,8 @@ def test_fr_z_psth_different_trials(sa_mocked):

def test_interspike_interval(sa_mocked):

test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked], ["test1", "test2"])
sa_mocked2 = deepcopy(sa_mocked)
test_msa = MergedSpikeAnalysis([sa_mocked, sa_mocked2], ["test1", "test2"])
test_msa.merge_data()
test_msa.get_interspike_intervals()
