Skip to content

Commit

Permalink
GEML Implement (LibCity#159)
Browse files Browse the repository at this point in the history
* Implement GEML model, and add the TrafficStateOdDataset.

* Optimize directory structure for OD matrix prediction

* Revise typo and gpu-device error

* Debug the customized loss function calculation
and fix some issues.

* Optimize directory structure for OD matrix prediction

* Optimize directory structure for OD matrix prediction

* Debug: The semantic matrix generation was wrong.

* Debug: fix import error
  • Loading branch information
Apolsus committed Sep 4, 2021
1 parent c104851 commit 7eca580
Show file tree
Hide file tree
Showing 14 changed files with 601 additions and 9 deletions.
16 changes: 16 additions & 0 deletions libtraffic/config/data/TrafficStateOdDataset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"batch_size": 64,
"cache_dataset": true,
"num_workers": 0,
"pad_with_last_sample": true,
"train_rate": 0.7,
"eval_rate": 0.1,
"scaler": "none",
"load_external": false,
"normal_external": false,
"ext_scaler": "none",
"input_window": 12,
"output_window": 12,
"add_time_in_day": false,
"add_day_in_week": false
}
36 changes: 36 additions & 0 deletions libtraffic/config/executor/GEMLExecutor.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"gpu": true,
"gpu_id": 0,
"max_epoch": 100,
"train_loss": "none",
"epoch": 0,
"learner": "adam",
"learning_rate": 0.01,
"weight_decay": 0,
"lr_epsilon": 1e-8,
"lr_beta1": 0.9,
"lr_beta2": 0.999,
"lr_alpha": 0.99,
"lr_momentum": 0,
"lr_decay": false,
"lr_scheduler": "multisteplr",
"lr_decay_ratio": 0.1,
"steps": [5, 20, 40, 70],
"step_size": 10,
"lr_T_max": 30,
"lr_eta_min": 0,
"lr_patience": 10,
"lr_threshold": 1e-4,
"clip_grad_norm": false,
"max_grad_norm": 1.0,
"use_early_stop": false,
"patience": 50,
"log_level": "INFO",
"log_every": 1,
"saved_model": true,
"load_best_epoch": true,
"hyper_tune": false,
"loss_p0": 0.5,
"loss_p1": 0.25,
"loss_p2": 0.25
}
26 changes: 26 additions & 0 deletions libtraffic/config/model/traffic_state_pred/GEML.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
{
"scaler": "minmax01",
"load_external": false,
"normal_external": false,
"ext_scaler": "none",
"add_time_in_day": true,
"add_day_in_week": false,
"use_row_column": true,
"max_epoch": 100,
"learner": "adam",
"learning_rate": 0.001,
"weight_decay": 1e-6,
"lr_beta1": 0.9,
"lr_beta2": 0.999,
"lr_decay": false,
"clip_grad_norm": false,
"use_early_stop": true,
"patience": 50,

"embed_dim": 32,
"p_interval": 1,

"loss_p0": 0.5,
"loss_p1": 0.25,
"loss_p2": 0.25
}
8 changes: 7 additions & 1 deletion libtraffic/config/task_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@
"ASTGCN", "MSTGCN", "MTGNN", "ACFM", "STResNet", "RNN", "LSTM", "GRU", "AutoEncoder", "Seq2Seq",
"STResNetCommon", "ACFMCommon", "ASTGCNCommon", "MSTGCNCommon","ToGCN", "CONVGCN", "STG2Seq",
"DMVSTNet", "ATDM", "GMAN", "GTS", "STDN", "HGCN", "STSGCN", "STAGGCN", "STNN", "ResLSTM", "DGCN",
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN", "MultiSTGCnetCommon"],
"MultiSTGCnet", "STMGAT", "CRANN", "STTN", "CONVGCNCommon", "DSAN", "DKFN", "CCRNN", "MultiSTGCnetCommon",
"GEML"],
"allowed_dataset": ["METR_LA", "PEMS_BAY", "PEMSD3", "PEMSD4", "PEMSD7", "PEMSD8", "PEMSD7(M)",
"LOOP_SEATTLE", "LOS_LOOP", "LOS_LOOP_SMALL", "Q_TRAFFIC", "SZ_TAXI",
"NYCBike20140409", "NYCBike20160708", "NYCBike20160809", "NYCTaxi20140112",
Expand Down Expand Up @@ -315,6 +316,11 @@
"dataset_class": "TrafficStatePointDataset",
"executor": "TrafficStateExecutor",
"evaluator": "TrafficStateEvaluator"
},
"GEML": {
"dataset_class": "TrafficStateOdDataset",
"executor": "GEMLExecutor",
"evaluator": "TrafficStateEvaluator"
}
},
"map_matching": {
Expand Down
2 changes: 2 additions & 0 deletions libtraffic/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
TrafficStateGridDataset
from libtraffic.data.dataset.traffic_state_grid_od_dataset import \
TrafficStateGridOdDataset
from libtraffic.data.dataset.traffic_state_od_dataset import TrafficStateOdDataset
from libtraffic.data.dataset.acfm_dataset import ACFMDataset
from libtraffic.data.dataset.tgclstm_dataset import TGCLSTMDataset
from libtraffic.data.dataset.astgcn_dataset import ASTGCNDataset
Expand Down Expand Up @@ -36,6 +37,7 @@
"TrafficStateCPTDataset",
"TrafficStatePointDataset",
"TrafficStateGridDataset",
"TrafficStateOdDataset",
"TrafficStateGridOdDataset",
"ACFMDataset",
"TGCLSTMDataset",
Expand Down
48 changes: 48 additions & 0 deletions libtraffic/data/dataset/traffic_state_datatset.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,54 @@ def _load_grid_4d(self, filename):
self._logger.info("Loaded file " + filename + '.grid' + ', shape=' + str(data.shape))
return data

def _load_od_4d(self, filename):
    """
    Load a ``.od`` file into a dense 4-d origin-destination array.

    The file is read as CSV with the layout
    ``[dyna_id, type, time, origin_id, destination_id, properties...]``.
    Rows are consumed positionally: the file is expected to be ordered by
    origin, then destination, then time, so that every (origin, destination)
    pair occupies ``len_time`` consecutive rows.

    The instance attribute ``data_col`` selects which property columns are
    loaded (a single column name or a list of names); an empty string loads
    every property column.

    Side effects: sets ``self.timesolts`` and ``self.idx_of_timesolts``.

    Args:
        filename(str): data file name without the ``.od`` suffix

    Returns:
        np.ndarray: 4d-array of shape (len_time, num_nodes, num_nodes, feature_dim)
    """
    self._logger.info("Loading file " + filename + '.od')
    odfile = pd.read_csv(self.data_path + filename + '.od')
    if self.data_col != '':  # load only the requested property columns
        if isinstance(self.data_col, list):
            data_col = self.data_col.copy()  # copy so self.data_col is not mutated
        else:  # str
            data_col = [self.data_col]
        data_col.insert(0, 'time')
        data_col.insert(1, 'origin_id')
        data_col.insert(2, 'destination_id')
        odfile = odfile[data_col]
    else:  # nothing requested: keep everything from the `time` column onwards
        odfile = odfile[odfile.columns[2:]]

    num_nodes = self.num_nodes
    # Every (origin, destination) pair contributes one row per time slot.
    len_time = odfile.shape[0] // num_nodes // num_nodes

    # Time slots are taken from the first (origin, destination) pair.
    self.timesolts = list(odfile['time'][:len_time])
    self.idx_of_timesolts = dict()
    if not odfile['time'].isna().any():  # only parse when no timestamp is missing
        self.timesolts = list(map(lambda x: x.replace('T', ' ').replace('Z', ''), self.timesolts))
        self.timesolts = np.array(self.timesolts, dtype='datetime64[ns]')
        self.idx_of_timesolts = {_ts: idx for idx, _ts in enumerate(self.timesolts)}

    # After the selection above the first three columns are always
    # time / origin_id / destination_id; slicing from 3 (rather than with a
    # negative index) stays correct even when there are no property columns.
    feature_cols = odfile.columns[3:]
    feature_dim = len(feature_cols)
    values = odfile[feature_cols].values[:num_nodes * num_nodes * len_time]

    # Row r = (i * num_nodes + j) * len_time + t  ->  data[i, j, t]
    data = np.zeros((num_nodes, num_nodes, len_time, feature_dim))
    data[...] = values.reshape(num_nodes, num_nodes, len_time, feature_dim)
    data = data.transpose((2, 0, 1, 3))  # (len_time, num_nodes, num_nodes, feature_dim)
    self._logger.info("Loaded file " + filename + '.od' + ', shape=' + str(data.shape))
    return data

def _load_grid_od_4d(self, filename):
"""
加载.gridod文件,格式[dyna_id, type, time, origin_row_id, origin_column_id,
Expand Down
55 changes: 55 additions & 0 deletions libtraffic/data/dataset/traffic_state_od_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os

import numpy as np

from libtraffic.data.dataset import TrafficStateDataset


class TrafficStateOdDataset(TrafficStateDataset):
    """Traffic-state dataset whose dynamic data is read from a ``.od`` file.

    Loading is delegated to the 4-d helpers of the parent class
    (``_load_od_4d`` and ``_add_external_information_4d``).
    """

    def __init__(self, config):
        super().__init__(config)
        cache_name = 'od_based_{}.npz'.format(self.parameters_str)
        self.cache_file_name = os.path.join('./libtraffic/cache/dataset_cache/', cache_name)
        # Called unconditionally, regardless of whether a .rel file exists.
        self._load_rel()

    def _load_dyna(self, filename):
        """Load the dynamic data through the parent's ``.od`` loader."""
        return super()._load_od_4d(filename)

    def _load_geo(self):
        """
        Load the .geo file, format [geo_id, type, coordinates, properties...]
        """
        super()._load_geo()

    def _load_rel(self):
        """
        Load the .rel file, format [rel_id, type, origin_id, destination_id, properties...]

        The work is done by the parent implementation; this override returns nothing.
        """
        super()._load_rel()

    def _add_external_information(self, df, ext_data=None):
        """
        Append external information (day of week, time of day, external data)
        by delegating to the parent's 4-d implementation.

        Args:
            df(np.ndarray): traffic state array, (len_time, ..., feature_dim)
            ext_data(np.ndarray): external data

        Returns:
            np.ndarray: traffic state merged with external data, (len_time, ..., feature_dim_plus)
        """
        return super()._add_external_information_4d(df, ext_data)

    def get_data_feature(self):
        """
        Collect the dataset features handed to the model.

        Returns:
            dict: the instance attributes scaler, adj_mx, num_nodes, feature_dim,
                ext_dim, output_dim and num_batches, keyed by their own names
        """
        exported = ("scaler", "adj_mx", "num_nodes", "feature_dim", "ext_dim",
                    "output_dim", "num_batches")
        return {attr: getattr(self, attr) for attr in exported}
12 changes: 7 additions & 5 deletions libtraffic/executor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from libtraffic.executor.traj_loc_pred_executor import TrajLocPredExecutor
from libtraffic.executor.traffic_state_executor import TrafficStateExecutor
from libtraffic.executor.dcrnn_executor import DCRNNExecutor
from libtraffic.executor.mtgnn_executor import MTGNNExecutor
from libtraffic.executor.hyper_tuning import HyperTuning
from libtraffic.executor.geml_executor import GEMLExecutor
from libtraffic.executor.geosan_executor import GeoSANExecutor
from libtraffic.executor.hyper_tuning import HyperTuning
from libtraffic.executor.map_matching_executor import MapMatchingExecutor
from libtraffic.executor.mtgnn_executor import MTGNNExecutor
from libtraffic.executor.traffic_state_executor import TrafficStateExecutor
from libtraffic.executor.traj_loc_pred_executor import TrajLocPredExecutor

__all__ = [
"TrajLocPredExecutor",
Expand All @@ -13,5 +14,6 @@
"MTGNNExecutor",
"HyperTuning",
"GeoSANExecutor",
"MapMatchingExecutor"
"MapMatchingExecutor",
"GEMLExecutor"
]
109 changes: 109 additions & 0 deletions libtraffic/executor/geml_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import time
from functools import partial

import numpy as np
import torch
import os

from libtraffic.executor.traffic_state_executor import TrafficStateExecutor
from libtraffic.model import loss


class GEMLExecutor(TrafficStateExecutor):
    """Executor for the GEML model.

    ``model.predict`` is unpacked here as a three-tuple
    ``(prediction, y_in, y_out)``; evaluation uses only the first element,
    while the training loss is a weighted sum over all three
    (weights ``loss_p0``, ``loss_p1``, ``loss_p2``).
    """

    # train_loss name -> (function name in the `loss` module, bind null_val=0?)
    _LOSS_TABLE = {
        'mae': ('masked_mae_torch', False),
        'mse': ('masked_mse_torch', False),
        'rmse': ('masked_rmse_torch', False),
        'mape': ('masked_mape_torch', False),
        'logcosh': ('log_cosh_loss', False),
        'huber': ('huber_loss', False),
        'quantile': ('quantile_loss', False),
        'masked_mae': ('masked_mae_torch', True),
        'masked_mse': ('masked_mse_torch', True),
        'masked_rmse': ('masked_rmse_torch', True),
        'masked_mape': ('masked_mape_torch', True),
        'r2': ('r2_score_torch', False),
        'evar': ('explained_variance_score_torch', False),
    }
    # Used when `train_loss` is not one of the names above.
    _DEFAULT_LOSS = ('masked_mae_torch', False)

    def __init__(self, config, model):
        TrafficStateExecutor.__init__(self, config, model)
        # Weights of the three loss terms (prediction, y_in, y_out).
        self.loss_p0 = config.get("loss_p0", 0.5)
        self.loss_p1 = config.get("loss_p1", 0.25)
        self.loss_p2 = config.get("loss_p2", 0.25)

    # Overridden so the three-tuple returned by `model.predict` can be unpacked.
    def evaluate(self, test_dataloader):
        """
        Run the model on the test data, save the predictions and evaluate them.

        Predictions and ground truth are inverse-transformed with the scaler,
        concatenated over batches and written to a time-stamped compressed
        ``.npz`` file in ``self.evaluate_res_dir``.

        Args:
            test_dataloader(torch.Dataloader): Dataloader

        Returns:
            the value returned by ``self.evaluator.save_result``
        """
        self._logger.info('Start evaluating ...')
        with torch.no_grad():
            self.model.eval()
            y_truths = []
            y_preds = []
            for batch in test_dataloader:
                batch.to_tensor(self.device)
                output, _, _ = self.model.predict(batch)  # y_in / y_out are not evaluated
                y_true = self._scaler.inverse_transform(batch['y'][..., :self.output_dim])
                y_pred = self._scaler.inverse_transform(output[..., :self.output_dim])
                y_truths.append(y_true.cpu().numpy())
                y_preds.append(y_pred.cpu().numpy())
            y_preds = np.concatenate(y_preds, axis=0)
            y_truths = np.concatenate(y_truths, axis=0)  # concatenate on batch
            outputs = {'prediction': y_preds, 'truth': y_truths}
            filename = \
                time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time())) + '_' \
                + self.config['model'] + '_' + self.config['dataset'] + '_predictions.npz'
            np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
            self.evaluator.clear()
            self.evaluator.collect({'y_true': torch.tensor(y_truths), 'y_pred': torch.tensor(y_preds)})
            test_result = self.evaluator.save_result(self.evaluate_res_dir)
            return test_result

    # Overridden so the three-tuple returned by `model.predict` can be unpacked.
    def _build_train_loss(self):
        """
        Build the training loss selected by ``self.train_loss``.

        Returns:
            None when ``train_loss`` is ``'none'`` (a warning is logged saying
            the loss defined in the model will be used); otherwise a callable
            that takes a batch and returns
            ``loss_p0 * lf(pred, y) + loss_p1 * lf(y_in, y_in_true)
            + loss_p2 * lf(y_out, y_out_true)``.
            Unrecognised names fall back to ``loss.masked_mae_torch``.
        """
        loss_name = self.train_loss.lower()
        if loss_name == 'none':
            self._logger.warning('Received none train loss func and will use the loss func defined in the model.')
            return None
        if loss_name in self._LOSS_TABLE:
            self._logger.info('You select `{}` as train loss function.'.format(loss_name))
            attr_name, mask_zero = self._LOSS_TABLE[loss_name]
        else:
            self._logger.warning('Received unrecognized train loss function, set default mae loss func.')
            attr_name, mask_zero = self._DEFAULT_LOSS

        # Resolve the loss function once instead of on every batch.
        lf = getattr(loss, attr_name)
        if mask_zero:
            lf = partial(lf, null_val=0)

        def func(batch):
            # y_true is 5-d (the permute below uses five axes); presumably
            # (B, T_out, N_origin, N_dest, F) -- TODO confirm.
            y_true = batch['y']
            # Sum over axis -2 of the original layout.
            y_in_true = torch.sum(y_true, dim=-2, keepdim=True)
            # Swap axes 2 and 3 first, i.e. sum over axis 2 of the original layout.
            y_out_true = torch.sum(y_true.permute(0, 1, 3, 2, 4), dim=-2, keepdim=True)
            y_predicted, y_in, y_out = self.model.predict(batch)
            # NOTE(review): only the main term is inverse-transformed; the in/out
            # targets above are built from the untransformed batch -- confirm intended.
            y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
            y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
            return (self.loss_p0 * lf(y_predicted, y_true)
                    + self.loss_p1 * lf(y_in, y_in_true)
                    + self.loss_p2 * lf(y_out, y_out_true))

        return func
4 changes: 2 additions & 2 deletions libtraffic/model/traffic_demand_prediction/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from libtraffic.model.traffic_demand_prediction.STG2Seq import STG2Seq
from libtraffic.model.traffic_demand_prediction.DMVSTNet import DMVSTNet
from libtraffic.model.traffic_demand_prediction.CCRNN import CCRNN
from libtraffic.model.traffic_demand_prediction.DMVSTNet import DMVSTNet
from libtraffic.model.traffic_demand_prediction.STG2Seq import STG2Seq

__all__ = [
"STG2Seq",
Expand Down
Loading

0 comments on commit 7eca580

Please sign in to comment.