forked from cjx0525/BGCN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
33 lines (30 loc) · 966 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from time import time
import os
def test(model, epoch, loader, device, CONFIG, metrics):
'''
test for dot-based model
'''
model.eval()
for metric in metrics:
metric.start()
start = time()
with torch.no_grad():
rs = model.propagate()
for users, ground_truth_u_b, train_mask_u_b in loader:
pred_b = model.evaluate(rs, users.to(device))
pred_b -= 1e8*train_mask_u_b.to(device)
for metric in metrics:
metric(pred_b, ground_truth_u_b.to(device))
print('Test: time={:d}s'.format(int(time()-start)))
for metric in metrics:
metric.stop()
print('{}:{}'.format(metric.get_title(), metric.metric), end='\t')
print('')
return metrics