Skip to content

Commit

Permalink
Node2vec (LibCity#235)
Browse files Browse the repository at this point in the history
* node2vec

* node2vec
  • Loading branch information
aptx1231 committed Dec 18, 2021
1 parent 30573a6 commit f141038
Show file tree
Hide file tree
Showing 12 changed files with 336 additions and 16 deletions.
2 changes: 2 additions & 0 deletions contribution_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ For a list of all models reproduced in LibCity, see [Docs](https://bigscity-libc
||HMMM|[XBR-1111](https://github.com/XBR-1111)|
|Road Network Representation Learning|ChebConv|[aptx1231](https://github.com/aptx1231)|
||LINE|[l782993610](https://github.com/l782993610)|
||Node2Vec|[aptx1231](https://github.com/aptx1231)|
||GAT|[Qwtdgh](https://github.com/Qwtdgh)|



3 changes: 3 additions & 0 deletions libcity/config/data/RoadNetWorkDataset.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{

}
3 changes: 3 additions & 0 deletions libcity/config/executor/GensimExecutor.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{

}
11 changes: 11 additions & 0 deletions libcity/config/model/road_representation/Node2Vec.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"output_dim": 128,
"is_directed": true,
"p": 2,
"q": 1,
"num_walks": 10,
"walk_length": 80,
"window_size": 10,
"num_workers": 10,
"max_epoch": 1000
}
7 changes: 6 additions & 1 deletion libcity/config/task_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@
},
"road_representation": {
"allowed_model": [
"ChebConv", "LINE", "GAT"
"ChebConv", "LINE", "GAT", "Node2Vec"
],
"allowed_dataset": [
"BJ_roadmap"
Expand All @@ -399,6 +399,11 @@
"dataset_class": "ChebConvDataset",
"executor": "ChebConvExecutor",
"evaluator": "RoadRepresentationEvaluator"
},
"Node2Vec": {
"dataset_class": "RoadNetWorkDataset",
"executor": "GensimExecutor",
"evaluator": "RoadRepresentationEvaluator"
}
}
}
4 changes: 3 additions & 1 deletion libcity/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from libcity.data.dataset.gsnet_dataset import GSNetDataset
from libcity.data.dataset.line_dataset import LINEDataset
from libcity.data.dataset.cstn_dataset import CSTNDataset
from libcity.data.dataset.roadnetwork_dataset import RoadNetWorkDataset

__all__ = [
"AbstractDataset",
Expand Down Expand Up @@ -67,5 +68,6 @@
'ChebConvDataset',
"GSNetDataset",
"LINEDataset",
"CSTNDataset"
"CSTNDataset",
"RoadNetWorkDataset"
]
29 changes: 29 additions & 0 deletions libcity/data/dataset/roadnetwork_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
from libcity.data.dataset import TrafficStateDataset


class RoadNetWorkDataset(TrafficStateDataset):
def __init__(self, config):
self.config = config
self.dataset = self.config.get('dataset', '')
self.data_path = './raw_data/' + self.dataset + '/'
self.geo_file = self.config.get('geo_file', self.dataset)
self.rel_file = self.config.get('rel_file', self.dataset)
assert os.path.exists(self.data_path + self.geo_file + '.geo')
assert os.path.exists(self.data_path + self.rel_file + '.rel')
super().__init__(config)

def get_data(self):
"""
返回数据的DataLoader,此类只负责返回路网结构adj_mx,而adj_mx在data_feature中,这里什么都不返回
"""
return None, None, None

def get_data_feature(self):
"""
返回一个 dict,包含数据集的相关特征
Returns:
dict: 包含数据集的相关特征的字典
"""
return {"adj_mx": self.adj_mx, "num_nodes": self.num_nodes}
33 changes: 21 additions & 12 deletions libcity/evaluator/road_representation_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def evaluate(self):
k_means.fit(node_emb)
y_predict = k_means.predict(node_emb)

rid_pos = self._load_geo()
rid_file = self._load_geo()
# 记录每个类别都有哪些geo实体
result_token = dict()
for i in range(len(y_predict)):
Expand All @@ -60,18 +60,27 @@ def evaluate(self):
self._logger.info('Kmeans category is saved at {}'.format(result_path))

# QGIS可视化
rid_pos = rid_pos['coordinates']
rid_type = rid_file['type'][0]
rid_pos = rid_file['coordinates']
rid2wkt = dict()
for i in range(rid_pos.shape[0]):
rid_list = eval(rid_pos[i])
wkt_str = 'LINESTRING('
for j in range(len(rid_list)):
rid = rid_list[j]
wkt_str += (str(rid[0]) + ' ' + str(rid[1]))
if j != len(rid_list) - 1:
wkt_str += ','
wkt_str += ')'
rid2wkt[i] = wkt_str
if rid_type == 'LineString':
for i in range(rid_pos.shape[0]):
rid_list = eval(rid_pos[i]) # [(lat1, lon1), (lat2, lon2)...]
wkt_str = 'LINESTRING('
for j in range(len(rid_list)):
rid = rid_list[j]
wkt_str += (str(rid[0]) + ' ' + str(rid[1]))
if j != len(rid_list) - 1:
wkt_str += ','
wkt_str += ')'
rid2wkt[i] = wkt_str
elif rid_type == 'Point':
for i in range(rid_pos.shape[0]):
rid_list = eval(rid_pos[i]) # [lat1, lon1]
wkt_str = 'Point({} {})'.format(rid_list[0], rid_list[1])
rid2wkt[i] = wkt_str
else:
raise ValueError('Error geo type!')

df = []
for i in range(len(y_predict)):
Expand Down
4 changes: 3 additions & 1 deletion libcity/executor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from libcity.executor.abstract_tradition_executor import AbstractTraditionExecutor
from libcity.executor.chebconv_executor import ChebConvExecutor
from libcity.executor.eta_executor import ETAExecutor
from libcity.executor.gensim_executor import GensimExecutor

__all__ = [
"TrajLocPredExecutor",
Expand All @@ -23,5 +24,6 @@
"AbstractTraditionExecutor",
"ChebConvExecutor",
"LINEExecutor",
"ETAExecutor"
"ETAExecutor",
"GensimExecutor"
]
33 changes: 33 additions & 0 deletions libcity/executor/gensim_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from libcity.utils import get_evaluator, ensure_dir
from libcity.executor.abstract_executor import AbstractExecutor


class GensimExecutor(AbstractExecutor):
def __init__(self, config, model):
self.evaluator = get_evaluator(config)
self.config = config
self.model = model
self.exp_id = config.get('exp_id', None)

self.cache_dir = './libcity/cache/{}/model_cache'.format(self.exp_id)
self.evaluate_res_dir = './libcity/cache/{}/evaluate_cache'.format(self.exp_id)
ensure_dir(self.cache_dir)
ensure_dir(self.evaluate_res_dir)

def evaluate(self, test_dataloader):
"""
use model to test data
"""
self.evaluator.evaluate()

def train(self, train_dataloader, eval_dataloader):
"""
use data to train model with config
"""
self.model.run()

def load_model(self, cache_name):
pass

def save_model(self, cache_name):
pass
Loading

0 comments on commit f141038

Please sign in to comment.