-
Notifications
You must be signed in to change notification settings - Fork 0
/
agent.py
124 lines (103 loc) · 4 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
This module defines the Agent class, which controls the whole process of training or testing segmentation models.
"""
import tensorflow as tf
from utils.misc import timeit
import pickle
from utils.misc import calculate_flops
# import os
# import pdb
from models import *
from test import *
from train import *
# Import os explicitly: without this the name is only available if one of the
# wildcard imports above happens to re-export it.
import os

# Enumerate GPUs in PCI bus order so device indices are stable (see issue #152)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only the first GPU to this process
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
class Agent:
    """
    Agent will run the program.

    Steps performed:
        * choose the type of operation (operator class) and the model class by name
        * reset the graph and create a session
        * create and build the model
        * create the operator (trainer or tester)
        * run it, finalize it and close the session
    """

    def __init__(self, args):
        """
        Store the configuration and resolve the model and operator classes.

        :param args: configuration object; ``args.task``, ``args.model`` and
            ``args.operator`` are read here. The last two must be names found
            in this module's globals (presumably provided by the wildcard
            imports at the top of the file).
        :raises KeyError: if ``args.model`` or ``args.operator`` is not a
            module-level name.
        """
        self.args = args
        self.task = args.task
        # Get the class from globals by selecting it by arguments
        self.model = globals()[args.model]
        # trainer or tester
        self.operator = globals()[args.operator]
        self.sess = None
        self.train_model = None

    @timeit
    def build_model(self):
        """
        Instantiate the selected model and build it under the 'network' scope.

        The branch is chosen from the ``name`` attribute of ``self.operator``.
        After this call ``self.model`` holds the built model instance instead
        of the model class, so this method is meant to be called once.
        """
        if self.operator.name == 'Train':
            with tf.variable_scope('network'):
                self.model = self.model(self.args)
                self.model.build()
        else:  # inference phase
            print('Building Test Network')
            with tf.variable_scope('network'):
                self.model = self.model(self.args)
                self.model.build()
                # FLOPs are only reported for the inference network
                calculate_flops()

    @timeit
    def run(self):
        """
        Initiate the Graph, sess, model, operator, then run the operator.

        The session is always closed, even when building the model or running
        the operator raises; the exception is then propagated to the caller.

        :return: None
        """
        print("Agent is running now...\n\n")

        # Reset the graph
        tf.reset_default_graph()

        # Create the sess; let GPU memory grow on demand instead of reserving it all
        gpu_options = tf.GPUOptions(allow_growth=True)
        self.sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True))

        try:
            # Create Model class and build it with the session installed as default
            with self.sess.as_default():
                self.build_model()

            # Create the operator; it receives the session explicitly.
            # From here on self.operator is an instance, not a class.
            self.operator = self.operator(self.args, self.sess, self.model)

            self.operator.run()
            self.operator.finalize()
        finally:
            # Release the session's resources no matter how the run ended
            self.sess.close()

        print("\nAgent is exited...\n")

    @staticmethod
    def load_pretrained_weights(sess, pretrained_path):
        """
        Assign pickled weights to the graph variables that share their names.

        Variables whose op name is not a key of the pickled mapping are left
        untouched.

        :param sess: session used to run the assign ops.
        :param pretrained_path: path of a pickle file containing a mapping
            from variable op name to a value exposing ``.shape``
            (presumably numpy arrays).
        :raises ValueError: if a stored value's shape differs from the shape
            of the variable with the same name.
        """
        print('############### START Loading from PKL ##################')
        # SECURITY: unpickling can execute arbitrary code; only load trusted files.
        with open(pretrained_path, 'rb') as ff:
            pretrained_weights = pickle.load(ff, encoding='latin1')

        print("Loading pretrained weights of resnet18")
        for v in tf.global_variables():
            name = v.op.name
            if name not in pretrained_weights:
                continue
            value = pretrained_weights[name]
            # Fail loudly on a mismatch instead of exiting with a success status
            if str(v.shape) != str(value.shape):
                raise ValueError(
                    "Shape mismatch for '{}': variable has shape {}, "
                    "pretrained value has shape {}".format(name, v.shape, value.shape)
                )
            sess.run(v.assign(value))
            print(name + " - loaded successfully, size ", value.shape)
        print("All pretrained weights of resnet18 is loaded")

    def debug(self, pretrained_path='pretrained_weights/linknet_weights.pkl'):
        """
        Load pretrained weights into the agent's session and debug the operator's layers.

        :param pretrained_path: pickle file with the weights to load; defaults
            to the LinkNet weights file.
        """
        self.load_pretrained_weights(self.sess, pretrained_path)
        try:
            self.operator.debug_layers()
        except KeyboardInterrupt:
            # Interrupting the debug run is tolerated: return without a traceback
            pass