Skip to content

Commit

Permalink
Merge pull request #2217 from LukasBeiske/config_ml_prefix
Browse files Browse the repository at this point in the history
Add option to configure ml prefixes
  • Loading branch information
maxnoe authored Jan 16, 2023
2 parents 3cebb42 + ea8c66a commit 8a7a367
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 12 deletions.
34 changes: 23 additions & 11 deletions ctapipe/reco/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,11 @@ class SKLearnReconstructor(Reconstructor):
#: Name of the target column in training table
target: str = None

prefix = Unicode(
default_value=None,
allow_none=True,
help="Prefix for the output of this model. If None, ``model_cls`` is used.",
).tag(config=True)
features = List(Unicode(), help="Features to use for this model").tag(config=True)
model_config = Dict({}, help="kwargs for the sklearn model").tag(config=True)
model_cls = Enum(SUPPORTED_MODELS.keys(), default_value=None, allow_none=True).tag(
Expand Down Expand Up @@ -117,6 +122,10 @@ def __init__(self, subarray=None, models=None, **kwargs):
"__init__() missing 1 required positional argument: 'subarray'"
)

if self.prefix is None:
# Default prefix is model_cls
self.prefix = self.model_cls

super().__init__(subarray, **kwargs)
self.subarray = subarray
self.feature_generator = FeatureGenerator(parent=self)
Expand All @@ -129,7 +138,7 @@ def __init__(self, subarray=None, models=None, **kwargs):
self.unit = None
self.stereo_combiner = StereoCombiner.from_name(
self.stereo_combiner_cls,
prefix=self.model_cls,
prefix=self.prefix,
property=self.property,
parent=self,
)
Expand All @@ -145,6 +154,9 @@ def __init__(self, subarray=None, models=None, **kwargs):
self.__dict__.update(loaded.__dict__)
self.subarray = subarray

if self.prefix is None:
self.prefix = self.model_cls

@abstractmethod
def __call__(self, event: ArrayEventContainer) -> None:
"""Event-wise prediction for the EventSource-Loop.
Expand Down Expand Up @@ -413,8 +425,8 @@ def __call__(self, event: ArrayEventContainer) -> None:
is_valid=False,
)

container.prefix = f"{self.model_cls}_tel"
event.dl2.tel[tel_id].energy[self.model_cls] = container
container.prefix = f"{self.prefix}_tel"
event.dl2.tel[tel_id].energy[self.prefix] = container

self.stereo_combiner(event)

Expand All @@ -431,14 +443,14 @@ def predict_table(self, key, table: Table) -> Table:

result = Table(
{
f"{self.model_cls}_tel_energy": energy,
f"{self.model_cls}_tel_is_valid": is_valid,
f"{self.prefix}_tel_energy": energy,
f"{self.prefix}_tel_is_valid": is_valid,
}
)
add_defaults_and_meta(
result,
ReconstructedEnergyContainer,
prefix=self.model_cls,
prefix=self.prefix,
stereo=False,
)
return result
Expand Down Expand Up @@ -479,8 +491,8 @@ def __call__(self, event: ArrayEventContainer) -> None:
prediction=np.nan, is_valid=False
)

container.prefix = f"{self.model_cls}_tel"
event.dl2.tel[tel_id].classification[self.model_cls] = container
container.prefix = f"{self.prefix}_tel"
event.dl2.tel[tel_id].classification[self.prefix] = container

self.stereo_combiner(event)

Expand All @@ -497,12 +509,12 @@ def predict_table(self, key, table: Table) -> Table:

result = Table(
{
f"{self.model_cls}_tel_prediction": score,
f"{self.model_cls}_tel_is_valid": is_valid,
f"{self.prefix}_tel_prediction": score,
f"{self.prefix}_tel_is_valid": is_valid,
}
)
add_defaults_and_meta(
result, ParticleClassificationContainer, prefix=self.model_cls, stereo=False
result, ParticleClassificationContainer, prefix=self.prefix, stereo=False
)
return result

Expand Down
2 changes: 1 addition & 1 deletion ctapipe/tools/apply_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def start(self):
self.loader.h5file = self.h5file

def _apply(self, reconstructor):
prefix = reconstructor.model_cls
prefix = reconstructor.prefix
property = reconstructor.property

desc = f"Applying {reconstructor.__class__.__name__}"
Expand Down
1 change: 1 addition & 0 deletions docs/changes/2217.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add option to configure prefixes of ml models keeping the model class as default prefix.

0 comments on commit 8a7a367

Please sign in to comment.