-
Notifications
You must be signed in to change notification settings - Fork 0
/
npz_dataset.py
53 lines (39 loc) · 1.51 KB
/
npz_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""
npz_dataset.py
Dataset class (for PyTorch DataLoader) for data saved in *.npz or *.npy format.
If using a *.npz file, it must contain an array 'u_set' that stores all the data and
can contain an optional array 'para_set' of known parameters for comparison.
"""
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class PDEDataset(Dataset):
    """PDE dataset with inputs x and targets also x.

    Each sample is the list ``[x, x, params]``. ``params`` is the entry of the
    stored parameter array for that sample, or a scalar NaN tensor when the
    file provides no parameters.
    """

    def __init__(self, data_file=None, transform=None, data_size=None):
        """
        Args:
            data_file (str or path-like): either a *.npy file holding the data
                array, or a *.npz archive containing a 'u_set' data array and
                an optional 'para_set' array of known parameters.
            transform (callable, optional): Optional transform to be applied
                on a sample.
            data_size: Currently ignored; accepted for backward compatibility.

        Raises:
            ValueError: if an *.npz archive does not contain a 'u_set' array.
        """
        data = np.load(data_file)
        if isinstance(data, np.ndarray):
            # *.npy file: the whole file is the data array, no parameters.
            self.data_x = data
            self.params = None
        else:
            # *.npz archive: np.load returns an NpzFile that keeps the file
            # handle open. Indexing it reads each array fully into memory, so
            # the archive can be closed as soon as the arrays are extracted.
            with data:
                if 'u_set' not in data.files:
                    raise ValueError("Dataset import failed. NPZ files must include 'u_set' array containing data.")
                self.data_x = data['u_set']
                self.params = data['para_set'] if 'para_set' in data.files else None
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of the first data axis)."""
        return len(self.data_x)

    def __getitem__(self, idx):
        """Return ``[x, x, params]`` for sample ``idx``, after ``transform`` if set."""
        # np.asarray turns NumPy scalars (e.g. one scalar parameter per sample
        # in a 1-D 'para_set') into 0-d arrays; torch.from_numpy rejects bare
        # scalars. For ndarray rows it is a no-op view, so no copy is made.
        x = torch.from_numpy(np.asarray(self.data_x[idx]))
        if self.params is None:
            params = torch.tensor(float('nan'))
        else:
            params = torch.from_numpy(np.asarray(self.params[idx]))
        sample = [x, x, params]
        if self.transform:
            sample = self.transform(sample)
        return sample