-
Notifications
You must be signed in to change notification settings - Fork 2
/
train_test_helper.py
260 lines (184 loc) · 11 KB
/
train_test_helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import torch
import torch.nn.functional as F
import numpy as np
from torch.nn.utils import clip_grad_norm_
from pirl_loss import loss_pirl, get_img_pair_probs
def get_count_correct_preds(network_output, target):
    """
    Count the rows of network_output whose arg-max over dim 1 equals target.
    :param network_output: Score tensor; the maximum is taken along dim 1 (presumably (batch, classes)).
    :param target: Ground-truth class index per row.
    :return: 0-d float tensor with the number of matches (float so later division is not truncated).
    """
    _, predicted = network_output.max(dim=1)
    hits = predicted == target
    return hits.sum().float()
def get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr):
    """
    Count pseudo-correct predictions for the pre-text task.
    An element counts as correct when the mean of its two probabilities is at least 0.5.
    :param img_pair_probs_arr: Prob vector of batch of images I and I_t to belong to same data distribution.
    :param img_mem_rep_probs_arr: Prob vector of batch of I and mem_bank_rep of I to belong to same data distribution
    :return: Python float with the count (float so that accuracy computed from it is not rounded to int).
    """
    mean_probs = (img_pair_probs_arr + img_mem_rep_probs_arr) * 0.5
    return mean_probs.ge(0.5).sum().float().item()
class PIRLModelTrainTest():
    """
    Train / evaluate a PIRL network against a memory bank of per-image representations.

    After every batch (in both training and evaluation) the memory-bank rows of the batch
    images are blended towards the network's current output:
        mem[i] <- beta * mem[i] + (1 - beta) * vi
    """

    def __init__(self, network, device, model_file_path, all_images_mem, train_image_indices,
                 val_image_indices, count_negatives, temp_parameter, beta, only_train=False, threshold=1e-4):
        """
        :param network: Module called as network(i_batch, i_t_patches_batch) -> (vi_batch, vi_t_batch).
        :param device: Device the memory bank and every batch are moved to.
        :param model_file_path: Base path for saved state dicts.
        :param all_images_mem: Initial memory bank; converted to a float tensor on device.
        :param train_image_indices: Image indices negatives are drawn from during training (copied).
        :param val_image_indices: Image indices negatives are drawn from during evaluation (copied).
        :param count_negatives: Maximum number of negatives sampled per batch.
        :param temp_parameter: Temperature passed to get_img_pair_probs.
        :param beta: Weight kept on the old memory-bank row in the moving-average update.
        :param only_train: If True, train() skips validation and best-model saving.
        :param threshold: Minimum validation-loss improvement required to save a new best model.
        """
        super(PIRLModelTrainTest, self).__init__()
        self.network = network
        self.device = device
        self.model_file_path = model_file_path
        self.threshold = threshold
        self.train_loss = 1e9
        self.val_loss = 1e9
        self.all_images_mem = torch.tensor(all_images_mem, dtype=torch.float).to(device)
        self.train_image_indices = train_image_indices.copy()
        self.val_image_indices = val_image_indices.copy()
        self.count_negatives = count_negatives
        self.temp_parameter = temp_parameter
        self.beta = beta
        self.only_train = only_train

    def _sample_negatives(self, candidate_indices, batch_img_indices):
        """
        Return memory-bank rows for up to count_negatives random images from candidate_indices,
        excluding the images of the current batch.
        """
        # Compare plain integers: torch.Tensor hashes by object identity, so a Python set built
        # from a tensor's elements would never match (and thus never exclude) any index.
        batch_idx_np = batch_img_indices.detach().cpu().numpy()
        pool = np.setdiff1d(np.asarray(candidate_indices), batch_idx_np)
        mn_indices = np.random.permutation(pool)[:self.count_negatives]
        mn_indices = torch.as_tensor(mn_indices, dtype=torch.long, device=self.all_images_mem.device)
        return self.all_images_mem[mn_indices]

    def _update_memory_bank(self, batch_img_indices, vi_batch):
        """Blend the batch representations into the memory bank in place (moving average with beta)."""
        # The bank must never be part of the autograd graph; update it without tracking gradients.
        with torch.no_grad():
            self.all_images_mem[batch_img_indices] = (self.beta * self.all_images_mem[batch_img_indices]) + \
                ((1 - self.beta) * vi_batch.detach())

    def _forward_batch(self, data_batch, batch_img_indices, candidate_indices):
        """
        Run one batch through the network and compute the PIRL loss.
        :return: (loss, img_pair_probs_arr, img_mem_rep_probs_arr, vi_batch, batch_img_indices), with
            batch_img_indices moved to self.device.
        """
        # Separate input image I batch and transformed image I_t batch (jigsaw patches) from data_batch
        i_batch = data_batch[0].to(self.device)
        i_t_patches_batch = data_batch[1].to(self.device)
        batch_img_indices = batch_img_indices.to(self.device)
        vi_batch, vi_t_batch = self.network(i_batch, i_t_patches_batch)
        # Negatives come from the memory bank, excluding the images of this batch
        mn_arr = self._sample_negatives(candidate_indices, batch_img_indices)
        mem_rep_of_batch_imgs = self.all_images_mem[batch_img_indices]
        # Prob for I, I_t to belong to same data distribution
        img_pair_probs_arr = get_img_pair_probs(vi_batch, vi_t_batch, mn_arr, self.temp_parameter)
        # Prob for I and mem_bank_rep of I to belong to same data distribution
        img_mem_rep_probs_arr = get_img_pair_probs(vi_batch, mem_rep_of_batch_imgs, mn_arr, self.temp_parameter)
        loss = loss_pirl(img_pair_probs_arr, img_mem_rep_probs_arr)
        return loss, img_pair_probs_arr, img_mem_rep_probs_arr, vi_batch, batch_img_indices

    def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader,
              no_train_samples, no_val_samples):
        """
        Run one training epoch, then (unless only_train) validate and save the best model.
        Every 10th epoch the state dict is also saved to model_file_path + '_epoch_<epoch>'.
        :return: (train_loss, train_acc, val_loss, val_acc); val values are 0.0 when only_train.
        """
        self.network.train()
        train_loss, correct, cnt_batches = 0, 0, 0
        for data_batch, batch_img_indices in train_data_loader:
            optimizer.zero_grad()
            loss, img_pair_probs_arr, img_mem_rep_probs_arr, vi_batch, batch_img_indices = \
                self._forward_batch(data_batch, batch_img_indices, self.train_image_indices)
            # Back-prop gradients => clip => update weights
            loss.backward()
            clip_grad_norm_(self.network.parameters(), params_max_norm)
            optimizer.step()
            # Update running loss and no of pseudo correct predictions for epoch
            correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr)
            train_loss += loss.item()
            cnt_batches += 1
            self._update_memory_bank(batch_img_indices, vi_batch)
        train_loss /= cnt_batches
        if epoch % 10 == 0:
            torch.save(self.network.state_dict(), self.model_file_path + '_epoch_{}'.format(epoch))
        if self.only_train is False:
            val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples)
            if val_loss < self.val_loss - self.threshold:
                self.val_loss = val_loss
                torch.save(self.network.state_dict(), self.model_file_path)
        else:
            val_loss, val_acc = 0.0, 0.0
        train_acc = correct / no_train_samples
        print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, train_loss, correct, no_train_samples, 100. * correct / no_train_samples))
        return train_loss, train_acc, val_loss, val_acc

    def test(self, epoch, test_data_loader, no_test_samples):
        """
        Evaluate on test_data_loader (negatives drawn from val_image_indices).
        The memory bank is still updated with the evaluated representations, as in training.
        :return: (test_loss averaged per batch, pseudo accuracy).
        """
        self.network.eval()
        test_loss, correct, cnt_batches = 0, 0, 0
        # Evaluation needs no gradients; skipping graph construction saves memory.
        with torch.no_grad():
            for data_batch, batch_img_indices in test_data_loader:
                loss, img_pair_probs_arr, img_mem_rep_probs_arr, vi_batch, batch_img_indices = \
                    self._forward_batch(data_batch, batch_img_indices, self.val_image_indices)
                correct += get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr)
                test_loss += loss.item()
                cnt_batches += 1
                self._update_memory_bank(batch_img_indices, vi_batch)
        test_loss /= cnt_batches
        test_acc = correct / no_test_samples
        print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, test_loss, correct, no_test_samples, 100. * correct / no_test_samples))
        return test_loss, test_acc
class ModelTrainTest():
    """
    Supervised train / evaluate loop for a classifier scored with F.nll_loss.
    The network is therefore expected to output log-probabilities.
    """

    def __init__(self, network, device, model_file_path, threshold=1e-4):
        """
        :param network: Classifier module called as network(data).
        :param device: Device each batch is moved to.
        :param model_file_path: Path the best (lowest validation loss) state dict is saved to.
        :param threshold: Minimum validation-loss improvement required to save a new best model.
        """
        super(ModelTrainTest, self).__init__()
        self.network = network
        self.device = device
        self.model_file_path = model_file_path
        self.threshold = threshold
        self.train_loss = 1e9
        self.val_loss = 1e9

    def train(self, optimizer, epoch, params_max_norm, train_data_loader, val_data_loader,
              no_train_samples, no_val_samples):
        """
        Run one training epoch, validate, and save the model if validation loss improved.
        :return: (train_loss averaged per batch, train_acc, val_loss, val_acc).
        """
        self.network.train()
        train_loss, correct, cnt_batches = 0, 0, 0
        for data, target in train_data_loader:
            data, target = data.to(self.device), target.to(self.device)
            optimizer.zero_grad()
            output = self.network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            clip_grad_norm_(self.network.parameters(), params_max_norm)
            optimizer.step()
            correct += get_count_correct_preds(output, target)
            train_loss += loss.item()
            cnt_batches += 1
            del data, target, output
        train_loss /= cnt_batches
        val_loss, val_acc = self.test(epoch, val_data_loader, no_val_samples)
        if val_loss < self.val_loss - self.threshold:
            self.val_loss = val_loss
            torch.save(self.network.state_dict(), self.model_file_path)
        train_acc = correct / no_train_samples
        print('\nAfter epoch {} - Train set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, train_loss, correct, no_train_samples, 100. * correct / no_train_samples))
        return train_loss, train_acc, val_loss, val_acc

    def test(self, epoch, test_data_loader, no_test_samples):
        """
        Evaluate the network.
        :return: (test_loss averaged per sample, test_acc).
        """
        self.network.eval()
        test_loss = 0
        correct = 0
        # Evaluation needs no gradients; skipping graph construction saves memory.
        with torch.no_grad():
            for data, target in test_data_loader:
                data, target = data.to(self.device), target.to(self.device)
                output = self.network(data)
                # Sum (not mean) per batch so the division below yields a per-sample average.
                # reduction='sum' replaces the deprecated size_average=False.
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                correct += get_count_correct_preds(output, target)
                del data, target, output
        test_loss /= no_test_samples
        test_acc = correct / no_test_samples
        print('\nAfter epoch {} - Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            epoch, test_loss, correct, no_test_samples, 100. * correct / no_test_samples))
        return test_loss, test_acc
if __name__ == '__main__':
    # Smoke test: the pretext counter expects probability vectors, so draw values in [0, 1)
    # (torch.randn would produce negative and >1 values).
    img_pair_probs_arr = torch.rand((256,))
    img_mem_rep_probs_arr = torch.rand((256,))
    print(get_count_correct_preds_pretext(img_pair_probs_arr, img_mem_rep_probs_arr))