Skip to content

Commit

Permalink
为所有masked的state预测评估函数添加mask_val,修复遗留问题
Browse files Browse the repository at this point in the history
  • Loading branch information
Lucas-lyh committed Jul 9, 2024
1 parent 4a6286f commit 7ae9f53
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libcity/config/model/traffic_state_pred/STSSL.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@
"pad_back_trend": 0,
"interval_period": 1,
"interval_trend": 7,
"eval_mask_val": 5.001
"mask_val": 5.001
}
16 changes: 10 additions & 6 deletions libcity/evaluator/traffic_state_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ class TrafficStateEvaluator(AbstractEvaluator):
def __init__(self, config):
self.metrics = config.get('metrics', ['MAE']) # 评估指标, 是一个 list
self.allowed_metrics = ['MAE', 'MSE', 'RMSE', 'MAPE', 'masked_MAE', 'masked_MSE', 'masked_RMSE', 'masked_MAPE',
'R2', 'EVAR', "IN_masked_MAE", "IN_masked_MAPE", "OUT_masked_MAE", "OUT_masked_MAPE"]
'R2', 'EVAR']
self.save_modes = config.get('save_mode', ['csv', 'json'])
self.mode = config.get('evaluator_mode', 'single') # or average
self.mask_val = config.get('eval_mask_val', None)
self.mask_val = config.get('mask_val', None)
self.config = config
self.len_timeslots = 0
self.result = {} # 每一种指标的结果
Expand Down Expand Up @@ -60,10 +60,12 @@ def collect(self, batch):
mask_val=self.mask_val).item())
elif metric == 'masked_MSE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0).item())
loss.masked_mse_torch(y_pred[:, :i], y_true[:, :i], 0,
mask_val=self.mask_val).item())
elif metric == 'masked_RMSE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0).item())
loss.masked_rmse_torch(y_pred[:, :i], y_true[:, :i], 0,
mask_val=self.mask_val).item())
elif metric == 'masked_MAPE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_mape_torch(y_pred[:, :i], y_true[:, :i], 0,
Expand Down Expand Up @@ -95,10 +97,12 @@ def collect(self, batch):
mask_val=self.mask_val).item())
elif metric == 'masked_MSE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item())
loss.masked_mse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0,
mask_val=self.mask_val).item())
elif metric == 'masked_RMSE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0).item())
loss.masked_rmse_torch(y_pred[:, i - 1], y_true[:, i - 1], 0,
mask_val=self.mask_val).item())
elif metric == 'masked_MAPE':
self.intermediate_result[metric + '@' + str(i)].append(
loss.masked_mape_torch(y_pred[:, i - 1], y_true[:, i - 1], 0,
Expand Down
12 changes: 7 additions & 5 deletions libcity/model/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def quantile_loss(preds, labels, delta=0.25):
return torch.mean(torch.where(condition, large_res, small_res))


def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan):
def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=None):
labels[torch.abs(labels) < 1e-4] = 0
if np.isnan(null_val) and eps != 0:
loss = torch.abs((preds - labels) / (labels + eps))
Expand All @@ -65,7 +65,7 @@ def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan):
mask = ~torch.isnan(labels)
else:
mask = labels.ne(null_val)
if not np.isnan(mask_val):
if mask_val:
mask &= labels.ge(mask_val)
mask = mask.float()
mask /= torch.mean(mask)
Expand All @@ -76,12 +76,14 @@ def masked_mape_torch(preds, labels, null_val=np.nan, eps=0, mask_val=np.nan):
return torch.mean(loss)


def masked_mse_torch(preds, labels, null_val=np.nan):
def masked_mse_torch(preds, labels, null_val=np.nan, mask_val=None):
labels[torch.abs(labels) < 1e-4] = 0
if np.isnan(null_val):
mask = ~torch.isnan(labels)
else:
mask = labels.ne(null_val)
if mask_val:
mask &= labels.ge(mask_val)
mask = mask.float()
mask /= torch.mean(mask)
mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
Expand All @@ -91,10 +93,10 @@ def masked_mse_torch(preds, labels, null_val=np.nan):
return torch.mean(loss)


def masked_rmse_torch(preds, labels, null_val=np.nan):
def masked_rmse_torch(preds, labels, null_val=np.nan, mask_val=None):
labels[torch.abs(labels) < 1e-4] = 0
return torch.sqrt(masked_mse_torch(preds=preds, labels=labels,
null_val=null_val))
null_val=null_val, mask_val=mask_val))


def r2_score_torch(preds, labels):
Expand Down

0 comments on commit 7ae9f53

Please sign in to comment.