Skip to content

Commit

Permalink
fix: TESTAM adapt to the multiple feature case (LibCity#433)
Browse files Browse the repository at this point in the history
* fix: TESTAM remove input_dim and add multi-feature

* fix: TESTAM adapt to the multiple feature case
  • Loading branch information
hczs committed Jul 12, 2024
1 parent 6265868 commit 8c367da
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 21 deletions.
1 change: 0 additions & 1 deletion libcity/config/model/traffic_state_pred/TESTAM.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

"dropout": 0.0,
"prob_mul": false,
"input_dim": 2,
"hidden_size": 32,
"layers": 3,
"is_quantile": true,
Expand Down
11 changes: 8 additions & 3 deletions libcity/executor/testam_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,20 @@ def get_label(ind_loss, gate, real):

real = batch['y']
out, gate, res = self.model.predict(batch)
predict = self._scaler.inverse_transform(out[..., :self.output_dim])
out = out.transpose(1, 3)
predict = self._scaler.inverse_transform(out[..., :self.output_dim]).transpose(1, 3)
real = self._scaler.inverse_transform(real)
# BTNF -> BFNT
real = real[..., :self.output_dim].transpose(1, 3)
ind_loss = loss.masked_mae_torch(self._scaler.inverse_transform(res), real.permute(0, 2, 3, 1), 0.0,
tmp_real = real.permute(0, 2, 3, 1)
ind_loss_real = torch.cat([tmp_real, tmp_real, tmp_real], dim=-1) if self.output_dim > 1 else tmp_real
ind_loss = loss.masked_mae_torch(self._scaler.inverse_transform(res), ind_loss_real, 0.0,
reduce=False)
if self.is_quantile:
gated_loss = loss.masked_mae_torch(predict, real, reduce=False).permute(0, 2, 3, 1)
l_worst_avoidance, l_best_choice = get_quantile_label(gated_loss, gate, real)
avg_loss = gated_loss.mean(dim=-1, keepdim=True) # 平均损失
tmp_real = real.mean(dim=1, keepdim=True)
l_worst_avoidance, l_best_choice = get_quantile_label(avg_loss, gate, tmp_real)
else:
l_worst_avoidance, l_best_choice = get_label(ind_loss, gate, real)
worst_avoidance = -.5 * l_worst_avoidance * torch.log(gate)
Expand Down
46 changes: 29 additions & 17 deletions libcity/model/traffic_speed_prediction/TESTAM.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,13 @@ class TemporalModel(nn.Module):
- Notes: in the trivial traffic forecasting problem, we have total 288 = 24 * 60 / 5 (5 min interval)
"""

def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size=288, activation=nn.ReLU()):
def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size=288, activation=nn.ReLU(),
output_dim=1):
super(TemporalModel, self).__init__()
self.vocab_size = vocab_size
self.act = activation
self.embedding = TemporalInformationEmbedding(hidden_size, vocab_size=vocab_size)
self.in_dim = in_dim
self.spd_proj = nn.Linear(in_dim, hidden_size)
self.spd_cat = nn.Linear(hidden_size * 2, hidden_size) # Cat speed information and TIM information

Expand All @@ -271,16 +273,20 @@ def __init__(self, hidden_size, num_nodes, layers, dropout, in_dim=1, vocab_size
self.attn_layers.append(SkipConnection(cp(module), cp(norm)))
self.ff.append(SkipConnection(cp(ff), cp(norm)))

self.proj = nn.Linear(hidden_size, 1)
self.proj = nn.Linear(hidden_size, output_dim)

def forward(self, x, speed=None):
# x cur_time_index or next_time_index B 1 T
TIM = self.embedding(x)
# For the traffic forecasting, we introduce learnable node features
# The user may modify this node feature into meta-learning based representation, which enables the ability to adopt the model into different dataset
x_nemb = torch.einsum('btc, nc -> bntc', TIM, self.node_features)
x_nemb = torch.einsum('btc, nc -> bntc', TIM, self.node_features) # BTN 32
if speed is None:
speed = torch.zeros_like(x_nemb[..., 0])
x_spd = self.spd_proj(speed.unsqueeze(dim=-1))
speed = torch.zeros_like(x_nemb[..., :self.in_dim])
else:
speed = speed.permute(0, 2, 3, 1)
# x_spd = self.spd_proj(speed.unsqueeze(dim=-1))
x_spd = self.spd_proj(speed)
x_nemb = self.spd_cat(torch.cat([x_spd, x_nemb], dim=-1))

attns = []
Expand Down Expand Up @@ -352,7 +358,8 @@ def forward(self, x, prev_hidden, supports):

out = self.proj(self.act(x))

return x_start - out[..., :-1], out[..., [-1]], hiddens
final_dim = -1 if self.out_dim <= 1 else -(self.out_dim - 1)
return x_start - out[..., :final_dim], out, hiddens


class AttentionModel(nn.Module):
Expand Down Expand Up @@ -560,7 +567,9 @@ def __init__(self, config, data_feature):
# model
dropout = config.get("dropout", 0.3)
prob_mul = config.get("prob_mul", False)
in_dim = config.get("input_dim", 2)
self.feature_dim = self.data_feature.get('feature_dim')
self.ext_dim = self.data_feature.get("ext_dim")
self.input_dim = self.feature_dim - self.ext_dim
self.output_dim = config.get("output_dim", 1)
hidden_size = config.get("hidden_size", 32)
layers = config.get("layers", 3)
Expand All @@ -572,13 +581,14 @@ def __init__(self, config, data_feature):
self.prob_mul = prob_mul
self.supports_len = 2

self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim=in_dim - 1, layers=layers, dropout=dropout,
vocab_size=self.vocab_size)
self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim=in_dim, layers=layers,
dropout=dropout)
self.attention_expert = AttentionModel(hidden_size, in_dim=in_dim, layers=layers, dropout=dropout)
self.identity_expert = TemporalModel(hidden_size, num_nodes, in_dim=self.input_dim, layers=layers,
dropout=dropout, vocab_size=self.vocab_size, output_dim=self.output_dim)
self.adaptive_expert = STModel(hidden_size, self.supports_len, num_nodes, in_dim=self.feature_dim,
layers=layers, dropout=dropout, out_dim=self.output_dim)
self.attention_expert = AttentionModel(hidden_size, in_dim=self.feature_dim, layers=layers, dropout=dropout,
out_dim=self.output_dim)

self.gate_network = MemoryGate(hidden_size, num_nodes)
self.gate_network = MemoryGate(hidden_size, num_nodes, input_dim=self.feature_dim)

for model in [self.identity_expert, self.adaptive_expert, self.attention_expert]:
for n, p in model.named_parameters():
Expand All @@ -597,10 +607,11 @@ def forward(self, input, gate_out=False):
g2 = torch.softmax(torch.relu(torch.mm(n2, n1.T)), dim=-1)
new_supports = [g1, g2]

# time_index = input[:, -self.ext_dim:, 0] # B, T
time_index = input[:, -1, 0] # B, T
cur_time_index = (time_index * self.vocab_size).long()
next_time_index = ((time_index * self.vocab_size + 12) % self.vocab_size).long()
o_identity, h_identity = self.identity_expert(cur_time_index, input[:, 0])
o_identity, h_identity = self.identity_expert(cur_time_index, input[:, :self.input_dim])
_, h_future = self.identity_expert(next_time_index)

_, o_adaptive, h_adaptive = self.adaptive_expert(input, h_future, new_supports)
Expand All @@ -612,7 +623,7 @@ def forward(self, input, gate_out=False):
B, N, T, _ = o_identity.size()
gate_in = [h_identity[-1], h_adaptive[-1], h_attention]
gate = torch.softmax(self.gate_network(input.permute(0, 2, 3, 1), gate_in), dim=-1)
out = torch.zeros_like(o_identity).view(-1, 1)
out = torch.zeros_like(o_identity).view(-1, self.input_dim)

outs = [o_identity, o_adaptive, o_attention]
counts = []
Expand All @@ -622,21 +633,22 @@ def forward(self, input, gate_out=False):
routes = routes.view(-1)

for i in range(len(outs)):
cur_out = outs[i].view(-1, 1)
cur_out = outs[i].view(-1, self.input_dim)
indices = torch.eq(routes, i).nonzero(as_tuple=True)[0]
out[indices] = cur_out[indices]
counts.append(len(indices))
if self.prob_mul:
out = out * (route_prob_max).unsqueeze(dim=-1)

out = out.view(B, N, T, 1)
out = out.view(B, N, T, self.input_dim)

out = out.permute(0, 3, 1, 2)
# out: BFNT
# gate: BNTF
# ind_out: BNTF
if self.training or gate_out:
return out, gate, ind_out

else:
# BFNT -> BTNF
return out.transpose(1, 3)
Expand Down

0 comments on commit 8c367da

Please sign in to comment.