
[Bug] Fix multiple issues in distributed multi-GPU GraphSAGE example #3870

Merged
merged 8 commits into dmlc:master from BarclayII:fix-dist-example
Mar 25, 2022

Conversation

BarclayII
Collaborator

@BarclayII BarclayII commented Mar 22, 2022

  • Previously, running the GraphSAGE example with multiple GPUs across multiple machines threw a device mismatch error.
  • Removed pad_data and switched to torch.distributed.algorithms.join.Join to handle uneven training set sizes across workers (see the sketch after this list).
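For reference, the training loop pattern with Join looks roughly like the sketch below. This is a simplified illustration rather than the exact example code: net, dataloader, feats, and labels are placeholder names, and the distributed process group is assumed to be initialized already.

import torch
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms.join import Join

model = DistributedDataParallel(net)            # net: the GraphSAGE module (placeholder)
opt = torch.optim.Adam(model.parameters(), lr=0.003)

# Join lets ranks that exhaust their minibatches early shadow the collective
# communications of the ranks still training, so no rank deadlocks on an all-reduce.
with Join([model]):
    for input_nodes, output_nodes, blocks in dataloader:
        batch_pred = model(blocks, feats[input_nodes])        # feats/labels: placeholder tensors
        loss = F.cross_entropy(batch_pred, labels[output_nodes])
        opt.zero_grad()
        loss.backward()
        opt.step()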

@BarclayII BarclayII requested review from classicsong, zheng-da and Rhett-Ying and removed request for classicsong March 22, 2022 07:31
@BarclayII BarclayII changed the title [Bug] Fix device mismatch in distributed multi-GPU GraphSAGE example [Bug] Fix multiple issues in distributed multi-GPU GraphSAGE example Mar 23, 2022
@BarclayII BarclayII linked an issue Mar 23, 2022 that may be closed by this pull request
@BarclayII
Collaborator Author

@bioannidis @tonyjie Could you try this solution out? PyTorch 1.10 now recommends using Join to handle uneven training set sizes across workers. I tried it myself and it seemed to work across different batch sizes.

@jermainewang
Member

I suggest also mentioning this in the related tutorials/user guides.

@tonyjie

tonyjie commented Mar 23, 2022

Actually, the GraphSAGE distributed training code works fine for me. I'm trying to use DistEdgeDataLoader for other applications, and found that the code works under some parameter settings but fails (gets stuck in a deadlock) under others. I just tried from torch.distributed.algorithms.join import Join but it didn't seem to work for my case.

But in general, settings with fewer machines tend to run successfully more often, e.g. 2 partitions compared to 4 partitions. So I think it's still a synchronization problem.

@BarclayII
Collaborator Author

BarclayII commented Mar 23, 2022

I just tried from torch.distributed.algorithms.join import Join but it didn't seem to work for my case.

That's bad, because the purpose of Join is to deal with potential deadlocks caused by uneven training sets. Could you tell us the following:

  • Number of nodes (or machines)
  • Size of your partitions (number of nodes) for each node
  • Number of trainers per node
  • Number of samplers per node
  • Number of vertices/edges iterated in your DistDataLoader in each trainer process (that is, the shape of the dataset argument in your DistDataLoader, which can vary across trainer processes)
  • Batch size
  • Shuffle and drop_last flags (which I assume are True and False respectively)

Also, were you able to notice where your code is hanging (e.g. something like torch.distributed.barrier, or within DGL's RPC calls)?

If you could provide a reproducible example it would be even better, although I understand it may be hard to do so.
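(Not DGL-specific, but one standard way to locate a hang is to let Python dump the stack traces of all threads after a timeout, for example by adding the snippet below near the top of the trainer script; 300 seconds is an arbitrary choice.)

# Debugging aid: after 300 seconds, print the stack traces of all threads to
# stderr, which usually shows whether the process is blocked in a
# torch.distributed collective or inside DGL's RPC layer.  Call
# faulthandler.cancel_dump_traceback_later() once training finishes normally.
import faulthandler
faulthandler.dump_traceback_later(300, exit=False)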

@tonyjie

tonyjie commented Mar 23, 2022

One setting that fails (stuck in deadlock) is as follows:

  • number of nodes = 2
  • number of trainers per node = 2
  • number of samplers per node = 0
  • number of servers per node = 1
  • For the entire graph: num_edges = 144000, num_nodes = 2625
  • The node ID ranges in each partition (2 partitions in total) are: nodetype_0: [0, 850], nodetype_1: [850, 1332] and nodetype_0: [1332, 2164], nodetype_1: [2164, 2625], i.e. 1332 nodes and 1293 nodes respectively.
  • I'm running in a distributed CPU setting, so Net(..., dev_id=None), which is equivalent to torch.device('cpu').
  • I'm using DistEdgeDataLoader, so its arguments differ from DistDataLoader's (though DistEdgeDataLoader is essentially a combination of EdgeCollator and DistDataLoader). I pass g as the DistGraph object and eids as a dict containing all edge IDs of the DistGraph, obtained from DistGraph.number_of_edges(...):
sampler = dgl.dataloading.MultiLayerNeighborSampler([None], return_eids=True)
dataloader = DistEdgeDataLoader(
    g,  # g = dgl.distributed.DistGraph
    eids={to_etype_name(k): th.arange(g.number_of_edges(etype=to_etype_name(k)))
          for k in [1, 2, 3, 4, 5]},
    graph_sampler=sampler,
    # use_ddp=True,
    batch_size=args.minibatch_size,
    shuffle=True,
    drop_last=False)
  • Here I set minibatch_size to 5000 (edges per iteration), so there are about 144000 / 2 / 5000 ≈ 14 steps per epoch.
  • Shuffle and drop_last flags: yes, they are set to True and False respectively.

But when num_trainers is set to 1 and every other parameter remains the same, the same code snippet works!

Potential questions & bugs on my side:

Number of vertices/edges iterated in your DistDataLoader in each trainer process (that is, the shape of the dataset argument in your DistDataLoader, which can vary across trainer processes)

I noticed that I'm not calling anything like dgl.distributed.node_split() in my example. Since I'm working with DistEdgeDataLoader, should I use dgl.distributed.edge_split() and pass the returned edge IDs as the eids argument of DistEdgeDataLoader?
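Something like the rough sketch below is what I have in mind; the edge type name 'follows' and the all-ones training mask are just placeholders for illustration, not my actual code.

import dgl
import torch as th

# Split the training edges across trainers first, then feed only this
# trainer's share to DistEdgeDataLoader.  g is the DistGraph.
train_mask = th.ones(g.number_of_edges(etype='follows'), dtype=th.bool)
train_eids = dgl.distributed.edge_split(
    train_mask, g.get_partition_book(), etype='follows', force_even=True)

sampler = dgl.dataloading.MultiLayerNeighborSampler([None], return_eids=True)
dataloader = dgl.dataloading.DistEdgeDataLoader(
    g, {'follows': train_eids}, sampler,
    batch_size=args.minibatch_size, shuffle=True, drop_last=False)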

Thanks for your detailed questions. Looking forward to your suggestions.

@tonyjie

tonyjie commented Mar 23, 2022

The code just hangs at a "random" step: Part x | Epoch n | Step m | ... |. I haven't looked into which line of code it gets stuck at, but I think it's some synchronization error. Also, I didn't use g.barrier.

@jermainewang jermainewang merged commit 7d41608 into dmlc:master Mar 25, 2022
@jermainewang
Member

@tonyjie I think your case is different and may be related to this patch #3867. Please download the latest nightly build and try again. If it doesn't work out still, please open a new issue on that (you could copy-paste the info here). I will merge this PR to resolve the original issue.

@BarclayII BarclayII deleted the fix-dist-example branch March 25, 2022 16:24
@tonyjie

tonyjie commented Mar 28, 2022

@tonyjie I think your case is different and may be related to this patch #3867. Please download the latest nightly build and try again. If it doesn't work out still, please open a new issue on that (you could copy-paste the info here). I will merge this PR to resolve the original issue.

I downloaded the latest nightly build with pip install --pre dgl-cu111 dglgo -f https://data.dgl.ai/wheels-test/repo.html, but it still doesn't work. That makes sense, because my distributed program hangs in the middle of training rather than at DistGraph.init() as described in that patch #3867.

I will open another issue later. Thanks.

BarclayII added a commit that referenced this pull request Mar 31, 2022
…3870)

* fix distributed multi-GPU example device

* try Join

* update version requirement in README

* use model.join

* fix docs

Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
@zheng-da
Collaborator

zheng-da commented Apr 5, 2022

I tried the idea here and it doesn't work for me either. Here is the error I got:

Part 0 | Epoch 00000 | Batch 000 | Train Acc: 0.0000 | Train Loss (ALL|GNN): 7.6276|7.6276 | Time: 4.3357
terminate called after throwing an instance of 'gloo::EnforceNotMet'
  what():  [enforce fail at ../third_party/gloo/gloo/transport/tcp/pair.cc:510] op.preamble.length <= op.nbytes. 24640 vs 4
Traceback (most recent call last):
  File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/distributed/algorithms/join.py", line 274, in __exit__
  File "/home/dzzhen/m5-gnn/python/m5gnn/model/rgcn_node_base.py", line 179, in fit
    loss.backward()
  File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
    join_hook.main_hook()
  File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 193, in main_hook
    ddp._match_all_reduce_for_bwd_pass()
  File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1070, in _match_all_reduce_for_bwd_pass
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/dzzhen/.local/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
    work.wait()
RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:589] Read error [10.2.23.254]:1439: Connection reset by peer
    Variable._execution_engine.run_backward(
RuntimeError: [../third_party/gloo/gloo/transport/tcp/pair.cc:589] Read error [10.2.23.254]:49550: Connection reset by peer

@Rhett-Ying
Collaborator

@tonyjie could you post the details (command, error) in this ticket: #3881?

@tonyjie

tonyjie commented Apr 6, 2022

@tonyjie could you post the details (command, error) in this ticket: #3881?

Hi, I think my problem is somewhat different and I have already (basically) solved it. I'm writing my own distributed training code for a link prediction task, so I use DistEdgeDataLoader, but I didn't call dgl.distributed.edge_split() before. The code could run under some conditions, but it would very often get stuck and hang in the middle of training (rather than hanging at the very beginning of distributed training as described in #3881).

Now that I use edge_split(), distributed training seems fine. Without edge_split(), every trainer was iterating over the full edge set, which probably caused some data race or communication error.
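As a quick sanity check (hypothetical snippet, not from the example), I print each trainer's share right after the split, where train_eids is the tensor returned by dgl.distributed.edge_split():

import torch.distributed as dist

# Each trainer should report a disjoint, roughly equal share of the training edges.
print('trainer %d got %d training edges' % (dist.get_rank(), len(train_eids)))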


Successfully merging this pull request may close these issues.

Bug in DistDGL when the batch size is small