-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
106 lines (89 loc) · 4.92 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import random

from args import *
from models import *
from data import *
from modules import *
from generator import *

import pandas
from allennlp.data.tokenizers.token import Token

from unitrigger_utils import get_embedding_weight
# Global experiment configuration and reproducibility setup.
args = training_params()
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# main() shuffles training data with random.shuffle(); seed the stdlib RNG
# too, otherwise runs are nondeterministic despite the cudnn settings below.
random.seed(args.seed)
pandas.set_option('display.max_columns', None)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.set_printoptions(precision=3)

# Result-table column keys (as produced in main's `rt` dicts) and the
# display headers they are renamed to, kept positionally aligned.
cols = ['target_label', 'test_clean_acc_without', 'test_clean_f1_without', 'test_target_acc_with', 'test_target_detect_auc', 'init_trigger', 'trigger']
labels = ['target_label', 'ACC', 'F1', 'TARGET_ACC', 'DETECT AUC', 'INIT', 'TRIGGER']
def main(args):
    """End-to-end trapdoor-defense experiment.

    Trains a classifier on trapdoor-augmented data, trains a trigger
    detector on the trapdoor splits, then runs a universal-trigger attack
    against each target label and prints a summary table of clean metrics,
    attacked accuracy, and detection AUC.

    Args:
        args: experiment configuration produced by ``training_params()``.
    """
    train_data, train_dev_data, dev_data, test_data, vocab = load_dataset(args)
    model, encoder, word_embeddings = initialize_model(args, vocab)

    # Choose trapdoor tokens per target label: DARCY searches for them,
    # otherwise they are sampled at random from the vocabulary.
    if args.trapdoor_method == "DARCY":
        model, trapdoors = prepare_trapdoors_DARCY(args, train_data, train_dev_data, vocab, args.target_labels,
                                                   trapdoor_num=args.trapdoor_num,
                                                   num_random=args.trapdoor_num_random,
                                                   inter_tree_neighbor=args.trapdoor_inter_neighbor,
                                                   intra_tree_neighbor=args.trapdoor_intra_neighbor)
        model.training = True
    else:
        trapdoors = prepare_trapdoors_random(vocab, args.num_class, args.trapdoor_num)

    # Build the trapdoor-augmented training set plus detector train/dev splits.
    additional_data, trapdoor_train_data, trapdoor_dev_data = populate_data(
        args, train_data, args.target_labels, trapdoors,
        dev_ratio=args.trapdoor_dev_ratio,
        trapdoor_ratio=args.trapdoor_ratio,
        detector_ratio=args.detector_ratio)
    train_data = train_data + additional_data
    random.shuffle(train_data)
    print("additional_data", len(additional_data))
    print("trapdoor_train_data", len(trapdoor_train_data))
    print("TRAPDOORS", trapdoors)

    # MAIN MODEL
    print("Training Model")
    train_model(args, model, vocab, train_data, train_dev_data)
    vocab.save_to_files(args.vocab_path)
    model.training = False
    detector_model = model

    # DETECTOR
    print("Training Detector")
    detector = load_detector(args, detector_model, encoder, vocab)
    train_detector(args, detector, vocab, trapdoor_train_data, trapdoor_dev_data)

    model.train()  # for enabling gradients to go through the network
    rt_test_clean = evaluate(model, vocab, test_data, trigger_token_ids=None, detector=None, prefix="test_clean")
    rt_dev_clean = evaluate(model, vocab, dev_data, trigger_token_ids=None, detector=None, prefix="dev_clean")

    # KD-tree over the embedding matrix, used to constrain the trigger search.
    # NOTE(review): .numpy() requires the weight to be on CPU — confirm for
    # GPU runs.
    embedding_weight = get_embedding_weight(model)
    tree = KDTree(embedding_weight.numpy())

    results = []
    for target_label, trapdoor in zip(args.target_labels, trapdoors):
        # NOTE(review): these histories are re-created every iteration, so
        # attack_unitrigger always receives empty lists — confirm whether
        # they were meant to accumulate across target labels.
        previous_inits = []
        previous_triggers = []
        targeted_test_data = filter_data(test_data, target_label)
        targeted_dev_data = filter_data(dev_data, target_label, sample_ratio=0.5)
        if not args.oracle:
            trigger_token_ids, init_trigger = attack_unitrigger(args, model, vocab, target_label, targeted_dev_data,
                                                                init_trigger=args.init_trigger, previous_inits=previous_inits, previous_triggers=previous_triggers, tree=tree)
        else:  # in oracle attacks the attackers have access to the detector
            # NOTE(review): this branch passes `exempt_triggers=` where the
            # other passes `previous_triggers=` — verify both keywords exist
            # on attack_unitrigger.
            trigger_token_ids, init_trigger = attack_unitrigger(args, model, vocab, target_label, targeted_dev_data,
                                                                init_trigger=args.init_trigger, previous_inits=previous_inits,
                                                                exempt_triggers=previous_triggers, tree=tree, surrogate=detector)
        # Accuracy on the targeted slice before and after planting the trigger.
        rt_before = evaluate(model, vocab, targeted_test_data, trigger_token_ids=None, detector=None, prefix="test_target")
        rt_after = evaluate(model, vocab, targeted_test_data, trigger_token_ids, detector, prefix="test_target")
        if not detector:
            # Without a detector, fill the detection columns so the report
            # schema stays intact.
            rt_after = {**rt_after, **{'test_target_detect_auc': 'N/A', 'test_target_detect_fpr': 'N/A', 'test_target_detect_tpr': 'N/A'}}
        rt = {**rt_dev_clean, **rt_test_clean, **rt_before, **rt_after}
        rt['target_label'] = target_label
        rt['trapdoor'] = ",".join([token.text for token in trapdoor])
        rt['trigger'] = translate_triggers(trigger_token_ids, vocab)
        rt['init_trigger'] = init_trigger
        results.append(rt)

    # Use the explicitly imported `pandas`; the previous `pd` alias only
    # existed by leaking out of a star import.
    df = pandas.DataFrame.from_dict(results)
    df = df[cols]
    df.columns = labels
    print(df)
# Script entry point: run the experiment with the module-level parsed args.
if __name__ == '__main__':
    main(args)