-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_mnist.py
129 lines (102 loc) · 4.84 KB
/
train_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# -*- coding: utf-8 -*-
import matplotlib
matplotlib.use('Agg')
import argparse
import chainer
from chainer import training
from chainer.training import extensions
from model import LeNets
from util.logger import setup_logger
from util.extensions import VisualizeDeepFeature
def main():
    """Parse command-line options and train a ``LeNets`` model on MNIST.

    The options control the mini-batch size, number of epochs, whether the
    center loss is enabled (with its alpha/lambda ratios), the snapshot
    frequency, the GPU to use, the output directory and an optional snapshot
    to resume from. Training is driven by a Chainer ``Trainer`` whose reports,
    plots, snapshots and deep-feature visualizations go to the ``--out``
    directory.
    """
    # Each entry is (flags, keyword arguments) for one command-line option;
    # they are registered in this order.
    option_specs = (
        (('--batchsize', '-b'),
         dict(type=int, default=32,
              help='Number of images in each mini-batch')),
        (('--epoch', '-e'),
         dict(type=int, default=30,
              help='Number of sweeps over the dataset to train')),
        (('--centerloss', '-c'),
         dict(action='store_true', default=False,
              help='Use center loss')),
        (('--alpha_ratio', '-a'),
         dict(type=float, default=0.5, help='alpha ratio')),
        (('--lambda_ratio', '-l'),
         dict(type=float, default=0.1, help='lambda ratio')),
        (('--frequency', '-f'),
         dict(type=int, default=-1,
              help='Frequency of taking a snapshot')),
        (('--gpu', '-g'),
         dict(type=int, default=-1,
              help='GPU ID (negative value indicates CPU)')),
        (('--out', '-o'),
         dict(default='result',
              help='Directory to output the result')),
        (('--resume', '-r'),
         dict(default='',
              help='Resume the training from snapshot')),
    )
    cli = argparse.ArgumentParser(description='Chainer example: MNIST')
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    opts = cli.parse_args()

    # Report the effective configuration before any heavy work starts.
    log = setup_logger(__name__)
    log.info("GPU: {}".format(opts.gpu))
    log.info("# Minibatch-size: {}".format(opts.batchsize))
    log.info("# epoch: {}".format(opts.epoch))
    log.info("Calculate center loss: {}".format(opts.centerloss))
    if opts.centerloss:
        # The ratios are only logged when center loss is enabled.
        log.info('# alpha: {}'.format(opts.alpha_ratio))
        log.info('# lambda: {}'.format(opts.lambda_ratio))

    class_count = 10
    net = LeNets(
        out_dim=class_count,
        alpha_ratio=opts.alpha_ratio,
        lambda_ratio=opts.lambda_ratio,
        is_center_loss=opts.centerloss,
    )

    use_gpu = opts.gpu >= 0
    if use_gpu:
        # Make the requested GPU current, then copy the model onto it.
        chainer.cuda.get_device_from_id(opts.gpu).use()
        net.to_gpu()

    # Adam optimizer with its default hyperparameters.
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(net)

    # MNIST loaded with ndim=3; both splits are fed by 4 worker processes.
    train_set, test_set = chainer.datasets.get_mnist(ndim=3)
    train_iter = chainer.iterators.MultiprocessIterator(
        train_set, opts.batchsize, n_processes=4)
    # The evaluation iterator makes a single, unshuffled pass.
    test_iter = chainer.iterators.MultiprocessIterator(
        test_set, opts.batchsize, n_processes=4,
        repeat=False, shuffle=False)

    # Trainer that stops after the requested number of epochs.
    updater = training.StandardUpdater(
        train_iter, optimizer, device=opts.gpu)
    trainer = training.Trainer(
        updater, (opts.epoch, 'epoch'), out=opts.out)

    # Evaluate the model on the test dataset.
    trainer.extend(extensions.Evaluator(test_iter, net, device=opts.gpu))

    # Dump the computational graph rooted at the 'main/loss' variable
    # ("main" is the target link of the "main" optimizer).
    trainer.extend(extensions.dump_graph('main/loss'))

    # A frequency of -1 means a single snapshot at the last epoch; any other
    # value is clamped to at least one epoch.
    if opts.frequency == -1:
        snapshot_interval = opts.epoch
    else:
        snapshot_interval = max(1, opts.frequency)
    trainer.extend(extensions.snapshot(),
                   trigger=(snapshot_interval, 'epoch'))

    # Write a log of the reported statistics.
    trainer.extend(extensions.LogReport())

    # Save loss and accuracy curves to the result dir when plotting is
    # available.
    if extensions.PlotReport.available():
        plot_specs = (
            (['main/loss', 'validation/main/loss'], 'loss.png'),
            (['main/accuracy', 'validation/main/accuracy'], 'accuracy.png'),
        )
        for y_keys, image_name in plot_specs:
            trainer.extend(
                extensions.PlotReport(y_keys, 'epoch', file_name=image_name))

    # Print selected log entries to stdout; "validation" is the default name
    # of the Evaluator extension.
    report_columns = [
        'epoch', 'iteration', 'main/loss', 'validation/main/loss',
        'main/accuracy', 'validation/main/accuracy', 'elapsed_time',
    ]
    trainer.extend(extensions.PrintReport(report_columns))

    # Visualize deep features of the first 10000 training samples every epoch.
    trainer.extend(
        VisualizeDeepFeature(train_set[:10000], class_count, opts.centerloss),
        trigger=(1, 'epoch'))

    # Print a progress bar to stdout.
    trainer.extend(extensions.ProgressBar())

    if opts.resume:
        # Restore the trainer state from the given snapshot file.
        chainer.serializers.load_npz(opts.resume, trainer)

    # Run the training loop.
    trainer.run()
# Start training only when this file is executed as a script.
if __name__ == "__main__":
    main()