From 4c057f645eda74ac1a7ec0e0d40cfc686552259f Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 10 Jul 2024 06:25:30 +0000 Subject: [PATCH] Successfully run training --- qlib/contrib/model/pytorch_general_nn.py | 51 ++++++++++++++++++++---- tests/model/test_general_nn.py | 9 ++++- 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 94f4397c52..e62e4a734d 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -488,12 +488,40 @@ def metric_fn(self, pred, label): raise ValueError("unknown metric `%s`" % self.metric) + + def _get_fl(self, data: torch.Tensor): + """ + get feature and label from data + - Handle the different data shape of time series and tabular data + + Parameters + ---------- + data : torch.Tensor + input data which maybe 3 dimension or 2 dimension + - 3dim: [batch_size, time_step, feature_dim] + - 2dim: [batch_size, feature_dim] + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + """ + if data.dim() == 3: + # it is a time series dataset + feature = data[:, :, 0:-1].to(self.device) + label = data[:, -1, -1].to(self.device) + elif data.dim() == 2: + # it is a tabular dataset + feature = data[:, 0:-1].to(self.device) + label = data[:, -1].to(self.device) + else: + raise ValueError("Unsupported data shape.") + return feature, label + def train_epoch(self, data_loader): self.dnn_model.train() for data, weight in data_loader: - feature = data[:, :, 0:-1].to(self.device) - label = data[:, -1, -1].to(self.device) + feature , label = self._get_fl(data) pred = self.dnn_model(feature.float()) loss = self.loss_fn(pred, label, weight.to(self.device)) @@ -526,19 +554,18 @@ def test_epoch(self, data_loader): def fit( self, - dataset, + dataset: Union[DatasetH, TSDatasetH], evals_result=dict(), save_path=None, reweighter=None, ): + ists = isinstance(dataset, TSDatasetH) # is this time series dataset + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) if dl_train.empty or dl_valid.empty: raise ValueError("Empty data from dataset, please check your dataset config.") - dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader - dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader - if reweighter is None: wl_train = np.ones(len(dl_train)) wl_valid = np.ones(len(dl_valid)) @@ -548,6 +575,15 @@ def fit( else: raise ValueError("Unsupported reweighter type.") + # Preprocess for data. To align to Dataset Interface for DataLoader + if ists: + dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader + dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader + else: + # If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader + dl_train = dl_train.values + dl_valid = dl_valid.values + train_loader = DataLoader( ConcatDataset(dl_train, wl_train), batch_size=self.batch_size, @@ -562,6 +598,7 @@ def fit( num_workers=self.n_jobs, drop_last=True, ) + del dl_train, dl_valid, wl_train, wl_valid save_path = get_or_create_path(save_path) @@ -605,7 +642,7 @@ def fit( if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset): + def predict(self, dataset: Union[DatasetH, TSDatasetH]): if not self.fitted: raise ValueError("model is not fitted yet!") diff --git a/tests/model/test_general_nn.py b/tests/model/test_general_nn.py index 2fa485fdad..0ebb9daa4b 100644 --- a/tests/model/test_general_nn.py +++ b/tests/model/test_general_nn.py @@ -67,9 +67,16 @@ def test_both_dataset(self): "dropout":0., }, ), + GeneralPTNN( + n_epochs=2, + pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP + pt_model_kwargs={ + "input_dim":3, + }, + ), ] - for ds, model in zip((tsds, tbds), model_l): + for ds, model in reversed(list(zip((tsds, tbds), model_l))): model.fit(ds) # It works model.predict(ds) # It works break