Skip to content

Commit

Permalink
Merge pull request #1311 from apdavison/container-add--check-types
Browse files Browse the repository at this point in the history
Check the types of objects added to a container with the new `add()` method
  • Loading branch information
JuliaSprenger authored Jul 27, 2023
2 parents 182a7e9 + 74b5aa4 commit 972901f
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 5 deletions.
3 changes: 3 additions & 0 deletions doc/source/authors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ and may not be the current affiliation of a contributor.
* Elodie Legouée [21]
* Heberto Mayorquin [24]
* Thomas Perret [25]
* Kyle Johnsen [26, 27]

1. Centre de Recherche en Neuroscience de Lyon, CNRS UMR5292 - INSERM U1028 - Universite Claude Bernard Lyon 1
2. Unité de Neuroscience, Information et Complexité, CNRS UPR 3293, Gif-sur-Yvette, France
Expand All @@ -88,6 +89,8 @@ and may not be the current affiliation of a contributor.
23. Bio Engineering Laboratory, DBSSE, ETH, Basel, Switzerland
24. CatalystNeuro
25. Institut des Sciences Cognitives Marc Jeannerod, CNRS UMR5229, Lyon, France
26. Georgia Institute of Technology
27. Emory University

If we've somehow missed you off the list we're very sorry - please let us know.

Expand Down
18 changes: 16 additions & 2 deletions neo/core/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,22 @@ def _get_container(self, cls):
def add(self, *objects):
"""Add a new Neo object to the Container"""
for obj in objects:
container = self._get_container(obj.__class__)
container.append(obj)
if (
obj.__class__.__name__ in self._child_objects
or (
hasattr(obj, "proxy_for")
and obj.proxy_for.__name__ in self._child_objects
)
):
container = self._get_container(obj.__class__)
container.append(obj)
else:
raise TypeError(
f"Cannot add object of type {obj.__class__.__name__} "
f"to a {self.__class__.__name__}, can only add objects of the "
f"following types: {self._child_objects}"
)



def filter(self, targdict=None, data=True, container=False, recursive=True,
Expand Down
6 changes: 4 additions & 2 deletions neo/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
self.allowed_types = None
else:
self.allowed_types = tuple(allowed_types)
for type_ in self.allowed_types:
if type_.__name__ not in self._child_objects:
raise TypeError(f"Groups can not contain objects of type {type_.__name__}")

if objects:
self.add(*objects)
Expand Down Expand Up @@ -140,8 +143,7 @@ def add(self, *objects):
if self.allowed_types and not isinstance(obj, self.allowed_types):
raise TypeError("This Group can only contain {}, but not {}"
"".format(self.allowed_types, type(obj)))
container = self._get_container(obj.__class__)
container.append(obj)
super().add(*objects)

def walk(self):
"""
Expand Down
6 changes: 5 additions & 1 deletion neo/test/coretest/test_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from neo.core import SpikeTrain, AnalogSignal, Event
from neo.test.tools import (assert_neo_object_is_compliant,
assert_same_sub_schema)
from neo.test.generate_datasets import random_block, simple_block
from neo.test.generate_datasets import random_block, simple_block, random_signal


N_EXAMPLES = 5
Expand Down Expand Up @@ -493,6 +493,10 @@ def test_add(self):
new_blk.add(*blk.segments)
assert len(new_blk.segments) == n_segs_start + len(blk.segments)

def test_add_invalid_type_raises_Exception(self):
new_blk = Block()
self.assertRaises(TypeError, new_blk.add, random_signal())


if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions neo/test/coretest/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from neo.core.segment import Segment
from neo.core.view import ChannelView
from neo.core.group import Group
from neo.core.block import Block


class TestGroup(unittest.TestCase):
Expand Down Expand Up @@ -91,3 +92,9 @@ def test_walk(self):
target.extend([children[1], children[2], *grandchildren[2]])
self.assertEqual(flattened,
target)

def test_add_invalid_type_raises_Exception(self):
group = Group()
self.assertRaises(TypeError, group.add, Block())

self.assertRaises(TypeError, Group, allowed_types=[Block])
4 changes: 4 additions & 0 deletions neo/test/coretest/test_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,10 @@ def test_add(self):
seg.add(proxy_epoch)
assert len(seg.epochs) == 1

def test_add_invalid_type_raises_Exception(self):
seg = Segment()
self.assertRaises(TypeError, seg.add, Block())


if __name__ == "__main__":
unittest.main()

0 comments on commit 972901f

Please sign in to comment.