From 766c164546b1bb718113ab1baae5db1ed7c3ee04 Mon Sep 17 00:00:00 2001 From: houcheng Date: Mon, 24 Jun 2024 16:13:00 +0800 Subject: [PATCH] fix: use the time interval to calculate the TESTAM vocab szie parameter --- libcity/config/model/traffic_state_pred/TESTAM.json | 3 +-- libcity/model/traffic_speed_prediction/TESTAM.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/libcity/config/model/traffic_state_pred/TESTAM.json b/libcity/config/model/traffic_state_pred/TESTAM.json index 9a823e9d..bdacb93b 100644 --- a/libcity/config/model/traffic_state_pred/TESTAM.json +++ b/libcity/config/model/traffic_state_pred/TESTAM.json @@ -24,6 +24,5 @@ "hidden_size": 32, "layers": 3, "is_quantile": true, - "quantile": 0.7, - "vocab_size": 288 + "quantile": 0.7 } \ No newline at end of file diff --git a/libcity/model/traffic_speed_prediction/TESTAM.py b/libcity/model/traffic_speed_prediction/TESTAM.py index f8be2242..b2f82438 100644 --- a/libcity/model/traffic_speed_prediction/TESTAM.py +++ b/libcity/model/traffic_speed_prediction/TESTAM.py @@ -566,7 +566,7 @@ def __init__(self, config, data_feature): layers = config.get("layers", 3) self.is_quantile = config.get("is_quantile", False) self.quantile = config.get("quantile", 0.7) - self.vocab_size = config.get("vocab_size", 288) + self.vocab_size = 24 * 60 * 60 // config.get("time_intervals", 300) self.dropout = dropout self.prob_mul = prob_mul