Skip to content

Commit

Permalink
Successfully run training
Browse files Browse the repository at this point in the history
  • Loading branch information
you-n-g committed Jul 10, 2024
1 parent a9fc343 commit 4c057f6
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 8 deletions.
51 changes: 44 additions & 7 deletions qlib/contrib/model/pytorch_general_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,12 +488,40 @@ def metric_fn(self, pred, label):

raise ValueError("unknown metric `%s`" % self.metric)


def _get_fl(self, data: torch.Tensor):
    """
    Split a batch into its feature and label tensors.

    Handles the different data shapes of time series and tabular data.
    The last entry along the final axis is taken as the label and every
    entry before it as the features; both results are moved to
    ``self.device``.

    Parameters
    ----------
    data : torch.Tensor
        input data, which may be 3-dimensional or 2-dimensional
        (the final axis holds the feature columns followed by the label)
        - 3dim: [batch_size, time_step, feature_dim]
        - 2dim: [batch_size, feature_dim]

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        ``(feature, label)``. For 3-dim input the label is read from the
        last time step only, so it has shape ``[batch_size]``.

    Raises
    ------
    ValueError
        If ``data`` is neither 2- nor 3-dimensional.
    """
    ndim = data.dim()
    if ndim == 3:
        # it is a time series dataset: keep every time step for the
        # features, but only the final time step's label
        feature = data[:, :, 0:-1].to(self.device)
        label = data[:, -1, -1].to(self.device)
    elif ndim == 2:
        # it is a tabular dataset
        feature = data[:, 0:-1].to(self.device)
        label = data[:, -1].to(self.device)
    else:
        # Report what was actually received so a mis-shaped batch can be
        # diagnosed without re-running under a debugger.
        raise ValueError(
            "Unsupported data shape: expected a 2-dim or 3-dim tensor, "
            f"got {ndim}-dim input with shape {tuple(data.shape)}."
        )
    return feature, label

def train_epoch(self, data_loader):
self.dnn_model.train()

for data, weight in data_loader:
feature = data[:, :, 0:-1].to(self.device)
label = data[:, -1, -1].to(self.device)
feature , label = self._get_fl(data)

pred = self.dnn_model(feature.float())
loss = self.loss_fn(pred, label, weight.to(self.device))
Expand Down Expand Up @@ -526,19 +554,18 @@ def test_epoch(self, data_loader):

def fit(
self,
dataset,
dataset: Union[DatasetH, TSDatasetH],
evals_result=dict(),
save_path=None,
reweighter=None,
):
ists = isinstance(dataset, TSDatasetH) # is this time series dataset

dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
if dl_train.empty or dl_valid.empty:
raise ValueError("Empty data from dataset, please check your dataset config.")

dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader

if reweighter is None:
wl_train = np.ones(len(dl_train))
wl_valid = np.ones(len(dl_valid))
Expand All @@ -548,6 +575,15 @@ def fit(
else:
raise ValueError("Unsupported reweighter type.")

# Preprocess for data. To align to Dataset Interface for DataLoader
if ists:
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
else:
# If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader
dl_train = dl_train.values
dl_valid = dl_valid.values

train_loader = DataLoader(
ConcatDataset(dl_train, wl_train),
batch_size=self.batch_size,
Expand All @@ -562,6 +598,7 @@ def fit(
num_workers=self.n_jobs,
drop_last=True,
)
del dl_train, dl_valid, wl_train, wl_valid

save_path = get_or_create_path(save_path)

Expand Down Expand Up @@ -605,7 +642,7 @@ def fit(
if self.use_gpu:
torch.cuda.empty_cache()

def predict(self, dataset):
def predict(self, dataset: Union[DatasetH, TSDatasetH]):
if not self.fitted:
raise ValueError("model is not fitted yet!")

Expand Down
9 changes: 8 additions & 1 deletion tests/model/test_general_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,16 @@ def test_both_dataset(self):
"dropout":0.,
},
),
GeneralPTNN(
n_epochs=2,
pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP
pt_model_kwargs={
"input_dim":3,
},
),
]

for ds, model in zip((tsds, tbds), model_l):
for ds, model in reversed(list(zip((tsds, tbds), model_l))):
model.fit(ds) # It works
model.predict(ds) # It works
break
Expand Down

0 comments on commit 4c057f6

Please sign in to comment.