XASNet is a graph neural network (GNN) model to predict X-ray absorption spectra (XAS) of small molecules while maintaing the explainibility of the predicted spectra. It can be chosen based on different GNN architectures, i.e. GraphNet, graph convolutional neural network (GCN), multi-head graph attention network (GATv2). XASNet can be trained on datasets of 3d molecules with variable sizes composed of the first- and second row of main group elements H, C, N, O, and F. Here, we trained the GNN models on custom-generated carbon K-edge XAS dataset of 65k small organic molecules (subset of original QM9), denoted as QM9-XAS.
To explain the predictions, feature attributions are employed to determine the respective contributions of various atoms in the molecules to the peaks observed in the XAS spectrum. Here, we also developed a method which assigns the ground-truth contributions of various atoms in a molecule to a peak in the TDDFT spectrum. The developed data pipeline produces atoms labels denoting whether a particular atom conztibute to an XAS peak.
Quickstart installation and usage example is given below. Training, prediction and explainability of XAS spectra are given in the example.
To install XASNet-XAI
, download this repository and use pip.
git clone https://github.com/Amirktb1994/XASNet-XAI
conda create -n xasnet-xai numpy
conda activate xasnet-xai
pip install ./XASNet-XAI
The raw and processed QM9-XAS dataset can be downloaded form
The labels of QM9-XAS dataset was used for training, validation and test. The labels of graphs in QM9-XAS are the correponding XAS spectra for QM9 structures. Python environment can be used according to the following to prepare the QM9-XAS dataset,
from XASNet.data import QM9_XAS
# load or create QM9-XAS graph dataset
root = 'path-to-save/load-QM9-XAS-dataset'
qm9_spec = QM9_XAS(
root=root,
raw_dir='./', # path to save/load the raw data necessary to build the graph dataset
spectra=xas_spectra # XAS spectra of all structures in QM9-XAS
)
# save the dataset if it doesn't exists
if not osp.exists(root):
torch.save(qm9_spec, root)
GNNTrainer
can be used to train and validate the GNN models. It can also be used for performing XAS spectra predictions with the trained models.
from XASNet.models import XASNet_GNN
from XASNet.trainer import GNNTrainer
# load the GNN model
trainer = GNNTrainer(model=gnn_model,
model_name="model-name",
device=device,
metric_path="./metrics")
trainer.train_val(
train_loader, # train data loader
val_loader, # val data loader
optimizer, # optimizer, i.e. AdamW
loss_fn, # loss function
scheduler, # learning rate scheduler
num_epochs, # number of epochs
write_every=1, # frequency to write train/val outcome
train_graphnet=True # whether the trained model is GraphNet
)
The explainibility of XAS predictions works with orca raw output files. The OrcaAnlyser
class analyse the orca raw output to obtain excited states and orbital contribution of atoms. The Contributions
class performs final calculations to obtain ground truth atomic contribution to a peak in XAS spectrum.
from XASNet.utils import GraphDataProducer
from XASNet.utils import (
GroundTruthGenerator,
OrcaAnlyser,
Contributions
)
# loading test dataset
root = 'path-to-qm9xas-dataset'
test_qm9xas = QM9_XAS(root=root,
raw_dir='./raw/')
# picking a graph from test dataset
mol_idx = 100
graph_picker = GraphDataProducer(
model=xasnet_gnn,
gnn_type="gatv2",
test_data=test_qm9xas,
idx_to_pick=mol_idx
)
# make prediction with the loaded model
graph = graph_picker.picked_graph
y_true = graph.spectrum
x_pred, y_pred = graph_picker.predictions()
# explainability based on input graph
path_orca_output = 'path-to-orca-raw-output-file'
path_orca_spectrum = 'path-to-xas-spectrum-output-file'
orca_analyzer = OrcaAnlyser(path_orca_output,
path_orca_spectrum)
excitations = orca_analyzer.give_excitations()
contributions = Contributions(
excitations,
all_cam_data,
281,
atom_labels
)
# obtaining core/virtual contribution of atoms to the peak, in this case 281 eV
acc, don = contributions.don_acc_contrs()
# obtaining corresponding cam contributions
cam_contr = contributions.cam_contrs()