-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_loader.py
129 lines (101 loc) · 5.08 KB
/
data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""
Copyright©[2023] Fraunhofer-Gesellschaft zur Foerderung der angewandten Forschung e.V. acting on behalf of its Fraunhofer-Institut für Kognitive Systeme IKS. All rights reserved.
This software is subject to the terms and conditions of the GNU GPLv2 (https://www.gnu.de/documents/gpl-2.0.de.html).
Contact: nicola.franco@iks.fraunhofer.de
Data loader for the experiments
This file contains the data loader for the experiments. It is based on the
data loader from the pytorch-ood repository. The data loader is used to
load the datasets and to create the dataloaders for the experiments.
"""
import torch
from torchvision.datasets import CIFAR10, CIFAR100, SVHN, LSUN
from pytorch_ood.dataset.img import (
LSUNCrop,
LSUNResize,
Textures,
TinyImageNetCrop,
TinyImageNetResize,
GaussianNoise,
UniformNoise,
)
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
def prepare_data(dataset_path, batch_size, dataset):
    """
    Prepare the in-distribution and out-of-distribution dataloaders.

    Parameters
    ----------
    dataset_path : str
        Path to the in-distribution dataset (CIFAR10/CIFAR100).
    batch_size : int
        Batch size for the dataloaders.
    dataset : str
        Name of the in-distribution dataset; 'cifar10' selects CIFAR10,
        anything else selects CIFAR100.

    Returns
    -------
    id_loader : torch.utils.data.DataLoader
        Dataloader for the in-distribution test split.
    loaders : dict
        Mapping from OOD dataset name to its dataloader, each restricted
        to a fixed random subset of at most 1000 samples.
    """
    trn = transforms.ToTensor()
    crop = transforms.Compose([transforms.CenterCrop(32), trn])
    resize = transforms.Compose([transforms.Resize((32, 32)), trn])

    is_cifar10 = dataset == 'cifar10'
    id_cls = CIFAR10 if is_cifar10 else CIFAR100
    id_dataset = id_cls(dataset_path, train=False, transform=trn, download=False)
    id_loader = DataLoader(id_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Fixed seed so the OOD subsets are identical across runs.
    np.random.seed(seed=1)

    # NOTE(review): OOD dataset locations are hard-coded to a specific
    # storage layout rather than derived from `dataset_path` — confirm
    # these paths exist in the target environment.
    ood_datasets = {
        'CIFAR100': CIFAR100('/storage/project_data/robustness_certificates/data/images_classic/', train=False, transform=trn, download=True)
        if is_cifar10 else CIFAR10('/storage/datasets/torchvision_cache/', train=False, transform=trn, download=False),
        'SVHN': SVHN('/storage/datasets/torchvision_cache/svhn/', split='test', transform=trn, download=False),
        'LSUNCrop': LSUNCrop('/storage/datasets/processed/', transform=crop, download=False),
        'GaussianNoise': GaussianNoise(length=1000, size=(32, 32, 3), transform=trn, seed=1),
        'LSUNResize': LSUNResize('/storage/datasets/processed/', transform=resize, download=False),
        'TinyImageNetCrop': TinyImageNetCrop('/storage/datasets/processed/', transform=crop, download=False),
        'Textures': Textures('/storage/datasets/processed/', transform=resize, download=False),
        'UniformNoise': UniformNoise(length=1000, size=(32, 32, 3), transform=trn, seed=1),
    }

    target_size = 1000
    loaders = {}
    # Distinct loop variable so the `dataset` parameter is not shadowed.
    for name, ood_set in ood_datasets.items():
        subset = np.random.choice(len(ood_set), size=min(len(ood_set), target_size), replace=False)
        sampler = torch.utils.data.SubsetRandomSampler(subset)
        loaders[name] = DataLoader(ood_set, batch_size=batch_size, shuffle=False, num_workers=2, sampler=sampler)
    return id_loader, loaders
def prepare_loader_for_robustness(dataset_path, batch_size):
    """Build a CIFAR10 test loader restricted to a fixed 1000-sample subset.

    Args:
        dataset_path (str): path to the CIFAR10 dataset
        batch_size (int): batch size for the loader

    Returns:
        torch.utils.data.DataLoader: loader over a fixed random subset of
        the CIFAR10 test split (at most 1000 samples).
    """
    # Fixed seed so the chosen subset is reproducible across runs.
    np.random.seed(seed=1)
    to_tensor = transforms.ToTensor()
    test_set = CIFAR10(dataset_path, train=False, transform=to_tensor, download=False)
    n_samples = min(len(test_set), 1000)
    indices = np.random.choice(len(test_set), size=n_samples, replace=False)
    return DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        sampler=torch.utils.data.SubsetRandomSampler(indices),
    )
def prepare_training_data(dataset_path, batch_size):
    """
    Build the dataloader used to train the model.

    Args:
        dataset_path (str): path to the CIFAR10 dataset
        batch_size (int): batch size for the loader

    Returns:
        torch.utils.data.DataLoader: shuffled loader over the CIFAR10
        training split.
    """
    train_set = CIFAR10(
        dataset_path, train=True, transform=transforms.ToTensor(), download=False)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)