Skip to content

Commit

Permalink
fix: use time_intervals to calculate STPGCN points_per_hour parameter (
Browse files Browse the repository at this point in the history
…LibCity#425)

* fix: use time_intervals to calculate points_per_hour

* fix: remove time_intervals parameter
  • Loading branch information
hczs committed Jul 2, 2024
1 parent 04232be commit a0c81ca
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 4 deletions.
1 change: 0 additions & 1 deletion libcity/config/data/STPGCNDataset.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
"input_window": 12,
"output_window": 12,

"points_per_hour": 12,
"alpha": 4,
"beta": 2
}
4 changes: 2 additions & 2 deletions libcity/data/dataset/dataset_subclass/stpgcn_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, config):
'point_based_{}.npz'.format(self.parameters_str))

self.feature_name = {'X': 'float', 'y': 'float', 'pos_w': 'int', 'pos_d': 'int'}
self.points_per_hour = config.get('points_per_hour', 12)
self.points_per_hour = self.time_intervals // 60
self.alpha = config.get('alpha', 4)
self.beta = config.get('beta', 2)
self.t_size = self.beta + 1
Expand All @@ -60,7 +60,7 @@ def get_data_feature(self):
"""
return {"scaler": self.scaler, "ext_dim": self.ext_dim, "spatial_distance": self.spatial_distance,
"range_mask": self.range_mask, "num_nodes": self.num_nodes, "feature_dim": self.feature_dim,
"output_dim": self.output_dim, "num_batches": self.num_batches}
"output_dim": self.output_dim, "num_batches": self.num_batches, "points_per_hour": self.points_per_hour}

def _load_cache_train_val_test(self):
self._logger.info('Loading ' + self.cache_file_name)
Expand Down
2 changes: 1 addition & 1 deletion libcity/model/traffic_flow_prediction/STPGCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def __init__(self, config, data_feature):
self.beta = config.get("beta", 2)
self.t_size = self.beta + 1
self.week_len = 7
self.day_len = config.get("points_per_hour") * 24
self.day_len = self.data_feature.get("points_per_hour") * 24
self.range_mask = torch.Tensor(self.range_mask).to(self.device)

self.PAD = GeneratePad(self.device, self.C, self.V, self.d, self.beta)
Expand Down

0 comments on commit a0c81ca

Please sign in to comment.