Skip to content

Commit

Permalink
Fix HDF5EventSource.__len__ if allowed_tels is not None
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnoe committed Nov 23, 2023
1 parent a2b6072 commit 19517b4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ctapipe/instrument/subarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,13 @@ def tel_ids_to_indices(self, tel_ids):
np.array:
array of corresponding tel indices
"""
tel_ids = np.array(tel_ids, dtype=int, copy=False).ravel()
if isinstance(tel_ids, (int, np.integer)):
pass
elif not isinstance(tel_ids, np.ndarray):
tel_ids = np.fromiter(tel_ids, dtype=int, count=len(tel_ids))
else:
tel_ids = np.array(tel_ids, dtype=int, copy=False).ravel()

return self.tel_index_array[tel_ids]

def tel_ids_to_mask(self, tel_ids):
Expand Down
8 changes: 8 additions & 0 deletions ctapipe/instrument/tests/test_subarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,16 @@ def test_tel_indexing(example_subarray):
assert sub.tel_index_array[tel_id] == sub.tel_indices[tel_id]

assert sub.tel_ids_to_indices(1) == 0
assert sub.tel_ids_to_indices(np.uint16(2)) == 1
assert np.all(sub.tel_ids_to_indices([1, 2, 3]) == np.array([0, 1, 2]))

# test it also works with sets
assert np.all(sub.tel_ids_to_indices({1, 2, 3}) == np.array([0, 1, 2]))

# and dict-keys
tel_data = {1: "foo", 2: "bar", 3: "baz"}
assert np.all(sub.tel_ids_to_indices(tel_data.keys()) == np.array([0, 1, 2]))


def test_tel_ids_to_mask(prod5_lst, reference_location):
subarray = SubarrayDescription(
Expand Down
13 changes: 12 additions & 1 deletion ctapipe/io/hdf5eventsource.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,8 +333,19 @@ def simulation_config(self) -> Dict[int, SimulationConfigContainer]:
"""
return self._simulation_configs

@lazyproperty
def _n_events_with_allowed_tels(self):
if self.allowed_tels is None:
return len(self.file_.root.dl1.event.subarray.trigger)

triggered_tels = self.file_.root.dl1.event.subarray.trigger.col(
"tels_with_trigger"
)
tel_idx = self.subarray.tel_ids_to_indices(self.allowed_tels)
return np.count_nonzero(np.any(triggered_tels[:, tel_idx], axis=1))

def __len__(self):
n_events = len(self.file_.root.dl1.event.subarray.trigger)
n_events = self._n_events_with_allowed_tels
if self.max_events is not None:
return min(n_events, self.max_events)
return n_events
Expand Down

0 comments on commit 19517b4

Please sign in to comment.