Skip to content

Commit

Permalink
Dmstgcn (LibCity#379)
Browse files Browse the repository at this point in the history
* feat: DMSTGCN

* fix: data size

* feat: add comments

* fix: add comment

---------

Co-authored-by: wangyongyao <wangyongyao@kuaishou.com>
  • Loading branch information
Kazeya27 and wangyongyao committed Dec 6, 2023
1 parent 6614b5b commit 05fba49
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 2 deletions.
16 changes: 16 additions & 0 deletions libcity/config/data/DMSTGCNDataset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"batch_size": 64,
"cache_dataset": true,
"num_workers": 0,
"pad_with_last_sample": true,
"train_rate": 0.6,
"eval_rate": 0.2,
"scaler": "standard",
"load_external": false,
"normal_external": true,
"ext_scaler": "standard",
"input_window": 12,
"output_window": 12,
"add_time_in_day": false,
"add_day_in_week": false
}
29 changes: 29 additions & 0 deletions libcity/config/model/traffic_state_pred/DMSTGCN.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"max_epoch": 200,

"learner": "adam",
"learning_rate": 0.001,
"lr_epsilon": 1e-8,
"weight_decay": 0.0001,

"lr_patience": 10,
"lr_decay_ratio": 0.3,
"lr_threshold": 1e-3,
"lr_scheduler": "reducelronplateau",

"clip_grad_norm": true,
"max_grad_norm": 5,
"use_early_stop": true,
"patience": 20,

"num_layers": 2,
"dropout": 0.3,
"residual_channels": 32,
"dilation_channels": 32,
"end_channels": 512,
"kernel_size": 2,
"num_blocks": 4,
"normalization": "batch",
"embedding_dims": 40,
"order": 2
}
7 changes: 6 additions & 1 deletion libcity/config/task_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
"STResNetCommon", "ACFMCommon", "ASTGCNCommon", "MSTGCNCommon", "ToGCN", "CONVGCN", "STG2Seq",
"DMVSTNet", "ATDM", "GMAN", "GTS", "STDN", "HGCN", "STSGCN", "STAGGCN", "STNN", "ResLSTM", "DGCN",
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN", "MultiSTGCnetCommon",
"GEML", "FNN", "GSNet", "CSTN", "D2STGNN", "STID","STGODE", "STNorm"
"GEML", "FNN", "GSNet", "CSTN", "D2STGNN", "STID","STGODE", "STNorm", "DMSTGCN"
],
"allowed_dataset": [
"METR_LA", "PEMS_BAY", "PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8", "PEMSD7(M)",
Expand All @@ -100,6 +100,11 @@
"NYCTAXI_OD", "NYCTAXI_GRID", "T_DRIVE_SMALL", "NYCBIKE", "AUSTINRIDE",
"BIKEDC", "BIKECHI", "NYC_RISK", "CHICAGO_RISK"
],
"DMSTGCN": {
"dataset_class": "DMSTGCNDataset",
"executor": "TrafficStateExecutor",
"evaluator": "TrafficStateEvaluator"
},
"STGODE": {
"dataset_class": "STGODEDataset",
"executor": "TrafficStateExecutor",
Expand Down
2 changes: 2 additions & 0 deletions libcity/data/dataset/dataset_subclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from libcity.data.dataset.dataset_subclass.gsnet_dataset import GSNetDataset
from libcity.data.dataset.dataset_subclass.line_dataset import LINEDataset
from libcity.data.dataset.dataset_subclass.stgode_dataset import STGODEDataset
from libcity.data.dataset.dataset_subclass.dmstgcn_dataset import DMSTGCNDataset

__all__ = [
"ACFMDataset",
Expand All @@ -46,4 +47,5 @@
"GSNetDataset",
"LINEDataset",
"STGODEDataset"
"DMSTGCNDataset"
]
71 changes: 71 additions & 0 deletions libcity/data/dataset/dataset_subclass/dmstgcn_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import os

import numpy as np

from libcity.data.dataset import TrafficStateDataset
from libcity.data.utils import generate_dataloader


class DMSTGCNDataset(TrafficStateDataset):
def __init__(self, config):
super().__init__(config)
self.feature_name = {'X': 'float', 'y': 'float', 'idx': 'int'} # idx: 该数据时间段序号
self.timeslots = 24 * 60 * 60 // self.time_intervals

def _load_dyna(self, filename):
return super()._load_dyna_3d(filename)

def _add_external_information(self, df, ext_data=None):
return super()._add_external_information_3d(df, ext_data)

def get_data(self):
# 加载数据集
x_train, y_train, x_val, y_val, x_test, y_test = [], [], [], [], [], []
if self.data is None:
self.data = {}
if self.cache_dataset and os.path.exists(self.cache_file_name):
x_train, y_train, x_val, y_val, x_test, y_test = self._load_cache_train_val_test()
else:
x_train, y_train, x_val, y_val, x_test, y_test = self._generate_train_val_test()
# 数据归一化
self.feature_dim = x_train.shape[-1]
self.ext_dim = self.feature_dim - self.output_dim
self.scaler = self._get_scalar(self.scaler_type,
x_train[..., :self.output_dim], y_train[..., :self.output_dim])
self.ext_scaler = self._get_scalar(self.ext_scaler_type,
x_train[..., self.output_dim:], y_train[..., self.output_dim:])
x_train[..., :self.output_dim] = self.scaler.transform(x_train[..., :self.output_dim])
y_train[..., :self.output_dim] = self.scaler.transform(y_train[..., :self.output_dim])
idx_train = np.arange(0, x_train.shape[0]) % self.timeslots
x_val[..., :self.output_dim] = self.scaler.transform(x_val[..., :self.output_dim])
y_val[..., :self.output_dim] = self.scaler.transform(y_val[..., :self.output_dim])
idx_val = np.arange(x_train.shape[0], x_train.shape[0] + x_val.shape[0]) % self.timeslots
x_test[..., :self.output_dim] = self.scaler.transform(x_test[..., :self.output_dim])
y_test[..., :self.output_dim] = self.scaler.transform(y_test[..., :self.output_dim])
idx_test = np.arange(x_train.shape[0] + x_val.shape[0],
x_train.shape[0] + x_val.shape[0] + x_test.shape[0]) % self.timeslots
if self.normal_external:
x_train[..., self.output_dim:] = self.ext_scaler.transform(x_train[..., self.output_dim:])
y_train[..., self.output_dim:] = self.ext_scaler.transform(y_train[..., self.output_dim:])
x_val[..., self.output_dim:] = self.ext_scaler.transform(x_val[..., self.output_dim:])
y_val[..., self.output_dim:] = self.ext_scaler.transform(y_val[..., self.output_dim:])
x_test[..., self.output_dim:] = self.ext_scaler.transform(x_test[..., self.output_dim:])
y_test[..., self.output_dim:] = self.ext_scaler.transform(y_test[..., self.output_dim:])
# 把训练集的X和y聚合在一起成为list,测试集验证集同理
# x_train/y_train: (num_samples, input_length, ..., feature_dim)
# train_data(list): train_data[i]是一个元组,由x_train[i]和y_train[i]组成
train_data = list(zip(x_train, y_train, idx_train))
eval_data = list(zip(x_val, y_val, idx_val))
test_data = list(zip(x_test, y_test, idx_test))
# 转Dataloader
self.train_dataloader, self.eval_dataloader, self.test_dataloader = \
generate_dataloader(train_data, eval_data, test_data, self.feature_name,
self.batch_size, self.num_workers, pad_with_last_sample=self.pad_with_last_sample)
self.num_batches = len(self.train_dataloader)
return self.train_dataloader, self.eval_dataloader, self.test_dataloader

def get_data_feature(self):
return {"scaler": self.scaler, "adj_mx": self.adj_mx, "ext_dim": self.ext_dim,
"num_nodes": self.num_nodes, "feature_dim": self.feature_dim,
"output_dim": self.output_dim, "num_batches": self.num_batches,
"time_slots": self.timeslots}
Loading

0 comments on commit 05fba49

Please sign in to comment.