From 8c367da1156fd6276cc0832ebecf52b1081bde47 Mon Sep 17 00:00:00 2001 From: powercheng <43227582+hczs@users.noreply.github.com> Date: Sat, 13 Jul 2024 00:12:07 +0800 Subject: [PATCH] fix: TESTAM adapt to the multiple feature case (#433) * fix: TESTAM remove input_dim and add multi-feature * fix: TESTAM adapt to the multiple feature case --- .../model/traffic_state_pred/TESTAM.json | 1 - libcity/executor/testam_executor.py | 11 +++-- .../model/traffic_speed_prediction/TESTAM.py | 46 ++++++++++++------- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/libcity/config/model/traffic_state_pred/TESTAM.json b/libcity/config/model/traffic_state_pred/TESTAM.json index bdacb93b..7706542d 100644 --- a/libcity/config/model/traffic_state_pred/TESTAM.json +++ b/libcity/config/model/traffic_state_pred/TESTAM.json @@ -20,7 +20,6 @@ "dropout": 0.0, "prob_mul": false, - "input_dim": 2, "hidden_size": 32, "layers": 3, "is_quantile": true, diff --git a/libcity/executor/testam_executor.py b/libcity/executor/testam_executor.py index a0bb0c0d..e0c0214b 100644 --- a/libcity/executor/testam_executor.py +++ b/libcity/executor/testam_executor.py @@ -100,15 +100,20 @@ def get_label(ind_loss, gate, real): real = batch['y'] out, gate, res = self.model.predict(batch) - predict = self._scaler.inverse_transform(out[..., :self.output_dim]) + out = out.transpose(1, 3) + predict = self._scaler.inverse_transform(out[..., :self.output_dim]).transpose(1, 3) real = self._scaler.inverse_transform(real) # BTNF -> BFNT real = real[..., :self.output_dim].transpose(1, 3) - ind_loss = loss.masked_mae_torch(self._scaler.inverse_transform(res), real.permute(0, 2, 3, 1), 0.0, + tmp_real = real.permute(0, 2, 3, 1) + ind_loss_real = torch.cat([tmp_real, tmp_real, tmp_real], dim=-1) if self.output_dim > 1 else tmp_real + ind_loss = loss.masked_mae_torch(self._scaler.inverse_transform(res), ind_loss_real, 0.0, reduce=False) if self.is_quantile: gated_loss = loss.masked_mae_torch(predict, real, reduce=False).permute(0, 2, 3, 1) - l_worst_avoidance, l_best_choice = get_quantile_label(gated_loss, gate, real) + avg_loss = gated_loss.mean(dim=-1, keepdim=True) # 平均损失 + tmp_real = real.mean(dim=1, keepdim=True) + l_worst_avoidance, l_best_choice = get_quantile_label(avg_loss, gate, tmp_real) else: l_worst_avoidance, l_best_choice = get_label(ind_loss, gate, real) worst_avoidance = -.5 * l_worst_avoidance * torch.log(gate) diff --git a/libcity/model/traffic_speed_prediction/TESTAM.py b/libcity/model/traffic_speed_prediction/TESTAM.py index b2f82438..91dd2e3c 100644 --- a/libcity/model/traffic_speed_prediction/TESTAM.py +++ b/libcity/model/traffic_speed_prediction/TESTAM.py @@ -251,11 +251,13 @@ class TemporalModel(nn.Module): - Notes: in the trivial traffic forecasting problem, we have total 288 = 24 * 60 / 5 (5 min interval) """ - def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size=288, activation=nn.ReLU()): + def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size=288, activation=nn.ReLU(), + output_dim=1): super(TemporalModel, self).__init__() self.vocab_size = vocab_size self.act = activation self.embedding = TemporalInformationEmbedding(hidden_size, vocab_size=vocab_size) + self.in_dim = in_dim self.spd_proj = nn.Linear(in_dim, hidden_size) self.spd_cat = nn.Linear(hidden_size * 2, hidden_size) # Cat speed information and TIM information @@ -271,16 +273,20 @@ def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size self.attn_layers.append(SkipConnection(cp(module), cp(norm))) self.ff.append(SkipConnection(cp(ff), cp(norm))) - self.proj = nn.Linear(hidden_size, 1) + self.proj = nn.Linear(hidden_size, output_dim) def forward(self, x, speed=None): + # x cur_time_index or next_time_index B 1 T TIM = self.embedding(x) # For the traffic forecasting, we introduce learnable node features # The user may modify this node feature into meta-learning based representation, which enables the ability to adopt the model into different dataset - x_nemb = torch.einsum('btc, nc -> bntc', TIM, self.node_features) + x_nemb = torch.einsum('btc, nc -> bntc', TIM, self.node_features) # BTN 32 if speed is None: - speed = torch.zeros_like(x_nemb[..., 0]) - x_spd = self.spd_proj(speed.unsqueeze(dim=-1)) + speed = torch.zeros_like(x_nemb[..., :self.in_dim]) + else: + speed = speed.permute(0, 2, 3, 1) + # x_spd = self.spd_proj(speed.unsqueeze(dim=-1)) + x_spd = self.spd_proj(speed) x_nemb = self.spd_cat(torch.cat([x_spd, x_nemb], dim=-1)) attns = [] @@ -352,7 +358,8 @@ def forward(self, x, prev_hidden, supports): out = self.proj(self.act(x)) - return x_start - out[..., :-1], out[..., [-1]], hiddens + final_dim = -1 if self.out_dim <= 1 else -(self.out_dim - 1) + return x_start - out[..., :final_dim], out, hiddens class AttentionModel(nn.Module): @@ -560,7 +567,9 @@ def __init__(self, config, data_feature): # model dropout = config.get("dropout", 0.3) prob_mul = config.get("prob_mul", False) - in_dim = config.get("input_dim", 2) + self.feature_dim = self.data_feature.get('feature_dim') + self.ext_dim = self.data_feature.get("ext_dim") + self.input_dim = self.feature_dim - self.ext_dim self.output_dim = config.get("output_dim", 1) hidden_size = config.get("hidden_size", 32) layers = config.get("layers", 3) @@ -572,13 +581,14 @@ def __init__(self, config, data_feature): self.prob_mul = prob_mul self.supports_len = 2 - self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim=in_dim - 1, layers=layers, dropout=dropout, - vocab_size=self.vocab_size) - self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim=in_dim, layers=layers, - dropout=dropout) - self.attention_expert = AttentionModel(hidden_size, in_dim=in_dim, layers=layers, dropout=dropout) + self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim=self.input_dim, layers=layers, + dropout=dropout, vocab_size=self.vocab_size, output_dim=self.output_dim) + self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim=self.feature_dim, + layers=layers, dropout=dropout, out_dim=self.output_dim) + self.attention_expert = AttentionModel(hidden_size, in_dim=self.feature_dim, layers=layers, dropout=dropout, + out_dim=self.output_dim) - self.gate_network = MemoryGate(hidden_size, num_nodes) + self.gate_network = MemoryGate(hidden_size, num_nodes, input_dim=self.feature_dim) for model in [self.identity_expert, self.adaptive_expert, self.attention_expert]: for n, p in model.named_parameters(): @@ -597,10 +607,11 @@ def forward(self, input, gate_out=False): g2 = torch.softmax(torch.relu(torch.mm(n2, n1.T)), dim=-1) new_supports = [g1, g2] + # time_index = input[:, -self.ext_dim:, 0] # B, T time_index = input[:, -1, 0] # B, T cur_time_index = (time_index * self.vocab_size).long() next_time_index = ((time_index * self.vocab_size + 12) % self.vocab_size).long() - o_identity, h_identity = self.identity_expert(cur_time_index, input[:, 0]) + o_identity, h_identity = self.identity_expert(cur_time_index, input[:, :self.input_dim]) _, h_future = self.identity_expert(next_time_index) _, o_adaptive, h_adaptive = self.adaptive_expert(input, h_future, new_supports) @@ -612,7 +623,7 @@ def forward(self, input, gate_out=False): B, N, T, _ = o_identity.size() gate_in = [h_identity[-1], h_adaptive[-1], h_attention] gate = torch.softmax(self.gate_network(input.permute(0, 2, 3, 1), gate_in), dim=-1) - out = torch.zeros_like(o_identity).view(-1, 1) + out = torch.zeros_like(o_identity).view(-1, self.input_dim) outs = [o_identity, o_adaptive, o_attention] counts = [] @@ -622,14 +633,14 @@ def forward(self, input, gate_out=False): routes = routes.view(-1) for i in range(len(outs)): - cur_out = outs[i].view(-1, 1) + cur_out = outs[i].view(-1, self.input_dim) indices = torch.eq(routes, i).nonzero(as_tuple=True)[0] out[indices] = cur_out[indices] counts.append(len(indices)) if self.prob_mul: out = out * (route_prob_max).unsqueeze(dim=-1) - out = out.view(B, N, T, 1) + out = out.view(B, N, T, self.input_dim) out = out.permute(0, 3, 1, 2) # out: BFNT @@ -637,6 +648,7 @@ def forward(self, input, gate_out=False): # ind_out: BNTF if self.training or gate_out: return out, gate, ind_out + else: # BFNT -> BTNF return out.transpose(1, 3)