[create_supervised_trainer] add automatic mixed precision (pytorch#1589)
* amp init

* docs complete

* add tests

* unscale_ + clip_grad_norm_, move checks to private func, more edge case checks

* scaler must be provided by user and it's optional

* full docstring, on_cuda_amp test

* extract into 4 functions for normal, amp, apex and tpu training

* explicit training step, independent mode

* mypy fix

* fix(tests): pytest.raises checks with match, skipif < 1.6.0

* fix(tests): align tests name, coverage append in tpu ci

* fix: remove unused amp import

* fix: docstring with default values, more tests, code review suggestions

* fix(docs): update function names

* fix: docstring from code review

* fix: engine state only has attribute scaler if scaler is True

* fix: address code review

* fix: create scaler or None in _check_arg

* fix: no return for scaler in supervised_training_step_amp

* fix: gpu tests for apex

* fix: gpu tests for apex and amp

* chore: add more tests for coverage

* fix: state only has scaler attribute if True

* fix: use prefix for scaler

* Apply suggestions from code review

Co-authored-by: vfdev <vfdev.5@gmail.com>

* fix: skip apex tests if apex is not installed

* fix: skip apex error test if apex is installed

* fix: ImportError instead of ModuleNotFoundError

* fix(docs): no device tpu in gpu functions and vice versa

* fix: raise an error instead of warn for invalid scaler and amp_mode

Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
3 people committed Feb 21, 2021
1 parent 51970e3 commit f379b18
Showing 5 changed files with 402 additions and 36 deletions.
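For context, a minimal usage sketch of the new arguments (not part of the diff; it assumes torch>=1.6.0, a CUDA device, and user-defined model, optimizer, loss_fn and train_loader):

from ignite.engine import create_supervised_trainer

model = model.to("cuda")
trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    device="cuda",
    amp_mode="amp",  # run the forward pass under torch.cuda.amp autocast
    scaler=True,     # create a default GradScaler internally, exposed as trainer.state.scaler
)
trainer.run(train_loader, max_epochs=5)
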
7 changes: 7 additions & 0 deletions docs/source/engine.rst
@@ -47,6 +47,13 @@ More details about those structures can be found in :doc:`concepts`.

.. autofunction:: ignite.engine.create_supervised_evaluator

.. autofunction:: ignite.engine.supervised_training_step

.. autofunction:: ignite.engine.supervised_training_step_amp

.. autofunction:: ignite.engine.supervised_training_step_apex

.. autofunction:: ignite.engine.supervised_training_step_tpu

Resuming the training
---------------------
304 changes: 278 additions & 26 deletions ignite/engine/__init__.py
@@ -10,10 +10,6 @@
from ignite.metrics import Metric
from ignite.utils import convert_tensor

if idist.has_xla_support:
    import torch_xla.core.xla_model as xm


__all__ = [
    "State",
    "create_supervised_trainer",
@@ -25,6 +21,10 @@
    "EventEnum",
    "CallableEventWithFilter",
    "RemovableEventHandle",
    "supervised_training_step",
    "supervised_training_step_amp",
    "supervised_training_step_apex",
    "supervised_training_step_tpu",
]


@@ -41,6 +41,233 @@ def _prepare_batch(
    )


def supervised_training_step(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
    """Factory function for supervised training.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

    Returns:
        Callable: update function.

    .. versionadded:: 0.5.0
    """

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
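
A short sketch of plugging the returned update function into an Engine (not in the diff; model, optimizer, loss_fn and train_loader are assumed to be already defined):

from ignite.engine import Engine, supervised_training_step

update_fn = supervised_training_step(model, optimizer, loss_fn, device="cuda")
trainer = Engine(update_fn)  # build an Engine around the plain training step
trainer.run(train_loader, max_epochs=2)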


def supervised_training_step_amp(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    scaler: Optional["torch.cuda.amp.GradScaler"] = None,
) -> Callable:
    """Factory function for supervised training using ``torch.cuda.amp``.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
        scaler (torch.cuda.amp.GradScaler, optional): GradScaler instance for gradient scaling. (default: None)

    Returns:
        Callable: update function.

    .. versionadded:: 0.5.0
    """

    try:
        from torch.cuda.amp import autocast
    except ImportError:
        raise ImportError("Please install torch>=1.6.0 to use amp_mode='amp'.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        with autocast(enabled=True):
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
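
A sketch of driving this step with an explicitly constructed GradScaler (not in the diff; assumes torch>=1.6.0 with CUDA available, plus user-defined model, optimizer and loss_fn):

from torch.cuda.amp import GradScaler
from ignite.engine import Engine, supervised_training_step_amp

scaler = GradScaler()  # reused across iterations; its state_dict can be checkpointed
update_fn = supervised_training_step_amp(
    model, optimizer, loss_fn, device="cuda", scaler=scaler
)
trainer = Engine(update_fn)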


def supervised_training_step_apex(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
    """Factory function for supervised training using apex.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, GPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

    Returns:
        Callable: update function.

    .. versionadded:: 0.5.0
    """

    try:
        from apex import amp as apex_amp
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Please install apex from https://github.com/nvidia/apex to use amp_mode='apex'.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        return output_transform(x, y, y_pred, loss)

    return update
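
As the warning in create_supervised_trainer below notes, apex requires the model and optimizer to be wrapped by amp.initialize before the trainer is built; a rough sketch (not in the diff; assumes apex is installed and model, optimizer and loss_fn exist):

from apex import amp
from ignite.engine import create_supervised_trainer

model = model.to("cuda")
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # must happen before building the trainer
trainer = create_supervised_trainer(model, optimizer, loss_fn, device="cuda", amp_mode="apex")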


def supervised_training_step_tpu(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: Union[Callable, torch.nn.Module],
    device: Optional[Union[str, torch.device]] = None,
    non_blocking: bool = False,
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
) -> Callable:
    """Factory function for supervised training using ``torch_xla``.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
            Device can be CPU, TPU.
        non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
            with respect to the host. For other cases, this argument has no effect.
        prepare_batch (callable, optional): function that receives `batch`, `device`, `non_blocking` and outputs
            tuple of tensors `(batch_x, batch_y)`.
        output_transform (callable, optional): function that receives 'x', 'y', 'y_pred', 'loss' and returns value
            to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.

    Returns:
        Callable: update function.

    .. versionadded:: 0.5.0
    """
    try:
        import torch_xla.core.xla_model as xm
    except ModuleNotFoundError:
        raise ModuleNotFoundError("torch_xla cannot be imported, please install PyTorch XLA.")

    def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()
        xm.optimizer_step(optimizer, barrier=True)
        return output_transform(x, y, y_pred, loss)

    return update
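
A rough sketch of TPU usage (not in the diff; assumes torch_xla is installed, a TPU is available, and model, optimizer and loss_fn exist):

import torch_xla.core.xla_model as xm
from ignite.engine import create_supervised_trainer

device = xm.xla_device()  # an "xla" device selects the TPU training step
model = model.to(device)
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)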


def _check_arg(
    on_tpu: bool, amp_mode: Optional[str], scaler: Optional[Union[bool, "torch.cuda.amp.GradScaler"]]
) -> Tuple[Optional[str], Optional["torch.cuda.amp.GradScaler"]]:
    """Checking tpu, amp and GradScaler instance combinations."""
    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    if amp_mode and on_tpu:
        raise ValueError("amp_mode cannot be used with xla device. Consider using amp_mode=None or device='cuda'.")

    if scaler:
        if amp_mode != "amp":
            raise ValueError(f"scaler argument is {scaler}, but amp_mode is {amp_mode}. Consider using amp_mode='amp'.")
        elif amp_mode == "amp" and isinstance(scaler, bool):
            try:
                from torch.cuda.amp import GradScaler
            except ImportError:
                raise ImportError("Please install torch>=1.6.0 to use scaler argument.")
            scaler = GradScaler(enabled=True)

    if on_tpu:
        return "tpu", None
    elif scaler and amp_mode == "amp":
        return amp_mode, scaler  # type: ignore[return-value]
    else:
        return amp_mode, None
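
To illustrate the combinations this helper resolves, a small sketch (not in the diff; `_check_arg` is private and called internally by `create_supervised_trainer`; assumes torch>=1.6.0):

mode, _scaler = _check_arg(False, "amp", True)    # ("amp", GradScaler(enabled=True)) -- default scaler created
mode, _scaler = _check_arg(False, "amp", False)   # ("amp", None) -- autocast only, no gradient scaling
mode, _scaler = _check_arg(False, "apex", False)  # ("apex", None)
_check_arg(False, "apex", True)                   # raises ValueError: scaler requires amp_mode='amp'
_check_arg(True, "amp", False)                    # raises RuntimeError without XLA support, otherwise ValueError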


def create_supervised_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
@@ -50,12 +277,14 @@ def create_supervised_trainer(
    prepare_batch: Callable = _prepare_batch,
    output_transform: Callable = lambda x, y, y_pred, loss: loss.item(),
    deterministic: bool = False,
    amp_mode: Optional[str] = None,
    scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
) -> Engine:
    """Factory function for creating a trainer for supervised models.

    Args:
        model (`torch.nn.Module`): the model to train.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer to use.
        loss_fn (torch.nn loss function): the loss function to use.
        device (str, optional): device type specification (default: None).
            Applies to batches after starting the engine. Model *will not* be moved.
@@ -69,48 +298,71 @@ def create_supervised_trainer(
        deterministic (bool, optional): if True, returns deterministic engine of type
            :class:`~ignite.engine.deterministic.DeterministicEngine`, otherwise :class:`~ignite.engine.engine.Engine`
            (default: False).
        amp_mode (str, optional): can be ``amp`` or ``apex``, model and optimizer will be cast to float16 using
            `torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_ for ``amp`` and
            using `apex <https://nvidia.github.io/apex>`_ for ``apex``. (default: None)
        scaler (torch.cuda.amp.GradScaler, bool, optional): GradScaler instance for gradient scaling if `torch>=1.6.0`
            and ``amp_mode`` is ``amp``. If ``amp_mode`` is ``apex``, this argument will be ignored.
            If True, will create default GradScaler. If GradScaler instance is passed, it will be used instead.
            (default: False)

    Note:
        If ``scaler`` is True, GradScaler instance will be created internally and trainer state has attribute named
        ``scaler`` for that instance and can be used for saving and loading.

    Note:
        `engine.state.output` for this engine is defined by `output_transform` parameter and is the loss
        of the processed batch by default.

    .. warning::
        The internal use of `device` has changed.
        `device` will now *only* be used to move the input data to the correct device.
        The `model` should be moved by the user before creating an optimizer.
        For more information see:

        - `PyTorch Documentation <https://pytorch.org/docs/stable/optim.html#constructing-it>`_
        - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

    .. warning::
        If ``amp_mode='apex'``, the model(s) and optimizer(s) must be initialized beforehand
        since ``amp.initialize`` should be called after you have finished constructing your model(s)
        and optimizer(s), but before you send your model through any DistributedDataParallel wrapper.

        See more: https://nvidia.github.io/apex/amp.html#module-apex.amp

    Returns:
        Engine: a trainer engine with supervised update function.

    .. versionchanged:: 0.5.0
        - Added ``amp_mode`` argument for automatic mixed precision.
        - Added ``scaler`` argument for gradient scaling.
    """

    device_type = device.type if isinstance(device, torch.device) else device
    on_tpu = "xla" in device_type if device_type is not None else False

    if on_tpu and not idist.has_xla_support:
        raise RuntimeError("In order to run on TPU, please install PyTorch XLA")

    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
        model.train()
        optimizer.zero_grad()
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        loss.backward()

        if on_tpu:
            xm.optimizer_step(optimizer, barrier=True)
        else:
            optimizer.step()

        return output_transform(x, y, y_pred, loss)
    mode, _scaler = _check_arg(on_tpu, amp_mode, scaler)

    if mode == "amp":
        _update = supervised_training_step_amp(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform, _scaler
        )
    elif mode == "apex":
        _update = supervised_training_step_apex(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )
    elif mode == "tpu":
        _update = supervised_training_step_tpu(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )
    else:
        _update = supervised_training_step(
            model, optimizer, loss_fn, device, non_blocking, prepare_batch, output_transform
        )

    trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
    if _scaler and scaler and isinstance(scaler, bool):
        trainer.state.scaler = _scaler  # type: ignore[attr-defined]

    return trainer
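
Because the internally created scaler is exposed as trainer.state.scaler, it can be saved and restored together with the trainer; a rough sketch using ignite's standard checkpoint handlers (not in the diff; the /tmp/checkpoints directory and the model, optimizer and loss_fn objects are assumed):

from ignite.engine import Events, create_supervised_trainer
from ignite.handlers import Checkpoint, DiskSaver

trainer = create_supervised_trainer(
    model, optimizer, loss_fn, device="cuda", amp_mode="amp", scaler=True
)
to_save = {"trainer": trainer, "optimizer": optimizer, "scaler": trainer.state.scaler}
checkpoint = Checkpoint(to_save, DiskSaver("/tmp/checkpoints", create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)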

3 changes: 3 additions & 0 deletions mypy.ini
@@ -24,6 +24,9 @@ warn_unreachable = False
warn_unused_configs = True
warn_unused_ignores = True

[mypy-apex.*]
ignore_missing_imports = True

[mypy-clearml.*]
ignore_missing_imports = True

