-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 5bea7aa
Showing
11 changed files
with
1,593 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
*.ckpt
*.log
*.pyc
*.tsv
*.csv
*.DS_Store
*.zip
.vscode/
*.txt
data/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Few-shot Uncertain Relation Learning

## Install
Make sure your local environment has the following installed:
```
Python3
Pytorch >= 1.5.0
visdom
```

## Usage
Start visdom:
```
python -m visdom.server
```
Run the experiment:
```
python run_exp.py
```
Check the results at:
```
http://localhost:8097/
```

Data is available at: [[download]](https://drive.google.com/file/d/1_B4pvegXsjiRX3BTv4m6IoSglYVw9NrO/view?usp=sharing)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from src.model_run import Model_Run | ||
from src.args import read_args | ||
from src.input_analysis import checkflow | ||
from src.res_analysis import workflow | ||
import os | ||
|
||
if __name__ == '__main__':
    # Entry point: configure one experiment, run training, then analyze results.
    args = read_args()

    # Experiment configuration (overrides the argparse defaults).
    args.experiment_name = 'test_GMUC'  # Experiment ID
    args.set_aggregator = 'GMUC'        # select model ['FSRL', 'FSUKGE', 'TEST', 'ST', 'QT', 'WWW', 'GMUC']
    args.datapath = './data/MAGA-PLUS-NL27K/NL27K-N3'
    args.eval_every = 2000
    args.max_batches = 60000
    args.rank_weight = 0.0
    args.ae_weight = 0.0

    # Make the experiment directories. exist_ok=True avoids the
    # check-then-create race of os.path.exists() + os.makedirs().
    exp_path = './Experiments/' + args.experiment_name
    os.makedirs(exp_path, exist_ok=True)
    os.makedirs(exp_path + '/checkpoints', exist_ok=True)

    # Input dataset analysis (optional, disabled by default).
    # checkflow(args)

    # Model execution.
    model_run = Model_Run(args)
    model_run.train()

    # Result analysis.
    workflow(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import argparse | ||
|
||
def read_args(argv=None):
    """Build and parse the command-line arguments for an experiment.

    Args:
        argv: optional list of argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]`` —
            identical to the previous behavior. Passing an explicit list
            allows programmatic/driver use (e.g. ``read_args([])``).

    Returns:
        argparse.Namespace with all experiment hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--set_aggregator", default="FSRL", type=str)  # model ['FSUKGE', 'FSRL']
    parser.add_argument("--datapath", default="../data/NL27K-FSUKGE", type=str)
    parser.add_argument("--random_seed", default=1, type=int)
    parser.add_argument("--random_embed", default=1, type=int)
    parser.add_argument("--few", default=3, type=int)
    parser.add_argument("--test", default=0, type=int)  # 0 train, 1 test
    parser.add_argument("--embed_model", default='ComplEx', type=str)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--embed_dim", default=100, type=int)
    parser.add_argument("--dropout", default=0.5, type=float)
    parser.add_argument("--fine_tune", default=0, type=int)
    parser.add_argument("--process_steps", default=2, type=int)  # Query encoder steps
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--weight_decay", default=0, type=float)
    parser.add_argument("--max_neighbor", default=30, type=int)
    parser.add_argument("--train_few", default=1, type=int)
    parser.add_argument("--margin", default=5.0, type=float)
    parser.add_argument("--eval_every", default=5000, type=int)
    parser.add_argument("--max_batches", default=40000, type=int)
    parser.add_argument("--rank_weight", default=1.0, type=float)
    parser.add_argument("--ae_weight", default=0.00001, type=float)
    parser.add_argument("--mae_weight", default=1.0, type=float)
    parser.add_argument("--if_GPU", default=1, type=int)
    parser.add_argument("--type_constrain", default=1, type=int)  # whether to type-constrain negatives
    parser.add_argument("--neg_nums", default=1, type=int)  # negatives per train query
    parser.add_argument("--sim", default='KL', type=str)  # [EL, KL] KG2E similarity function
    parser.add_argument("--experiment_name", default='default', type=str)

    args = parser.parse_args(argv)
    return args
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import json | ||
import random | ||
|
||
def train_generate(datapath, neg_nums, batch_size, few, symbol2id, ent2id, e1rel_e2, type_constrain):
    '''
    Infinite generator of few-shot training batches.

    Args:
        datapath: directory containing train_tasks.json and rel2candidates_all.json.
        neg_nums: number of negative (corrupted-tail) samples per query triple.
        batch_size: number of query triples per batch.
        few: number of support triples per task (the "shot" count).
        symbol2id: mapping from entity/relation symbol to symbol index.
        ent2id: mapping from entity name to entity index.
        e1rel_e2: mapping head+relation -> list of known true tails,
            used to filter false negatives during corruption.
        type_constrain: if falsy, corrupt with any entity instead of the
            relation's type-constrained candidate set.

    Yields:
        (support_pairs, query_pairs, false_pairs,
         support_left, support_right, query_left, query_right,
         query_confidence, false_left, false_right)
        where each *_pairs element is [head_symbol, tail_symbol, confidence].
    '''
    # Load task data once up front; 'with' ensures the handles are closed
    # (the previous json.load(open(...)) form leaked file descriptors).
    with open(datapath + '/train_tasks.json') as f:
        train_tasks = json.load(f)
    with open(datapath + '/rel2candidates_all.json') as f:
        rel2candidates = json.load(f)

    task_pool = list(train_tasks.keys())
    num_tasks = len(task_pool)

    print("train data generation")

    rel_idx = 0
    while True:
        # Reshuffle the task order once per full pass over the tasks.
        if rel_idx % num_tasks == 0:
            random.shuffle(task_pool)
        query = task_pool[rel_idx % num_tasks]
        rel_idx += 1

        candidates = rel2candidates[query]
        # Skip relations with too few candidate entities to sample from.
        if len(candidates) <= 20:
            continue

        train_and_test = train_tasks[query]
        random.shuffle(train_and_test)

        # First `few` triples form the support set; each triple is
        # [head, relation, tail, confidence].
        support_triples = train_and_test[:few]
        support_pairs = [[symbol2id[t[0]], symbol2id[t[2]], float(t[3])] for t in support_triples]  # (h, t, s)
        support_left = [ent2id[t[0]] for t in support_triples]
        support_right = [ent2id[t[2]] for t in support_triples]

        all_test_triples = train_and_test[few:]
        if not all_test_triples:
            continue

        # Sample the query batch (with replacement if the pool is small).
        if len(all_test_triples) < batch_size:
            query_triples = [random.choice(all_test_triples) for _ in range(batch_size)]
        else:
            query_triples = random.sample(all_test_triples, batch_size)

        query_pairs = [[symbol2id[t[0]], symbol2id[t[2]], float(t[3])] for t in query_triples]  # (h, t, s)
        query_left = [ent2id[t[0]] for t in query_triples]
        query_right = [ent2id[t[2]] for t in query_triples]
        query_confidence = [float(t[3]) for t in query_triples]

        # Negative sampling: corrupt the tail of each query triple with an
        # entity that is not a known true tail for (head, relation).
        false_pairs = []
        false_left = []
        false_right = []
        if not type_constrain:
            candidates = list(ent2id.keys())  # fall back to all entities
        for triple in query_triples:
            e_h, rel, e_t = triple[0], triple[1], triple[2]
            for _ in range(neg_nums):
                while True:
                    noise = random.choice(candidates)
                    if (noise not in e1rel_e2[e_h + rel]) and noise != e_t:
                        break
                false_pairs.append([symbol2id[e_h], symbol2id[noise], 0.0])  # (h, t, s)
                false_left.append(ent2id[e_h])
                false_right.append(ent2id[noise])

        yield support_pairs, query_pairs, false_pairs, support_left, support_right, query_left, query_right, query_confidence, false_left, false_right
|
||
|
Oops, something went wrong.