-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_video_classification.py
39 lines (28 loc) · 2.21 KB
/
train_video_classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
from data.classification.video.video_classification_dataset_factories import create_video_classification_datasets, \
create_video_classification_datamodule
from models.classification.classification_model_factories import create_classification_module
from models.classification.video.video_classification_model_configuration import VideoClassificationModelConfiguration
from models.classification.video.video_classification_model_description import VideoClassificationModelDescription
from models.classification.video.video_classification_model_factories import register_all_video_classification_models
from training.classification.classification_training_factories import run_classification_training
from utilities.configuration.configuration_reader import ConfigurationReader
from utilities.json import write_json_file
from utilities.random import initialize_random_numbers
def run_experiment(configuration_reader: ConfigurationReader) -> None:
    """Train a video-classification model for one experiment and record its description.

    Steps, in order:
      1. Build the datasets (plus weights and a mapping) from the configuration.
      2. Wrap the datasets in a datamodule.
      3. Create a classification module sized by ``mapping.num_classes``.
      4. Run training, which yields metrics and a confusion matrix.
      5. Write a ``VideoClassificationModelDescription`` to ``model.json``
         under ``configuration_reader.get_artifact_path()``.

    Args:
        configuration_reader: Configuration for a single experiment, as produced
            by ``ConfigurationReader.create_from_cmdline``.
    """
    print(f"Starting experiment: {configuration_reader.experiment}, environment: {configuration_reader.environment}")

    # The second value is forwarded to the model factory. It is presumably
    # class weights for the loss; TODO confirm against the dataset factory.
    dataset_splits, dataset_weights, class_mapping = create_video_classification_datasets(configuration_reader)
    datamodule = create_video_classification_datamodule(configuration_reader, dataset_splits, class_mapping)

    classifier = create_classification_module(
        configuration_reader,
        VideoClassificationModelConfiguration,
        class_mapping.num_classes,
        dataset_weights,
    )

    training_metrics, confusion = run_classification_training(
        configuration_reader,
        VideoClassificationModelConfiguration,
        datamodule,
        classifier,
        class_mapping,
    )

    # Persist a JSON description of the trained model next to the other run artifacts.
    model_description = VideoClassificationModelDescription.create(
        configuration_reader, class_mapping, training_metrics, confusion
    )
    description_path = configuration_reader.get_artifact_path() / "model.json"
    write_json_file(description_path, model_description)
def main() -> None:
    """Script entry point: prepare global state, then run every configured experiment.

    Calls ``initialize_random_numbers()`` and registers all video-classification
    models. It then runs ``run_experiment`` once for each configuration reader
    that ``ConfigurationReader.create_from_cmdline`` returns for the
    ``"video_classification_training"`` task.
    """
    initialize_random_numbers()
    register_all_video_classification_models()

    # One reader per experiment described on the command line; run them sequentially.
    for reader in ConfigurationReader.create_from_cmdline(task="video_classification_training"):
        run_experiment(reader)


if __name__ == "__main__":
    main()