Skip to content

Commit

Permalink
fix HA to match result in DCRNN (LibCity#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
XBR-1111 committed Dec 10, 2021
1 parent 93af58f commit cac90a0
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions test/test_HA.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

config = {
'model': 'HA',
'lag': [7*24*12, 1],
'weight': [0.2, 0.8],
'n_sample': [4, 4],
'lag': [24 * 7 * 12],
'weight': [1],
'dataset': 'METR_LA',
'train_rate': 0.7,
'eval_rate': 0.1,
'input_window': 12,
'output_windows': 3,
'null_value': 0,
'metrics': ['MAE', 'MAPE', 'MSE', 'RMSE', 'masked_MAE',
'masked_MAPE', 'masked_MSE', 'masked_RMSE', 'R2', 'EVAR']
}
Expand Down Expand Up @@ -78,15 +78,13 @@ def historical_average(data):
output_window = config.get('output_window', 3)
lag = config.get('lag', 7 * 24 * 12)
weight = config.get('weight', 1.0)
n_sample = config.get('n_sample', 4)
null_value = config.get('null_value', 0)

if isinstance(lag, int):
lag = [lag]
if isinstance(weight, int) or isinstance(weight, float):
weight = [weight]
if isinstance(n_sample, int):
n_sample = [n_sample]
assert sum(weight) == 1
assert int(t * (train_rate + eval_rate)) > max(np.array(n_sample) * np.array(lag))

y_true = []
y_pred = []
Expand All @@ -96,8 +94,12 @@ def historical_average(data):
# y_pred
y_pred_i = 0
for j in range(len(lag)):
# 隔lag[j]时间步采样n_sample[j]次, 得到(n_sample[j], N, F)取平均值得到(N, F), 最后用weight[j]加权
y_pred_i += weight[j] * np.mean(data[i - n_sample[j] * lag[j]:i:lag[j], :, :], axis=0)
# 隔lag[j]时间步在整个训练集采样, 得到(n_sample, N, F)取平均值得到(N, F), 最后用weight[j]加权
inds = [j for j in range(i % lag[j], int(t * (train_rate + eval_rate)), lag[j])]
history = data[inds, :, :]
# 对得到的history数据去除空值后求平均
y_pred_i += weight[j] * np.mean(data[inds, :, :], axis=0, where=history != null_value)
y_pred_i[np.isnan(y_pred_i)] = 0
y_pred.append(y_pred_i) # (N, F)

y_pred = np.array(y_pred) # (test_size, N, F)
Expand All @@ -113,8 +115,10 @@ def main():
print(config)
data = get_data(config.get('dataset', ''))
y_pred, y_true = historical_average(data)
# y_pred = y_pred[:, :, :, 0]
# y_true = y_true[:, :, :, 0]
evaluate_model(y_pred=y_pred, y_true=y_true, metrics=config['metrics'],
path=config['model']+'_'+config['dataset']+'_metrics.csv')
path=config['model'] + '_' + config['dataset'] + '_metrics.csv')


if __name__ == '__main__':
Expand Down

0 comments on commit cac90a0

Please sign in to comment.