Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Fix multiple issues in distributed multi-GPU GraphSAGE example #3870

Merged
merged 8 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/pytorch/graphsage/dist/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ This is an example of training GraphSage in a distributed fashion. Before traini
sudo pip3 install ogb
```

**Requires PyTorch 1.10.0+ to work.**

To train GraphSage, it has five steps:

### Step 0: Setup a Distributed File System
Expand Down
120 changes: 49 additions & 71 deletions examples/pytorch/graphsage/dist/train_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.distributed.algorithms.join import Join
import socket

def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
"""
Expand Down Expand Up @@ -155,41 +157,11 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(pred[test_nid], labels[test_nid])

def pad_data(nids, device):
"""
In distributed traning scenario, we need to make sure that each worker has same number of
batches. Otherwise the synchronization(barrier) is called diffirent times, which results in
the worker with more batches hangs up.

This function pads the nids to the same size for all workers, by repeating the head ids till
the maximum size among all workers.
"""
import torch.distributed as dist
# NCCL backend only supports GPU tensors, thus here we need to allocate it to gpu
num_nodes = th.tensor(nids.numel()).to(device)
dist.all_reduce(num_nodes, dist.ReduceOp.MAX)
max_num_nodes = int(num_nodes)
nids_length = nids.shape[0]
if max_num_nodes > nids_length:
pad_size = max_num_nodes % nids_length
repeat_size = max_num_nodes // nids_length
new_nids = th.cat([nids for _ in range(repeat_size)] + [nids[:pad_size]], axis=0)
print("Pad nids from {} to {}".format(nids_length, max_num_nodes))
else:
new_nids = nids
assert new_nids.shape[0] == max_num_nodes
return new_nids


def run(args, device, data):
# Unpack data
train_nid, val_nid, test_nid, in_feats, n_classes, g = data
shuffle = True
if args.pad_data:
train_nid = pad_data(train_nid, device)
# Current pipeline doesn't support duplicate node id within the same batch
# Therefore turn off shuffling to avoid potential duplicate node id within the same batch
shuffle = False
# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')],
dgl.distributed.sample_neighbors, device)
Expand All @@ -209,8 +181,7 @@ def run(args, device, data):
if args.num_gpus == -1:
model = th.nn.parallel.DistributedDataParallel(model)
else:
dev_id = g.rank() % args.num_gpus
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
model = th.nn.parallel.DistributedDataParallel(model, device_ids=[device], output_device=device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
Expand All @@ -233,43 +204,46 @@ def run(args, device, data):
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
step_time = []
for step, blocks in enumerate(dataloader):
tic_step = time.time()
sample_time += tic_step - start

# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
batch_inputs = blocks[0].srcdata['features']
batch_labels = blocks[-1].dstdata['labels']
batch_labels = batch_labels.long()

num_seeds += len(blocks[-1].dstdata[dgl.NID])
num_inputs += len(blocks[0].srcdata[dgl.NID])
blocks = [block.to(device) for block in blocks]
batch_labels = batch_labels.to(device)
# Compute loss and prediction
start = time.time()
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
forward_end = time.time()
optimizer.zero_grad()
loss.backward()
compute_end = time.time()
forward_time += forward_end - start
backward_time += compute_end - forward_end

optimizer.step()
update_time += time.time() - compute_end

step_t = time.time() - tic_step
step_time.append(step_t)
iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
start = time.time()

with Join([model]):
BarclayII marked this conversation as resolved.
Show resolved Hide resolved
for step, blocks in enumerate(dataloader):
tic_step = time.time()
sample_time += tic_step - start

# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
batch_inputs = blocks[0].srcdata['features']
batch_labels = blocks[-1].dstdata['labels']
batch_labels = batch_labels.long()

num_seeds += len(blocks[-1].dstdata[dgl.NID])
num_inputs += len(blocks[0].srcdata[dgl.NID])
blocks = [block.to(device) for block in blocks]
batch_labels = batch_labels.to(device)
# Compute loss and prediction
start = time.time()
#print(g.rank(), blocks[0].device, model.module.layers[0].fc_neigh.weight.device, dev_id)
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
forward_end = time.time()
optimizer.zero_grad()
loss.backward()
compute_end = time.time()
forward_time += forward_end - start
backward_time += compute_end - forward_end

optimizer.step()
update_time += time.time() - compute_end

step_t = time.time() - tic_step
step_time.append(step_t)
iter_tput.append(len(blocks[-1].dstdata[dgl.NID]) / step_t)
if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
print('Part {} | Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB | time {:.3f} s'.format(
g.rank(), epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc, np.sum(step_time[-args.log_every:])))
start = time.time()

toc = time.time()
print('Part {}, Epoch Time(s): {:.4f}, sample+data_copy: {:.4f}, forward: {:.4f}, backward: {:.4f}, update: {:.4f}, #seeds: {}, #inputs: {}'.format(
Expand All @@ -285,11 +259,14 @@ def run(args, device, data):
time.time() - start))

def main(args):
print(socket.gethostname(), 'Initializing DGL dist')
dgl.distributed.initialize(args.ip_config)
if not args.standalone:
print(socket.gethostname(), 'Initializing DGL process group')
th.distributed.init_process_group(backend=args.backend)
print(socket.gethostname(), 'Initializing DistGraph')
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
print('rank:', g.rank())
print(socket.gethostname(), 'rank:', g.rank())

pb = g.get_partition_book()
if 'trainer_id' in g.ndata:
Expand All @@ -311,7 +288,8 @@ def main(args):
if args.num_gpus == -1:
device = th.device('cpu')
else:
device = th.device('cuda:'+str(args.local_rank))
dev_id = g.rank() % args.num_gpus
device = th.device('cuda:'+str(dev_id))
labels = g.ndata['labels'][np.arange(g.number_of_nodes())]
n_classes = len(th.unique(labels[th.logical_not(th.isnan(labels))]))
print('#labels:', n_classes)
Expand Down