Add RoBERTa base model to applications (LLNL#1999)

* initial commit of new functions
* improvements and fixes based on validating the SBERT model
* added proper GELU calculation
* formatted with black
* formatted with black
* fixes to imports
* added name options
* added RoBERTa base model from YubNub work
* auto-formatted with black
* incorporated feedback, added loading of pretrained weights
* bug fixes and corrections after validating against the PyTorch model
* replaced custom functions with newly added PFE functions
* added code to run the RoBERTa model
* added script to download and process pretrained weights
* minor changes
* added README and small changes to clean up code
* Update python/lbann/modules/__init__.py (Co-authored-by: Tim Moon <moon13@llnl.gov>)
* Update applications/nlp/RoBERTa/roberta.py (Co-authored-by: Tim Moon <moon13@llnl.gov>)
* synthetic dataset is now 1024 samples
* minor revision based on feedback
* added missing position_ids init function
* Update applications/nlp/RoBERTa/roberta.py (Co-authored-by: Tim Moon <moon13@llnl.gov>)
* added missing reshape for tessellate layer

Co-authored-by: Tim Moon <moon13@llnl.gov>

Showing 5 changed files with 993 additions and 0 deletions.

applications/nlp/RoBERTa/README.md
@@ -0,0 +1,38 @@
# RoBERTa

This directory contains an LBANN implementation of RoBERTa, an optimized
version of the BERT model. This implementation is based on and validated
against the [HuggingFace PyTorch RoBERTa
implementation](https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/roberta/modeling_roberta.py#L695).

## Dependencies

- PyTorch

## Usage

You will need to run the `get_model_config.py` script to download the model
configuration file and pretrained weights from the HuggingFace repository. By
default, the RoBERTa model will load the pretrained weights provided in the
HuggingFace implementation. If you want to train the model from scratch,
without loading pretrained weights, then run `get_model_config.py
--no-weights`:

```bash
# Download config and pretrained weights
python3 get_model_config.py

# Download just the config
python3 get_model_config.py --no-weights
```

The directory should now contain a `config.json` file and, optionally,
`pytorch_model.bin` and `pretrained_weights/`. Modifying values in
`config.json` will change the parameters used to build the RoBERTa model.
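
For example, you can shrink the model for a quick smoke test by editing the
config before launching. A minimal sketch (`num_hidden_layers` and
`num_attention_heads` are standard HuggingFace config keys, but check your
downloaded `config.json` for the exact fields it contains):

```python
import json

with open("config.json") as f:
    config = json.load(f)

# Assumed HuggingFace config keys; verify they exist in your config.json
config["num_hidden_layers"] = 2
config["num_attention_heads"] = 4

with open("config.json", "w") as f:
    json.dump(config, f, indent=2)
```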

An example of how to run the model is provided in `main.py`, and a synthetic
dataset is provided in `dataset.py`. Run the example with:

```bash
python3 main.py --nodes 1 --procs-per-node 2 --time-limit 60 --partition pbatch --epochs 5 --mini-batch-size 10
```

applications/nlp/RoBERTa/dataset.py
@@ -0,0 +1,15 @@
import numpy as np

# Synthetic dataset: 1024 samples, each a class label in [0, 10) followed by
# 512 token IDs drawn uniformly from RoBERTa's 50265-token vocabulary
# (16 x 32 = 512, matching the input shape used in main.py).
data = np.random.randint(50265, size=(1024, 513))
data[:, 0] = data[:, 0] % 10


def get_sample(i):
    vals = data[i]
    return vals.flatten().astype(np.float32)


def num_samples():
    return 1024


def sample_dims():
    return (513,)
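
These three functions (`get_sample`, `num_samples`, `sample_dims`) are the
interface that the Python data reader in `main.py` points at. A quick
standalone sanity check, assuming the file above is importable as `dataset`:

```python
import numpy as np
import dataset  # the module defined above

sample = dataset.get_sample(0)
assert sample.shape == dataset.sample_dims()  # (513,): 1 label + 512 tokens
assert sample.dtype == np.float32
assert 0 <= sample[0] < 10  # first entry is the class label
print(dataset.num_samples(), "samples of shape", dataset.sample_dims())
```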

applications/nlp/RoBERTa/get_model_config.py
@@ -0,0 +1,97 @@
import sys
import os
import warnings
import itertools
import time
import glob
import urllib.request
import argparse

import numpy as np
import torch

# Model config and pretrained weights to fetch from the HuggingFace hub
files = {
    "config.json": "https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1/resolve/main/config.json",
    "pytorch_model.bin": "https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1/resolve/main/pytorch_model.bin",
}
weights_dir = "pretrained_weights"


def download_file(url, fn):
    def report_hook(count, block_size, total_size):
        # Draw a one-line progress bar with download size and elapsed time
        duration = int(time.time() - start_time)
        progress_size = int(count * block_size / (1024 ** 2))
        percent = min(int(count * block_size * 100 / total_size), 100)
        prog_bar = "|" + "#" * int(percent / 2) + "-" * (50 - int(percent / 2)) + "|"
        sys.stdout.write(
            f"\r{prog_bar} {percent}%, {progress_size} MB, {duration}s elapsed"
        )
        sys.stdout.flush()

    if os.path.exists(fn):
        warnings.warn(f"File '{fn}' already exists, skipping download")
    else:
        print(f"\n\nDownloading {fn} from {url}\n")
        start_time = time.time()
        urllib.request.urlretrieve(url, fn, report_hook)


def extract_weights(model, weights_dir):
    # Save each tensor in the PyTorch state dict as its own .npy file
    for name, weights in model.items():
        weights = np.array(weights).astype(np.float32)
        np.save(f"./{weights_dir}/{name}.npy", weights)


def process_weights(weights_dir):
    # Combine each LayerNorm's weight and bias into a single file
    layernorm_files = glob.glob(f"./{weights_dir}/*LayerNorm*.npy")
    layernorm_groups = {}
    for fn in layernorm_files:
        base_fn = fn.split(".LayerNorm")[0]
        if base_fn in layernorm_groups:
            layernorm_groups[base_fn].append(fn)
        else:
            layernorm_groups[base_fn] = [fn]

    for base_fn, fns in layernorm_groups.items():
        weight_fn = [fn for fn in fns if "weight.npy" in fn][0]
        bias_fn = [fn for fn in fns if "bias.npy" in fn][0]

        weight_bias_vals = np.stack([np.load(weight_fn), np.load(bias_fn)]).T.copy()
        np.save(f"{base_fn}.layernorm.weightbias.npy", weight_bias_vals)

    # Transpose embedding layer weights
    embed_files = [
        glob.glob(f"{weights_dir}/{e}.npy")
        for e in (
            "*position_embeddings*",
            "*token_type_embeddings*",
            "*word_embeddings*",
        )
    ]
    embed_files = itertools.chain(*embed_files)
    for fn in embed_files:
        np.save(fn, np.load(fn).T.copy())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--no-weights", action="store_true", help="avoids downloading model weights"
    )
    args = parser.parse_args()

    if args.no_weights:
        del files["pytorch_model.bin"]

    # Download model files from HuggingFace
    for fn, url in files.items():
        download_file(url, fn)

    if not args.no_weights:
        # Extract weights from the PyTorch checkpoint
        if not os.path.exists(weights_dir):
            os.makedirs(weights_dir)
        model = torch.load("pytorch_model.bin", map_location="cpu")
        extract_weights(model, weights_dir)

        # Process weights for loading into LBANN
        process_weights(weights_dir)
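
After running the script, each fused LayerNorm file should hold a
(hidden_size, 2) array whose columns are [weight, bias], and the embedding
matrices should be transposed in place. A small inspection sketch (assumes the
script has been run with its default `pretrained_weights/` output directory):

```python
import glob
import numpy as np

# One fused LayerNorm file: column 0 is the weight, column 1 the bias
fn = glob.glob("pretrained_weights/*layernorm.weightbias.npy")[0]
print(fn, np.load(fn).shape)  # expect (hidden_size, 2)

# Word embeddings were transposed, so the vocabulary is now the second axis
for fn in glob.glob("pretrained_weights/*word_embeddings*.npy"):
    print(fn, np.load(fn).shape)  # expect (hidden_size, vocab_size)
```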

applications/nlp/RoBERTa/main.py
@@ -0,0 +1,226 @@
from types import SimpleNamespace
import argparse
import os
import sys
import json
import numpy as np

import lbann
from lbann.util import str_list
import lbann.contrib.args
import lbann.contrib.launcher

from roberta import RobertaModel

# ----------------------------------------------
# Options
# ----------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument(
    "--epochs",
    default=10,
    type=int,
    help="number of epochs to train",
)
parser.add_argument(
    "--mini-batch-size",
    default=32,
    type=int,
    help="size of minibatches for training",
)
parser.add_argument(
    "--job-name",
    action="store",
    default="lbann_RoBERTa",
    type=str,
    help="scheduler job name",
    metavar="NAME",
)
parser.add_argument(
    "--work-dir",
    action="store",
    default=None,
    type=str,
    help="working directory",
    metavar="DIR",
)
parser.add_argument("--batch-job", action="store_true", help="submit as batch job")
parser.add_argument(
    "--checkpoint", action="store_true", help="checkpoint trainer after every epoch"
)
lbann.contrib.args.add_scheduler_arguments(parser)
lbann_params = parser.parse_args()

# ----------------------------------------------
# Data Reader
# ----------------------------------------------
def make_data_reader():
    reader = lbann.reader_pb2.DataReader()

    # Train data reader
    _reader = reader.reader.add()
    _reader.name = "python"
    _reader.role = "train"
    _reader.shuffle = True
    _reader.percent_of_data_to_use = 1.0
    _reader.python.module = "dataset"
    _reader.python.module_dir = os.path.dirname(os.path.realpath(__file__))
    _reader.python.sample_function = "get_sample"
    _reader.python.num_samples_function = "num_samples"
    _reader.python.sample_dims_function = "sample_dims"

    # Validation data reader
    _reader = reader.reader.add()
    _reader.name = "python"
    _reader.role = "validate"
    _reader.shuffle = False
    _reader.percent_of_data_to_use = 1.0
    _reader.python.module = "dataset"
    _reader.python.module_dir = os.path.dirname(os.path.realpath(__file__))
    _reader.python.sample_function = "get_sample"
    _reader.python.num_samples_function = "num_samples"
    _reader.python.sample_dims_function = "sample_dims"

    # Test data reader
    _reader = reader.reader.add()
    _reader.name = "python"
    _reader.role = "test"
    _reader.shuffle = False
    _reader.percent_of_data_to_use = 1.0
    _reader.python.module = "dataset"
    _reader.python.module_dir = os.path.dirname(os.path.realpath(__file__))
    _reader.python.sample_function = "get_sample"
    _reader.python.num_samples_function = "num_samples"
    _reader.python.sample_dims_function = "sample_dims"

    return reader


# ----------------------------------------------
# Loss
# ----------------------------------------------
class CrossEntropyLoss(lbann.modules.Module):
    """Cross-entropy loss for classification.

    Given an input vector x, weight matrix W, and label y:

        L = -log( softmax(W*x) * onehot(y) )

    Args:
        num_classes (int): Number of classes.
        weights (lbann.Weights): Matrix with dimensions of
            num_classes x input_size. Each row is an embedding vector
            for the corresponding class.
        data_layout (str): Data layout of fully-connected layer.
    """

    def __init__(
        self,
        num_classes,
        weights=[],
        data_layout="data_parallel",
    ):
        self.num_classes = num_classes
        self.data_layout = data_layout
        self.fc = lbann.modules.FullyConnectedModule(
            self.num_classes,
            weights=weights,
            bias=False,
            activation=lbann.LogSoftmax,
            name="class_fc",
            data_layout=self.data_layout,
        )

    def forward(self, x, label):
        """Compute cross-entropy loss.

        Args:
            x (lbann.Layer): Input vector.
            label (lbann.Layer): Label. Should have one entry, which
                will be cast to an integer.

        Returns:
            lbann.Layer: Loss function value.
        """
        log_probs = self.fc(x)
        label_onehot = lbann.OneHot(
            label,
            size=self.num_classes,
            data_layout=self.data_layout,
        )
        loss = lbann.Multiply(
            log_probs,
            label_onehot,
            data_layout=self.data_layout,
        )
        loss = lbann.Reduction(
            loss,
            mode="sum",
            data_layout=self.data_layout,
        )
        loss = lbann.Negative(loss, data_layout=self.data_layout)
        return loss


# ----------------------------------------------
# Build and Run Model
# ----------------------------------------------
with open("./config.json") as f:
    config = json.load(f, object_hook=lambda d: SimpleNamespace(**d))
config.input_shape = (16, 32)
config.load_weights = os.path.exists("./pretrained_weights")

# Construct the model: the first entry of each sample is the label, the
# remaining 16 x 32 = 512 entries are the token IDs
input_ = lbann.Slice(
    lbann.Input(data_field="samples"),
    slice_points=str_list([0, 1, 1 + np.prod(config.input_shape)]),
)
labels = lbann.Identity(input_)
sample = lbann.Reshape(input_, dims=str_list(config.input_shape))
roberta = RobertaModel(config, load_weights=config.load_weights)
out = roberta(sample)
out = lbann.ChannelwiseFullyConnected(out, output_channel_dims=[1000])
loss = CrossEntropyLoss(10, data_layout="model_parallel")
obj = loss(out, labels)
metrics = [lbann.Metric(obj, name="loss")]

model = lbann.Model(
    lbann_params.epochs,
    layers=lbann.traverse_layer_graph(input_),
    objective_function=obj,
    metrics=metrics,
    callbacks=[
        lbann.CallbackPrint(),
        lbann.CallbackTimer(),
    ],
)

# Setup trainer, optimizer, data_reader
trainer = lbann.Trainer(
    mini_batch_size=lbann_params.mini_batch_size,
    num_parallel_readers=1,
)
optimizer = lbann.Adam(
    learn_rate=0.01,
    beta1=0.9,
    beta2=0.99,
    eps=1e-8,
)
data_reader = make_data_reader()

# Launch LBANN
kwargs = lbann.contrib.args.get_scheduler_kwargs(lbann_params)
kwargs["environment"] = {}
lbann.contrib.launcher.run(
    trainer,
    model,
    data_reader,
    optimizer,
    work_dir=lbann_params.work_dir,
    job_name=lbann_params.job_name,
    lbann_args=["--num_io_threads=1"],
    batch_job=lbann_params.batch_job,
    **kwargs,
)
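
As a reference for what the `CrossEntropyLoss` module above assembles from
LBANN layers, the same quantity written directly in NumPy (an illustrative
sketch with made-up shapes; it is not part of the LBANN graph):

```python
import numpy as np

def cross_entropy_reference(W, x, y):
    """Reference for L = -log( softmax(W @ x)[y] )."""
    logits = W @ x
    logits = logits - logits.max()  # shift for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())  # log-softmax
    return -log_probs[y]  # pick out the true class

W = np.random.randn(10, 1000)  # num_classes x input_size
x = np.random.randn(1000)
print(cross_entropy_reference(W, x, y=3))
```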