# utils.py
import os

import pandas as pd
import torch


def init_weights(m):
    """ Initializes Linear layers with Xavier-uniform weights (ReLU gain)
    and a small constant bias.
    Use with network.apply(init_weights)
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=torch.nn.init.calculate_gain('relu'))
        m.bias.data.fill_(0.01)


def init_ortho(m):
    """ Initializes weight layers with an orthogonal matrix
    Use with network.apply(init_ortho)
    """
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.orthogonal_(m.weight)
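
# Usage sketch for the two initializers above (`net` is an assumed nn.Module,
# not defined in this file):
#   net.apply(init_weights)  # Xavier-uniform weights, constant 0.01 bias
#   net.apply(init_ortho)    # orthogonal weight matrices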


def get_n_params(model):
    """ Returns the number of learnable parameters of a network
    """
    return sum(p.numel() for p in model.parameters())
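
# For example, get_n_params(torch.nn.Linear(10, 5)) returns 55
# (50 weights + 5 biases).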


def add_noise(inputs):
    """ Adds small Gaussian noise to a tensor and
    clamps the result so it stays between 0 and 1
    """
    noise = torch.randn_like(inputs) * 0.01
    return torch.clip(inputs + noise, min=0, max=1)
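
# For example, add_noise(torch.rand(3, 32, 32)) perturbs an image-like tensor
# with sigma = 0.01 Gaussian noise while keeping every value inside [0, 1].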


def accuracy(y_pred, y_true):
    """ Computes the accuracy between two class tensors
    (predictions and targets are rounded to the nearest class)
    """
    y_pred = torch.round(y_pred)
    y_true = torch.round(y_true)
    right = (y_pred == y_true)
    return right.float().mean().item()
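
# For example, accuracy(torch.tensor([0.9, 0.2]), torch.tensor([1., 0.]))
# returns 1.0, since both predictions round to the correct class.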


def minmax_scale(v, new_min, new_max):
    """ Linearly rescales a tensor so its values span [new_min, new_max]
    Args:
        v : tensor to scale
        new_min : minimum value of the scaled tensor
        new_max : maximum value of the scaled tensor
    """
    with torch.no_grad():
        v_min, v_max = v.min(), v.max()
        v = (v - v_min) / (v_max - v_min) * (new_max - new_min) + new_min
        return v
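
# For example, minmax_scale(torch.tensor([0., 5., 10.]), -1, 1)
# returns tensor([-1., 0., 1.]).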


def apply_weight_decay(*modules, weight_decay_factor=0., wo_bn=True):
    """
    https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/5
    Applies weight decay to a pytorch model while skipping BatchNorm layers.
    It mirrors what pytorch optimizers do internally:
        if group['weight_decay'] != 0:
            grad = grad.add(p, alpha=group['weight_decay'])
    where p is the param.
    :param modules: modules whose Linear / Conv2d weights receive decay
    :param weight_decay_factor: decay coefficient added to the gradient
    :param wo_bn: if True, BatchNorm layers are left untouched
    """
    for module in modules:
        for m in module.modules():
            if wo_bn and isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                continue
            if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)) and m.weight.grad is not None:
                with torch.no_grad():
                    m.weight.grad += m.weight * weight_decay_factor
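
# Training-loop sketch (loss / netD / optimizerD are assumptions, not defined
# in this file): the decay is added to the gradients after backward() and
# before the optimizer step, i.e.
#   loss.backward()
#   apply_weight_decay(netD, weight_decay_factor=1e-4)
#   optimizerD.step()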


def write_params(p, folder='saved_models', verbose=0):
    """ Writes the params dict in a separate "-PARAMS" parameter file
    """
    filename = p['filename']
    string = f"""Name : {p['filename']}
Last epoch : {p['epoch']}
########### GLOBAL ###########
DS: {p['ds']}
Run test : {p['run_test']}
Batch size : {p['bs']}
Crop_size : {p['crop_size']}\n\n"""
    string += f"""########### ARCHI ###########
Input dim : {p['z_dim']}
{p['archi_info']}\n\n"""
    string += f"""########### TRAINING PARAMS ###########
Epochs : {p['n_epoch']}
Save freq : {p['save_frequency']}
Discriminator learning factor (k) : {p['k']}\n\n"""
    string += f"""########### MODEL PARAMS ###########
lrG : {p['lrG']}
lrD : {p['lrD']}
beta : {p['beta1']}
Weight decay Discriminator : {p['weight_decayD']}
Weight decay Generator : {p['weight_decayG']}
label_reals : {p['label_reals']}
label_fakes : {p['label_fakes']}"""
    if verbose:
        print(string)
        print()
        print("#######################\n")
    filename += "-PARAMS"
    with open(os.path.join(folder, filename), 'w+') as file:
        file.write(string)
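
# The dict passed to write_params must provide every key used in the f-strings
# above; illustrative sketch (values are placeholders, not from a real run):
#   p = {'filename': 'run0', 'epoch': 0, 'ds': '<dataset>', 'run_test': False,
#        'bs': 64, 'crop_size': 64, 'z_dim': 100, 'archi_info': '<archi>',
#        'n_epoch': 10, 'save_frequency': 5, 'k': 1, 'lrG': 2e-4, 'lrD': 2e-4,
#        'beta1': 0.5, 'weight_decayD': 0., 'weight_decayG': 0.,
#        'label_reals': 1., 'label_fakes': 0.}
#   write_params(p, verbose=1)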


def get_epoch_from_log(param_dict, folder='saved_models', verbose=1):
    """ Reads the "-PARAMS" file to fill the param dict with
    the correct parameters
    """
    with open(os.path.join(folder, param_dict["filename"] + "-PARAMS"), "r") as f:
        lines = pd.Series(f.readlines()).str.strip('\n')
    #### TODO: check that the logged parameters match the current ones ####
    search = {'filename': 'Name', 'archi_info': 'upsample type',
              'lrG': 'lrG', 'lrD': 'lrD', 'beta1': 'beta',
              'weight_decayD': 'Weight decay Discriminator',
              'weight_decayG': 'Weight decay Generator',
              'k': "Discriminator learning factor (k)",
              'z_dim': 'Input dim', 'n_epoch': 'Epochs',
              'save_frequency': 'Save freq', 'label_fakes': 'label_fakes',
              'label_reals': 'label_reals', 'ds': 'DS', 'run_test': "Run test",
              'bs': "Batch size", 'crop_size': "Crop_size", 'epoch': "Last epoch"}
    for param in param_dict.keys():
        try:
            line = lines[lines.str.startswith(search[param])]
            pp = line.values[0].split(":")[1].strip()
            if param not in ['filename', 'archi_info', 'ds', 'run_test']:
                pp = float(pp)
                if int(pp) == pp:
                    pp = int(pp)
            param_dict[param] = pp
        except Exception as e:
            print(f"Could not retrieve {param} from log : {e}")