-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
executable file
·59 lines (47 loc) · 1.34 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import random
import tinn
class Data:
    """In-memory dataset of whitespace-separated numeric rows.

    Each non-blank line of the source file holds ``nips`` input values
    followed by ``nops`` target values. Inputs and targets are stored in two
    parallel lists, ``in_`` and ``tg``, so ``in_[i]`` is paired with ``tg[i]``.
    """

    def __init__(self, path, nips, nops):
        """Load the dataset at *path* with *nips* inputs and *nops* targets per row."""
        self.read_data(path, nips, nops)

    def __repr__(self):
        # An empty dataset has no first row to measure; report zero widths
        # instead of raising IndexError.
        n_in = len(self.in_[0]) if self.in_ else 0
        n_out = len(self.tg[0]) if self.tg else 0
        return f'{len(self)} rows with {n_in} inputs and {n_out} outputs.'

    def read_data(self, path, nips, nops):
        """(Re)load rows from *path*, replacing any previously loaded data.

        Blank lines are skipped. Each remaining row contributes its first
        *nips* values as inputs and the following *nops* values as targets;
        any columns beyond ``nips + nops`` are ignored.

        Raises:
            ValueError: if a row has fewer than ``nips + nops`` values, or one
                of those values cannot be parsed as a float.
        """
        width = nips + nops
        inputs, targets = [], []
        with open(path) as data_file:
            for lineno, line in enumerate(data_file, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing empty lines
                if len(fields) < width:
                    raise ValueError(
                        f'{path}:{lineno}: expected at least {width} values, '
                        f'got {len(fields)}'
                    )
                row = [float(value) for value in fields[:width]]
                inputs.append(row[:nips])
                targets.append(row[nips:])
        # Assign only after the whole file parsed, so a failed reload does not
        # leave the object half-populated.
        self.in_, self.tg = inputs, targets

    def shuffle(self):
        """Shuffle rows in place, keeping each input paired with its target.

        A single permutation from ``random.shuffle`` is applied to both lists.
        """
        indexes = list(range(len(self.in_)))
        random.shuffle(indexes)
        self.in_ = [self.in_[i] for i in indexes]
        self.tg = [self.tg[i] for i in indexes]

    def __len__(self):
        """Return the number of loaded rows."""
        return len(self.in_)
def main():
    """Train a Tinn network on ``semeion.data`` and round-trip it through disk.

    Runs three passes over the data. Each pass shuffles the rows, trains on
    every row, prints the mean per-row error returned by ``tinn.xttrain``
    together with the current learning rate, and then decays the rate by a
    factor of 0.99. Afterwards the network is saved to ``saved.tinn``,
    reloaded with ``tinn.xtload``, and used to predict the first row of the
    (last-shuffled) data. The row's target and the prediction are printed on
    separate lines.
    """
    # Topology: 256 inputs, 28 hidden neurons, 10 outputs.
    n_inputs, n_hidden, n_outputs = 256, 28, 10
    learning_rate = 1.0
    decay = 0.99  # multiplicative learning-rate annealing per pass

    dataset = Data('semeion.data', n_inputs, n_outputs)
    net = tinn.Tinn(n_inputs, n_hidden, n_outputs)

    for _epoch in range(3):
        dataset.shuffle()
        # Accumulate left-to-right with a plain loop (not sum()) so the float
        # result matches a straightforward running total exactly.
        total_error = 0
        for sample, target in zip(dataset.in_, dataset.tg):
            total_error += tinn.xttrain(net, sample, target, learning_rate)
        print(f'error {total_error/len(dataset)} :: learning rate {learning_rate}')
        learning_rate *= decay

    # Persist and reload so the prediction below exercises the saved network.
    net.save('saved.tinn')
    restored = tinn.xtload('saved.tinn')

    first_input, first_target = dataset.in_[0], dataset.tg[0]
    prediction = tinn.xtpredict(restored, first_input)
    print(' '.join(map(str, first_target)))
    print(' '.join(map(str, prediction)))


if __name__ == '__main__':
    # Run the training demo only when executed as a script, not on import.
    main()