From 7ae9f534d1bf7917eaf3b8358bb2fe8d09eb8087 Mon Sep 17 00:00:00 2001 From: luyuheng Date: Tue, 9 Jul 2024 16:06:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=BA=E6=89=80=E6=9C=89masked=E7=9A=84state?= =?UTF-8?q?=E9=A2=84=E6=B5=8B=E8=AF=84=E4=BC=B0=E5=87=BD=E6=95=B0=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0mask=5Fval=EF=BC=8C=E4=BF=AE=E5=A4=8D=E9=81=97?= =?UTF-8?q?=E7=95=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../config/model/traffic_state_pred/STSSL.json | 2 +- libcity/evaluator/traffic_state_evaluator.py | 16 ++++++++++------ libcity/model/loss.py | 12 +++++++----- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/libcity/config/model/traffic_state_pred/STSSL.json b/libcity/config/model/traffic_state_pred/STSSL.json index 7ce6268b..77d205cc 100644 --- a/libcity/config/model/traffic_state_pred/STSSL.json +++ b/libcity/config/model/traffic_state_pred/STSSL.json @@ -34,5 +34,5 @@ "pad_back_trend": 0, "interval_period": 1, "interval_trend": 7, - "eval_mask_val": 5.001 + "mask_val": 5.001 } diff --git a/libcity/evaluator/traffic_state_evaluator.py b/libcity/evaluator/traffic_state_evaluator.py index 77968cc2..6255eba3 100644 --- a/libcity/evaluator/traffic_state_evaluator.py +++ b/libcity/evaluator/traffic_state_evaluator.py @@ -13,10 +13,10 @@ class TrafficStateEvaluator(AbstractEvaluator): def __init__(self, config): self.metrics = config.get('metrics', ['MAE']) # 评估指标, 是一个 list self.allowed_metrics = ['MAE', 'MSE', 'RMSE', 'MAPE', 'masked_MAE', 'masked_MSE', 'masked_RMSE', 'masked_MAPE', - 'R2', 'EVAR', "IN_masked_MAE", "IN_masked_MAPE", "OUT_masked_MAE", "OUT_masked_MAPE"] + 'R2', 'EVAR'] self.save_modes = config.get('save_mode', ['csv', 'json']) self.mode = config.get('evaluator_mode', 'single') # or average - self.mask_val = config.get('eval_mask_val', None) + self.mask_val = config.get('mask_val', None) self.config = config self.len_timeslots = 0 self.result = {} # 每一种指标的结果 @@ -60,10 +60,12 @@ def collect(self, batch): mask_val=self.mask_val).item()) elif metric == 'masked_MSE': self.intermediate_result[metric + '@' + str(i)].append( - loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0).item()) + loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0, + mask_val=self.mask_val).item()) elif metric == 'masked_RMSE': self.intermediate_result[metric + '@' + str(i)].append( - loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0).item()) + loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0, + mask_val=self.mask_val).item()) elif metric == 'masked_MAPE': self.intermediate_result[metric + '@' + str(i)].append( loss.masked_mape_torch(y_pred[:, :i], y_true[:, :i], 0, @@ -95,10 +97,12 @@ def collect(self, batch): mask_val=self.mask_val).item()) elif metric == 'masked_MSE': self.intermediate_result[metric + '@' + str(i)].append( - loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()) + loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, + mask_val=self.mask_val).item()) elif metric == 'masked_RMSE': self.intermediate_result[metric + '@' + str(i)].append( - loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item()) + loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, + mask_val=self.mask_val).item()) elif metric == 'masked_MAPE': self.intermediate_result[metric + '@' + str(i)].append( loss.masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1], 0, diff --git a/libcity/model/loss.py b/libcity/model/loss.py index 1862ea89..7a3c8e40 100644 --- a/libcity/model/loss.py +++ b/libcity/model/loss.py @@ -56,7 +56,7 @@ def quantile_loss(preds, labels, delta=0.25): return torch.mean(torch.where(condition, large_res, small_res)) -def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan): +def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=None): labels[torch.abs(labels) < 1e-4] = 0 if np.isnan(null_val) and eps != 0: loss = torch.abs((preds - labels) / (labels + eps)) @@ -65,7 +65,7 @@ def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan): mask = ~torch.isnan(labels) else: mask = labels.ne(null_val) - if not np.isnan(mask_val): + if mask_val: mask &= labels.ge(mask_val) mask = mask.float() mask /= torch.mean(mask) @@ -76,12 +76,14 @@ def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan): return torch.mean(loss) -def masked_mse_torch(preds, labels, null_val=np.nan): +def masked_mse_torch(preds, labels, null_val=np.nan, mask_val=None): labels[torch.abs(labels) < 1e-4] = 0 if np.isnan(null_val): mask = ~torch.isnan(labels) else: mask = labels.ne(null_val) + if mask_val: + mask &= labels.ge(mask_val) mask = mask.float() mask /= torch.mean(mask) mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask) @@ -91,10 +93,10 @@ def masked_mse_torch(preds, labels, null_val=np.nan): return torch.mean(loss) -def masked_rmse_torch(preds, labels, null_val=np.nan): +def masked_rmse_torch(preds, labels, null_val=np.nan, mask_val=None): labels[torch.abs(labels) < 1e-4] = 0 return torch.sqrt(masked_mse_torch(preds=preds, labels=labels, - null_val=null_val)) + null_val=null_val, mask_val=mask_val)) def r2_score_torch(preds, labels):