Commit 5bea7aa

init update

zhangjiatao committed Nov 5, 2020 (0 parents)
Showing 11 changed files with 1,593 additions and 0 deletions.
10 changes: 10 additions & 0 deletions .gitignore
@@ -0,0 +1,10 @@
*.ckpt
*.log
*.pyc
*.tsv
*.csv
*.DS_Store
*.zip
.vscode/
*.txt
data/
25 changes: 25 additions & 0 deletions README.md
@@ -0,0 +1,25 @@
# Few-shot Uncertain Relation Learning

## Install
Make sure your local environment has the following installed:
```
Python3
PyTorch >= 1.5.0
visdom
```
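
If these are missing, one common way to install them is with pip (a sketch; pick the PyTorch build that matches your CUDA setup):
```
pip install "torch>=1.5.0" visdom
```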

## Usage
Start visdom:
```
python -m visdom.server
```
Run the experiment:
```
python run_exp.py
```
Check the results at:
```
http://localhost:8097/
```

Data is available at: [[download]](https://drive.google.com/file/d/1_B4pvegXsjiRX3BTv4m6IoSglYVw9NrO/view?usp=sharing)
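
Note that `run_exp.py` (below) sets `args.datapath` to `./data/MAGA-PLUS-NL27K/NL27K-N3`, and `src/data_generator.py` reads `train_tasks.json` and `rel2candidates_all.json` from that directory, so the downloaded archive should be unpacked into the git-ignored `data/` directory. A layout like the following is assumed (the exact archive structure is an assumption):
```
data/
└── MAGA-PLUS-NL27K/
    └── NL27K-N3/
        ├── train_tasks.json
        └── rel2candidates_all.json
```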
34 changes: 34 additions & 0 deletions run_exp.py
@@ -0,0 +1,34 @@
from src.model_run import Model_Run
from src.args import read_args
from src.input_analysis import checkflow
from src.res_analysis import workflow
import os

if __name__ == '__main__':
    args = read_args()

    args.experiment_name = 'test_GMUC'  # Experiment ID
    args.set_aggregator = 'GMUC'        # select model ['FSRL', 'FSUKGE', 'TEST', 'ST', 'QT', 'WWW', 'GMUC']
    args.datapath = './data/MAGA-PLUS-NL27K/NL27K-N3'
    args.eval_every = 2000
    args.max_batches = 60000
    args.rank_weight = 0.0
    args.ae_weight = 0.0

    # make the experiment directory and its checkpoints subdirectory
    exp_path = './Experiments/' + args.experiment_name
    os.makedirs(exp_path, exist_ok=True)
    os.makedirs(exp_path + '/checkpoints', exist_ok=True)

    # input dataset analysis
    # checkflow(args)

    # model execution
    model_run = Model_Run(args)
    model_run.train()

    # result analysis
    workflow(args)
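
Because `run_exp.py` reassigns `experiment_name`, `set_aggregator`, `datapath`, `eval_every`, `max_batches`, `rank_weight`, and `ae_weight` after `read_args()` returns, command-line values for those seven flags are silently overwritten; the remaining flags defined in `src/args.py` (below) still take effect. A hypothetical invocation tuning only pass-through flags:
```
python run_exp.py --lr 0.001 --few 5 --margin 3.0
```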
47 changes: 47 additions & 0 deletions src/args.py
@@ -0,0 +1,47 @@
import argparse

def read_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--set_aggregator", default="FSRL", type=str)  # model ['FSUKGE', 'FSRL']
    parser.add_argument("--datapath", default="../data/NL27K-FSUKGE", type=str)
    parser.add_argument("--random_seed", default=1, type=int)
    parser.add_argument("--random_embed", default=1, type=int)
    parser.add_argument("--few", default=3, type=int)
    parser.add_argument("--test", default=0, type=int)  # 0 train, 1 test
    parser.add_argument("--embed_model", default='ComplEx', type=str)
    parser.add_argument("--batch_size", default=128, type=int)
    parser.add_argument("--embed_dim", default=100, type=int)
    parser.add_argument("--dropout", default=0.5, type=float)
    parser.add_argument("--fine_tune", default=0, type=int)
    # parser.add_argument("--aggregate", default='max', type=str)
    parser.add_argument("--process_steps", default=2, type=int)  # steps in the query encoder
    # parser.add_argument("--aggregator", default='max', type=str)
    parser.add_argument("--lr", default=0.01, type=float)
    parser.add_argument("--weight_decay", default=0, type=float)
    parser.add_argument("--max_neighbor", default=30, type=int)
    parser.add_argument("--train_few", default=1, type=int)
    parser.add_argument("--margin", default=5.0, type=float)
    parser.add_argument("--eval_every", default=5000, type=int)
    parser.add_argument("--max_batches", default=40000, type=int)
    # parser.add_argument("--prefix", default='intial', type=str)
    parser.add_argument("--rank_weight", default=1.0, type=float)
    parser.add_argument("--ae_weight", default=0.00001, type=float)
    parser.add_argument("--mae_weight", default=1.0, type=float)
    parser.add_argument("--if_GPU", default=1, type=int)
    parser.add_argument("--type_constrain", default=1, type=int)  # whether to apply type constraints
    parser.add_argument("--neg_nums", default=1, type=int)  # number of negatives per training query
    parser.add_argument("--sim", default='KL', type=str)  # ['EL', 'KL'] KG2E similarity function
    parser.add_argument("--experiment_name", default='default', type=str)

    args = parser.parse_args()
    # args.save_path = 'models/' + args.prefix

    # print(args.embed_dim)

    # print("------arguments/parameters-------")
    # for k, v in vars(args).items():
    #     print(k + ': ' + str(v))
    # print("---------------------------------")

    return args

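
Since `read_args()` returns a plain `argparse.Namespace`, the effective configuration can be dumped with a few lines. A minimal sketch mirroring the commented-out block above (it assumes `src/` is importable as a package):
```
from src.args import read_args

if __name__ == '__main__':
    args = read_args()
    for k, v in sorted(vars(args).items()):
        print('{}: {}'.format(k, v))
```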
96 changes: 96 additions & 0 deletions src/data_generator.py
@@ -0,0 +1,96 @@
import json
import random

def train_generate(datapath, neg_nums, batch_size, few, symbol2id, ent2id, e1rel_e2, type_constrain):
    '''
    Endlessly yield training batches: support pairs, query pairs, and
    corrupted (false) pairs, plus the entity ids needed for matching.
    '''
    # # 5_FSRL_T3-2_query
    # train_tasks_n0 = json.load(open(datapath + '/train_tasks_n0.json'))

    train_tasks = json.load(open(datapath + '/train_tasks.json'))
    rel2candidates = json.load(open(datapath + '/rel2candidates_all.json'))
    task_pool = list(train_tasks.keys())
    # print(task_pool[0])

    num_tasks = len(task_pool)

    # for query_ in train_tasks.keys():
    #     print(len(train_tasks[query_]))
    #     if len(train_tasks[query_]) < 4:
    #         print(len(train_tasks[query_]))

    print("train data generation")

    rel_idx = 0

    while True:
        # reshuffle the task order once per full pass over the relations
        if rel_idx % num_tasks == 0:
            random.shuffle(task_pool)
        query = task_pool[rel_idx % num_tasks]
        # print(query)
        rel_idx += 1

        # query_rand = random.randint(0, (num_tasks - 1))
        # query = task_pool[query_rand]

        candidates = rel2candidates[query]
        # print(rel_idx)

        # skip relations with too few candidate entities
        if len(candidates) <= 20:
            continue

        train_and_test = train_tasks[query]
        random.shuffle(train_and_test)

        support_triples = train_and_test[:few]
        support_pairs = [[symbol2id[triple[0]], symbol2id[triple[2]], float(triple[3])] for triple in support_triples]  # (h, t, s)
        support_left = [ent2id[triple[0]] for triple in support_triples]
        support_right = [ent2id[triple[2]] for triple in support_triples]

        all_test_triples = train_and_test[few:]

        # # start 5_FSRL_T3-2_query
        # all_test_triples_n0 = []
        # train_and_test_n0 = train_tasks_n0[query]
        # random.shuffle(train_and_test_n0)
        # for triple in train_and_test_n0:
        #     if triple not in support_triples:
        #         all_test_triples_n0.append(triple)
        # all_test_triples = all_test_triples_n0
        # # end 5_FSRL_T3-2_query

        if len(all_test_triples) == 0:
            continue

        # sample with replacement when there are fewer queries than batch_size
        if len(all_test_triples) < batch_size:
            query_triples = [random.choice(all_test_triples) for _ in range(batch_size)]
        else:
            query_triples = random.sample(all_test_triples, batch_size)

        query_pairs = [[symbol2id[triple[0]], symbol2id[triple[2]], float(triple[3])] for triple in query_triples]  # (h, t, s)
        query_left = [ent2id[triple[0]] for triple in query_triples]
        query_right = [ent2id[triple[2]] for triple in query_triples]
        query_confidence = [float(triple[3]) for triple in query_triples]

        false_pairs = []
        false_left = []
        false_right = []
        if not type_constrain:
            candidates = list(ent2id.keys())  # no type constraint: sample negatives from all entities
        for triple in query_triples:
            for i in range(neg_nums):
                e_h = triple[0]
                rel = triple[1]
                e_t = triple[2]
                # rejection-sample a tail that does not form a known true triple
                while True:
                    noise = random.choice(candidates)
                    if (noise not in e1rel_e2[e_h + rel]) and noise != e_t:
                        break
                false_pairs.append([symbol2id[e_h], symbol2id[noise], 0.0])  # (h, t, s): negatives get confidence 0
                false_left.append(ent2id[e_h])
                false_right.append(ent2id[noise])

        yield support_pairs, query_pairs, false_pairs, support_left, support_right, query_left, query_right, query_confidence, false_left, false_right
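
A minimal sketch of how a training loop might consume this generator. The id maps (`symbol2id`, `ent2id`, `e1rel_e2`) are assumed to have been built elsewhere from the dataset files; names and shapes below are for illustration only:
```
# Hypothetical consumer of train_generate; the three id dictionaries are
# assumptions built elsewhere from the dataset files.
gen = train_generate(args.datapath, args.neg_nums, args.batch_size, args.few,
                     symbol2id, ent2id, e1rel_e2, args.type_constrain)

(support, query, false,
 support_left, support_right,
 query_left, query_right,
 query_conf, false_left, false_right) = next(gen)

# support: few x [head_id, tail_id, confidence]
# query:   batch_size x [head_id, tail_id, confidence]
# false:   batch_size * neg_nums x [head_id, tail_id, 0.0]
```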


(Diffs for the remaining 6 files were not loaded.)
