diff --git a/doc/source/rawio.rst b/doc/source/rawio.rst index 5f0d41b45..52127700a 100644 --- a/doc/source/rawio.rst +++ b/doc/source/rawio.rst @@ -281,6 +281,32 @@ Read event timestamps and times In [42]: print(ev_times) [ 0.0317] +Signal streams and signal buffers +--------------------------------- + +For reading analog signals **neo.rawio** has 2 important concepts: + + 1. The **signal_stream** : it is a group of channels that can be read together using :func:`get_analogsignal_chunk()`. + This group of channels is guaranteed to have the same sampling rate, and the same duration per segment. + Most of the time, this group of channels is a "logical" group of channels. In short they are from the same headstage + or from the same auxiliary board. + Optionally, depending on the format, a **signal_stream** can be a slice of or an entire **signal_buffer**. + + 2. The **signal_buffer** : it is a group of channels that share the same data layout in a file. The simplest example + is a set of channels that can be read by a simple ``signals = np.memmap(file, shape=..., dtype=..., offset=...)``. + A **signal_buffer** can contain one or several **signal_stream**'s (very often it is only one). + There are two kinds of formats that handle this concept: + + * Formats which use :func:`np.memmap()` internally + * Formats based on hdf5 + + There are many formats that do not handle this concept: + + * the ones that use an external python package for reading data (edf, ced, plexon2, ...) + * the ones with a complicated data layout (e.g. those where the data blocks are split without structure) + + To check if a format makes use of the buffer api you can call the method ``has_buffer_description_api()`` of the + rawio class. 
diff --git a/neo/rawio/axonrawio.py b/neo/rawio/axonrawio.py index 8caf8554c..6980bf17a 100644 --- a/neo/rawio/axonrawio.py +++ b/neo/rawio/axonrawio.py @@ -53,7 +53,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -63,7 +63,7 @@ from neo.core import NeoReadWriteError -class AxonRawIO(BaseRawIO): +class AxonRawIO(BaseRawWithBufferApiIO): """ Class for Class for reading data from pCLAMP and AxoScope files (.abf version 1 and 2) @@ -92,7 +92,7 @@ class AxonRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename=""): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename def _parse_header(self): @@ -115,8 +115,6 @@ def _parse_header(self): head_offset = info["sections"]["DataSection"]["uBlockIndex"] * BLOCKSIZE totalsize = info["sections"]["DataSection"]["llNumEntries"] - self._raw_data = np.memmap(self.filename, dtype=sig_dtype, mode="r", shape=(totalsize,), offset=head_offset) - # 3 possible modes if version < 2.0: mode = info["nOperationMode"] @@ -142,7 +140,7 @@ def _parse_header(self): ) else: episode_array = np.empty(1, [("offset", "i4"), ("len", "i4")]) - episode_array[0]["len"] = self._raw_data.size + episode_array[0]["len"] = totalsize episode_array[0]["offset"] = 0 # sampling_rate @@ -154,9 +152,14 @@ def _parse_header(self): # one sweep = one segment nb_segment = episode_array.size + stream_id = "0" + buffer_id = "0" + # Get raw data by segment - self._raw_signals = {} + # self._raw_signals = {} self._t_starts = {} + self._buffer_descriptions = {0 :{}} + self._stream_buffer_slice = {stream_id : None} pos = 0 for seg_index in range(nb_segment): length = episode_array[seg_index]["len"] @@ -169,7 +172,15 @@ def _parse_header(self): if (fSynchTimeUnit != 0) and (mode == 1): length /= fSynchTimeUnit - self._raw_signals[seg_index] = self._raw_data[pos : pos + length].reshape(-1, nbchannel) + 
self._buffer_descriptions[0][seg_index] = {} + self._buffer_descriptions[0][seg_index][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : str(sig_dtype), + "order": "C", + "file_offset" : head_offset + pos * sig_dtype.itemsize, + "shape" : (int(length // nbchannel), int(nbchannel)), + } pos += length t_start = float(episode_array[seg_index]["offset"]) @@ -227,17 +238,14 @@ def _parse_header(self): offset -= info["listADCInfo"][chan_id]["fSignalOffset"] else: gain, offset = 1.0, 0.0 - stream_id = "0" - buffer_id = "0" - signal_channels.append( - (name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id) - ) + + signal_channels.append((name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)) signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) # one unique signal stream and buffer - signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) + signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype) + signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype) # only one events channel : tag # In ABF timstamps are not attached too any particular segment @@ -295,21 +303,26 @@ def _segment_t_start(self, block_index, seg_index): return self._t_starts[seg_index] def _segment_t_stop(self, block_index, seg_index): - t_stop = self._t_starts[seg_index] + self._raw_signals[seg_index].shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = self._t_starts[seg_index] + sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - shape = self._raw_signals[seg_index].shape - return shape[0] + # def _get_signal_size(self, block_index, seg_index, stream_index): + # shape = self._raw_signals[seg_index].shape + # 
return shape[0] def _get_signal_t_start(self, block_index, seg_index, stream_index): return self._t_starts[seg_index] - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - raw_signals = self._raw_signals[seg_index][slice(i_start, i_stop), channel_indexes] - return raw_signals + # def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): + # if channel_indexes is None: + # channel_indexes = slice(None) + # raw_signals = self._raw_signals[seg_index][slice(i_start, i_stop), channel_indexes] + # return raw_signals + + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + def _event_count(self, block_index, seg_index, event_channel_index): return self._raw_ev_timestamps.size diff --git a/neo/rawio/baserawio.py b/neo/rawio/baserawio.py index e875d7aa2..34c3e85a3 100644 --- a/neo/rawio/baserawio.py +++ b/neo/rawio/baserawio.py @@ -77,6 +77,8 @@ from neo import logging_handler +from .utils import get_memmap_chunk_from_opened_file + possible_raw_modes = [ "one-file", @@ -182,6 +184,15 @@ def __init__(self, use_cache: bool = False, cache_path: str = "same_as_resource" self.header = None self.is_header_parsed = False + self._has_buffer_description_api = False + + def has_buffer_description_api(self) -> bool: + """ + Return if the reader handle the buffer API. 
+ If True then the reader support internally `get_analogsignal_buffer_description()` + """ + return self._has_buffer_description_api + def parse_header(self): """ Parses the header of the file(s) to allow for faster computations @@ -191,6 +202,7 @@ def parse_header(self): # this must create # self.header['nb_block'] # self.header['nb_segment'] + # self.header['signal_buffers'] # self.header['signal_streams'] # self.header['signal_channels'] # self.header['spike_channels'] @@ -663,6 +675,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int | """ stream_index = self._get_stream_index_from_arg(stream_index) + return self._get_signal_size(block_index, seg_index, stream_index) def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None): @@ -1311,7 +1324,6 @@ def _get_analogsignal_chunk( ------- array of samples, with each requested channel in a column """ - raise (NotImplementedError) ### @@ -1350,6 +1362,150 @@ def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype): raise (NotImplementedError) + ### + # buffer api zone + # must be implemented if has_buffer_description_api=True + def get_analogsignal_buffer_description(self, block_index: int = 0, seg_index: int = 0, buffer_id: str = None): + if not self.has_buffer_description_api: + raise ValueError("This reader do not support buffer_description API") + descr = self._get_analogsignal_buffer_description(block_index, seg_index, buffer_id) + return descr + + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + raise (NotImplementedError) + + + +class BaseRawWithBufferApiIO(BaseRawIO): + """ + Generic class for reader that support "buffer api". 
+ + In short reader that are internally based on: + + * np.memmap + * hdf5 + + In theses cases _get_signal_size and _get_analogsignal_chunk are totaly generic and do not need to be implemented in the class. + + For this class sub classes must implements theses two dict: + * self._buffer_descriptions[block_index][seg_index] = buffer_description + * self._stream_buffer_slice[buffer_id] = None or slicer o indices + + """ + + def __init__(self, *arg, **kwargs): + super().__init__(*arg, **kwargs) + self._has_buffer_description_api = True + + def _get_signal_size(self, block_index, seg_index, stream_index): + buffer_id = self.header["signal_streams"][stream_index]["buffer_id"] + buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id) + # some hdf5 revert teh buffer + time_axis = buffer_desc.get("time_axis", 0) + return buffer_desc['shape'][time_axis] + + def _get_analogsignal_chunk( + self, + block_index: int, + seg_index: int, + i_start: int | None, + i_stop: int | None, + stream_index: int, + channel_indexes: list[int] | None, + ): + + stream_id = self.header["signal_streams"][stream_index]["id"] + buffer_id = self.header["signal_streams"][stream_index]["buffer_id"] + + buffer_slice = self._stream_buffer_slice[stream_id] + + + buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id) + + i_start = i_start or 0 + i_stop = i_stop or buffer_desc['shape'][0] + + if buffer_desc['type'] == "raw": + + # open files on demand and keep reference to opened file + if not hasattr(self, '_memmap_analogsignal_buffers'): + self._memmap_analogsignal_buffers = {} + if block_index not in self._memmap_analogsignal_buffers: + self._memmap_analogsignal_buffers[block_index] = {} + if seg_index not in self._memmap_analogsignal_buffers[block_index]: + self._memmap_analogsignal_buffers[block_index][seg_index] = {} + if buffer_id not in self._memmap_analogsignal_buffers[block_index][seg_index]: + fid = 
open(buffer_desc['file_path'], mode='rb') + self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] = fid + else: + fid = self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] + + num_channels = buffer_desc['shape'][1] + + raw_sigs = get_memmap_chunk_from_opened_file(fid, num_channels, i_start, i_stop, np.dtype(buffer_desc['dtype']), file_offset=buffer_desc['file_offset']) + + + elif buffer_desc['type'] == 'hdf5': + + # open files on demand and keep reference to opened file + if not hasattr(self, '_hdf5_analogsignal_buffers'): + self._hdf5_analogsignal_buffers = {} + if block_index not in self._hdf5_analogsignal_buffers: + self._hdf5_analogsignal_buffers[block_index] = {} + if seg_index not in self._hdf5_analogsignal_buffers[block_index]: + self._hdf5_analogsignal_buffers[block_index][seg_index] = {} + if buffer_id not in self._hdf5_analogsignal_buffers[block_index][seg_index]: + import h5py + h5file = h5py.File(buffer_desc['file_path'], mode="r") + self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id] = h5file + else: + h5file = self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id] + + hdf5_path = buffer_desc["hdf5_path"] + full_raw_sigs = h5file[hdf5_path] + + time_axis = buffer_desc.get("time_axis", 0) + if time_axis == 0: + raw_sigs = full_raw_sigs[i_start:i_stop, :] + elif time_axis == 1: + raw_sigs = full_raw_sigs[:, i_start:i_stop].T + else: + raise RuntimeError("Should never happen") + + if buffer_slice is not None: + raw_sigs = raw_sigs[:, buffer_slice] + + + + else: + raise NotImplementedError() + + # this is a pre slicing when the stream do not contain all channels (for instance spikeglx when load_sync_channel=False) + if buffer_slice is not None: + raw_sigs = raw_sigs[:, buffer_slice] + + # channel slice requested + if channel_indexes is not None: + raw_sigs = raw_sigs[:, channel_indexes] + + + return raw_sigs + + def __del__(self): + if hasattr(self, '_memmap_analogsignal_buffers'): + for 
block_index in self._memmap_analogsignal_buffers.keys(): + for seg_index in self._memmap_analogsignal_buffers[block_index].keys(): + for buffer_id, fid in self._memmap_analogsignal_buffers[block_index][seg_index].items(): + fid.close() + del self._memmap_analogsignal_buffers + + if hasattr(self, '_hdf5_analogsignal_buffers'): + for block_index in self._hdf5_analogsignal_buffers.keys(): + for seg_index in self._hdf5_analogsignal_buffers[block_index].keys(): + for buffer_id, h5_file in self._hdf5_analogsignal_buffers[block_index][seg_index].items(): + h5_file.close() + del self._hdf5_analogsignal_buffers + def pprint_vector(vector, lim: int = 8): vector = np.asarray(vector) diff --git a/neo/rawio/bci2000rawio.py b/neo/rawio/bci2000rawio.py index 96fac9183..d7e7cf003 100644 --- a/neo/rawio/bci2000rawio.py +++ b/neo/rawio/bci2000rawio.py @@ -1,6 +1,9 @@ """ BCI2000RawIO is a class to read BCI2000 .dat files. https://www.bci2000.org/mediawiki/index.php/Technical_Reference:BCI2000_File_Format + +Note : BCI2000RawIO cannot implemented using has_buffer_description_api because the buffer +is not compact. 
The buffer of signals is not compact (has some interleaved state uint in between channels) """ import numpy as np @@ -50,9 +53,11 @@ def _parse_header(self): self.header["nb_block"] = 1 self.header["nb_segment"] = [1] - # one unique stream and buffer - signal_buffers = np.array(("Signals", "0"), dtype=_signal_buffer_dtype) - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) + # one unique stream but no buffer because channels are not compact + stream_id = "0" + buffer_id = "" + signal_buffers = np.array([], dtype=_signal_buffer_dtype) + signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype) self.header["signal_buffers"] = signal_buffers self.header["signal_streams"] = signal_streams @@ -80,8 +85,6 @@ def _parse_header(self): if isinstance(offset, str): offset = float(offset) - stream_id = "0" - buffer_id = "0" sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id, buffer_id)) self.header["signal_channels"] = np.array(sig_channels, dtype=_signal_channel_dtype) diff --git a/neo/rawio/brainvisionrawio.py b/neo/rawio/brainvisionrawio.py index 9cfcf615b..6ac597695 100644 --- a/neo/rawio/brainvisionrawio.py +++ b/neo/rawio/brainvisionrawio.py @@ -13,7 +13,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -21,10 +21,12 @@ _event_channel_dtype, ) +from .utils import get_memmap_shape + from neo.core import NeoReadWriteError -class BrainVisionRawIO(BaseRawIO): +class BrainVisionRawIO(BaseRawWithBufferApiIO): """Class for reading BrainVision files Parameters @@ -42,8 +44,8 @@ class BrainVisionRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename=""): - BaseRawIO.__init__(self) - self.filename = filename + BaseRawWithBufferApiIO.__init__(self) + self.filename = str(filename) def _parse_header(self): # Read header file (vhdr) @@ -78,13 +80,23 @@ def 
_parse_header(self): sig_dtype = fmts[fmt] - # raw signals memmap - sigs = np.memmap(binary_filename, dtype=sig_dtype, mode="r", offset=0) - if sigs.size % nb_channel != 0: - sigs = sigs[: -sigs.size % nb_channel] - self._raw_signals = sigs.reshape(-1, nb_channel) + + stream_id = "0" + buffer_id = "0" + self._buffer_descriptions = {0 :{0 : {}}} + self._stream_buffer_slice = {} + shape = get_memmap_shape(binary_filename, sig_dtype, num_channels=nb_channel, offset=0) + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : binary_filename, + "dtype" : str(sig_dtype), + "order": "C", + "file_offset" : 0, + "shape" : shape, + } + self._stream_buffer_slice[stream_id] = None - signal_buffers = np.array(("Signals", "0"), dtype=_signal_buffer_dtype) + signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) sig_channels = [] @@ -181,24 +193,14 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop ### - def _get_signal_size(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("`stream_index` must be 0") - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals - ### def _spike_count(self, block_index, seg_index, unit_index): return 0 @@ -232,6 +234,9 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) event_times = 
event_timestamps.astype(dtype) / self._sampling_rate return event_times + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + def read_brainvsion_soup(filename): with open(filename, "r", encoding="utf8") as f: diff --git a/neo/rawio/elanrawio.py b/neo/rawio/elanrawio.py index 18de2472d..06c928145 100644 --- a/neo/rawio/elanrawio.py +++ b/neo/rawio/elanrawio.py @@ -23,7 +23,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -31,10 +31,12 @@ _event_channel_dtype, ) +from .utils import get_memmap_shape + from neo.core import NeoReadWriteError -class ElanRawIO(BaseRawIO): +class ElanRawIO(BaseRawWithBufferApiIO): """ Class for reading time-frequency EEG data maps from the Elan software @@ -59,7 +61,7 @@ class ElanRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename=None, entfile=None, posfile=None): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = pathlib.Path(filename) # check whether ent and pos files are defined @@ -156,20 +158,28 @@ def _parse_header(self): sig_dtype = np.dtype(">i" + str(n)) # unique buffer and stream - signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) + stream_id = "0" + buffer_id = "0" + signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype) + signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype) + sig_channels = [] - for c, chan_info in enumerate(channel_infos[:-2]): + for c, chan_info in enumerate(channel_infos): chan_name = chan_info["label"] chan_id = str(c) + if c < len(channel_infos) - 2: + # standard channels + stream_id = "0" + else: + # last 2 channels are not included in the stream + stream_id = "" + gain = 
(chan_info["max_physic"] - chan_info["min_physic"]) / ( chan_info["max_logic"] - chan_info["min_logic"] ) offset = -chan_info["min_logic"] * gain + chan_info["min_physic"] - stream_id = "0" - buffer_id = "0" sig_channels.append( ( chan_name, @@ -187,8 +197,18 @@ def _parse_header(self): sig_channels = np.array(sig_channels, dtype=_signal_channel_dtype) # raw data - self._raw_signals = np.memmap(self.filename, dtype=sig_dtype, mode="r", offset=0).reshape(-1, nb_channel + 2) - self._raw_signals = self._raw_signals[:, :-2] + self._buffer_descriptions = {0 :{0 : {}}} + self._stream_buffer_slice = {} + shape = get_memmap_shape(self.filename, sig_dtype, num_channels=nb_channel + 2, offset=0) + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : self.filename, + "dtype" : sig_dtype, + "order": "C", + "file_offset" : 0, + "shape" : shape, + } + self._stream_buffer_slice["0"] = slice(0, -2) # triggers with open(self.posfile, mode="rt", encoding="ascii", newline=None) as f: @@ -246,25 +266,15 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("`stream_index` must be 0") - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): if stream_index != 0: raise ValueError("`stream_index` must be 0") return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals - def _spike_count(self, block_index, seg_index, unit_index): return 0 @@ -291,3 +301,6 
@@ def _get_event_timestamps(self, block_index, seg_index, event_channel_index, t_s def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index): event_times = event_timestamps.astype(dtype) / self._sampling_rate return event_times + + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] diff --git a/neo/rawio/maxwellrawio.py b/neo/rawio/maxwellrawio.py index 7224f27af..6cb8ced6f 100644 --- a/neo/rawio/maxwellrawio.py +++ b/neo/rawio/maxwellrawio.py @@ -27,7 +27,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -38,7 +38,7 @@ from neo.core import NeoReadWriteError -class MaxwellRawIO(BaseRawIO): +class MaxwellRawIO(BaseRawWithBufferApiIO): """ Class for reading MaxOne or MaxTwo files. @@ -59,7 +59,7 @@ class MaxwellRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename="", rec_name=None): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename self.rec_name = rec_name @@ -119,10 +119,12 @@ def _parse_header(self): # create signal channels max_sig_length = 0 - self._signals = {} + self._buffer_descriptions = {0 :{0 :{}}} + self._stream_buffer_slice = {} sig_channels = [] well_indices_to_remove = [] for stream_index, stream_id in enumerate(signal_streams["id"]): + if int(version) == 20160704: sr = 20000.0 settings = h5file["settings"] @@ -135,11 +137,11 @@ def _parse_header(self): else: gain = settings["gain"][0] gain_uV = 3.3 / (1024 * gain) * 1e6 - sigs = h5file["sig"] + hdf5_path = "sig" mapping = h5file["mapping"] ids = np.array(mapping["channel"]) ids = ids[ids >= 0] - self._channel_slice = ids + self._stream_buffer_slice[stream_id] = ids elif int(version) > 20160704: settings = h5file["wells"][stream_id][self.rec_name]["settings"] sr = settings["sampling"][0] @@ -154,12 +156,25 @@ def 
_parse_header(self): gain_uV = 3.3 / (1024 * gain) * 1e6 mapping = settings["mapping"] if "routed" in h5file["wells"][stream_id][self.rec_name]["groups"]: - sigs = h5file["wells"][stream_id][self.rec_name]["groups"]["routed"]["raw"] + hdf5_path = f"/wells/{stream_id}/{self.rec_name}/groups/routed/raw" else: warnings.warn(f"No 'routed' group found for well {stream_id}") well_indices_to_remove.append(stream_index) continue + self._stream_buffer_slice[stream_id] = None + + buffer_id = stream_id + shape = h5file[hdf5_path].shape + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "hdf5", + "file_path" : str(self.filename), + "hdf5_path" : hdf5_path, + "shape" : shape, + "time_axis": 1, + } + self._stream_buffer_slice[stream_id] = slice(None) + channel_ids = np.array(mapping["channel"]) electrode_ids = np.array(mapping["electrode"]) mask = channel_ids >= 0 @@ -175,8 +190,7 @@ def _parse_header(self): (ch_name, str(chan_id), sr, "uint16", "uV", gain_uV, offset_uV, stream_id, buffer_id) ) - self._signals[stream_id] = sigs - max_sig_length = max(max_sig_length, sigs.shape[1]) + max_sig_length = max(max_sig_length, shape[1]) self._t_stop = max_sig_length / sr @@ -210,57 +224,60 @@ def _segment_t_start(self, block_index, seg_index): def _segment_t_stop(self, block_index, seg_index): return self._t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - stream_id = self.header["signal_streams"][stream_index]["id"] - sigs = self._signals[stream_id] - return sigs.shape[1] + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + + # def _get_signal_size(self, block_index, seg_index, stream_index): + # stream_id = self.header["signal_streams"][stream_index]["id"] + # sigs = self._signals[stream_id] + # return sigs.shape[1] def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, 
seg_index, i_start, i_stop, stream_index, channel_indexes): - stream_id = self.header["signal_streams"][stream_index]["id"] - sigs = self._signals[stream_id] - - if i_start is None: - i_start = 0 - if i_stop is None: - i_stop = sigs.shape[1] - - resorted_indexes = None - if channel_indexes is None: - channel_indexes = slice(None) - else: - if np.array(channel_indexes).size > 1 and np.any(np.diff(channel_indexes) < 0): - # get around h5py constraint that it does not allow datasets - # to be indexed out of order - order_f = np.argsort(channel_indexes) - sorted_channel_indexes = channel_indexes[order_f] - # use argsort again on order_f to obtain resorted_indexes - resorted_indexes = np.argsort(order_f) - - try: - if resorted_indexes is None: - if self._old_format: - sigs = sigs[self._channel_slice, i_start:i_stop] - sigs = sigs[channel_indexes] - else: - sigs = sigs[channel_indexes, i_start:i_stop] - else: - if self._old_format: - sigs = sigs[self._channel_slice, i_start:i_stop] - sigs = sigs[sorted_channel_indexes] - else: - sigs = sigs[sorted_channel_indexes, i_start:i_stop] - sigs = sigs[resorted_indexes] - except OSError as e: - print("*" * 10) - print(_hdf_maxwell_error) - print("*" * 10) - raise (e) - sigs = sigs.T - - return sigs + # def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): + # stream_id = self.header["signal_streams"][stream_index]["id"] + # sigs = self._signals[stream_id] + + # if i_start is None: + # i_start = 0 + # if i_stop is None: + # i_stop = sigs.shape[1] + + # resorted_indexes = None + # if channel_indexes is None: + # channel_indexes = slice(None) + # else: + # if np.array(channel_indexes).size > 1 and np.any(np.diff(channel_indexes) < 0): + # # get around h5py constraint that it does not allow datasets + # # to be indexed out of order + # order_f = np.argsort(channel_indexes) + # sorted_channel_indexes = channel_indexes[order_f] + # # use argsort again on order_f to obtain 
resorted_indexes + # resorted_indexes = np.argsort(order_f) + + # try: + # if resorted_indexes is None: + # if self._old_format: + # sigs = sigs[self._channel_slice, i_start:i_stop] + # sigs = sigs[channel_indexes] + # else: + # sigs = sigs[channel_indexes, i_start:i_stop] + # else: + # if self._old_format: + # sigs = sigs[self._channel_slice, i_start:i_stop] + # sigs = sigs[sorted_channel_indexes] + # else: + # sigs = sigs[sorted_channel_indexes, i_start:i_stop] + # sigs = sigs[resorted_indexes] + # except OSError as e: + # print("*" * 10) + # print(_hdf_maxwell_error) + # print("*" * 10) + # raise (e) + # sigs = sigs.T + + # return sigs _hdf_maxwell_error = """Maxwell file format is based on HDF5. diff --git a/neo/rawio/micromedrawio.py b/neo/rawio/micromedrawio.py index c92e3ea72..7e30a3829 100644 --- a/neo/rawio/micromedrawio.py +++ b/neo/rawio/micromedrawio.py @@ -14,13 +14,14 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, _spike_channel_dtype, _event_channel_dtype, ) +from .utils import get_memmap_shape from neo.core import NeoReadWriteError @@ -32,7 +33,7 @@ def read_f(self, fmt, offset=None): return struct.unpack(fmt, self.read(struct.calcsize(fmt))) -class MicromedRawIO(BaseRawIO): +class MicromedRawIO(BaseRawWithBufferApiIO): """ Class for reading data from micromed (.trc). 
@@ -45,11 +46,15 @@ class MicromedRawIO(BaseRawIO): extensions = ["trc", "TRC"] rawmode = "one-file" + def __init__(self, filename=""): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename def _parse_header(self): + + self._buffer_descriptions = {0 :{ 0: {}}} + with open(self.filename, "rb") as fid: f = StructFile(fid) @@ -97,9 +102,22 @@ def _parse_header(self): # raw signals memmap sig_dtype = "u" + str(Bytes) - self._raw_signals = np.memmap(self.filename, dtype=sig_dtype, mode="r", offset=Data_Start_Offset).reshape( - -1, Num_Chan - ) + # self._raw_signals = np.memmap(self.filename, dtype=sig_dtype, mode="r", offset=Data_Start_Offset).reshape( + # -1, Num_Chan + # ) + signal_shape = get_memmap_shape(self.filename, sig_dtype, num_channels=Num_Chan, offset=Data_Start_Offset) + buffer_id = "0" + stream_id = "0" + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : sig_dtype, + "order": "C", + "file_offset" : Data_Start_Offset, # samples begin after the header; must match the offset used for signal_shape + "shape" : signal_shape, + } + + # Reading Code Info zname2, pos, length = zones["ORDER"] @@ -128,16 +146,16 @@ def _parse_header(self): (sampling_rate,) = f.read_f("H") sampling_rate *= Rate_Min chan_id = str(c) - stream_id = "0" - buffer_id = "0" - signal_channels.append( - (chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id) - ) - signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) + + signal_channels.append((chan_name, chan_id, sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)) - signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) + + signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) + + self._stream_buffer_slice = {"0": slice(None)} + signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype) + signal_streams = np.array([("Signals", 
stream_id, buffer_id)], dtype=_signal_stream_dtype) if np.unique(signal_channels["sampling_rate"]).size != 1: raise NeoReadWriteError("The sampling rates must be the same across signal channels") @@ -166,7 +184,7 @@ def _parse_header(self): keep = ( (rawevent["start"] >= rawevent["start"][0]) - & (rawevent["start"] < self._raw_signals.shape[0]) + & (rawevent["start"] < signal_shape[0]) & (rawevent["start"] != 0) ) rawevent = rawevent[keep] @@ -207,25 +225,15 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("`stream_index` must be 0") - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): if stream_index != 0: raise ValueError("`stream_index` must be 0") return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(channel_indexes) - raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals - def _spike_count(self, block_index, seg_index, unit_index): return 0 @@ -266,3 +274,6 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index): durations = raw_duration.astype(dtype) / self._sampling_rate return durations + + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] diff --git a/neo/rawio/neuronexusrawio.py b/neo/rawio/neuronexusrawio.py index 7d7f6e970..1858b2409 100644 --- a/neo/rawio/neuronexusrawio.py +++ 
b/neo/rawio/neuronexusrawio.py @@ -43,7 +43,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -53,7 +53,7 @@ from neo.core import NeoReadWriteError -class NeuroNexusRawIO(BaseRawIO): +class NeuroNexusRawIO(BaseRawWithBufferApiIO): extensions = ["xdat", "json"] rawmode = "one-file" @@ -94,7 +94,7 @@ def __init__(self, filename: str | Path = ""): """ - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) if not Path(filename).is_file(): raise FileNotFoundError(f"The metadata file {filename} was not found") @@ -134,14 +134,18 @@ def _parse_header(self): binary_file = self.binary_file timestamp_file = self.timestamp_file - # Make the two memory maps - self._raw_data = np.memmap( - binary_file, - dtype=BINARY_DTYPE, - mode="r", - shape=(self._n_samples, self._n_channels), - offset=0, # headerless binary file - ) + # this will create a memory map with the generic mechanism + buffer_id = "0" + self._buffer_descriptions = {0 :{0 :{}}} + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(binary_file), + "dtype" : BINARY_DTYPE, + "order": "C", + "file_offset" : 0, + "shape" : (self._n_samples, self._n_channels), + } + # Make the memory map for timestamp self._timestamps = np.memmap( timestamp_file, dtype=np.int64, # this is from the allego sample reader timestamps are np.int64 @@ -205,10 +209,12 @@ def _parse_header(self): signal_streams["id"] = [str(stream_id) for stream_id in stream_ids] # One unique buffer signal_streams["buffer_id"] = buffer_id - + self._stream_buffer_slice = {} for stream_index, stream_id in enumerate(stream_ids): name = stream_id_to_stream_name.get(int(stream_id), "") signal_streams["name"][stream_index] = name + chan_inds = np.flatnonzero(signal_channels["stream_id"] == stream_id) + self._stream_buffer_slice[stream_id] = chan_inds # No events event_channels = [] @@ -245,26 +251,6 @@ def 
_parse_header(self): for d in (bl_annotations, seg_annotations): d["rec_datetime"] = rec_datetime - def _get_signal_size(self, block_index, seg_index, stream_index): - - # All streams have the same size so just return the raw_data (num_samples, num_chans) - return self._raw_data.shape[0] - - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - - if i_start is None: - i_start = 0 - if i_stop is None: - i_stop = self._get_signal_size(block_index, seg_index, stream_index) - - raw_data = self._raw_data[i_start:i_stop, :] - - if channel_indexes is None: - channel_indexes = slice(None) - - raw_data = raw_data[:, channel_indexes] - return raw_data - def _segment_t_stop(self, block_index, seg_index): t_stop = self.metadata["status"]["t_range"][1] @@ -303,6 +289,9 @@ def read_metadata(self, fname_metadata): return metadata + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + # this is pretty useless right now, but I think after a # refactor with sub streams we could adapt this for the sub-streams diff --git a/neo/rawio/neuroscoperawio.py b/neo/rawio/neuroscoperawio.py index 78f12b4ae..0090687df 100644 --- a/neo/rawio/neuroscoperawio.py +++ b/neo/rawio/neuroscoperawio.py @@ -22,7 +22,7 @@ from xml.etree import ElementTree from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -30,8 +30,10 @@ _event_channel_dtype, ) +from .utils import get_memmap_shape -class NeuroScopeRawIO(BaseRawIO): + +class NeuroScopeRawIO(BaseRawWithBufferApiIO): extensions = ["xml", "dat", "lfp", "eeg"] rawmode = "one-file" @@ -65,7 +67,7 @@ def __init__(self, filename, binary_file=None): filename provided with a supported data extension (.dat, .lfp, .eeg): - It assumes that the XML file has the same name and a .xml extension. 
""" - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename self.binary_file = binary_file @@ -106,11 +108,24 @@ def _parse_header(self): raise (NotImplementedError) # Extract signal from the data file - self._raw_signals = np.memmap(self.data_file_path, dtype=sig_dtype, mode="r", offset=0).reshape(-1, nb_channel) + # one unique stream and buffer + shape = get_memmap_shape(self.data_file_path, sig_dtype, num_channels=nb_channel, offset=0) + buffer_id = "0" + stream_id = "0" + self._buffer_descriptions = {0: {0:{}}} + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(self.data_file_path), + "dtype" : sig_dtype, + "order": "C", + "file_offset" : 0, + "shape" : shape, + } + self._stream_buffer_slice = {stream_id : None} # one unique stream and buffer - signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) + signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype) + signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype) # signals sig_channels = [] @@ -150,25 +165,15 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("`stream_index` must be 0") - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): if stream_index != 0: raise ValueError("`stream_index` must be 0") return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - 
raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals - def _resolve_xml_and_data_paths(self): """ Resolves XML and data paths from the provided filename and binary_file attributes. @@ -212,3 +217,6 @@ def _resolve_xml_and_data_paths(self): self.xml_file_path = xml_file_path self.data_file_path = data_file_path + + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] diff --git a/neo/rawio/openephysbinaryrawio.py b/neo/rawio/openephysbinaryrawio.py index b03d129ee..565771ae0 100644 --- a/neo/rawio/openephysbinaryrawio.py +++ b/neo/rawio/openephysbinaryrawio.py @@ -16,7 +16,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -24,8 +24,10 @@ _event_channel_dtype, ) +from .utils import get_memmap_shape -class OpenEphysBinaryRawIO(BaseRawIO): + +class OpenEphysBinaryRawIO(BaseRawWithBufferApiIO): """ Handle several Blocks and several Segments. 
@@ -62,7 +64,7 @@ class OpenEphysBinaryRawIO(BaseRawIO): rawmode = "one-dir" def __init__(self, dirname="", load_sync_channel=False, experiment_names=None): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.dirname = dirname if experiment_names is not None: if isinstance(experiment_names, str): @@ -127,7 +129,8 @@ def _parse_header(self): for chan_info in info["channels"]: chan_id = chan_info["channel_name"] if "SYNC" in chan_id and not self.load_sync_channel: - continue + # the channel is removed from stream but not the buffer + stream_id = "" if chan_info["units"] == "": # in some cases for some OE version the unit is "", but the gain is to "uV" units = "uV" @@ -160,13 +163,30 @@ def _parse_header(self): signal_buffers = np.array(signal_buffers, dtype=_signal_buffer_dtype) # create memmap for signals + self._buffer_descriptions = {} + self._stream_buffer_slice = {} for block_index in range(nb_block): + self._buffer_descriptions[block_index] = {} for seg_index in range(nb_segment_per_block[block_index]): + self._buffer_descriptions[block_index][seg_index] = {} for stream_index, info in self._sig_streams[block_index][seg_index].items(): num_channels = len(info["channels"]) - memmap_sigs = np.memmap(info["raw_filename"], info["dtype"], order="C", mode="r").reshape( - -1, num_channels - ) + # memmap_sigs = np.memmap(info["raw_filename"], info["dtype"], order="C", mode="r").reshape( + # -1, num_channels + # ) + stream_id = str(stream_index) + buffer_id = str(stream_index) + shape = get_memmap_shape(info["raw_filename"], info["dtype"], num_channels=num_channels, + offset=0) + self._buffer_descriptions[block_index][seg_index][buffer_id] = { + "type" : "raw", + "file_path" : str(info["raw_filename"]), + "dtype" : info["dtype"], + "order": "C", + "file_offset" : 0, + "shape" : shape, + } + has_sync_trace = self._sig_streams[block_index][seg_index][stream_index]["has_sync_trace"] # check sync channel validity (only for AP and LF) @@ -174,7 +194,13 @@ 
def _parse_header(self): raise ValueError( "SYNC channel is not present in the recording. " "Set load_sync_channel to False" ) - info["memmap"] = memmap_sigs + + if has_sync_trace and not self.load_sync_channel: + self._stream_buffer_slice[stream_id] = slice(None, -1) + else: + self._stream_buffer_slice[stream_id] = None + + # info["memmap"] = memmap_sigs # events zone # channel map: one channel one stream @@ -275,7 +301,11 @@ def _parse_header(self): # loop over signals for stream_index, info in self._sig_streams[block_index][seg_index].items(): t_start = info["t_start"] - dur = info["memmap"].shape[0] / float(info["sample_rate"]) + stream_id = str(stream_index) + buffer_id = str(stream_index) + sig_size = self._buffer_descriptions[block_index][seg_index][buffer_id]["shape"][0] + # dur = info["memmap"].shape[0] / float(info["sample_rate"]) + dur = sig_size / float(info["sample_rate"]) t_stop = t_start + dur if global_t_start is None or global_t_start > t_start: global_t_start = t_start @@ -373,25 +403,25 @@ def _channels_to_group_id(self, channel_indexes): group_id = group_ids[0] return group_id - def _get_signal_size(self, block_index, seg_index, stream_index): - sigs = self._sig_streams[block_index][seg_index][stream_index]["memmap"] - return sigs.shape[0] + # def _get_signal_size(self, block_index, seg_index, stream_index): + # sigs = self._sig_streams[block_index][seg_index][stream_index]["memmap"] + # return sigs.shape[0] def _get_signal_t_start(self, block_index, seg_index, stream_index): t_start = self._sig_streams[block_index][seg_index][stream_index]["t_start"] return t_start - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - sigs = self._sig_streams[block_index][seg_index][stream_index]["memmap"] - has_sync_trace = self._sig_streams[block_index][seg_index][stream_index]["has_sync_trace"] + # def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, 
channel_indexes): + # sigs = self._sig_streams[block_index][seg_index][stream_index]["memmap"] + # has_sync_trace = self._sig_streams[block_index][seg_index][stream_index]["has_sync_trace"] - if not self.load_sync_channel and has_sync_trace: - sigs = sigs[:, :-1] + # if not self.load_sync_channel and has_sync_trace: + # sigs = sigs[:, :-1] - sigs = sigs[i_start:i_stop, :] - if channel_indexes is not None: - sigs = sigs[:, channel_indexes] - return sigs + # sigs = sigs[i_start:i_stop, :] + # if channel_indexes is not None: + # sigs = sigs[:, channel_indexes] + # return sigs def _spike_count(self, block_index, seg_index, unit_index): pass @@ -450,6 +480,10 @@ def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index): durations = raw_duration.astype(dtype) return durations + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + + _possible_event_stream_names = ( "timestamps", diff --git a/neo/rawio/plexonrawio.py b/neo/rawio/plexonrawio.py index 62b0a1a88..1318ab3be 100644 --- a/neo/rawio/plexonrawio.py +++ b/neo/rawio/plexonrawio.py @@ -309,6 +309,7 @@ def _parse_header(self): # In that case we use the channel prefix both as stream id and name buffer_id = "" stream_name = stream_id_to_stream_name.get(stream_id, stream_id) + buffer_id = "" signal_streams.append((stream_name, stream_id, buffer_id)) signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) diff --git a/neo/rawio/rawbinarysignalrawio.py b/neo/rawio/rawbinarysignalrawio.py index 1797772e9..feef2f98b 100644 --- a/neo/rawio/rawbinarysignalrawio.py +++ b/neo/rawio/rawbinarysignalrawio.py @@ -22,16 +22,17 @@ class RawBinarySignalIO import os from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, _spike_channel_dtype, _event_channel_dtype, ) +from .utils import get_memmap_shape -class 
RawBinarySignalRawIO(BaseRawIO): +class RawBinarySignalRawIO(BaseRawWithBufferApiIO): """ Class for reading raw binary files with user specified values Parameters @@ -65,7 +66,7 @@ def __init__( signal_offset=0.0, bytesoffset=0, ): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename self.dtype = dtype self.sampling_rate = sampling_rate @@ -80,21 +81,32 @@ def _source_name(self): def _parse_header(self): if os.path.exists(self.filename): - self._raw_signals = np.memmap(self.filename, dtype=self.dtype, mode="r", offset=self.bytesoffset).reshape( - -1, self.nb_channel - ) + # one unique buffer and stream + buffer_id = "0" + stream_id = "0" + shape = get_memmap_shape(self.filename, self.dtype, num_channels=self.nb_channel, offset=self.bytesoffset) + self._buffer_descriptions = {0:{0:{}}} + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : self.dtype, # user-supplied dtype, the same one used to compute shape + "order": "C", + "file_offset" : self.bytesoffset, + "shape" : shape, + } + self._stream_buffer_slice = {stream_id : None} + + else: # The the neo.io.RawBinarySignalIO is used for write_segment - self._raw_signals = None + self._buffer_descriptions = None signal_channels = [] - if self._raw_signals is not None: + if self._buffer_descriptions is not None: for c in range(self.nb_channel): name = f"ch{c}" chan_id = f"{c}" units = "" - stream_id = "0" - buffer_id = "0" signal_channels.append( ( name, @@ -144,21 +156,14 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self.sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self.sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - if stream_index != 0: - raise ValueError("stream_index must be 0") - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, 
seg_index, stream_index): if stream_index != 0: raise ValueError("stream_index must be 0") return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] diff --git a/neo/rawio/rawmcsrawio.py b/neo/rawio/rawmcsrawio.py index 585e5fdd0..cba8baae6 100644 --- a/neo/rawio/rawmcsrawio.py +++ b/neo/rawio/rawmcsrawio.py @@ -17,16 +17,16 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, _spike_channel_dtype, _event_channel_dtype, ) +from .utils import get_memmap_shape - -class RawMCSRawIO(BaseRawIO): +class RawMCSRawIO(BaseRawWithBufferApiIO): """ Class for reading an mcs file converted by the MC_DataToo binary converter @@ -41,7 +41,7 @@ class RawMCSRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename=""): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename def _source_name(self): @@ -54,19 +54,31 @@ def _parse_header(self): self.sampling_rate = info["sampling_rate"] self.nb_channel = len(info["channel_names"]) - # one unique stream and buffer - signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype) - signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype) - - self._raw_signals = np.memmap(self.filename, dtype=self.dtype, mode="r", offset=info["header_size"]).reshape( - -1, self.nb_channel - ) + # one unique stream and buffer with all channels + stream_id = "0" + buffer_id = "0" + signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype) + signal_buffers = np.array([("Signals", 
buffer_id)], dtype=_signal_buffer_dtype) + + # self._raw_signals = np.memmap(self.filename, dtype=self.dtype, mode="r", offset=info["header_size"]).reshape( + # -1, self.nb_channel + # ) + file_offset = info["header_size"] + shape = get_memmap_shape(self.filename, self.dtype, num_channels=self.nb_channel, offset=file_offset) + self._buffer_descriptions = {0:{0:{}}} + self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : self.dtype, # must be the same dtype that was used to compute shape + "order": "C", + "file_offset" : file_offset, + "shape" : shape, + } + self._stream_buffer_slice = {stream_id : None} sig_channels = [] for c in range(self.nb_channel): chan_id = str(c) - stream_id = "0" - buffer_id = "0" sig_channels.append( ( info["channel_names"][c], @@ -107,20 +119,25 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self.sampling_rate + # t_stop = self._raw_signals.shape[0] / self.sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self.sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - return self._raw_signals.shape[0] + # def _get_signal_size(self, block_index, seg_index, stream_index): + # return self._raw_signals.shape[0] def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - if channel_indexes is None: - channel_indexes = slice(None) - raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] - return raw_signals + # def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): + # if channel_indexes is None: + # channel_indexes = slice(None) + # raw_signals = self._raw_signals[slice(i_start, i_stop), channel_indexes] + # return raw_signals + + def 
_get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] def parse_mcs_raw_header(filename): diff --git a/neo/rawio/spikegadgetsrawio.py b/neo/rawio/spikegadgetsrawio.py index 3204f2725..c9fa909e6 100644 --- a/neo/rawio/spikegadgetsrawio.py +++ b/neo/rawio/spikegadgetsrawio.py @@ -273,6 +273,7 @@ def _parse_header(self): signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype) + # no buffer concept here data are too fragmented signal_buffers = np.array([], dtype=_signal_buffer_dtype) diff --git a/neo/rawio/spikeglxrawio.py b/neo/rawio/spikeglxrawio.py index 9595dc04a..419528e93 100644 --- a/neo/rawio/spikeglxrawio.py +++ b/neo/rawio/spikeglxrawio.py @@ -57,16 +57,17 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, _spike_channel_dtype, _event_channel_dtype, ) +from .utils import get_memmap_shape -class SpikeGLXRawIO(BaseRawIO): +class SpikeGLXRawIO(BaseRawWithBufferApiIO): """ Class for reading data from a SpikeGLX system @@ -108,7 +109,7 @@ class SpikeGLXRawIO(BaseRawIO): rawmode = "one-dir" def __init__(self, dirname="", load_sync_channel=False, load_channel_location=False): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.dirname = dirname self.load_sync_channel = load_sync_channel self.load_channel_location = load_channel_location @@ -124,21 +125,42 @@ def _parse_header(self): stream_names = sorted(list(srates.keys()), key=lambda e: srates[e])[::-1] nb_segment = np.unique([info["seg_index"] for info in self.signals_info_list]).size - self._memmaps = {} + + # self._memmaps = {} self.signals_info_dict = {} + # one unique block + self._buffer_descriptions = {0 :{}} + self._stream_buffer_slice = {} for info in self.signals_info_list: - # key is (seg_index, 
stream_name) - key = (info["seg_index"], info["stream_name"]) + seg_index, stream_name = info["seg_index"], info["stream_name"] + key = (seg_index, stream_name) if key in self.signals_info_dict: raise KeyError(f"key {key} is already in the signals_info_dict") self.signals_info_dict[key] = info # create memmap - data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C") + # data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C") # this should be (info['sample_length'], info['num_chan']) # be some file are shorten - data = data.reshape(-1, info["num_chan"]) - self._memmaps[key] = data + # data = data.reshape(-1, info["num_chan"]) + # self._memmaps[key] = data + + buffer_id = stream_name + block_index = 0 + + if seg_index not in self._buffer_descriptions[0]: + self._buffer_descriptions[block_index][seg_index] = {} + + self._buffer_descriptions[block_index][seg_index][buffer_id] = { + "type" : "raw", + "file_path" : info["bin_file"], + "dtype" : "int16", + "order": "C", + "file_offset" : 0, + "shape" : get_memmap_shape(info["bin_file"], "int16", num_channels=info["num_chan"], offset=0), + } + + # create channel header signal_buffers = [] @@ -153,6 +175,7 @@ def _parse_header(self): signal_buffers.append((buffer_name, buffer_id)) stream_id = stream_name + stream_index = stream_names.index(info["stream_name"]) signal_streams.append((stream_name, stream_id, buffer_id)) @@ -173,10 +196,17 @@ def _parse_header(self): buffer_id, ) ) + + # all channels by default unless load_sync_channel=False + self._stream_buffer_slice[stream_id] = None # check sync channel validity if "nidq" not in stream_name: if not self.load_sync_channel and info["has_sync_trace"]: - signal_channels = signal_channels[:-1] + # the last channel is removed from the stream but not from the buffer + last_chan = signal_channels[-1] + last_chan = last_chan[:-2] + ("", buffer_id) + signal_channels = signal_channels[:-1] + [last_chan] + 
self._stream_buffer_slice[stream_id] = slice(0, -1) if self.load_sync_channel and not info["has_sync_trace"]: raise ValueError("SYNC channel is not present in the recording. " "Set load_sync_channel to False") @@ -261,42 +291,42 @@ def _segment_t_start(self, block_index, seg_index): def _segment_t_stop(self, block_index, seg_index): return self._t_stops[seg_index] - def _get_signal_size(self, block_index, seg_index, stream_index): - stream_id = self.header["signal_streams"][stream_index]["id"] - memmap = self._memmaps[seg_index, stream_id] - return int(memmap.shape[0]) + # def _get_signal_size(self, block_index, seg_index, stream_index): + # stream_id = self.header["signal_streams"][stream_index]["id"] + # memmap = self._memmaps[seg_index, stream_id] + # return int(memmap.shape[0]) def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - stream_id = self.header["signal_streams"][stream_index]["id"] - memmap = self._memmaps[seg_index, stream_id] - stream_name = self.header["signal_streams"]["name"][stream_index] - - # take care of sync channel - info = self.signals_info_dict[0, stream_name] - if not self.load_sync_channel and info["has_sync_trace"]: - memmap = memmap[:, :-1] - - # since we cut the memmap, we can simplify the channel selection - if channel_indexes is None: - channel_selection = slice(None) - elif isinstance(channel_indexes, slice): - channel_selection = channel_indexes - elif not isinstance(channel_indexes, slice): - if np.all(np.diff(channel_indexes) == 1): - # consecutive channel then slice this avoid a copy (because of ndarray.take(...) 
- # and so keep the underlying memmap - channel_selection = slice(channel_indexes[0], channel_indexes[0] + len(channel_indexes)) - else: - channel_selection = channel_indexes - else: - raise ValueError("get_analogsignal_chunk : channel_indexes" "must be slice or list or array of int") - - raw_signals = memmap[slice(i_start, i_stop), channel_selection] - - return raw_signals + # def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): + # stream_id = self.header["signal_streams"][stream_index]["id"] + # memmap = self._memmaps[seg_index, stream_id] + # stream_name = self.header["signal_streams"]["name"][stream_index] + + # # take care of sync channel + # info = self.signals_info_dict[0, stream_name] + # if not self.load_sync_channel and info["has_sync_trace"]: + # memmap = memmap[:, :-1] + + # # since we cut the memmap, we can simplify the channel selection + # if channel_indexes is None: + # channel_selection = slice(None) + # elif isinstance(channel_indexes, slice): + # channel_selection = channel_indexes + # elif not isinstance(channel_indexes, slice): + # if np.all(np.diff(channel_indexes) == 1): + # # consecutive channel then slice this avoid a copy (because of ndarray.take(...) 
+ # # and so keep the underlying memmap + # channel_selection = slice(channel_indexes[0], channel_indexes[0] + len(channel_indexes)) + # else: + # channel_selection = channel_indexes + # else: + # raise ValueError("get_analogsignal_chunk : channel_indexes" "must be slice or list or array of int") + + # raw_signals = memmap[slice(i_start, i_stop), channel_selection] + + # return raw_signals def _event_count(self, event_channel_idx, block_index=None, seg_index=None): timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None) @@ -336,6 +366,11 @@ def _rescale_event_timestamp(self, event_timestamps, dtype, event_channel_index) def _rescale_epoch_duration(self, raw_duration, dtype, event_channel_index): return None + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] + + + def scan_files(dirname): """ diff --git a/neo/rawio/utils.py b/neo/rawio/utils.py new file mode 100644 index 000000000..0365da625 --- /dev/null +++ b/neo/rawio/utils.py @@ -0,0 +1,56 @@ +import mmap +import numpy as np + +def get_memmap_shape(filename, dtype, num_channels=None, offset=0): + dtype = np.dtype(dtype) + with open(filename, mode='rb') as f: + f.seek(0, 2) + flen = f.tell() + num_bytes = flen - offset # named num_bytes so the builtin bytes is not shadowed + if num_bytes % dtype.itemsize != 0: + raise ValueError("Size of available data is not a multiple of the data-type size.") + size = num_bytes // dtype.itemsize + if num_channels is None: + shape = (size,) + else: + shape = (size // num_channels, num_channels) + return shape + + +def get_memmap_chunk_from_opened_file(fid, num_channels, start, stop, dtype, file_offset=0): + """ + Utility function to get a chunk as a memmap array directly from an opened file. + Using this instead of np.memmap can avoid memory consumption when multiprocessing. + + Similar mechanism is used in spikeinterface. 
+ + """ + bytes_per_sample = num_channels * np.dtype(dtype).itemsize # accept any dtype-like, as get_memmap_shape does + + # Calculate byte offsets + start_byte = file_offset + start * bytes_per_sample + end_byte = file_offset + stop * bytes_per_sample + + # Calculate the length of the data chunk to load into memory + length = end_byte - start_byte + + # The mmap offset must be a multiple of mmap.ALLOCATIONGRANULARITY + memmap_offset, start_offset = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) + memmap_offset *= mmap.ALLOCATIONGRANULARITY + + # Adjust the length so it includes the extra data from rounding down + # the memmap offset to a multiple of ALLOCATIONGRANULARITY + length += start_offset + + memmap_obj = mmap.mmap(fid.fileno(), length=length, access=mmap.ACCESS_READ, offset=memmap_offset) + + arr = np.ndarray( + shape=((stop - start), num_channels), + dtype=dtype, + buffer=memmap_obj, + offset=start_offset, + ) + + return arr + + diff --git a/neo/rawio/winedrrawio.py b/neo/rawio/winedrrawio.py index 58f78bb7c..78d0db67e 100644 --- a/neo/rawio/winedrrawio.py +++ b/neo/rawio/winedrrawio.py @@ -12,7 +12,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -22,7 +22,7 @@ ) -class WinEdrRawIO(BaseRawIO): +class WinEdrRawIO(BaseRawWithBufferApiIO): extensions = ["EDR", "edr"] rawmode = "one-file" @@ -36,7 +36,7 @@ def __init__(self, filename=""): The *.edr file to be loaded """ - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename def _source_name(self): @@ -62,16 +62,19 @@ def _parse_header(self): val = float(val) header[key] = val - self._raw_signals = np.memmap( - self.filename, - dtype="int16", - mode="r", - shape=( - header["NP"] // header["NC"], - header["NC"], - ), - offset=header["NBH"], - ) + # one unique block with one unique segment + # one unique buffer split into several streams + buffer_id = "0" + self._buffer_descriptions = {0 :{0 :{}}} + 
self._buffer_descriptions[0][0][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : "int16", + "order": "C", + "file_offset" : int(header["NBH"]), + "shape" : (header["NP"] // header["NC"], header["NC"]), + } + DT = header["DT"] if "TU" in header: @@ -103,12 +106,14 @@ def _parse_header(self): characteristics = signal_channels[_common_sig_characteristics] unique_characteristics = np.unique(characteristics) signal_streams = [] + self._stream_buffer_slice = {} for i in range(unique_characteristics.size): mask = unique_characteristics[i] == characteristics signal_channels["stream_id"][mask] = str(i) # unique buffer for all streams buffer_id = "0" signal_streams.append((f"stream {i}", str(i), buffer_id)) + self._stream_buffer_slice[str(i)] = np.flatnonzero(mask) signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) # all stream are in the same unique buffer : memmap @@ -139,20 +144,12 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals.shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - return self._raw_signals.shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - stream_id = self.header["signal_streams"][stream_index]["id"] - (global_channel_indexes,) = np.nonzero(self.header["signal_channels"]["stream_id"] == stream_id) - if channel_indexes is None: - channel_indexes = slice(None) - global_channel_indexes = global_channel_indexes[channel_indexes] - raw_signals = self._raw_signals[slice(i_start, i_stop), global_channel_indexes] - return raw_signals + def _get_analogsignal_buffer_description(self, block_index, seg_index, 
buffer_id): + return self._buffer_descriptions[block_index][seg_index][buffer_id] diff --git a/neo/rawio/winwcprawio.py b/neo/rawio/winwcprawio.py index a760c8bf0..36049ae8b 100644 --- a/neo/rawio/winwcprawio.py +++ b/neo/rawio/winwcprawio.py @@ -13,7 +13,7 @@ import numpy as np from .baserawio import ( - BaseRawIO, + BaseRawWithBufferApiIO, _signal_channel_dtype, _signal_stream_dtype, _signal_buffer_dtype, @@ -23,7 +23,7 @@ ) -class WinWcpRawIO(BaseRawIO): +class WinWcpRawIO(BaseRawWithBufferApiIO): """ Class for reading WinWCP data @@ -38,7 +38,7 @@ class WinWcpRawIO(BaseRawIO): rawmode = "one-file" def __init__(self, filename=""): - BaseRawIO.__init__(self) + BaseRawWithBufferApiIO.__init__(self) self.filename = filename def _source_name(self): @@ -47,9 +47,9 @@ def _source_name(self): def _parse_header(self): SECTORSIZE = 512 - # only one memmap for all segment to avoid - # "error: [Errno 24] Too many open files" - self._memmap = np.memmap(self.filename, dtype="uint8", mode="r") + # one unique block with several segments + # one unique buffer splited in several streams + self._buffer_descriptions = {0 :{}} with open(self.filename, "rb") as fid: @@ -59,7 +59,6 @@ def _parse_header(self): for line in headertext.split("\r\n"): if "=" not in line: continue - # print '#' , line , '#' key, val = line.split("=") if key in [ "NC", @@ -81,7 +80,6 @@ def _parse_header(self): header[key] = val nb_segment = header["NR"] - self._raw_signals = {} all_sampling_interval = [] # loop for record number for seg_index in range(header["NR"]): @@ -96,9 +94,16 @@ def _parse_header(self): NP = NP // header["NC"] NC = header["NC"] ind0 = offset + header["NBA"] * SECTORSIZE - ind1 = ind0 + NP * NC * 2 - sigs = self._memmap[ind0:ind1].view("int16").reshape(NP, NC) - self._raw_signals[seg_index] = sigs + buffer_id = "0" + self._buffer_descriptions[0][seg_index] = {} + self._buffer_descriptions[0][seg_index][buffer_id] = { + "type" : "raw", + "file_path" : str(self.filename), + "dtype" : 
"int16", + "order": "C", + "file_offset" : ind0, + "shape" : (NP, NC), + } all_sampling_interval.append(analysisHeader["SamplingInterval"]) @@ -128,12 +133,15 @@ def _parse_header(self): characteristics = signal_channels[_common_sig_characteristics] unique_characteristics = np.unique(characteristics) signal_streams = [] + self._stream_buffer_slice = {} for i in range(unique_characteristics.size): mask = unique_characteristics[i] == characteristics signal_channels["stream_id"][mask] = str(i) # unique buffer for all streams buffer_id = "0" - signal_streams.append((f"stream {i}", str(i), buffer_id)) + stream_id = str(i) + signal_streams.append((f"stream {i}", stream_id, buffer_id)) + self._stream_buffer_slice[stream_id] = np.flatnonzero(mask) signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype) # all stream are in the same unique buffer : memmap @@ -164,23 +172,15 @@ def _segment_t_start(self, block_index, seg_index): return 0.0 def _segment_t_stop(self, block_index, seg_index): - t_stop = self._raw_signals[seg_index].shape[0] / self._sampling_rate + sig_size = self.get_signal_size(block_index, seg_index, 0) + t_stop = sig_size / self._sampling_rate return t_stop - def _get_signal_size(self, block_index, seg_index, stream_index): - return self._raw_signals[seg_index].shape[0] - def _get_signal_t_start(self, block_index, seg_index, stream_index): return 0.0 - def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): - stream_id = self.header["signal_streams"][stream_index]["id"] - (global_channel_indexes,) = np.nonzero(self.header["signal_channels"]["stream_id"] == stream_id) - if channel_indexes is None: - channel_indexes = slice(None) - inds = global_channel_indexes[channel_indexes] - raw_signals = self._raw_signals[seg_index][slice(i_start, i_stop), inds] - return raw_signals + def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id): + return 
self._buffer_descriptions[block_index][seg_index][buffer_id] AnalysisDescription = [ diff --git a/neo/rawio/xarray_utils.py b/neo/rawio/xarray_utils.py new file mode 100644 index 000000000..8652e8614 --- /dev/null +++ b/neo/rawio/xarray_utils.py @@ -0,0 +1,185 @@ +""" +Experimental module to export a rawio reader that supports the buffer_description API +to an xarray dataset using the zarr specification format v2. + +A block/segment/stream corresponds to one xarray.Dataset + +An xarray.DataTree can also be exposed to get everything at block/segment/stream level + +Note : + * only some IOs support this at the moment (has_buffer_description_api()=True) +""" +import json + +import numpy as np + +import base64 + + + + +def to_zarr_v2_reference(rawio_reader, block_index=0, seg_index=0, buffer_id=None): + """ + Transform the buffer_description_api into a dict ready for the xarray API 'reference://' + + + See https://fsspec.github.io/kerchunk/spec.html + See https://docs.xarray.dev/en/latest/user-guide/io.html#kerchunk + + See https://zarr-specs.readthedocs.io/en/latest/v2/v2.0.html + + + Useful reading: https://github.com/saalfeldlab/n5 + + """ + + # TODO later implement zarr v3 + + # rawio_reader.
+ signal_buffers = rawio_reader.header["signal_buffers"] + + buffer_index = list(signal_buffers["id"]).index(buffer_id) + + buffer_name = signal_buffers["name"][buffer_index] + + + rfs = dict() + rfs["version"] = 1 + rfs["refs"] = dict() + rfs["refs"][".zgroup"] = json.dumps(dict(zarr_format=2)) + zattrs = dict(name=buffer_name) + rfs["refs"][".zattrs"] = json.dumps(zattrs) + + + descr = rawio_reader.get_analogsignal_buffer_description(block_index=block_index, seg_index=seg_index, + buffer_id=buffer_id) + + if descr["type"] == "raw": + + + # channel : small enough can be internal with base64 + mask = rawio_reader.header["signal_channels"]["buffer_id"] == buffer_id + channels = rawio_reader.header["signal_channels"][mask] + channel_ids = channels["id"] + base64_encoded = base64.b64encode(channel_ids.tobytes()) + rfs["refs"]["channel/0"] = "base64:" + base64_encoded.decode() + zarray = dict( + chunks=channel_ids.shape, + compressor=None, + dtype=channel_ids.dtype.str, + fill_value=None, + filters=None, + order="C", + shape=channel_ids.shape, + zarr_format=2, + ) + zattrs = dict( + _ARRAY_DIMENSIONS=['channel'], + ) + rfs["refs"]["channel/.zattrs"] =json.dumps(zattrs) + rfs["refs"]["channel/.zarray"] =json.dumps(zarray) + + # traces buffer + dtype = np.dtype(descr["dtype"]) + zarray = dict( + chunks=descr["shape"], + compressor=None, + dtype=dtype.str, + fill_value=None, + filters=None, + order=descr["order"], + shape=descr["shape"], + zarr_format=2, + ) + zattrs = dict( + _ARRAY_DIMENSIONS=['time', 'channel'], + name=buffer_name, + ) + units = np.unique(channels['units']) + if units.size == 1: + zattrs['units'] = units[0] + gain = np.unique(channels['gain']) + offset = np.unique(channels['offset']) + if gain.size == 1 and offset.size == 1: + zattrs['scale_factor'] = gain[0] + zattrs['add_offset'] = offset[0] + zattrs['sampling_rate'] = float(channels['sampling_rate'][0]) + + # unique big chunk + # TODO later : optional split in several small chunks + array_size = 
np.prod(descr["shape"], dtype='int64') + chunk_size = array_size * dtype.itemsize + rfs["refs"]["traces/0.0"] = [str(descr["file_path"]), descr["file_offset"], chunk_size] + rfs["refs"]["traces/.zarray"] =json.dumps(zarray) + rfs["refs"]["traces/.zattrs"] =json.dumps(zattrs) + + elif descr["type"] == "hdf5": + raise NotImplementedError + else: + raise ValueError(f"buffer_description type not handled {descr['type']}") + + # TODO later channel array_annotations + + return rfs + + + +def to_xarray_dataset(rawio_reader, block_index=0, seg_index=0, buffer_id=None): + """ + Utility function that transforms an instance of a rawio into a xarray.Dataset + with lazy access. + This works only for rawio classes that return True with has_buffer_description_api() and inherit from + BaseRawWithBufferApiIO. + + + Note : the original idea of the function is from Ben Dichter in this page + https://gist.github.com/bendichter/30a9afb34b2178098c99f3b01fe72e75 + """ + import xarray as xr + + rfs = to_zarr_v2_reference(rawio_reader, block_index=block_index, seg_index=seg_index, buffer_id=buffer_id) + + ds = xr.open_dataset( + "reference://", + mask_and_scale=True, + engine="zarr", + backend_kwargs={ + "storage_options": dict( + fo=rfs, + remote_protocol="file", + ), + "consolidated": False, + }, + ) + return ds + +def to_xarray_datatree(rawio_reader): + """ + Expose a neo.rawio reader class to a xarray DataTree to lazily read signals.
+ """ + try: + # not released in xarray 2024.7.0, this will be released soon + from xarray import DataTree + except ImportError: + # this needs the experimental DataTree in pypi xarray-datatree + try: + from datatree import DataTree + except ImportError: + raise ImportError("use xarray dev branch or pip install xarray-datatree") + + signal_buffers = rawio_reader.header['signal_buffers'] + buffer_ids = signal_buffers["id"] + + tree = DataTree(name="root") + + num_block = rawio_reader.header['nb_block'] + for block_index in range(num_block): + block = DataTree(name=f'block{block_index}', parent=tree) + num_seg = rawio_reader.header['nb_segment'][block_index] + for seg_index in range(num_seg): + segment = DataTree(name=f'segment{seg_index}', parent=block) + for buffer_id in buffer_ids: + ds = to_xarray_dataset(rawio_reader, block_index=block_index, seg_index=seg_index, buffer_id=buffer_id) + DataTree(data=ds, name=ds.attrs['name'], parent=segment) + + return tree diff --git a/neo/test/rawiotest/common_rawio_test.py b/neo/test/rawiotest/common_rawio_test.py index 9ad9853fb..488cb9fbf 100644 --- a/neo/test/rawiotest/common_rawio_test.py +++ b/neo/test/rawiotest/common_rawio_test.py @@ -125,3 +125,7 @@ def test_read_all(self): logging.getLogger().setLevel(logging.INFO) compliance.benchmark_speed_read_signals(reader) logging.getLogger().setLevel(level) + + # buffer api + if reader.has_buffer_description_api(): + compliance.check_buffer_api(reader) diff --git a/neo/test/rawiotest/rawio_compliance.py b/neo/test/rawiotest/rawio_compliance.py index a4d7131f8..263c22997 100644 --- a/neo/test/rawiotest/rawio_compliance.py +++ b/neo/test/rawiotest/rawio_compliance.py @@ -93,7 +93,8 @@ def check_signal_stream_buffer_hierachy(reader): assert stream["buffer_id"] in h["signal_buffers"]["id"] for channel in h["signal_channels"]: - assert channel["stream_id"] in h["signal_streams"]["id"] + if channel["stream_id"] != "": + assert channel["stream_id"] in h["signal_streams"]["id"] if channel["buffer_id"] != "": 
assert channel["buffer_id"] in h["signal_buffers"]["id"] @@ -152,7 +153,10 @@ def iter_over_sig_chunks(reader, stream_index, channel_indexes, chunksize=1024): for seg_index in range(nb_seg): sig_size = reader.get_signal_size(block_index, seg_index, stream_index) - nb = sig_size // chunksize + 1 + nb = int(np.floor(sig_size / chunksize)) + if sig_size % chunksize > 0: + nb += 1 + for i in range(nb): i_start = i * chunksize i_stop = min((i + 1) * chunksize, sig_size) @@ -477,3 +481,33 @@ def read_events(reader): def has_annotations(reader): assert hasattr(reader, "raw_annotations"), "raw_annotation are not set" + + +def check_buffer_api(reader): + buffer_ids = reader.header["signal_buffers"]["id"] + + nb_block = reader.block_count() + nb_event_channel = reader.event_channels_count() + + + for block_index in range(nb_block): + nb_seg = reader.segment_count(block_index) + for seg_index in range(nb_seg): + for buffer_id in buffer_ids: + descr = reader.get_analogsignal_buffer_description( + block_index=block_index, seg_index=seg_index, buffer_id=buffer_id + ) + assert descr["type"] in ("raw", "hdf5"), "buffer_description type uncorrect" + + try: + import xarray + HAVE_XARRAY = True + except ImportError: + HAVE_XARRAY = False + + if HAVE_XARRAY: + # this test quickly the experimental xaray_utils.py if xarray is present on the system + # this is not the case for the CI + from neo.rawio.xarray_utils import to_xarray_dataset + ds = to_xarray_dataset(reader, block_index=block_index, seg_index=seg_index, buffer_id=buffer_id) + assert isinstance(ds, xarray.Dataset)