# main.py (42 lines, 34 loc, 2.29 KB)
import torch
from dig.threedgraph.dataset import MD17
from dig.threedgraph.evaluation import ThreeDEvaluator
from utils.run import run
from methods.q_schnet.q_schnet import Q_SchNet
from methods.q_dimenetpp.q_dimenetpp import Q_DimeNetPP
from methods.q_spherenet.q_spherenet import Q_SphereNet
# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MD17 dataset 'md17_benzene2017', rooted at ./dataset/.
dataset_md17 = MD17(root='dataset/', name='md17_benzene2017')

# Seeded index split with 1000 training and 1000 validation samples; the
# total size handed to the splitter is the number of entries in `data.y`.
split_idx_md17 = dataset_md17.get_idx_split(
    len(dataset_md17.data.y),
    train_size=1000,
    valid_size=1000,
    seed=42,
)

# Materialise the three subsets in train / valid / test order.
train_dataset_md17, valid_dataset_md17, test_dataset_md17 = (
    dataset_md17[split_idx_md17[part]] for part in ('train', 'valid', 'test')
)
print(
    'train, validaion, test:',
    len(train_dataset_md17),
    len(valid_dataset_md17),
    len(test_dataset_md17),
)
# Active model: Q_SchNet configured to predict both energy and forces.
model_md17 = Q_SchNet(
    energy_and_force=True,
    cutoff=10.0,
    num_layers=6,
    hidden_channels=128,
    out_channels=1,
    num_filters=128,
    num_gaussians=50,
)

# Alternative model configurations; uncomment exactly one of these (and
# comment out the Q_SchNet assignment above) to switch architectures.
#
# model_md17 = Q_DimeNetPP(
#     energy_and_force=True, cutoff=5.0, num_layers=4,
#     hidden_channels=128, out_channels=1, int_emb_size=64,
#     basis_emb_size=8, out_emb_channels=256,
#     num_spherical=3, num_radial=6, envelope_exponent=5,
#     num_before_skip=1, num_after_skip=2, num_output_layers=3,
#     num=-1, n=30, sort_by_angle=False,
# )
#
# model_md17 = Q_SphereNet(
#     energy_and_force=True, cutoff=5.0, num_layers=4,
#     hidden_channels=128, out_channels=1, int_emb_size=64,
#     basis_emb_size_dist=8, basis_emb_size_angle=8,
#     basis_emb_size_torsion=8, out_emb_channels=256,
#     num_spherical=3, num_radial=6, envelope_exponent=5,
#     num_before_skip=1, num_after_skip=2, num_output_layers=3,
#     n=30, sort_by_angle=True, num=-1,
# )

# L1 (mean absolute error) loss and the DIG 3D-graph evaluator.
evaluation_md17 = ThreeDEvaluator()
loss_func_md17 = torch.nn.L1Loss()
# Location of the saved checkpoint file (checkpoint.pt). Must be filled in
# before running: torch.load fails on the empty default.
PATH = ''  # the path of checkpoint.pt

# Restore the model weights from the checkpoint.
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine still loads when only the CPU is
# available.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(PATH, map_location=device)
model_md17.load_state_dict(checkpoint['model_state_dict'])

# Hand everything to the project's runner. only_test=True is passed together
# with the loaded checkpoint (presumably evaluation only, no training —
# confirm against utils.run).
run3d_md17 = run()
run3d_md17.run(
    device,
    train_dataset_md17,
    valid_dataset_md17,
    test_dataset_md17,
    model_md17,
    loss_func_md17,
    evaluation_md17,
    epochs=500,
    batch_size=1,
    vt_batch_size=32,
    lr=0.0005,
    lr_decay_factor=0.5,
    lr_decay_step_size=100,
    energy_and_force=True,
    save_dir='',
    log_dir='',
    only_test=True,
    checkpoint=checkpoint,
)