ASTGNN update #441

Open · wants to merge 3 commits into master
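This PR wires the ASTGNN traffic state prediction model into LibCity: it adds default executor and model configs, an ASTGNNExecutor with a two-phase train-then-fine-tune schedule, and registers the ASTGNN and ASTGNNCommon variants in task_config.json.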
35 changes: 35 additions & 0 deletions libcity/config/executor/ASTGNNExecutor.json
@@ -0,0 +1,35 @@
{
    "gpu": true,
    "gpu_id": 0,
    "max_epoch": 100,
    "train_loss": "none",
    "epoch": 0,
    "learner": "adam",
    "learning_rate": 0.01,
    "weight_decay": 0,
    "lr_epsilon": 1e-8,
    "lr_beta1": 0.9,
    "lr_beta2": 0.999,
    "lr_alpha": 0.99,
    "lr_momentum": 0,
    "lr_decay": false,
    "lr_scheduler": "multisteplr",
    "lr_decay_ratio": 0.1,
    "steps": [5, 20, 40, 70],
    "step_size": 10,
    "lr_T_max": 30,
    "lr_eta_min": 0,
    "lr_patience": 10,
    "lr_threshold": 1e-4,
    "clip_grad_norm": false,
    "max_grad_norm": 1.0,
    "use_early_stop": false,
    "patience": 50,
    "log_level": "INFO",
    "log_every": 1,
    "saved_model": true,
    "load_best_epoch": true,
    "hyper_tune": false,
    "fine_tune_lr": 0.0001,
    "fine_tune_epochs": 50
}
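As a usage note, these executor defaults are merged with the model and dataset configs at run time. A minimal sketch, assuming LibCity's ConfigParser interface (the exact signature may differ across versions):

from libcity.config import ConfigParser

# executor defaults such as fine_tune_epochs can be overridden by the model
# config or by user-supplied arguments
config = ConfigParser(task='traffic_state_pred', model='ASTGNN',
                      dataset='PEMSD4', config_file=None)
print(config.get('fine_tune_epochs'))  # 50, from ASTGNNExecutor.json above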
35 changes: 35 additions & 0 deletions libcity/config/model/traffic_state_pred/ASTGNN.json
@@ -0,0 +1,35 @@
{
    "max_epoch": 100,
    "fine_tune_epochs": 50,

    "learner": "adam",
    "learning_rate": 0.001,
    "scaler": "minmax11",

    "len_trend": 0,
    "len_period": 0,
    "len_closeness": 1,

    "lr_decay": false,
    "lr_scheduler": "steplr",
    "step_size": 50,
    "lr_decay_ratio": 0.5,
    "batch_size": 16,

    "clip_grad_norm": false,
    "max_grad_norm": 10,
    "use_early_stop": false,
    "patience": 15,

    "num_layers": 4,
    "d_model": 64,
    "nb_head": 8,
    "dropout": 0.0,
    "aware_temporal_context": 1,
    "ScaledSAt": 1,
    "SE": true,
    "TE": true,
    "kernel_size": 3,
    "smooth_layer_num": 0,
    "output_dim": 1
}
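One constraint worth noting: ASTGNN's multi-head attention splits d_model across nb_head heads, so d_model must be divisible by nb_head. A quick sanity-check sketch with the values above:

d_model, nb_head = 64, 8      # values from ASTGNN.json
assert d_model % nb_head == 0
d_k = d_model // nb_head      # per-head dimension: 8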
35 changes: 35 additions & 0 deletions libcity/config/model/traffic_state_pred/ASTGNNCommon.json
@@ -0,0 +1,35 @@
{
    "max_epoch": 100,
    "fine_tune_epochs": 50,

    "learner": "adam",
    "learning_rate": 0.001,
    "scaler": "minmax11",

    "len_trend": 0,
    "len_period": 0,
    "len_closeness": 1,

    "lr_decay": false,
    "lr_scheduler": "steplr",
    "step_size": 50,
    "lr_decay_ratio": 0.5,
    "batch_size": 16,

    "clip_grad_norm": false,
    "max_grad_norm": 10,
    "use_early_stop": false,
    "patience": 5,

    "num_layers": 4,
    "d_model": 64,
    "nb_head": 8,
    "dropout": 0.0,
    "aware_temporal_context": 1,
    "ScaledSAt": 1,
    "SE": true,
    "TE": true,
    "kernel_size": 3,
    "smooth_layer_num": 0,
    "output_dim": 1
}
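ASTGNNCommon shares these defaults with ASTGNN apart from a smaller early-stop patience (5 vs. 15). As the task_config.json registration below shows, the practical difference is the dataset class: ASTGNN is paired with ASTGCNDataset, which assembles the closeness/period/trend inputs controlled by len_closeness, len_period and len_trend, while ASTGNNCommon uses the plain TrafficStatePointDataset.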
12 changes: 11 additions & 1 deletion libcity/config/task_config.json
@@ -90,7 +90,7 @@
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN", "MultiSTGCnetCommon",
"GEML", "FNN", "GSNet", "CSTN", "D2STGNN", "STID","STGODE", "STNorm", "DMSTGCN", "ESG", "SSTBAN", "STTSNet",
"FOGS", "RGSL", "DSTAGNN", "STPGCN", "HIEST", "STAEformer", "TESTAM", "MultiSPANS", "SimST", "TimeMixer",
"MegaCRN", "Trafformer", "STSSL", "STWave", "PDFormer", "STGNCDE"
"MegaCRN", "Trafformer", "STSSL", "STWave", "PDFormer", "STGNCDE", "ASTGNN", "ASTGNNCommon"
],
"allowed_dataset": [
"METR_LA", "PEMS_BAY", "PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8", "PEMSD7(M)",
@@ -102,6 +102,16 @@
"NYCTAXI_OD", "NYCTAXI_GRID", "T_DRIVE_SMALL", "NYCBIKE", "AUSTINRIDE",
"BIKEDC", "BIKECHI", "NYC_RISK", "CHICAGO_RISK", "NYCTAXI20140112_FLOW"
],
"ASTGNNCommon": {
"dataset_class": "TrafficStatePointDataset",
"executor": "ASTGNNExecutor",
"evaluator": "TrafficStateEvaluator"
},
"ASTGNN": {
"dataset_class": "ASTGCNDataset",
"executor": "ASTGNNExecutor",
"evaluator": "TrafficStateEvaluator"
},
"STGNCDE": {
"dataset_class": "TrafficStatePointDataset",
"executor": "TrafficStateExecutor",
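With the registrations above in place, the model can be launched through LibCity's standard pipeline. A minimal sketch, assuming the run_model entry point from libcity.pipeline (the dataset name is only illustrative):

from libcity.pipeline import run_model

run_model(task='traffic_state_pred', model_name='ASTGNN', dataset_name='PEMSD4')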
2 changes: 2 additions & 0 deletions libcity/executor/__init__.py
@@ -20,6 +20,7 @@
from libcity.executor.megacrn_executor import MegaCRNExecutor
from libcity.executor.trafformer_executor import TrafformerExecutor
from libcity.executor.pdformer_executor import PDFormerExecutor
from libcity.executor.astgnn_executor import ASTGNNExecutor

__all__ = [
"TrajLocPredExecutor",
@@ -44,4 +45,5 @@
"MegaCRNExecutor",
"TrafformerExecutor",
"PDFormerExecutor",
"ASTGNNExecutor",
]
80 changes: 80 additions & 0 deletions libcity/executor/astgnn_executor.py
@@ -0,0 +1,80 @@
import time
import os
import numpy as np
import torch
from libcity.executor.traffic_state_executor import TrafficStateExecutor


class ASTGNNExecutor(TrafficStateExecutor):
    def __init__(self, config, model, data_feature):
        TrafficStateExecutor.__init__(self, config, model, data_feature)
        self.fine_tune_epochs = config.get("fine_tune_epochs", 0)
        self.fine_tune_lr = config.get("fine_tune_lr", 0.001)
        self.raw_epochs = self.epochs
        self.epochs = self.epochs + self.fine_tune_epochs

    def _train_epoch(self, train_dataloader, epoch_idx, loss_func=None):
        """
        This executor computes the training and validation losses differently,
        and appends a fine-tuning phase after the main training epochs.
        """
        fine_tune = epoch_idx >= self.raw_epochs
        if epoch_idx == self.raw_epochs:
            # entering the fine-tune phase: rebuild the optimizer once with the
            # fine-tune learning rate (rebuilding it on every fine-tune epoch
            # would reset the optimizer state, e.g. Adam's moment estimates)
            self.learning_rate = self.fine_tune_lr
            self.optimizer = self._build_optimizer()
        self.model.train()
        # the main phase optimizes the training loss; the fine-tune phase
        # switches to the loss used for validation
        loss_func = self.model.calculate_train_loss if not fine_tune else self.model.calculate_val_loss
        losses = []
        for batch in train_dataloader:
            self.optimizer.zero_grad()
            batch.to_tensor(self.device)
            loss = loss_func(batch)
            self._logger.debug(loss.item())
            losses.append(loss.item())
            loss.backward()
            if self.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
            self.optimizer.step()
        return losses

    def _valid_epoch(self, eval_dataloader, epoch_idx, loss_func=None):
        with torch.no_grad():
            self.model.eval()
            # the loss_func argument is ignored: validation always uses the
            # model's own validation loss
            loss_func = self.model.calculate_val_loss
            losses = []
            for batch in eval_dataloader:
                batch.to_tensor(self.device)
                loss = loss_func(batch)
                self._logger.debug(loss.item())
                losses.append(loss.item())
            mean_loss = np.mean(losses)
            self._writer.add_scalar('eval loss', mean_loss, epoch_idx)
            return mean_loss

    def evaluate(self, test_dataloader):
        self._logger.info('Start evaluating ...')
        with torch.no_grad():
            self.model.eval()
            y_truths = []
            y_preds = []
            for batch in test_dataloader:
                batch.to_tensor(self.device)
                output = self.model.predict(batch)
                labels = self.model.get_label(batch)
                y_true = self._scaler.inverse_transform(labels)
                y_pred = self._scaler.inverse_transform(output)
                y_truths.append(y_true.cpu().numpy())
                y_preds.append(y_pred.cpu().numpy())
            y_preds = np.concatenate(y_preds, axis=0)
            y_truths = np.concatenate(y_truths, axis=0)  # concatenate on the batch dimension
            outputs = {'prediction': y_preds, 'truth': y_truths}
            filename = \
                time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time())) + '_' \
                + self.config['model'] + '_' + self.config['dataset'] + '_predictions.npz'
            np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
            self.evaluator.clear()
            self.evaluator.collect({'y_true': torch.tensor(y_truths), 'y_pred': torch.tensor(y_preds)})
            test_result = self.evaluator.save_result(self.evaluate_res_dir)
            return test_result
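For reference, a minimal standalone sketch (not LibCity API) of the two-phase schedule the executor above runs with the defaults from ASTGNNExecutor.json:

raw_epochs, fine_tune_epochs = 100, 50   # max_epoch and fine_tune_epochs
for epoch_idx in range(raw_epochs + fine_tune_epochs):
    fine_tune = epoch_idx >= raw_epochs
    # epochs 0..99 optimize calculate_train_loss at learning_rate (0.01);
    # epochs 100..149 optimize calculate_val_loss at fine_tune_lr (0.0001)
    loss_name = 'calculate_val_loss' if fine_tune else 'calculate_train_loss'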