diff --git a/doc/source/authors.rst b/doc/source/authors.rst index cf4cc5dd4..95821bc11 100644 --- a/doc/source/authors.rst +++ b/doc/source/authors.rst @@ -60,6 +60,7 @@ and may not be the current affiliation of a contributor. * Etienne Combrisson [6] * Ben Dichter [24] * Elodie Legouée [21] +* Oliver Kloss [13] * Heberto Mayorquin [24] * Thomas Perret [25] * Kyle Johnsen [26, 27] diff --git a/neo/core/__init__.py b/neo/core/__init__.py index 17c1517af..300fecf49 100644 --- a/neo/core/__init__.py +++ b/neo/core/__init__.py @@ -10,7 +10,9 @@ Classes: .. autoclass:: Block +.. automethod:: Block.filter .. autoclass:: Segment +.. automethod:: Segment.filter .. autoclass:: Group .. autoclass:: AnalogSignal @@ -35,6 +37,9 @@ from neo.core.analogsignal import AnalogSignal from neo.core.irregularlysampledsignal import IrregularlySampledSignal +# Import FilterClasses +from neo.core import filters + from neo.core.event import Event from neo.core.epoch import Epoch diff --git a/neo/core/container.py b/neo/core/container.py index 89521c28e..6a36ef33b 100644 --- a/neo/core/container.py +++ b/neo/core/container.py @@ -6,6 +6,8 @@ """ from copy import deepcopy + +from neo.core import filters from neo.core.baseneo import BaseNeo, _reference_name, _container_name from neo.core.objectlist import ObjectList from neo.core.spiketrain import SpikeTrain @@ -21,24 +23,25 @@ def unique_objs(objs): return [obj for obj in objs if id(obj) not in seen and not seen.add(id(obj))] - def filterdata(data, targdict=None, objects=None, **kwargs): """ Return a list of the objects in data matching *any* of the search terms in either their attributes or annotations. Search terms can be provided as keyword arguments or a dictionary, either as a positional - argument after data or to the argument targdict. targdict can also - be a list of dictionaries, in which case the filters are applied - sequentially. If targdict and kwargs are both supplied, the - targdict filters are applied first, followed by the kwarg filters. - A targdict of None or {} and objects = None corresponds to no filters - applied, therefore returning all child objects. - Default targdict and objects is None. + argument after data or to the argument targdict. + A key of a provided dictionary is the name of the requested annotation + and the value is a FilterCondition object. + E.g.: Equal(x), LessThan(x), InRange(x, y). + targdict can also + be a list of dictionaries, in which case the filters are applied + sequentially. - objects (optional) should be the name of a Neo object type, - a neo object class, or a list of one or both of these. If specified, - only these objects will be returned. + A list of dictionaries is handled as follows: [ { or } and { or } ] + If targdict and kwargs are both supplied, the + targdict filters are applied first, followed by the kwarg filters. + A targdict of None or {} corresponds to no filters applied, therefore + returning all child objects. Default targdict is None. """ # if objects are specified, get the classes @@ -72,20 +75,26 @@ def filterdata(data, targdict=None, objects=None, **kwargs): else: # do the actual filtering results = [] - for key, value in sorted(targdict.items()): - for obj in data: - if (hasattr(obj, key) and getattr(obj, key) == value and - all([obj is not res for res in results])): + for obj in data: + for key, value in sorted(targdict.items()): + if hasattr(obj, key) and getattr(obj, key) == value: results.append(obj) - elif (key in obj.annotations and obj.annotations[key] == value and - all([obj is not res for res in results])): + break + if isinstance(value, filters.FilterCondition) and key in obj.annotations: + if value.evaluate(obj.annotations[key]): + results.append(obj) + break + if key in obj.annotations and obj.annotations[key] == value: results.append(obj) + break + + # remove duplicates from results + results = list({ id(res): res for res in results }.values()) # keep only objects of the correct classes if objects: results = [result for result in results if - result.__class__ in objects or - result.__class__.__name__ in objects] + result.__class__ in objects or result.__class__.__name__ in objects] if results and all(isinstance(obj, SpikeTrain) for obj in results): return SpikeTrainList(results) @@ -366,9 +375,17 @@ def filter(self, targdict=None, data=True, container=False, recursive=True, Return a list of child objects matching *any* of the search terms in either their attributes or annotations. Search terms can be provided as keyword arguments or a dictionary, either as a positional - argument after data or to the argument targdict. targdict can also + argument after data or to the argument targdict. + A key of a provided dictionary is the name of the requested annotation + and the value is a FilterCondition object. + E.g.: equal(x), less_than(x), InRange(x, y). + + targdict can also be a list of dictionaries, in which case the filters are applied - sequentially. If targdict and kwargs are both supplied, the + sequentially. + + A list of dictionaries is handled as follows: [ { or } and { or } ] + If targdict and kwargs are both supplied, the targdict filters are applied first, followed by the kwarg filters. A targdict of None or {} corresponds to no filters applied, therefore returning all child objects. Default targdict is None. @@ -391,6 +408,8 @@ def filter(self, targdict=None, data=True, container=False, recursive=True, >>> obj.filter(name="Vm") >>> obj.filter(objects=neo.SpikeTrain) >>> obj.filter(targdict={'myannotation':3}) + >>> obj.filter(name=neo.core.filters.Equal(5)) + >>> obj.filter({'name': neo.core.filters.LessThan(5)}) """ if isinstance(targdict, str): diff --git a/neo/core/filters.py b/neo/core/filters.py new file mode 100644 index 000000000..500cc146f --- /dev/null +++ b/neo/core/filters.py @@ -0,0 +1,173 @@ +""" +This module implements :class:`FilterCondition`, which enables use of different filter conditions +for neo.core.container.filter. + +Classes: + - :class:`FilterCondition`: Abstract base class for defining filter conditions. + - :class:`Equals`: Filter condition to check if a value is equal to the control value. + - :class:`IsNot`: Filter condition to check if a value is not equal to the control value. + - :class:`LessThanOrEquals`: Filter condition to check if a value is less than or equal to the + control value. + - :class:`GreaterThanOrEquals`: Filter condition to check if a value is greater than or equal to + the control value. + - :class:`LessThan`: Filter condition to check if a value is less than the control value. + - :class:`GreaterThan`: Filter condition to check if a value is greater than the control value. + - :class:`IsIn`: Filter condition to check if a value is in a list or equal to the control + value. + - :class:`InRange`: Filter condition to check if a value is in a specified range. + +The provided classes allow users to select filter conditions and use them with +:func:`neo.core.container.filter()` to perform specific filtering operations on data. +""" +from abc import ABC, abstractmethod +from numbers import Number +from typing import Union, Any + + +class FilterCondition(ABC): + """ + FilterCondition object is given as parameter to container.filter(): + + Usage: + segment.filter(my_annotation=) or + segment=filter({'my_annotation': }) + """ + @abstractmethod + def __init__(self, control: Any) -> None: + """ + Initialize new FilterCondition object. + + Parameters: + control: Any - The control value to be used for filtering. + + This is an abstract base class and should not be instantiated directly. + """ + + @abstractmethod + def evaluate(self, compare: Any) -> bool: + """ + Evaluate the filter condition for given value. + + Parameters: + compare: Any - The value to be compared with the control value. + + Returns: + bool: True if the condition is satisfied, False otherwise. + + This method should be implemented in subclasses. + """ + + +class Equals(FilterCondition): + """ + Filter condition to check if target value is equal to the control value. + """ + def __init__(self, control: Any) -> None: + self.control = control + + def evaluate(self, compare: Any) -> bool: + return compare == self.control + + +class IsNot(FilterCondition): + """ + Filter condition to check if target value is not equal to the control value. + """ + def __init__(self, control: Any) -> None: + self.control = control + + def evaluate(self, compare: Any) -> bool: + return compare != self.control + + +class LessThanOrEquals(FilterCondition): + """ + Filter condition to check if target value is less than or equal to the control value. + """ + def __init__(self, control: Number) -> None: + self.control = control + + def evaluate(self, compare: Number) -> bool: + return compare <= self.control + + +class GreaterThanOrEquals(FilterCondition): + """ + Filter condition to check if target value is greater than or equal to the control value. + """ + def __init__(self, control: Number) -> None: + self.control = control + + def evaluate(self, compare: Number) -> bool: + return compare >= self.control + + +class LessThan(FilterCondition): + """ + Filter condition to check if target value is less than the control value. + """ + def __init__(self, control: Number) -> None: + self.control = control + + def evaluate(self, compare: Number) -> bool: + return compare < self.control + + +class GreaterThan(FilterCondition): + """ + Filter condition to check if target value is greater than the control value. + """ + def __init__(self, control: Number) -> None: + self.control = control + + def evaluate(self, compare: Number) -> bool: + return compare > self.control + + +class IsIn(FilterCondition): + """ + Filter condition to check if target is in control. + """ + def __init__(self, control: Union[list, tuple, set, int]) -> None: + self.control = control + + def evaluate(self, compare: Any) -> bool: + if isinstance(self.control, (list, tuple, set)): + return compare in self.control + if isinstance(self.control, int): + return compare == self.control + + raise SyntaxError('parameter not of type list, tuple, set or int') + + +class InRange(FilterCondition): + """ + Filter condition to check if a value is in a specified range. + + Usage: + InRange(upper_bound, upper_bound, left_closed=False, right_closed=False) + + Parameters: + lower_bound: int - The lower bound of the range. + upper_bound: int - The upper bound of the range. + left_closed: bool - If True, the range includes the lower bound (lower_bound <= compare). + right_closed: bool - If True, the range includes the upper bound (compare <= upper_bound). + """ + def __init__(self, lower_bound: Number, upper_bound: Number, + left_closed: bool=False, right_closed: bool=False) -> None: + if not isinstance(lower_bound, Number) or not isinstance(upper_bound, Number): + raise ValueError("parameter is not a number") + + self.lower_bound = lower_bound + self.upper_bound = upper_bound + self.left_closed = left_closed + self.right_closed = right_closed + + def evaluate(self, compare: Number) -> bool: + if not self.left_closed and not self.right_closed: + return self.lower_bound <= compare <= self.upper_bound + if not self.left_closed and self.right_closed: + return self.lower_bound <= compare < self.upper_bound + if self.left_closed and not self.right_closed: + return self.lower_bound < compare <= self.upper_bound + return self.lower_bound < compare < self.upper_bound diff --git a/neo/test/coretest/test_container.py b/neo/test/coretest/test_container.py index b21f86d75..94e285698 100644 --- a/neo/test/coretest/test_container.py +++ b/neo/test/coretest/test_container.py @@ -4,8 +4,13 @@ import unittest +import quantities as pq + import numpy as np +import neo.core +from neo.core import filters + try: from IPython.lib.pretty import pretty except ImportError as err: @@ -38,6 +43,26 @@ class TestContainerNeo(unittest.TestCase): TestCase to make sure basic initialization and methods work ''' + @classmethod + def setUpClass(cls): + seg = neo.core.Segment() + st1 = neo.core.SpikeTrain([1, 2] * pq.ms, t_stop=10) + st1.annotate(test=5) + st2 = neo.core.SpikeTrain([3, 4] * pq.ms, t_stop=10) + st2.annotate(test=6) + st2.annotate(name='st_num_1') + st2.annotate(filt=6) + st3 = neo.core.SpikeTrain([5, 6] * pq.ms, t_stop=10) + st3.annotate(list=[1, 2]) + st3.annotate(dict={'key': 5}) + seg.spiketrains.append(st1) + seg.spiketrains.append(st2) + seg.spiketrains.append(st3) + + cls.seg = seg + cls.st1 = st1 + cls.st2 = st2 + def test_init(self): '''test to make sure initialization works properly''' container = Container(name='a container', description='this is a test') @@ -95,6 +120,115 @@ def test_filter(self): container = Container() self.assertRaises(TypeError, container.filter, "foo") + def test_filter_results(self): + ''' + Tests FilterConditions correct results + ''' + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.Equals(5))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.LessThan(6))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.GreaterThan(4))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.IsNot(1))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.IsIn([1, 2, 5]))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.InRange(1, 5))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.GreaterThanOrEquals(5))[0].annotations) + self.assertEqual(self.st1.annotations, + self.seg.filter(test=filters.LessThanOrEquals(5))[0].annotations) + + def test_filter_equal(self): + ''' + Tests FilterCondition object "Equals". + ''' + self.assertEqual(1, len(self.seg.filter(test=filters.Equals(5)))) + self.assertEqual(0, len(self.seg.filter(test=filters.Equals(1)))) + self.assertEqual(1, len(self.seg.filter({'list': filters.Equals([1, 2])}))) + self.assertEqual(1, len(self.seg.filter(dict=filters.Equals({'key': 5})))) + + def test_filter_is_not(self): + ''' + Tests FilterCondition object "IsNot". + ''' + self.assertEqual(2, len(self.seg.filter(test=filters.IsNot(1)))) + self.assertEqual(1, len(self.seg.filter(test=filters.IsNot(5)))) + self.assertEqual(0, len(self.seg.filter([{"test": filters.IsNot(5)}, + {"test": filters.IsNot(6)}]))) + + def test_filter_less_than(self): + ''' + Tests FilterCondition object "LessThan". + ''' + self.assertEqual(0, len(self.seg.filter(test=filters.LessThan(5)))) + self.assertEqual(1, len(self.seg.filter(test=filters.LessThan(6)))) + self.assertEqual(2, len(self.seg.filter(test=filters.LessThan(7)))) + + def test_filter_less_than_equal(self): + ''' + Tests FilterCondition object "LessThanEquals". + ''' + self.assertEqual(0, len(self.seg.filter(test=filters.LessThanOrEquals(4)))) + self.assertEqual(1, len(self.seg.filter(test=filters.LessThanOrEquals(5)))) + self.assertEqual(2, len(self.seg.filter(test=filters.LessThanOrEquals(6)))) + + def test_filter_greater_than(self): + ''' + Tests FilterCondition object "GreaterThan". + ''' + self.assertEqual(0, len(self.seg.filter(test=filters.GreaterThan(6)))) + self.assertEqual(1, len(self.seg.filter(test=filters.GreaterThan(5)))) + self.assertEqual(2, len(self.seg.filter(test=filters.GreaterThan(4)))) + + def test_filter_greater_than_equal(self): + ''' + Tests FilterCondition object "GreaterThanEquals". + ''' + self.assertEqual(0, len(self.seg.filter(test=filters.GreaterThanOrEquals(7)))) + self.assertEqual(1, len(self.seg.filter(test=filters.GreaterThanOrEquals(6)))) + self.assertEqual(2, len(self.seg.filter(test=filters.GreaterThanOrEquals(5)))) + + def test_filter_is_in(self): + ''' + Tests FilterCondition object "IsIn". + ''' + # list + self.assertEqual(0, len(self.seg.filter(test=filters.IsIn([4, 7, 10])))) + self.assertEqual(1, len(self.seg.filter(test=filters.IsIn([5, 7, 10])))) + self.assertEqual(2, len(self.seg.filter(test=filters.IsIn([5, 6, 10])))) + # tuple + self.assertEqual(0, len(self.seg.filter(test=filters.IsIn((4, 7, 10))))) + self.assertEqual(1, len(self.seg.filter(test=filters.IsIn((5, 7, 10))))) + self.assertEqual(2, len(self.seg.filter(test=filters.IsIn((5, 6, 10))))) + # set + self.assertEqual(0, len(self.seg.filter(test=filters.IsIn({4, 7, 10})))) + self.assertEqual(1, len(self.seg.filter(test=filters.IsIn({5, 7, 10})))) + self.assertEqual(2, len(self.seg.filter(test=filters.IsIn({5, 6, 10})))) + + def test_filter_in_range(self): + ''' + Tests FilterCondition object "InRange". + ''' + with self.assertRaises(ValueError): + self.seg.filter(test=filters.InRange("wrong", 6, False, False)) + self.assertEqual(2, len(self.seg.filter(test=filters.InRange(5, 6, False, False)))) + self.assertEqual(1, len(self.seg.filter(test=filters.InRange(5, 6, True, False)))) + self.assertEqual(1, len(self.seg.filter(test=filters.InRange(5, 6, False, True)))) + self.assertEqual(0, len(self.seg.filter(test=filters.InRange(5, 6, True, True)))) + + def test_filter_filter_consistency(self): + ''' + Tests old functionality with new filter method. + ''' + self.assertEqual(2, len(self.seg.filter({'test': filters.Equals(5), + 'filt': filters.Equals(6)}))) + self.assertEqual(0, len(self.seg.filter([{'test': filters.Equals(5)}, + {'filt': filters.Equals(6)}]))) + self.assertEqual(1, len(self.seg.filter(name='st_num_1'))) + class Test_Container_merge(unittest.TestCase): '''