Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1] [contrib/engines] setup typing in contrib part of the library #1351

Merged
merged 5 commits into from
Oct 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 128 additions & 64 deletions ignite/contrib/engines/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numbers
import warnings
from collections.abc import Mapping, Sequence
from functools import partial
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist
Expand All @@ -20,29 +22,32 @@
WandBLogger,
global_step_from_engine,
)
from ignite.contrib.handlers.base_logger import BaseLogger
from ignite.contrib.handlers.param_scheduler import ParamScheduler
from ignite.contrib.metrics import GpuInfo
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, TerminateOnNan
from ignite.handlers.checkpoint import BaseSaveHandler
from ignite.metrics import RunningAverage


def setup_common_training_handlers(
trainer,
train_sampler=None,
to_save=None,
save_every_iters=1000,
output_path=None,
lr_scheduler=None,
with_gpu_stats=False,
output_names=None,
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
device=None,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
trainer: Engine,
train_sampler: Optional[DistributedSampler] = None,
to_save: Optional[Mapping] = None,
save_every_iters: int = 1000,
output_path: Optional[str] = None,
lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
with_gpu_stats: bool = False,
output_names: Optional[Iterable[str]] = None,
with_pbars: bool = True,
with_pbar_on_iters: bool = True,
log_every_iters: int = 100,
device: Optional[Union[str, torch.device]] = None,
stop_on_nan: bool = True,
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):
"""Helper method to setup trainer with common handlers (it also supports distributed configuration):
- :class:`~ignite.handlers.TerminateOnNan`
Expand Down Expand Up @@ -119,20 +124,20 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check


def _setup_common_training_handlers(
trainer,
to_save=None,
save_every_iters=1000,
output_path=None,
lr_scheduler=None,
with_gpu_stats=False,
output_names=None,
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
trainer: Engine,
to_save: Optional[Mapping] = None,
save_every_iters: int = 1000,
output_path: Optional[str] = None,
lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
with_gpu_stats: bool = False,
output_names: Optional[Iterable[str]] = None,
with_pbars: bool = True,
with_pbar_on_iters: bool = True,
log_every_iters: int = 100,
stop_on_nan: bool = True,
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):
if output_path is not None and save_handler is not None:
raise ValueError(
Expand Down Expand Up @@ -200,21 +205,21 @@ def output_transform(x, index, name):


def _setup_common_distrib_training_handlers(
trainer,
train_sampler=None,
to_save=None,
save_every_iters=1000,
output_path=None,
lr_scheduler=None,
with_gpu_stats=False,
output_names=None,
with_pbars=True,
with_pbar_on_iters=True,
log_every_iters=100,
stop_on_nan=True,
clear_cuda_cache=True,
save_handler=None,
**kwargs
trainer: Engine,
train_sampler: Optional[DistributedSampler] = None,
to_save: Optional[Mapping] = None,
save_every_iters: int = 1000,
output_path: Optional[str] = None,
lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
with_gpu_stats: bool = False,
output_names: Optional[Iterable[str]] = None,
with_pbars: bool = True,
with_pbar_on_iters: bool = True,
log_every_iters: int = 100,
stop_on_nan: bool = True,
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
**kwargs: Any
):

_setup_common_training_handlers(
Expand Down Expand Up @@ -257,10 +262,14 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
)


def _setup_logging(logger, trainer, optimizers, evaluators, log_every_iters):
def _setup_logging(
logger: BaseLogger,
trainer: Engine,
optimizers: Union[Optimizer, Dict[str, Optimizer]],
evaluators: Union[Engine, Dict[str, Engine]],
log_every_iters: int,
):
if optimizers is not None:
from torch.optim.optimizer import Optimizer

if not isinstance(optimizers, (Optimizer, Mapping)):
raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")

Expand Down Expand Up @@ -298,7 +307,14 @@ def _setup_logging(logger, trainer, optimizers, evaluators, log_every_iters):
)


def setup_tb_logging(output_path, trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_tb_logging(
output_path: str,
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -323,7 +339,13 @@ def setup_tb_logging(output_path, trainer, optimizers=None, evaluators=None, log
return logger


def setup_visdom_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_visdom_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -347,7 +369,13 @@ def setup_visdom_logging(trainer, optimizers=None, evaluators=None, log_every_it
return logger


def setup_mlflow_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_mlflow_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -371,7 +399,13 @@ def setup_mlflow_logging(trainer, optimizers=None, evaluators=None, log_every_it
return logger


def setup_neptune_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_neptune_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -395,7 +429,13 @@ def setup_neptune_logging(trainer, optimizers=None, evaluators=None, log_every_i
return logger


def setup_wandb_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_wandb_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -419,7 +459,13 @@ def setup_wandb_logging(trainer, optimizers=None, evaluators=None, log_every_ite
return logger


def setup_plx_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_plx_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -443,7 +489,13 @@ def setup_plx_logging(trainer, optimizers=None, evaluators=None, log_every_iters
return logger


def setup_trains_logging(trainer, optimizers=None, evaluators=None, log_every_iters=100, **kwargs):
def setup_trains_logging(
trainer: Engine,
optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
log_every_iters: int = 100,
**kwargs: Any
):
"""Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:
- Training metrics, e.g. running average loss values
- Learning rate(s)
Expand All @@ -467,16 +519,23 @@ def setup_trains_logging(trainer, optimizers=None, evaluators=None, log_every_it
return logger


def get_default_score_fn(metric_name):
def wrapper(engine):
def get_default_score_fn(metric_name: str) -> Callable:
    """Create a default score function reading a metric from an engine's state.

    Used (e.g. by ``gen_save_best_models_by_val_score``) as the
    ``score_function`` passed to checkpoint/early-stopping handlers: the
    returned callable looks up ``metric_name`` in ``engine.state.metrics``.

    Args:
        metric_name (str): key to look up in ``engine.state.metrics``.

    Returns:
        Callable: a function of a single ``Engine`` argument returning
        ``engine.state.metrics[metric_name]``. A missing metric propagates
        as ``KeyError`` — the metric must be attached to the engine before
        the score function is called.
    """

    def wrapper(engine: "Engine") -> Any:
        return engine.state.metrics[metric_name]

    return wrapper


def gen_save_best_models_by_val_score(
save_handler, evaluator, models, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
save_handler: Union[Callable, BaseSaveHandler],
evaluator: Engine,
models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
metric_name: str,
n_saved: int = 3,
trainer: Optional[Engine] = None,
tag: str = "val",
**kwargs: Any
):
"""Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
Expand Down Expand Up @@ -521,15 +580,20 @@ def gen_save_best_models_by_val_score(
score_function=get_default_score_fn(metric_name),
**kwargs,
)
evaluator.add_event_handler(
Events.COMPLETED, best_model_handler,
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)

return best_model_handler


def save_best_model_by_val_score(
output_path, evaluator, model, metric_name, n_saved=3, trainer=None, tag="val", **kwargs
output_path: str,
evaluator: Engine,
model: torch.nn.Module,
metric_name: str,
n_saved: int = 3,
trainer: Optional[Engine] = None,
tag: str = "val",
**kwargs: Any
):
"""Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
(named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
Expand Down Expand Up @@ -561,7 +625,7 @@ def save_best_model_by_val_score(
)


def add_early_stopping_by_val_score(patience, evaluator, trainer, metric_name):
def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
"""Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
Metric value should increase in order to keep training and not early stop.

Expand Down
16 changes: 13 additions & 3 deletions ignite/contrib/engines/tbptt.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# coding: utf-8
from typing import Callable, Mapping, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

from ignite.engine import Engine, EventEnum, _prepare_batch
from ignite.utils import apply_to_tensor
Expand All @@ -17,7 +20,7 @@ class Tbptt_Events(EventEnum):
TIME_ITERATION_COMPLETED = "time_iteration_completed"


def _detach_hidden(hidden):
def _detach_hidden(hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]):
"""Cut backpropagation graph.

Auxillary function to cut the backpropagation graph by detaching the hidden
Expand All @@ -27,7 +30,14 @@ def _detach_hidden(hidden):


def create_supervised_tbptt_trainer(
model, optimizer, loss_fn, tbtt_step, dim=0, device=None, non_blocking=False, prepare_batch=_prepare_batch
model: nn.Module,
optimizer: Optimizer,
loss_fn: nn.Module,
tbtt_step: int,
dim: int = 0,
device: Optional[str] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
):
"""Create a trainer for truncated backprop through time supervised models.

Expand Down Expand Up @@ -73,7 +83,7 @@ def create_supervised_tbptt_trainer(

"""

def _update(engine, batch):
def _update(engine: Engine, batch: Sequence[torch.Tensor]):
loss_list = []
hidden = None

Expand Down