[WIP] Merge idist into master #1045

Merged — 54 commits merged into master from idist on May 31, 2020
Changes from 1 commit
Commits (54)
177fb6f
Improved parallel utils (#1023)
vfdev-5 May 11, 2020
91d8875
[WIP] create from context for XLA
vfdev-5 May 11, 2020
3cfccd4
autopep8 fix
May 11, 2020
f71043f
Tests for _sync_model for XLA
vfdev-5 May 11, 2020
093ddb1
autopep8 fix
May 11, 2020
7ad7fcf
More tests and updates
vfdev-5 May 11, 2020
d57b3c9
autopep8 fix
May 11, 2020
7fcadca
[WIP] create from context for Native Torch Dist
vfdev-5 May 12, 2020
5a6e052
autopep8 fix
May 12, 2020
1c362fe
Added tests for idist.* created from context for native dist settings
vfdev-5 May 12, 2020
12512cf
[WIP] Fix tests
vfdev-5 May 13, 2020
228fd89
Fixed metric related tests
vfdev-5 May 13, 2020
b09ea05
autopep8 fix
May 13, 2020
a23da8e
Merge branch 'master' of https://github.com/pytorch/ignite into idist
vfdev-5 May 13, 2020
da72b15
[WIP] idist - Docs & code updates (#1034)
vfdev-5 May 15, 2020
0352bc6
Merge branch 'master' into origin-idist
vfdev-5 May 15, 2020
16256cf
Merge branch 'master' of https://github.com/pytorch/ignite into origi…
vfdev-5 May 16, 2020
914bba9
Tpu metrics (#1042)
vfdev-5 May 16, 2020
feb79b4
Merge branch 'master' into idist
vfdev-5 May 16, 2020
25d38d1
Increased err tol for mse and rmse tests on single TPU
vfdev-5 May 16, 2020
8886948
Fixes #991 (#1047)
vfdev-5 May 16, 2020
add8a4d
Merge branch 'master' into idist
vfdev-5 May 16, 2020
bdae449
add TPU checkpointing to CPU. (#1005)
erip May 16, 2020
d1cc29d
Updated tests on checkpoint and TPU
vfdev-5 May 16, 2020
977ac8c
Merge branch 'master' into idist
vfdev-5 May 17, 2020
15072ae
Added barrier op in idist (#1050)
vfdev-5 May 17, 2020
ac86d46
Merge branch 'master' into idist
vfdev-5 May 18, 2020
037e7f7
Fixed bug with torch.cuda.set_device
vfdev-5 May 19, 2020
2a01cc3
Fixed cuda device index, added warning if cuda device index != local …
vfdev-5 May 19, 2020
1f54ab5
autopep8 fix
May 19, 2020
199224a
Merge branch 'master' into idist
vfdev-5 May 22, 2020
888a654
Issue 1011 (#1053)
vfdev-5 May 22, 2020
ae1bdf5
Improved device() method (#1062)
vfdev-5 May 23, 2020
0fa8c61
Merge branch 'master' into idist
sdesrozis May 23, 2020
537dbd0
Idist kwargs dict (#1064)
vfdev-5 May 23, 2020
727f038
removed badly merged _need_to_sync
vfdev-5 May 23, 2020
530c422
Improved device and setup_common_training_handlers (#1066)
vfdev-5 May 24, 2020
74ddacb
Idist improve2 (#1075)
vfdev-5 May 28, 2020
6735dc0
Merge branch 'master' into idist
vfdev-5 May 28, 2020
b1b5d56
Merge branch 'master' into idist
vfdev-5 May 28, 2020
1e5d7d3
Added support for str input for all gather (#1081)
vfdev-5 May 29, 2020
89e1358
Fix #1055 (#1068)
sdesrozis May 29, 2020
1c34eda
Merge branch 'master' into idist
vfdev-5 May 29, 2020
d277a25
Fix failing tests on multi-gpus
vfdev-5 May 29, 2020
d9a80c6
Fix failing XLA tests
vfdev-5 May 30, 2020
f617787
Merge branch 'master' into idist
vfdev-5 May 30, 2020
a8f03e8
Merge branch 'master' into idist
vfdev-5 May 31, 2020
b41cf6d
Fixes failing tests on multi-GPUs
vfdev-5 May 31, 2020
222cb60
autopep8 fix
May 31, 2020
b3b9aff
Remove useless barriers (#1085)
sdesrozis May 31, 2020
44f4c63
Fixes failing TPU with fork mp
vfdev-5 May 31, 2020
8989e5e
Merge branch 'master' into idist
vfdev-5 May 31, 2020
f4ee4f9
Applied review suggestions
vfdev-5 May 31, 2020
669ef8a
autopep8 fix
May 31, 2020
Added barrier op in idist (#1050)
* Added barrier op in idist

* Fixed test and updated one_rank_only to use idist

* Moved one_rank_only to idist, adapted tests

* autopep8 fix

* Removed redundant imports

* Another test fix of setup_logger

Co-authored-by: AutoPEP8 <>
vfdev-5 committed May 17, 2020
commit 15072ae39807223cd10af3a1e60b6688bb36cf7f
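
Before the per-file diff, a minimal sketch of the synchronization primitive this commit introduces. It is not part of the diff; it assumes a distributed context has already been initialized, and the function name aggregate_ranks is hypothetical.

    import torch
    import ignite.distributed as idist

    def aggregate_ranks():
        # each process contributes its own rank
        t = torch.tensor([idist.get_rank()], dtype=torch.float)
        if idist.get_rank() == 0:
            t += 10.0
        # idist.barrier() blocks until every process has reached this point
        idist.barrier()
        # collective sum across all processes (mirrors _test_distrib_barrier below)
        return idist.all_reduce(t)
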
7 changes: 7 additions & 0 deletions ignite/distributed/comp_models/base.py
@@ -113,6 +113,10 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:
    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        pass

    @abstractmethod
    def barrier(self):
        pass


class _SerialModel(ComputationModel):
    """Private class defines non-distributed computation model for code compatibility with other distributed models.
@@ -174,3 +178,6 @@ def _do_all_reduce(self, tensor: torch.Tensor, op: str = "sum") -> torch.Tensor:

    def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        pass

    def barrier(self):
        pass
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/native.py
@@ -280,3 +280,6 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        output = [torch.zeros_like(tensor) for _ in range(self.get_world_size())]
        dist.all_gather(output, tensor)
        return torch.cat(output, dim=0)

    def barrier(self):
        dist.barrier()
3 changes: 3 additions & 0 deletions ignite/distributed/comp_models/xla.py
@@ -144,3 +144,6 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        output[self.get_rank() % group_size] = tensor
        xm.all_reduce("sum", [output,])
        return output.reshape(-1, *output.shape[2:])

    def barrier(self):
        xm.rendezvous("barrier")
46 changes: 46 additions & 0 deletions ignite/distributed/utils.py
@@ -27,10 +27,12 @@
    "set_local_rank",
    "all_reduce",
    "all_gather",
    "barrier",
    "hostname",
    "has_xla_support",
    "sync",
    "registered_computation_models",
    "one_rank_only",
]

_model = _SerialModel()
@@ -297,6 +299,13 @@ def all_gather(tensor: Union[torch.Tensor, Number]) -> torch.Tensor:
    return _model.all_gather(tensor)


@_sync_model_wrapper
def barrier():
    """Helper method to synchronize all processes.
    """
    _model.barrier()


def set_local_rank(index: int):
    """Method to hint the local rank in case if torch native distributed context is created by user
    without using :meth:`~ignite.distributed.utils.initialize` or :meth:`~ignite.distributed.utils.spawn`.
@@ -417,3 +426,40 @@ def show_config():
    logger.info("num tasks per_node: {}".format(get_ntasks_per_node()))
    logger.info("num nodes: {}".format(get_num_nodes()))
    logger.info("node rank: {}".format(get_node_rank()))


def one_rank_only(rank: int = 0, with_barrier: bool = False):
    """Decorator to filter handlers wrt a rank number

    Args:
        rank (int): rank number of the handler (default: 0).
        with_barrier (bool): synchronisation with a barrier (default: False).

    .. code-block:: python

        engine = ...

        @engine.on(...)
        @one_rank_only()  # means @one_rank_only(rank=0)
        def some_handler(_):
            ...

        @engine.on(...)
        @one_rank_only(rank=1)
        def some_handler(_):
            ...
    """

    def _one_rank_only(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            ret = None
            if get_rank() == rank:
                ret = func(*args, **kwargs)
            if with_barrier:
                barrier()
            return ret

        return wrapper

    return _one_rank_only
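
The decorator above pairs naturally with an Engine handler. A hedged sketch, not part of the diff: the engine, event, and file name below are hypothetical; rank 0 writes a snapshot while with_barrier=True makes every rank meet at the barrier before continuing.

    import torch
    import ignite.distributed as idist
    from ignite.engine import Engine, Events

    engine = Engine(lambda e, batch: batch)

    @engine.on(Events.EPOCH_COMPLETED)
    @idist.one_rank_only(rank=0, with_barrier=True)
    def save_snapshot(engine):
        # runs on rank 0 only; all ranks then synchronize inside the wrapper
        torch.save({"epoch": engine.state.epoch}, "snapshot.pt")

    engine.run([1, 2, 3], max_epochs=2)
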
48 changes: 4 additions & 44 deletions ignite/utils.py
@@ -1,13 +1,11 @@
import collections.abc as collections
import logging
import random
from functools import wraps
from typing import Any, Callable, Optional, Tuple, Type, Union

import torch
import torch.distributed as dist

__all__ = ["convert_tensor", "apply_to_tensor", "apply_to_type", "to_onehot", "setup_logger", "one_rank_only"]
__all__ = ["convert_tensor", "apply_to_tensor", "apply_to_type", "to_onehot", "setup_logger"]


def convert_tensor(
@@ -119,10 +117,9 @@ def setup_logger(
    formatter = logging.Formatter(format)

    if distributed_rank is None:
        if dist.is_available() and dist.is_initialized():
            distributed_rank = dist.get_rank()
        else:
            distributed_rank = 0
        import ignite.distributed as idist

        distributed_rank = idist.get_rank()

    if distributed_rank > 0:
        logger.addHandler(logging.NullHandler())
@@ -158,40 +155,3 @@ def manual_seed(seed: int) -> None:
        np.random.seed(seed)
    except ImportError:
        pass


def one_rank_only(rank: int = 0, barrier: bool = False):
    """Decorator to filter handlers wrt a rank number

    Args:
        rank (int): rank number of the handler (default: 0).
        barrier (bool): synchronisation with a barrier (default: False).

    .. code-block:: python

        engine = ...

        @engine.on(...)
        @one_rank_only()  # means @one_rank_only(rank=0)
        def some_handler(_):
            ...

        @engine.on(...)
        @one_rank_only(rank=1)
        def some_handler(_):
            ...
    """

    def _one_rank_only(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            ret = None
            if dist.get_rank() == rank:
                ret = func(*args, **kwargs)
            if barrier:
                dist.barrier()
            return ret

        return wrapper

    return _one_rank_only
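
A short hedged sketch of the effect of the setup_logger change above: the rank is now resolved through idist.get_rank() (covering native torch.distributed as well as XLA contexts), so non-zero ranks receive a NullHandler and stay silent. The logger name below is hypothetical.

    import ignite.distributed as idist
    from ignite.utils import setup_logger

    # distributed_rank defaults to idist.get_rank() after this change
    logger = setup_logger("trainer")
    logger.info("printed on rank 0 only; this process has rank %d", idist.get_rank())
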
140 changes: 140 additions & 0 deletions tests/ignite/distributed/test_utils.py
@@ -6,6 +6,7 @@

import ignite.distributed as idist
from ignite.distributed.utils import has_xla_support, sync
from ignite.engine import Engine, Events


def _sanity_check():
@@ -430,6 +431,56 @@ def _test_fn(index):
    xmp_executor(_test_fn, args=(), nprocs=n)


def _test_distrib_barrier(device):

    t = torch.tensor([idist.get_rank()], device=device, dtype=torch.float)
    true_res = sum([i for i in range(idist.get_world_size())])

    if idist.get_rank() == 0:
        t += 10.0
    idist.barrier()

    tt = idist.all_reduce(t)
    assert tt.item() == true_res + 10.0


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_barrier_nccl(distributed_context_single_node_nccl):

    device = "cuda:{}".format(distributed_context_single_node_nccl["local_rank"])
    _test_distrib_barrier(device)


@pytest.mark.distributed
def test_idist_barrier_gloo(distributed_context_single_node_gloo):

    device = "cpu"
    _test_distrib_barrier(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_barrier_xla():

    device = idist.device()
    _test_distrib_barrier(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_barrier_xla_in_child_proc(xmp_executor):
    n = int(os.environ["NUM_TPU_WORKERS"])

    def _test_fn(index):
        device = idist.device()
        _test_distrib_barrier(device)

    xmp_executor(_test_fn, args=(), nprocs=n)


@pytest.mark.distributed
def test_idist_methods_overhead_gloo(distributed_context_single_node_gloo):
    import time
@@ -473,3 +524,92 @@ def test_idist_methods_overhead_nccl(distributed_context_single_node_nccl):
    t2 = elapsed / n

    assert t2 * 3 > t1, "{} * 3 vs {}".format(t2, t1)


def _test_distrib_one_rank_only(device):
    def _test(barrier):
        # last rank
        rank = idist.get_world_size() - 1

        value = torch.tensor(0).to(device)

        @idist.one_rank_only(rank=rank, with_barrier=barrier)
        def initialize():
            value.data = torch.tensor(100).to(device)

        initialize()

        value_list = idist.all_gather(tensor=value)

        for r in range(idist.get_world_size()):
            if r == rank:
                assert value_list[r].item() == 100
            else:
                assert value_list[r].item() == 0

    _test(barrier=True)
    _test(barrier=False)


def _test_distrib_one_rank_only_with_engine(device):
    def _test(barrier):
        engine = Engine(lambda e, b: b)

        batch_sum = torch.tensor(0).to(device)

        @engine.on(Events.ITERATION_COMPLETED)
        @idist.one_rank_only(with_barrier=barrier)  # ie rank == 0
        def _(_):
            batch_sum.data += torch.tensor(engine.state.batch).to(device)

        engine.run([1, 2, 3], max_epochs=2)

        value_list = idist.all_gather(tensor=batch_sum)

        for r in range(idist.get_world_size()):
            if r == 0:
                assert value_list[r].item() == 12
            else:
                assert value_list[r].item() == 0

    _test(barrier=True)
    _test(barrier=False)


@pytest.mark.distributed
def test_idist_one_rank_only_gloo(distributed_context_single_node_gloo):
    device = "cpu"
    _test_distrib_one_rank_only(device=device)
    _test_distrib_one_rank_only_with_engine(device=device)


@pytest.mark.distributed
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_one_rank_only_nccl(local_rank, distributed_context_single_node_nccl):
    device = "cuda:{}".format(local_rank)
    _test_distrib_one_rank_only(device=device)
    _test_distrib_one_rank_only_with_engine(device=device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_one_rank_only_xla():

    device = idist.device()
    _test_distrib_one_rank_only(device=device)
    _test_distrib_one_rank_only_with_engine(device=device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
def test_idist_one_rank_only_xla_nprocs(xmp_executor):
    n = int(os.environ["NUM_TPU_WORKERS"])

    def _test_fn(index):
        device = idist.device()
        _test_distrib_one_rank_only(device=device)
        _test_distrib_one_rank_only_with_engine(device=device)

    xmp_executor(_test_fn, args=(), nprocs=n)