-
Notifications
You must be signed in to change notification settings - Fork 0
/
pytorch_ViT_DINO.py
173 lines (161 loc) · 7.22 KB
/
pytorch_ViT_DINO.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import torch
from skimage import io
from tqdm import tqdm
import argparse
import os
import csv
from vit_pytorch import Dino
import torch.optim as optim
from pathlib import Path
from torch.utils.data import DataLoader
from utils.data_utils import *
from utils.model_utils import *
def train(
    train_loader: DataLoader,
    learner: Dino,
    optimizer: optim.Optimizer,
    device: torch.device,
    project_directory: Path,
    model_save_name: str,
    num_epochs: int
) -> None:
    """Run DINO self-supervised training and checkpoint the result.

    Each epoch: averages the per-batch DINO losses, prints them, and appends
    ``(epoch, mean_loss)`` to a CSV log. After all epochs, saves learner and
    optimizer state dicts via ``save_checkpoint``.

    Args:
        train_loader: yields ``(tensor_batch, labels)`` pairs; labels are ignored.
        learner: DINO wrapper whose forward pass returns the self-distillation loss.
        optimizer: optimizer over the learner's parameters.
        device: device each batch is moved to before the forward pass.
        project_directory: root containing 'results_and_models' and 'models' dirs.
        model_save_name: checkpoint filename including the 8-char '.pth.tar'
            suffix; the stem (name minus that suffix) names the loss-log CSV.
        num_epochs: number of full passes over ``train_loader``.
    """
    # If a loss log exists from a previous run, remove it so the CSV only
    # reflects this training run. unlink(missing_ok=True) replaces the
    # os.path.exists/os.remove pair and avoids a check-then-act race.
    res_file_path = (
        Path(project_directory)
        / 'results_and_models'
        / f'{model_save_name[:-8]}_dino_loss_values.csv'
    )
    res_file_path.unlink(missing_ok=True)
    for epoch in range(num_epochs):
        losses = []
        for batch in tqdm(train_loader):
            inputs = batch[0].to(device=device)  # [0]: tensors [1]: labels
            loss = learner(inputs)
            optimizer.zero_grad()  # clear gradient information
            loss.backward()        # calculate gradient
            optimizer.step()
            learner.update_moving_average()  # EMA update of the teacher network
            losses.append(loss.item())
        # Plain float arithmetic on already-detached .item() values -- no
        # autograd involved, so the original torch.no_grad() was unnecessary.
        # Guard against an empty loader instead of raising ZeroDivisionError.
        total_loss = sum(losses) / len(losses) if losses else float('nan')
        print('epoch', epoch, 'loss: ', total_loss)
        with open(res_file_path, mode="a", newline="") as log_file:
            csv.writer(log_file).writerow((epoch, total_loss))
    checkpoint = {'state_dict': learner.state_dict(), 'optimizer': optimizer.state_dict()}
    save_checkpoint(state=checkpoint, filepath=Path(project_directory) / "models" / model_save_name)
def create_argparser() -> argparse.Namespace:
    """Build and parse the command-line arguments for DINO ViT training.

    Returns:
        argparse.Namespace with paths, training hyperparameters, ViT
        architecture settings, and DINO-specific parameters.
    """
    parser = argparse.ArgumentParser()
    # directory where all relevant folders are located
    parser.add_argument("-project_directory", type=Path)
    # directory where the PCAM data are located
    parser.add_argument("-data_root", type=Path)
    # number of epochs to train for
    parser.add_argument("-num_epochs", type=int, default=18)
    # number of classes to predict between
    parser.add_argument("-num_classes", type=int, default=2)
    # proportion to weight parameter update by
    parser.add_argument("-learning_rate", type=float, default=3e-4)
    # number of epochs trained past when the loss decreases to a minimum
    parser.add_argument("-patience", type=int, default=5)
    # number of inputs before gradient is calculated
    parser.add_argument("-batch_size", type=int, default=120)
    # filename of the model, including .pth.tar
    parser.add_argument("-model_save_name", type=str)
    # channels-first image shape. BUGFIX: the previous `type=tuple, nargs="+"`
    # applied tuple() to each CLI token, turning "96" into ('9', '6'); parse
    # exactly three ints instead, keeping the same default.
    parser.add_argument("-img_shape", type=int, nargs=3, default=(3, 96, 96))
    # size of image patch, 8, 16 and 32 are good values
    parser.add_argument("-patch_size", type=int, default=16)
    # last dimension of output tensor after linear transformation
    parser.add_argument("-dim", type=int, default=1024)
    # number of transformer blocks
    parser.add_argument("-depth", type=int, default=6)
    # number of heads in multi-head attention layer
    parser.add_argument("-heads", type=int, default=8)
    # dimension of multilayer perceptron layer
    parser.add_argument("-mlp_dim", type=int, default=2048)
    # hidden layer name or index, from which to extract the embedding
    parser.add_argument("-hidden_layer", type=str, default='to_latent')
    # projector network hidden dimension
    parser.add_argument("-projection_hidden_size", type=int, default=512)
    # number of layers in projection network
    parser.add_argument("-projection_layers", type=int, default=4)
    # output logits dimensions (referenced as K in paper)
    # NOTE(review): the DINO paper uses K = 65536 (2**16); 65336 looks like a
    # typo, but the default is kept to preserve existing behavior -- confirm.
    parser.add_argument("-num_classes_K", type=int, default=65336)
    # student temperature
    parser.add_argument("-student_temp", type=float, default=0.9)
    # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    parser.add_argument("-teacher_temp", type=float, default=0.04)
    # upper bound for local crop - 0.4 was recommended in the paper
    parser.add_argument("-local_upper_crop_scale", type=float, default=0.4)
    # lower bound for global crop - 0.5 was recommended in the paper
    parser.add_argument("-global_lower_crop_scale", type=float, default=0.5)
    # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    parser.add_argument("-moving_average_decay", type=float, default=0.9)
    # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
    parser.add_argument("-center_moving_average_decay", type=float, default=0.9)
    return parser.parse_args()
def main():
    """Entry point: build, train, and persist a DINO ViT on the PCAM data."""
    args = create_argparser()
    device = define_device()

    # Backbone vision transformer.
    model = define_model(
        img_shape=args.img_shape,
        patch_size=args.patch_size,
        num_classes=args.num_classes,
        dim=args.dim,
        depth=args.depth,
        mlp_dim=args.mlp_dim,
        heads=args.heads,
    )

    # DINO hyperparameters shared by the learner and the results log.
    dino_cfg = dict(
        hidden_layer=args.hidden_layer,
        projection_hidden_size=args.projection_hidden_size,
        projection_layers=args.projection_layers,
        num_classes_K=args.num_classes_K,
        student_temp=args.student_temp,
        teacher_temp=args.teacher_temp,
        local_upper_crop_scale=args.local_upper_crop_scale,
        global_lower_crop_scale=args.global_lower_crop_scale,
        moving_average_decay=args.moving_average_decay,
        center_moving_average_decay=args.center_moving_average_decay,
    )

    # Self-distillation wrapper around the backbone.
    learner = define_learner(model=model, img_shape=args.img_shape, **dino_cfg)
    optimizer = define_optimizer(learner=learner, learning_rate=args.learning_rate)
    learner.to(device)
    print_model_summary(model=learner)

    # Data pipeline.
    train_dataset = create_dino_datasets(dataset_root=args.data_root)
    train_dataloader = create_dino_dataloaders(
        train_dataset=train_dataset,
        batch_size=args.batch_size,
    )

    # Train, then record the run's configuration alongside its results.
    train(
        train_loader=train_dataloader,
        learner=learner,
        optimizer=optimizer,
        device=device,
        project_directory=args.project_directory,
        model_save_name=args.model_save_name,
        num_epochs=args.num_epochs,
    )
    save_dino_results(
        optimizer=optimizer,
        batch_size=args.batch_size,
        model_save_name=args.model_save_name,
        patience=args.patience,
        patch_size=args.patch_size,
        dim=args.dim,
        depth=args.depth,
        heads=args.heads,
        mlp_dim=args.mlp_dim,
        project_directory=args.project_directory,
        **dino_cfg,
    )


if __name__ == '__main__':
    main()