[WIP] Merge idist into master #1045

Merged: 54 commits into master on May 31, 2020

Commits
177fb6f
Improved parallel utils (#1023)
vfdev-5 May 11, 2020
91d8875
[WIP] create from context for XLA
vfdev-5 May 11, 2020
3cfccd4
autopep8 fix
May 11, 2020
f71043f
Tests for _sync_model for XLA
vfdev-5 May 11, 2020
093ddb1
autopep8 fix
May 11, 2020
7ad7fcf
More tests and updates
vfdev-5 May 11, 2020
d57b3c9
autopep8 fix
May 11, 2020
7fcadca
[WIP] create from context for Native Torch Dist
vfdev-5 May 12, 2020
5a6e052
autopep8 fix
May 12, 2020
1c362fe
Added tests for idist.* created from context for native dist settings
vfdev-5 May 12, 2020
12512cf
[WIP] Fix tests
vfdev-5 May 13, 2020
228fd89
Fixed metric related tests
vfdev-5 May 13, 2020
b09ea05
autopep8 fix
May 13, 2020
a23da8e
Merge branch 'master' of https://github.com/pytorch/ignite into idist
vfdev-5 May 13, 2020
da72b15
[WIP] idist - Docs & code updates (#1034)
vfdev-5 May 15, 2020
0352bc6
Merge branch 'master' into origin-idist
vfdev-5 May 15, 2020
16256cf
Merge branch 'master' of https://github.com/pytorch/ignite into origi…
vfdev-5 May 16, 2020
914bba9
Tpu metrics (#1042)
vfdev-5 May 16, 2020
feb79b4
Merge branch 'master' into idist
vfdev-5 May 16, 2020
25d38d1
Increased err tol for mse and rmse tests on single TPU
vfdev-5 May 16, 2020
8886948
Fixes #991 (#1047)
vfdev-5 May 16, 2020
add8a4d
Merge branch 'master' into idist
vfdev-5 May 16, 2020
bdae449
add TPU checkpointing to CPU. (#1005)
erip May 16, 2020
d1cc29d
Updated tests on checkpoint and TPU
vfdev-5 May 16, 2020
977ac8c
Merge branch 'master' into idist
vfdev-5 May 17, 2020
15072ae
Added barrier op in idist (#1050)
vfdev-5 May 17, 2020
ac86d46
Merge branch 'master' into idist
vfdev-5 May 18, 2020
037e7f7
Fixed bug with torch.cuda.set_device
vfdev-5 May 19, 2020
2a01cc3
Fixed cuda device index, added warning if cuda device index != local …
vfdev-5 May 19, 2020
1f54ab5
autopep8 fix
May 19, 2020
199224a
Merge branch 'master' into idist
vfdev-5 May 22, 2020
888a654
Issue 1011 (#1053)
vfdev-5 May 22, 2020
ae1bdf5
Improved device() method (#1062)
vfdev-5 May 23, 2020
0fa8c61
Merge branch 'master' into idist
sdesrozis May 23, 2020
537dbd0
Idist kwargs dict (#1064)
vfdev-5 May 23, 2020
727f038
removed badly merged _need_to_sync
vfdev-5 May 23, 2020
530c422
Improved device and setup_common_training_handlers (#1066)
vfdev-5 May 24, 2020
74ddacb
Idist improve2 (#1075)
vfdev-5 May 28, 2020
6735dc0
Merge branch 'master' into idist
vfdev-5 May 28, 2020
b1b5d56
Merge branch 'master' into idist
vfdev-5 May 28, 2020
1e5d7d3
Added support for str input for all gather (#1081)
vfdev-5 May 29, 2020
89e1358
Fix #1055 (#1068)
sdesrozis May 29, 2020
1c34eda
Merge branch 'master' into idist
vfdev-5 May 29, 2020
d277a25
Fix failing tests on multi-gpus
vfdev-5 May 29, 2020
d9a80c6
Fix failing XLA tests
vfdev-5 May 30, 2020
f617787
Merge branch 'master' into idist
vfdev-5 May 30, 2020
a8f03e8
Merge branch 'master' into idist
vfdev-5 May 31, 2020
b41cf6d
Fixes failing tests on multi-GPUs
vfdev-5 May 31, 2020
222cb60
autopep8 fix
May 31, 2020
b3b9aff
Remove useless barriers (#1085)
sdesrozis May 31, 2020
44f4c63
Fixes failing TPU with fork mp
vfdev-5 May 31, 2020
8989e5e
Merge branch 'master' into idist
vfdev-5 May 31, 2020
f4ee4f9
Applied review suggestions
vfdev-5 May 31, 2020
669ef8a
autopep8 fix
May 31, 2020
27 changes: 27 additions & 0 deletions docs/source/distributed.rst
@@ -0,0 +1,27 @@
ignite.distributed
==================

Helper module to use distributed settings for multiple backends:

- backends from native torch distributed configuration: "nccl", "gloo", "mpi"

- XLA on TPUs via `pytorch/xla <https://github.com/pytorch/xla>`_

This module wraps common methods to fetch information about the distributed configuration, to initialize/finalize
a process group, and to spawn multiple processes.


Examples:

- Example to spawn `nprocs` processes that run `fn` with `args`: :meth:`~ignite.distributed.spawn`


.. currentmodule:: ignite.distributed

.. automodule:: ignite.distributed
:members:
:imported-members:

.. attribute:: has_xla_support

True if `torch_xla` package is found
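
For reference, a minimal usage sketch of the `spawn` helper documented above. It assumes the signature `idist.spawn(backend, fn, args, nproc_per_node=...)` and that the spawned `fn` receives the local rank as its first argument; treat both as assumptions drawn from the module docs, not a verbatim excerpt of this PR.

import ignite.distributed as idist

def training(local_rank, a, b):
    # each spawned worker can query its rank and device from the helper module
    print(idist.get_rank(), idist.device(), a, b)

# Hypothetical launch: 4 "gloo" processes on the current node running `training`
idist.spawn("gloo", training, args=(1, 2), nproc_per_node=4)
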
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -78,6 +78,7 @@ of stable version as dependency):
engine
handlers
metrics
distributed
exceptions
utils

1 change: 1 addition & 0 deletions ignite/__init__.py
@@ -1,4 +1,5 @@
import ignite.contrib
import ignite.distributed
import ignite.engine
import ignite.exceptions
import ignite.handlers
106 changes: 62 additions & 44 deletions ignite/contrib/engines/common.py
@@ -4,9 +4,11 @@
from functools import partial

import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist
from ignite.contrib.handlers import (
LRScheduler,
MLflowLogger,
NeptuneLogger,
PolyaxonLogger,
@@ -19,7 +21,7 @@
)
from ignite.contrib.metrics import GpuInfo
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping, ModelCheckpoint, TerminateOnNan
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, TerminateOnNan
from ignite.metrics import RunningAverage


@@ -35,7 +37,9 @@ def setup_common_training_handlers(
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
device="cuda",
device=None,
stop_on_nan=True,
clear_cuda_cache=True,
):
"""Helper method to setup trainer with common handlers (it also supports distributed configuration):
- :class:`~ignite.handlers.TerminateOnNan`
@@ -49,22 +53,31 @@
or sequence or a single tensor.
train_sampler (torch.utils.data.DistributedSampler, optional): Optional distributed sampler used to call
`set_epoch` method on epoch started event.
to_save (dict, optional): dictionary with objects to save in the checkpoint. This is used with
:class:`~ignite.handlers.ModelCheckpoint`.
to_save (dict, optional): dictionary with objects to save in the checkpoint. This argument is passed to
:class:`~ignite.handlers.Checkpoint` instance.
save_every_iters (int, optional): saving interval. By default, `to_save` objects are stored
every 1000 iterations.
output_path (str, optional): output path to indicate where `to_save` objects are stored.
lr_scheduler (ParamScheduler or subclass of `torch.optim.lr_scheduler._LRScheduler`): learning rate scheduler
as native torch LRScheduler or ignite's parameter scheduler.
with_gpu_stats (bool, optional): if True, :class:`~ignite.contrib.metrics.handlers.GpuInfo` is attached to the
trainer. This requires `pynvml` package to be installed.
output_names (list/tuple): list of names associated with `update_function` output dictionary.
with_pbars (bool, optional): if True, two progress bars on epochs and optionally on iterations are attached
output_names (list/tuple, optional): list of names associated with `update_function` output dictionary.
with_pbars (bool, optional): if True, two progress bars on epochs and optionally on iterations are attached.
Default, True.
with_pbar_on_iters (bool, optional): if True, a progress bar on iterations is attached to the trainer.
Default, True.
log_every_iters (int, optional): logging interval for :class:`~ignite.contrib.metrics.handlers.GpuInfo` and for
epoch-wise progress bar.
device (str or torch.device, optional): Optional device specification in case of distributed computation usage.
epoch-wise progress bar. Default, 100.
stop_on_nan (bool, optional): if True, :class:`~ignite.handlers.TerminateOnNan` handler is added to the trainer.
Default, True.
clear_cuda_cache (bool, optional): if True, `torch.cuda.empty_cache()` is called at the end of every epoch.
Default, True.
device (str or torch.device, optional): deprecated argument; it will be removed in v0.5.0.
"""
if device is not None:
warnings.warn("Argument device is unused and deprecated. It will be removed in v0.5.0")

kwargs = dict(
to_save=to_save,
save_every_iters=save_every_iters,
@@ -75,16 +88,16 @@
with_pbars=with_pbars,
with_pbar_on_iters=with_pbar_on_iters,
log_every_iters=log_every_iters,
device=device,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
)
if dist.is_available() and dist.is_initialized():

if idist.get_world_size() > 1:
_setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **kwargs)
else:
if train_sampler is not None:
if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
warnings.warn(
"Argument train_sampler distributed sampler used to call `set_epoch` method on epoch "
"started event, but no distributed setting detected",
UserWarning,
"Argument train_sampler is a distributed sampler, but no distributed setting detected", UserWarning,
)
_setup_common_training_handlers(trainer, **kwargs)
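
Putting the updated entry point together, a hedged usage sketch of the new signature, with `trainer`, `model` and `optimizer` assumed to exist; the flag values simply restate the defaults documented above:

from ignite.contrib.engines.common import setup_common_training_handlers

setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    save_every_iters=1000,
    output_path="/tmp/output",  # required whenever to_save is given
    stop_on_nan=True,           # new flag: attach TerminateOnNan
    clear_cuda_cache=True,      # new flag: empty the CUDA cache each epoch
)

Note that `device` is no longer passed; the helper now derives the distributed configuration from `ignite.distributed` instead.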

@@ -98,28 +111,35 @@ def _setup_common_training_handlers(
save_every_iters=1000,
output_path=None,
lr_scheduler=None,
with_gpu_stats=True,
with_gpu_stats=False,
output_names=None,
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
device="cuda",
stop_on_nan=True,
clear_cuda_cache=True,
):
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
if stop_on_nan:
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

if lr_scheduler is not None:
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
if torch.cuda.is_available() and clear_cuda_cache:
trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

if to_save is not None:
if output_path is None:
raise ValueError("If to_save argument is provided then output_path argument should be also defined")
checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training", require_empty=False)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save)
checkpoint_handler = Checkpoint(
to_save, DiskSaver(dirname=output_path, require_empty=False), filename_prefix="training",
)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

if with_gpu_stats:
GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
@@ -140,9 +160,9 @@ def output_transform(x, index, name):
)

for i, n in enumerate(output_names):
RunningAverage(
output_transform=partial(output_transform, index=i, name=n), epoch_bound=False, device=device
).attach(trainer, n)
RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
trainer, n
)

if with_pbars:
if with_pbar_on_iters:
@@ -162,43 +182,38 @@ def _setup_common_distrib_training_handlers(
save_every_iters=1000,
output_path=None,
lr_scheduler=None,
with_gpu_stats=True,
with_gpu_stats=False,
output_names=None,
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
device="cuda",
stop_on_nan=True,
clear_cuda_cache=True,
):
if not (dist.is_available() and dist.is_initialized()):
raise RuntimeError("Distributed setting is not initialized, please call `dist.init_process_group` before.")

_setup_common_training_handlers(
trainer,
to_save=None,
to_save=to_save,
output_path=output_path,
save_every_iters=save_every_iters,
lr_scheduler=lr_scheduler,
with_gpu_stats=with_gpu_stats,
output_names=output_names,
with_pbars=(dist.get_rank() == 0) and with_pbars,
with_pbars=(idist.get_rank() == 0) and with_pbars,
with_pbar_on_iters=with_pbar_on_iters,
log_every_iters=log_every_iters,
device=device,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
)

if train_sampler is not None:
if not callable(getattr(train_sampler, "set_epoch", None)):
raise TypeError("Train sampler should have `set_epoch` method")
if not isinstance(train_sampler, DistributedSampler):
raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

@trainer.on(Events.EPOCH_STARTED)
def distrib_set_epoch(engine):
train_sampler.set_epoch(engine.state.epoch - 1)

if dist.get_rank() == 0:
if to_save is not None:
if output_path is None:
raise ValueError("If to_save argument is provided then output_path argument should be also defined")
checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training", require_empty=False)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save)


def empty_cuda_cache(_):
torch.cuda.empty_cache()
@@ -447,21 +462,24 @@ def save_best_model_by_val_score(output_path, evaluator, model, metric_name, n_s
tag (str, optional): score name prefix: `{tag}_{metric_name}`. By default, tag is "val".

Returns:
A :class:`~ignite.handlers.checkpoint.ModelCheckpoint` handler.
A :class:`~ignite.handlers.checkpoint.Checkpoint` handler.
"""
global_step_transform = None
if trainer is not None:
global_step_transform = global_step_from_engine(trainer)

best_model_handler = ModelCheckpoint(
dirname=output_path,
best_model_handler = Checkpoint(
{"model": model,},
DiskSaver(dirname=output_path, require_empty=False),
filename_prefix="best",
n_saved=n_saved,
global_step_transform=global_step_transform,
score_name="{}_{}".format(tag, metric_name.lower()),
score_function=get_default_score_fn(metric_name),
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {"model": model,})
evaluator.add_event_handler(
Events.COMPLETED, best_model_handler,
)

return best_model_handler

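The recurring change in this file is the move from `ModelCheckpoint` to `Checkpoint` plus `DiskSaver`, with `idist.*` replacing direct `torch.distributed` queries. A minimal sketch of the new checkpointing pattern, assuming `trainer`, `model` and `optimizer` already exist:

from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

to_save = {"model": model, "optimizer": optimizer}
handler = Checkpoint(
    to_save,
    DiskSaver(dirname="/tmp/checkpoints", require_empty=False),
    filename_prefix="training",
)
# Checkpoint already holds to_save, so it is attached without extra arguments
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Unlike the old `ModelCheckpoint` call, the objects to save are bound at construction time rather than passed alongside the event handler.
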
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/__init__.py
@@ -1,4 +1,3 @@
from ignite.contrib.handlers.base_logger import global_step_from_engine
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
@@ -19,3 +18,4 @@
from ignite.contrib.handlers.trains_logger import TrainsLogger
from ignite.contrib.handlers.visdom_logger import VisdomLogger
from ignite.contrib.handlers.wandb_logger import WandBLogger
from ignite.handlers import global_step_from_engine # ref
1 change: 0 additions & 1 deletion ignite/contrib/handlers/base_logger.py
@@ -6,7 +6,6 @@
import torch

from ignite.engine import Engine, State
from ignite.handlers import global_step_from_engine


class BaseHandler(metaclass=ABCMeta):
13 changes: 3 additions & 10 deletions ignite/contrib/handlers/mlflow_logger.py
@@ -3,12 +3,8 @@

import torch

from ignite.contrib.handlers.base_logger import (
BaseLogger,
BaseOptimizerParamsHandler,
BaseOutputHandler,
global_step_from_engine,
)
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.handlers import global_step_from_engine

__all__ = ["MLflowLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]

@@ -287,10 +283,7 @@ def __getattr__(self, attr):

import mlflow

def wrapper(*args, **kwargs):
return getattr(mlflow, attr)(*args, **kwargs)

return wrapper
return getattr(mlflow, attr)
Contributor: Same comment as above - would be nice to replace this if we can.


def close(self):
import mlflow
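The simplified `__getattr__` above returns the delegated attribute directly instead of wrapping it in a closure; since bound methods are already callable, the wrapper added nothing. A self-contained sketch of the pattern with a hypothetical stand-in class:

class Experiment:
    def log_metric(self, name, value):
        print(name, value)

class LoggerFacade:
    def __init__(self):
        self.experiment = Experiment()

    def __getattr__(self, attr):
        # called only when normal attribute lookup fails: delegate it
        return getattr(self.experiment, attr)

LoggerFacade().log_metric("loss", 0.1)  # prints: loss 0.1
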
12 changes: 7 additions & 5 deletions ignite/contrib/handlers/neptune_logger.py
@@ -6,13 +6,14 @@
import torch

import ignite
import ignite.distributed as idist
from ignite.contrib.handlers.base_logger import (
BaseLogger,
BaseOptimizerParamsHandler,
BaseOutputHandler,
BaseWeightsScalarHandler,
global_step_from_engine,
)
from ignite.handlers import global_step_from_engine
from ignite.handlers.checkpoint import BaseSaveHandler

__all__ = [
@@ -478,10 +479,7 @@ def __getattr__(self, attr):

import neptune

def wrapper(*args, **kwargs):
return getattr(neptune, attr)(*args, **kwargs)

return wrapper
return getattr(neptune, attr)
Contributor: And here.


def __init__(self, *args, **kwargs):
try:
@@ -571,14 +569,18 @@ def score_function(engine):

"""

@idist.one_rank_only()
def __init__(self, neptune_logger: NeptuneLogger):
self._logger = neptune_logger

@idist.one_rank_only()
def __call__(self, checkpoint: Mapping, filename: str, metadata: Optional[Mapping] = None) -> None:
# won't work on XLA

with tempfile.NamedTemporaryFile() as tmp:
torch.save(checkpoint, tmp.name)
self._logger.log_artifact(tmp.name, filename)

@idist.one_rank_only(with_barrier=True)
def remove(self, filename: str) -> None:
self._logger.delete_artifacts(filename)
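
`NeptuneSaver` now guards its methods with `idist.one_rank_only`, which this PR introduces. A hedged sketch of the decorator's behaviour, assuming it defaults to rank 0 as the name suggests and that `with_barrier=True` makes the other ranks wait:

import ignite.distributed as idist

@idist.one_rank_only()  # body runs on rank 0 only; other ranks skip it
def log_once(msg):
    print("rank 0 says:", msg)

@idist.one_rank_only(with_barrier=True)  # other ranks block on a barrier
def remove_once(path):
    print("removing", path)
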
13 changes: 3 additions & 10 deletions ignite/contrib/handlers/polyaxon_logger.py
@@ -3,12 +3,8 @@

import torch

from ignite.contrib.handlers.base_logger import (
BaseLogger,
BaseOptimizerParamsHandler,
BaseOutputHandler,
global_step_from_engine,
)
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.handlers import global_step_from_engine

__all__ = ["PolyaxonLogger", "OutputHandler", "OptimizerParamsHandler", "global_step_from_engine"]

@@ -271,10 +267,7 @@ def __init__(self, *args, **kwargs):
self.experiment = Experiment(*args, **kwargs)

def __getattr__(self, attr):
def wrapper(*args, **kwargs):
return getattr(self.experiment, attr)(*args, **kwargs)

return wrapper
return getattr(self.experiment, attr)
Contributor: And here.


def _create_output_handler(self, *args, **kwargs):
return OutputHandler(*args, **kwargs)