forked from LibCity/Bigscity-LibCity
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement GEML model, and add the TrafficStateOdDataset. * Optimize directory structure for OD matrix prediction * Fix typo and GPU-device error * Fix customized loss function calculation, and fix some related issues. * Optimize directory structure for OD matrix prediction * Optimize directory structure for OD matrix prediction * Fix: the semantic matrix generation was wrong. * Fix import error
- Loading branch information
Showing
14 changed files
with
601 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
{ | ||
"batch_size": 64, | ||
"cache_dataset": true, | ||
"num_workers": 0, | ||
"pad_with_last_sample": true, | ||
"train_rate": 0.7, | ||
"eval_rate": 0.1, | ||
"scaler": "none", | ||
"load_external": false, | ||
"normal_external": false, | ||
"ext_scaler": "none", | ||
"input_window": 12, | ||
"output_window": 12, | ||
"add_time_in_day": false, | ||
"add_day_in_week": false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
{ | ||
"gpu": true, | ||
"gpu_id": 0, | ||
"max_epoch": 100, | ||
"train_loss": "none", | ||
"epoch": 0, | ||
"learner": "adam", | ||
"learning_rate": 0.01, | ||
"weight_decay": 0, | ||
"lr_epsilon": 1e-8, | ||
"lr_beta1": 0.9, | ||
"lr_beta2": 0.999, | ||
"lr_alpha": 0.99, | ||
"lr_momentum": 0, | ||
"lr_decay": false, | ||
"lr_scheduler": "multisteplr", | ||
"lr_decay_ratio": 0.1, | ||
"steps": [5, 20, 40, 70], | ||
"step_size": 10, | ||
"lr_T_max": 30, | ||
"lr_eta_min": 0, | ||
"lr_patience": 10, | ||
"lr_threshold": 1e-4, | ||
"clip_grad_norm": false, | ||
"max_grad_norm": 1.0, | ||
"use_early_stop": false, | ||
"patience": 50, | ||
"log_level": "INFO", | ||
"log_every": 1, | ||
"saved_model": true, | ||
"load_best_epoch": true, | ||
"hyper_tune": false, | ||
"loss_p0": 0.5, | ||
"loss_p1": 0.25, | ||
"loss_p2": 0.25 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
{ | ||
"scaler": "minmax01", | ||
"load_external": false, | ||
"normal_external": false, | ||
"ext_scaler": "none", | ||
"add_time_in_day": true, | ||
"add_day_in_week": false, | ||
"use_row_column": true, | ||
"max_epoch": 100, | ||
"learner": "adam", | ||
"learning_rate": 0.001, | ||
"weight_decay": 1e-6, | ||
"lr_beta1": 0.9, | ||
"lr_beta2": 0.999, | ||
"lr_decay": false, | ||
"clip_grad_norm": false, | ||
"use_early_stop": true, | ||
"patience": 50, | ||
|
||
"embed_dim": 32, | ||
"p_interval": 1, | ||
|
||
"loss_p0": 0.5, | ||
"loss_p1": 0.25, | ||
"loss_p2": 0.25 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import os | ||
|
||
import numpy as np | ||
|
||
from libtraffic.data.dataset import TrafficStateDataset | ||
|
||
|
||
class TrafficStateOdDataset(TrafficStateDataset):
    """Traffic-state dataset specialised for OD (origin-destination) matrices.

    Dynamic data is loaded as a 4-d array through the parent's
    ``_load_od_4d`` helper, and external information (time of day,
    day of week, extra features) is merged via the 4-d variant.
    """

    def __init__(self, config):
        super().__init__(config)
        # Cache file is keyed by the parameter string so that different
        # configurations do not collide with each other.
        self.cache_file_name = os.path.join(
            './libtraffic/cache/dataset_cache/',
            'od_based_{}.npz'.format(self.parameters_str))
        self._load_rel()  # don't care whether there is a .rel file

    def _load_dyna(self, filename):
        """Load the dynamic OD file *filename* as a 4-d array."""
        return super()._load_od_4d(filename)

    def _load_geo(self):
        """Load the .geo file: [geo_id, type, coordinates, properties...]."""
        super()._load_geo()

    def _load_rel(self):
        """Load the .rel file: [rel_id, type, origin_id, destination_id, properties...].

        Returns:
            np.ndarray: self.adj_mx, an N*N adjacency matrix
        """
        super()._load_rel()

    def _add_external_information(self, df, ext_data=None):
        """Merge external information (day of week, time of day, extra data).

        Args:
            df(np.ndarray): traffic-state array, (len_time, ..., feature_dim)
            ext_data(np.ndarray): external data

        Returns:
            np.ndarray: merged traffic-state and external data,
                (len_time, ..., feature_dim_plus)
        """
        return super()._add_external_information_4d(df, ext_data)

    def get_data_feature(self):
        """Describe the dataset for downstream consumers.

        scaler is the normalisation method, adj_mx the adjacency matrix,
        num_nodes the number of grids, feature_dim the input dimension and
        output_dim the model output dimension.

        Returns:
            dict: dataset feature dictionary
        """
        return {
            "scaler": self.scaler,
            "adj_mx": self.adj_mx,
            "num_nodes": self.num_nodes,
            "feature_dim": self.feature_dim,
            "ext_dim": self.ext_dim,
            "output_dim": self.output_dim,
            "num_batches": self.num_batches,
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import time | ||
from functools import partial | ||
|
||
import numpy as np | ||
import torch | ||
import os | ||
|
||
from libtraffic.executor.traffic_state_executor import TrafficStateExecutor | ||
from libtraffic.model import loss | ||
|
||
|
||
class GEMLExecutor(TrafficStateExecutor):
    """Executor for the GEML OD-matrix prediction model.

    GEML's ``predict`` returns three tensors — the OD prediction plus the
    in-flow and out-flow predictions — and the training loss is a weighted
    sum of the three terms (weights ``loss_p0`` / ``loss_p1`` / ``loss_p2``).
    """

    def __init__(self, config, model):
        TrafficStateExecutor.__init__(self, config, model)
        # Weights of the OD / in-flow / out-flow terms of the combined loss.
        self.loss_p0 = config.get("loss_p0", 0.5)
        self.loss_p1 = config.get("loss_p1", 0.25)
        self.loss_p2 = config.get("loss_p2", 0.25)

    # Only overridden because predict() returns a 3-tuple here.
    def evaluate(self, test_dataloader):
        """
        use model to test data

        Args:
            test_dataloader(torch.Dataloader): Dataloader
        """
        self._logger.info('Start evaluating ...')
        with torch.no_grad():
            self.model.eval()
            # self.evaluator.clear()
            y_truths = []
            y_preds = []
            for batch in test_dataloader:
                batch.to_tensor(self.device)
                # GEML returns (od_prediction, in_flow, out_flow); only the
                # OD prediction is scored here.
                output, _, _ = self.model.predict(batch)
                y_true = self._scaler.inverse_transform(batch['y'][..., :self.output_dim])
                y_pred = self._scaler.inverse_transform(output[..., :self.output_dim])
                y_truths.append(y_true.cpu().numpy())
                y_preds.append(y_pred.cpu().numpy())
                # evaluate_input = {'y_true': y_true, 'y_pred': y_pred}
                # self.evaluator.collect(evaluate_input)
            # self.evaluator.save_result(self.evaluate_res_dir)
            y_preds = np.concatenate(y_preds, axis=0)
            y_truths = np.concatenate(y_truths, axis=0)  # concatenate on batch
            outputs = {'prediction': y_preds, 'truth': y_truths}
            filename = \
                time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime(time.time())) + '_' \
                + self.config['model'] + '_' + self.config['dataset'] + '_predictions.npz'
            np.savez_compressed(os.path.join(self.evaluate_res_dir, filename), **outputs)
            self.evaluator.clear()
            self.evaluator.collect({'y_true': torch.tensor(y_truths), 'y_pred': torch.tensor(y_preds)})
            test_result = self.evaluator.save_result(self.evaluate_res_dir)
            return test_result

    # Overridden because predict() returns a 3-tuple and the loss is a
    # weighted three-part sum.
    def _build_train_loss(self):
        """Select the training loss according to the global `train_loss` option.

        If `train_loss` is 'none', return None so the model's own loss
        function is used. Otherwise return a closure that accepts a `Batch`
        and returns the weighted GEML loss (torch.tensor).
        """
        if self.train_loss.lower() == 'none':
            self._logger.warning('Received none train loss func and will use the loss func defined in the model.')
            return None
        # Dispatch table replaces the original if/elif chain; it also lets us
        # resolve the loss function ONCE here instead of on every batch (the
        # original re-ran the whole selection inside the closure).
        loss_funcs = {
            'mae': loss.masked_mae_torch,
            'mse': loss.masked_mse_torch,
            'rmse': loss.masked_rmse_torch,
            'mape': loss.masked_mape_torch,
            'logcosh': loss.log_cosh_loss,
            'huber': loss.huber_loss,
            'quantile': loss.quantile_loss,
            'masked_mae': partial(loss.masked_mae_torch, null_val=0),
            'masked_mse': partial(loss.masked_mse_torch, null_val=0),
            'masked_rmse': partial(loss.masked_rmse_torch, null_val=0),
            'masked_mape': partial(loss.masked_mape_torch, null_val=0),
            'r2': loss.r2_score_torch,
            'evar': loss.explained_variance_score_torch,
        }
        train_loss = self.train_loss.lower()
        if train_loss not in loss_funcs:
            self._logger.warning('Received unrecognized train loss function, set default mae loss func.')
        else:
            self._logger.info('You select `{}` as train loss function.'.format(train_loss))
        lf = loss_funcs.get(train_loss, loss.masked_mae_torch)

        def func(batch):
            y_true = batch['y']
            # Row/column sums of the OD matrix give the ground-truth
            # in-flow and out-flow per node.
            # NOTE(review): these sums are taken on the scaled y, while the
            # OD term below compares inverse-transformed values — assumes
            # the model's y_in/y_out are on the same (scaled) scale; confirm.
            y_in_true = torch.sum(y_true, dim=-2, keepdim=True)  # (B, TO, N, 1)
            y_out_true = torch.sum(y_true.permute(0, 1, 3, 2, 4), dim=-2, keepdim=True)  # (B, TO, N, 1)
            y_predicted, y_in, y_out = self.model.predict(batch)
            y_true = self._scaler.inverse_transform(y_true[..., :self.output_dim])
            y_predicted = self._scaler.inverse_transform(y_predicted[..., :self.output_dim])
            return (self.loss_p0 * lf(y_predicted, y_true)
                    + self.loss_p1 * lf(y_in, y_in_true)
                    + self.loss_p2 * lf(y_out, y_out_true))

        return func
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.