Activate MyPy in ignite.contrib.engines (pytorch#1416)
* Activate mypy in ignite.contrib.engines

* Fix review comments

* fix extra event too

* Update to fix strict errors
rzats committed Nov 11, 2020
1 parent ea086e1 commit 96689c5
Showing 4 changed files with 78 additions and 55 deletions.
118 changes: 71 additions & 47 deletions ignite/contrib/engines/common.py
@@ -1,7 +1,7 @@
import numbers
import warnings
from functools import partial
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast

import torch
import torch.nn as nn
@@ -47,7 +47,7 @@ def setup_common_training_handlers(
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):
) -> None:
"""Helper method to setup trainer with common handlers (it also supports distributed configuration):
- :class:`~ignite.handlers.TerminateOnNan`
@@ -88,24 +88,24 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
**kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
"""

_kwargs = dict(
to_save=to_save,
save_every_iters=save_every_iters,
output_path=output_path,
lr_scheduler=lr_scheduler,
with_gpu_stats=with_gpu_stats,
output_names=output_names,
with_pbars=with_pbars,
with_pbar_on_iters=with_pbar_on_iters,
log_every_iters=log_every_iters,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
save_handler=save_handler,
)
_kwargs.update(kwargs)

if idist.get_world_size() > 1:
_setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
_setup_common_distrib_training_handlers(
trainer,
train_sampler=train_sampler,
to_save=to_save,
save_every_iters=save_every_iters,
output_path=output_path,
lr_scheduler=lr_scheduler,
with_gpu_stats=with_gpu_stats,
output_names=output_names,
with_pbars=with_pbars,
with_pbar_on_iters=with_pbar_on_iters,
log_every_iters=log_every_iters,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
save_handler=save_handler,
**kwargs,
)
else:
if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
warnings.warn(
@@ -114,7 +114,22 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
"Train sampler argument will be ignored",
UserWarning,
)
_setup_common_training_handlers(trainer, **_kwargs)
_setup_common_training_handlers(
trainer,
to_save=to_save,
save_every_iters=save_every_iters,
output_path=output_path,
lr_scheduler=lr_scheduler,
with_gpu_stats=with_gpu_stats,
output_names=output_names,
with_pbars=with_pbars,
with_pbar_on_iters=with_pbar_on_iters,
log_every_iters=log_every_iters,
stop_on_nan=stop_on_nan,
clear_cuda_cache=clear_cuda_cache,
save_handler=save_handler,
**kwargs,
)


setup_common_distrib_training_handlers = setup_common_training_handlers
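
Note: the main change in the hunk above replaces the old `_kwargs = dict(...)` / `**_kwargs` splat with explicit keyword arguments. Under strict mypy, the inferred value type of such a dict is the join of all value types (typically `object`), so every splatted argument loses its precise type. A minimal sketch with hypothetical names:

    from typing import Optional

    def configure(epochs: int = 1, output_path: Optional[str] = None) -> None:
        print(epochs, output_path)

    # mypy infers Dict[str, object] here, so the splat fails strict checking:
    # 'Argument "epochs" to "configure" has incompatible type "object"'.
    _kwargs = dict(epochs=10, output_path="/tmp/run")
    configure(**_kwargs)

    # Explicit keywords keep each argument checkable against the signature:
    configure(epochs=10, output_path="/tmp/run")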
@@ -135,7 +150,7 @@ def _setup_common_training_handlers(
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):
) -> None:
if output_path is not None and save_handler is not None:
raise ValueError(
"Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
@@ -146,7 +161,9 @@ def _setup_common_training_handlers(

if lr_scheduler is not None:
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
trainer.add_event_handler(
Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
)
elif isinstance(lr_scheduler, LRScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
else:
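
Note: the `cast` inside the lambda is needed because mypy does not carry `isinstance` narrowing into a closure: the lambda body is checked against the declared `Optional[...]` type of `lr_scheduler`, since the variable could in principle be rebound before the handler fires. A minimal sketch of the pattern (assuming the private `_LRScheduler` import this file uses elsewhere):

    from typing import Callable, Optional, cast

    from torch.optim.lr_scheduler import _LRScheduler

    def make_step_handler(lr_scheduler: Optional[_LRScheduler]) -> Callable[[], None]:
        if isinstance(lr_scheduler, _LRScheduler):
            # Without the cast, mypy reports inside the closure that the value
            # may be None and has no attribute "step", despite the check above.
            return lambda: cast(_LRScheduler, lr_scheduler).step()
        return lambda: None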
@@ -164,15 +181,19 @@ def _setup_common_training_handlers(
if output_path is not None:
save_handler = DiskSaver(dirname=output_path, require_empty=False)

checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
checkpoint_handler = Checkpoint(
to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
)
trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

if with_gpu_stats:
GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
GpuInfo().attach(
trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters) # type: ignore[arg-type]
)

if output_names is not None:

def output_transform(x, index, name):
def output_transform(x: Any, index: int, name: str) -> Any:
if isinstance(x, Mapping):
return x[name]
elif isinstance(x, Sequence):
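
Note: two smaller strict-mode fixes in the hunk above. `save_handler` is `Optional` in the signature, and mypy cannot see that the surrounding logic guarantees it is non-None by the time `Checkpoint` is built, hence the `cast`. The `GpuInfo().attach(...)` call instead uses an error-code-scoped ignore, `# type: ignore[arg-type]`, which works well with the `warn_unused_ignores = True` setting in mypy.ini below: other error kinds on that line are still reported, and mypy flags the comment if it ever becomes unnecessary. A minimal sketch of the scoped-ignore idiom:

    def double(value: int) -> int:
        return value * 2

    loosely_typed: object = 21
    # Suppresses only the arg-type error; any other mypy error on this
    # line (e.g. a misspelled name) would still be reported.
    result = double(loosely_typed)  # type: ignore[arg-type]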
@@ -217,7 +238,7 @@ def _setup_common_distrib_training_handlers(
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):
) -> None:

_setup_common_training_handlers(
trainer,
@@ -241,18 +262,18 @@ def _setup_common_distrib_training_handlers(
raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

@trainer.on(Events.EPOCH_STARTED)
def distrib_set_epoch(engine):
train_sampler.set_epoch(engine.state.epoch - 1)
def distrib_set_epoch(engine: Engine) -> None:
cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)


def empty_cuda_cache(_):
def empty_cuda_cache(_: Engine) -> None:
torch.cuda.empty_cache()
import gc

gc.collect()


def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None: # type: ignore
raise DeprecationWarning(
"ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
"Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -262,10 +283,10 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
def _setup_logging(
logger: BaseLogger,
trainer: Engine,
optimizers: Union[Optimizer, Dict[str, Optimizer]],
evaluators: Union[Engine, Dict[str, Engine]],
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer], Dict[None, Optimizer]]],
evaluators: Optional[Union[Engine, Dict[str, Engine]]],
log_every_iters: int,
):
) -> None:
if optimizers is not None:
if not isinstance(optimizers, (Optimizer, Mapping)):
raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")
@@ -311,7 +332,7 @@ def setup_tb_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> TensorboardLogger:
"""Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
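
Note: each `setup_*_logging` helper now declares its concrete return type (here `TensorboardLogger`, and similarly for the Visdom, MLflow, Neptune, WandB, Polyaxon, and Trains variants below), so callers get attribute checking on the returned logger without an annotation or cast. An illustrative use; the leading parameters of `setup_tb_logging` are elided in this hunk, so the names below are assumptions:

    # mypy now knows `tb_logger` is a TensorboardLogger, so attribute
    # access on it is checked rather than inferred as Any.
    tb_logger = setup_tb_logging(output_path, trainer)
    tb_logger.close()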
@@ -343,7 +364,7 @@ def setup_visdom_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> VisdomLogger:
"""Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -374,7 +395,7 @@ def setup_mlflow_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> MLflowLogger:
"""Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -405,7 +426,7 @@ def setup_neptune_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> NeptuneLogger:
"""Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -436,7 +457,7 @@ def setup_wandb_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> WandBLogger:
"""Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -467,7 +488,7 @@ def setup_plx_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> PolyaxonLogger:
"""Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -498,7 +519,7 @@ def setup_trains_logging(
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
) -> TrainsLogger:
"""Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
@@ -523,8 +544,8 @@ def setup_trains_logging(
return logger


def get_default_score_fn(metric_name: str):
def wrapper(engine: Engine):
def get_default_score_fn(metric_name: str) -> Any:
def wrapper(engine: Engine) -> Any:
score = engine.state.metrics[metric_name]
return score

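Note: `get_default_score_fn` returns a closure meant for the `score_function=` argument of handlers such as `Checkpoint` and `EarlyStopping`; the commit annotates both levels as `Any`. A tighter alternative, assuming the metric values behave like floats, would be:

    from typing import Callable

    from ignite.engine import Engine

    def get_default_score_fn(metric_name: str) -> Callable[[Engine], float]:
        def wrapper(engine: Engine) -> float:
            return engine.state.metrics[metric_name]
        return wrapper
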
@@ -540,7 +561,7 @@ def gen_save_best_models_by_val_score(
trainer: Optional[Engine] = None,
tag: str = "val",
**kwargs: Any
):
) -> Checkpoint:
"""Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
Models with highest metric value will be retained. The logic of how to store objects is delegated to
@@ -570,9 +591,10 @@ def gen_save_best_models_by_val_score(
if trainer is not None:
global_step_transform = global_step_from_engine(trainer)

to_save = models
if isinstance(models, nn.Module):
to_save = {"model": models}
to_save = {"model": models} # type: Dict[str, nn.Module]
else:
to_save = models

best_model_handler = Checkpoint(
to_save,
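
Note: the `to_save` branches above were reordered so that the comment-style annotation (used instead of a variable annotation, presumably for older-Python compatibility) sits on the first assignment. In the old order, `to_save = models` bound the variable to the full union, and `isinstance` narrowing of `models` does not transfer to `to_save`, so the later `Checkpoint(to_save, ...)` call saw a union where a mapping is required. A reduced sketch:

    from typing import Dict, Union

    import torch.nn as nn

    def as_dict(models: Union[nn.Module, Dict[str, nn.Module]]) -> Dict[str, nn.Module]:
        if isinstance(models, nn.Module):
            to_save = {"model": models}  # type: Dict[str, nn.Module]
        else:
            to_save = models  # mypy narrows `models` to the dict form here
        return to_save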
@@ -598,7 +620,7 @@ def save_best_model_by_val_score(
trainer: Optional[Engine] = None,
tag: str = "val",
**kwargs: Any
):
) -> Checkpoint:
"""Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
Models with highest metric value will be retained.
@@ -629,7 +651,9 @@ def save_best_model_by_val_score(
)


def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
def add_early_stopping_by_val_score(
patience: int, evaluator: Engine, trainer: Engine, metric_name: str
) -> EarlyStopping:
"""Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
Metric value should increase in order to keep training and not early stop.
9 changes: 6 additions & 3 deletions ignite/contrib/engines/tbptt.py
@@ -1,4 +1,5 @@
# coding: utf-8
import collections.abc as collections
from typing import Callable, Mapping, Optional, Sequence, Union

import torch
@@ -20,7 +21,9 @@ class Tbptt_Events(EventEnum):
TIME_ITERATION_COMPLETED = "time_iteration_completed"


def _detach_hidden(hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]):
def _detach_hidden(
hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
"""Cut backpropagation graph.
Auxillary function to cut the backpropagation graph by detaching the hidden
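
Note: the diff elides the body of `_detach_hidden`, and the new return annotation mixes `typing` generics on the parameter with `collections.abc` classes on the return. A self-contained sketch of the recursion such a signature implies (an illustration, not the committed body, which likely delegates to an ignite utility):

    import collections.abc as collections
    from typing import Mapping, Sequence, Union

    import torch

    Hidden = Union[torch.Tensor, Sequence, Mapping, str, bytes]

    def detach_hidden_sketch(hidden: Hidden) -> Hidden:
        # Detach tensors; recurse into containers; leave str/bytes alone.
        if isinstance(hidden, torch.Tensor):
            return hidden.detach()
        if isinstance(hidden, (str, bytes)):
            return hidden
        if isinstance(hidden, collections.Mapping):
            return {k: detach_hidden_sketch(v) for k, v in hidden.items()}
        if isinstance(hidden, collections.Sequence):
            return type(hidden)(detach_hidden_sketch(v) for v in hidden)
        return hidden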
@@ -38,7 +41,7 @@ def create_supervised_tbptt_trainer(
device: Optional[str] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
):
) -> Engine:
"""Create a trainer for truncated backprop through time supervised models.
Training recurrent model on long sequences is computationally intensive as
@@ -83,7 +86,7 @@ def create_supervised_tbptt_trainer(
"""

def _update(engine: Engine, batch: Sequence[torch.Tensor]):
def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
loss_list = []
hidden = None

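Note: the diff truncates `_update` after `hidden = None`; the declared `-> float` matches the usual truncated-backprop loop, where per-segment losses are collected and a single float becomes `engine.state.output`. A compressed sketch under assumed model, loss, and optimizer names (not the committed body):

    import torch

    def tbptt_update_sketch(model, optimizer, loss_fn, segments) -> float:
        loss_list = []
        hidden = None
        for x, y in segments:
            y_pred, hidden = model(x, hidden)
            loss = loss_fn(y_pred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Cut the graph between segments so backprop stays truncated.
            hidden = hidden.detach() if isinstance(hidden, torch.Tensor) else hidden
            loss_list.append(loss.item())
        return sum(loss_list) / len(loss_list)
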
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/tqdm_logger.py
@@ -159,7 +159,7 @@ def attach(
engine: Engine,
metric_names: Optional[str] = None,
output_transform: Optional[Callable] = None,
event_name: Events = Events.ITERATION_COMPLETED,
event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED,
closing_event_name: Events = Events.EPOCH_COMPLETED,
):
"""
4 changes: 0 additions & 4 deletions mypy.ini
@@ -28,10 +28,6 @@ warn_unused_ignores = True

ignore_errors = True

[mypy-ignite.contrib.engines.*]

ignore_errors = True

[mypy-horovod.*]
ignore_missing_imports = True
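
Note: deleting the `[mypy-ignite.contrib.engines.*]` override above is what actually activates checking for the package: with its `ignore_errors = True` gone, the global options at the top of mypy.ini (including `warn_unused_ignores = True`, visible in this hunk's context) now apply to these files, which is what forced the annotations, casts, and scoped ignores in the hunks above.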

