-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
39 lines (31 loc) · 1.71 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os, comet_ml, cv2
from pytorch_lightning.loggers import CometLogger
from train import fit
import pytorch_lightning as pl
from funcs import save_model, load_model, model_path_er
from predict import predict
#from comet_ml import Experiment
from private.comet_key import key
def main(num_epochs=50, folder="Dataset", arch='inc'):
    """Train a model, save it, and log it plus sample predictions to Comet.

    Args:
        num_epochs: Number of training epochs, forwarded to ``fit.train_fn``.
        folder: Dataset directory name, resolved relative to the current
            working directory and passed to ``fit.train_fn`` as ``data_dir``.
        arch: Architecture key forwarded to ``fit.train_fn`` as
            ``model_arch``; also used as the Comet experiment name and as the
            ``arch`` of the saved model.

    Side effects:
        Seeds RNGs via ``pl.seed_everything(42, workers=True)``, creates a
        Comet experiment in project ``"deep-net"``, saves the trained model
        with ``unique_id='best_val'``, uploads that file to Comet, and logs
        every readable image in ``./test-imgs`` together with the output of
        ``predict.main`` for it.
    """
    data_dir = os.path.join(os.getcwd(), folder)
    pl.seed_everything(42, workers=True)
    comet_logger = CometLogger(api_key=key(), experiment_name=arch, project_name="deep-net")

    # ________________________________train_____________________________________
    model = fit.train_fn(
        model_arch=arch,
        num_epochs=num_epochs,
        logger=comet_logger,
        data_dir=data_dir,
    )
    save_model(model=model, arch=arch, unique_id='best_val')
    # NOTE(review): element [2] of model_path_er(...) is presumably the saved
    # model's file path - confirm against funcs.model_path_er.
    # The asset name keeps its original leading space so that anything that
    # looks the model up in Comet by name still matches.
    comet_logger.experiment.log_model(
        ' ' + arch + " Model",
        model_path_er(arch=arch, unique_id='best_val')[2],
    )

    # ______________________________test________________________________________
    _log_test_images(comet_logger.experiment, model, os.path.join(os.getcwd(), 'test-imgs'))


def _log_test_images(experiment, model, image_folder):
    """Log each readable image in *image_folder* and the model's prediction for it."""
    # os.listdir order is arbitrary; sorting keeps the Comet upload order stable.
    for image_name in sorted(os.listdir(image_folder)):
        image_path = os.path.join(image_folder, image_name)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None for sub-directories and files it cannot
            # decode; skip them instead of passing None to log_image/predict.
            continue
        # cv2.imread yields BGR channel order, while Comet interprets numpy
        # arrays as RGB; convert so the logged colours are correct.
        experiment.log_image(cv2.cvtColor(image, cv2.COLOR_BGR2RGB), name=image_name)
        experiment.log_image(
            predict.main(input_img_path=image_path, base_model=model),
            name='pred_' + image_name,
        )
if __name__ == "__main__":
    # Script entry point: train and evaluate the ResNet variant.
    # Other configurations that have been run with this script:
    #   main(num_epochs=50, folder="Dataset", arch='alex')
    #   main(num_epochs=40, folder="Dataset", arch='vgg')
    #   main(num_epochs=50, folder="Dataset", arch='inc')
    main(
        arch='res',
        folder="Dataset",
        num_epochs=50,
    )