Activate mypy in ignite.engine (pytorch#1379)
* Activate mypy in ignite.engine

* Fix missing import

* Fix typing issues with nightly build

* Fix PR findings

Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
3 people committed Oct 14, 2020
1 parent 260017d commit a635897
Showing 7 changed files with 127 additions and 94 deletions.
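The changes below follow a few recurring patterns rather than one large refactor: return-type annotations on previously untyped functions, comment-style variable annotations (e.g. # type: List[str]) for attributes that start out empty or as None, defensive RuntimeError checks before dereferencing optional state, and narrowly scoped # type: ignore[...] comments where torch's DataLoader internals lack stubs. A minimal, hypothetical sketch of these patterns (illustration only, not code from this commit):

    from typing import Any, List, Optional

    class Holder:
        def __init__(self) -> None:
            # comment-style annotation: gives mypy the element type of an initially empty list
            self.names = []  # type: List[str]
            self.data = None  # type: Optional[List[int]]

        def first(self) -> int:
            # defensive check, so mypy can narrow self.data to List[int] below
            if self.data is None:
                raise RuntimeError("Internal error, self.data is None.")
            return self.data[0]

        def combine(self, *args: Any, **kwargs: Any) -> str:
            return "{} args, {} kwargs".format(len(args), len(kwargs))
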
3 changes: 2 additions & 1 deletion ignite/engine/__init__.py
@@ -1,3 +1,4 @@
import collections
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union

import torch
@@ -27,7 +28,7 @@

def _prepare_batch(
batch: Sequence[torch.Tensor], device: Optional[Union[str, torch.device]] = None, non_blocking: bool = False
):
) -> Tuple[Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes], ...]:
"""Prepare batch for training: pass to a device with options.
"""
54 changes: 32 additions & 22 deletions ignite/engine/deterministic.py
@@ -2,7 +2,7 @@
import warnings
from collections import OrderedDict
from functools import wraps
from typing import Callable, Generator, Iterator, Optional
from typing import Any, Callable, Generator, Iterator, List, Optional, cast

import torch
from torch.utils.data import DataLoader
@@ -61,7 +61,7 @@ def __init__(self, batch_sampler: BatchSampler, start_iteration: Optional[int] =
if not isinstance(batch_sampler, BatchSampler):
raise TypeError("Argument batch_sampler should be torch.utils.data.sampler.BatchSampler")

self.batch_indices = None
self.batch_indices = [] # type: List
self.batch_sampler = batch_sampler
self.start_iteration = start_iteration
self.sampler = self.batch_sampler.sampler
@@ -84,7 +84,7 @@ def __len__(self) -> int:
return len(self.batch_sampler)


def _get_rng_states():
def _get_rng_states() -> List[Any]:
output = [random.getstate(), torch.get_rng_state()]
try:
import numpy as np
@@ -96,7 +96,7 @@ def _get_rng_states():
return output


def _set_rng_states(rng_states):
def _set_rng_states(rng_states: List[Any]) -> None:
random.setstate(rng_states[0])
torch.set_rng_state(rng_states[1])
try:
@@ -107,14 +107,14 @@ def _set_rng_states(rng_states):
pass


def _repr_rng_state(rng_states):
def _repr_rng_state(rng_states: List[Any]) -> str:
from hashlib import md5

out = " ".join([md5(str(list(s)).encode("utf-8")).hexdigest() for s in rng_states])
return out


def keep_random_state(func: Callable):
def keep_random_state(func: Callable) -> Callable:
"""Helper decorator to keep random state of torch, numpy and random intact
while executing a function. For more details on usage, please see :ref:`Dataflow synchronization`.
@@ -123,7 +123,7 @@ def keep_random_state(func: Callable):
"""

@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: Any, **kwargs: Any) -> None:
rng_states = _get_rng_states()
func(*args, **kwargs)
_set_rng_states(rng_states)
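
A small, self-contained usage sketch for keep_random_state (illustration only, not part of the diff); the wrapper above captures the RNG states, runs the function, and restores them:

    import torch
    from ignite.engine.deterministic import keep_random_state

    @keep_random_state
    def consume_rng() -> None:
        # draws from torch's global RNG inside the wrapped call
        torch.randperm(10)

    state_before = torch.get_rng_state()
    consume_rng()
    # the decorator restored the RNG state captured right before the call
    assert torch.equal(state_before, torch.get_rng_state())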
@@ -181,16 +181,20 @@ def state_dict(self) -> OrderedDict:
return state_dict

def _init_run(self) -> None:
seed = torch.randint(0, int(1e9), (1,)).item()
self.state.seed = seed
self.state.seed = int(torch.randint(0, int(1e9), (1,)).item())
if not hasattr(self.state, "rng_states"):
self.state.rng_states = None
self.state.rng_states = None # type: ignore[attr-defined]

if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)

self._dataloader_len = self._get_data_length(self.state.dataloader)

# if input data is torch dataloader we replace batch sampler by a batch sampler
@@ -199,22 +203,24 @@ def _setup_engine(self) -> None:
# attribute _dataset_kind is introduced since 1.3.0 => before 1.3.0 all datasets are map-like
can_patch_dataloader = True
if hasattr(self.state.dataloader, "_dataset_kind"):
from torch.utils.data.dataloader import _DatasetKind
from torch.utils.data.dataloader import _DatasetKind # type: ignore[attr-defined]

_dataloader_kind = self.state.dataloader._dataset_kind
_dataloader_kind = self.state.dataloader._dataset_kind # type: ignore[attr-defined]
can_patch_dataloader = _dataloader_kind == _DatasetKind.Map
if can_patch_dataloader:
if (self._dataloader_len is not None) and hasattr(self.state.dataloader.sampler, "epoch"):
if self._dataloader_len is not None and hasattr(
self.state.dataloader.sampler, "epoch" # type: ignore[attr-defined]
):
if self._dataloader_len != self.state.epoch_length:
warnings.warn(
"When defined engine's epoch length is different of input dataloader length, "
"distributed sampler indices can not be setup in a reproducible manner"
)

batch_sampler = self.state.dataloader.batch_sampler
batch_sampler = self.state.dataloader.batch_sampler # type: ignore[attr-defined]
if not (batch_sampler is None or isinstance(batch_sampler, ReproducibleBatchSampler)):
self.state.dataloader = update_dataloader(
self.state.dataloader, ReproducibleBatchSampler(batch_sampler)
self.state.dataloader, ReproducibleBatchSampler(batch_sampler) # type: ignore[arg-type]
)

iteration = self.state.iteration
@@ -228,28 +234,32 @@ def _setup_engine(self) -> None:
# restore rng state if in the middle
in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
if (getattr(self.state, "rng_states", None) is not None) and in_the_middle:
_set_rng_states(self.state.rng_states)
self.state.rng_states = None
_set_rng_states(self.state.rng_states) # type: ignore[attr-defined]
self.state.rng_states = None # type: ignore[attr-defined]

def _from_iteration(self, iteration: int) -> Iterator:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)
data = self.state.dataloader
if isinstance(data, DataLoader):
try:
# following is unsafe for IterableDatasets
iteration %= len(data.batch_sampler)
iteration %= len(data.batch_sampler) # type: ignore[attr-defined, arg-type]
# Synchronize dataflow according to state.iteration
self._setup_seed()
if iteration > 0:
# batch sampler is ReproducibleBatchSampler
data.batch_sampler.start_iteration = iteration
data.batch_sampler.start_iteration = iteration # type: ignore[attr-defined, union-attr]
return iter(data)
except TypeError as e:
# Probably we can do nothing with DataLoader built upon IterableDatasets
pass

self.logger.info("Resuming from iteration for provided data will fetch data until required iteration ...")
if hasattr(data, "__len__"):
iteration %= len(data)
iteration %= len(data) # type: ignore[arg-type]
# Synchronize dataflow from the begining
self._setup_seed(iteration=0)
data_iter = iter(data)
Expand All @@ -263,11 +273,11 @@ def _from_iteration(self, iteration: int) -> Iterator:

return data_iter

def _setup_seed(self, _=None, iter_counter=None, iteration=None):
def _setup_seed(self, _: Any = None, iter_counter: Optional[int] = None, iteration: Optional[int] = None) -> None:
if iter_counter is None:
le = self._dataloader_len if self._dataloader_len is not None else 1
else:
le = iter_counter
if iteration is None:
iteration = self.state.iteration
manual_seed(self.state.seed + iteration // le)
manual_seed(self.state.seed + iteration // le) # type: ignore[operator]
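
An aside on the last line: manual_seed(self.state.seed + iteration // le) reseeds once per pass over the data, so every iteration within the same pass derives the same seed. A tiny numeric sketch, assuming an epoch length of 100 and a base seed of 12:

    le, seed = 100, 12
    print([seed + it // le for it in (0, 99, 100, 250)])  # -> [12, 12, 13, 14]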
72 changes: 44 additions & 28 deletions ignite/engine/engine.py
@@ -5,7 +5,9 @@
import weakref
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from typing import Any, Callable, Iterable, List, Optional, Union
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union

from torch.utils.data import DataLoader

from ignite._utils import _to_hours_mins_secs
from ignite.base import Serializable
@@ -120,18 +122,18 @@ def compute_mean_std(engine, batch):
_state_dict_one_of_opt_keys = ("iteration", "epoch")

def __init__(self, process_function: Callable):
self._event_handlers = defaultdict(list)
self._event_handlers = defaultdict(list) # type: Dict[Any, List]
self.logger = logging.getLogger(__name__ + "." + self.__class__.__name__)
self._process_function = process_function
self.last_event_name = None
self.last_event_name = None # type: Optional[Events]
self.should_terminate = False
self.should_terminate_single_epoch = False
self.state = State()
self._state_dict_user_keys = []
self._allowed_events = []
self._state_dict_user_keys = [] # type: List[str]
self._allowed_events = [] # type: List[EventEnum]

self._dataloader_iter = None
self._init_iter = []
self._dataloader_iter = None # type: Optional[Iterator[Any]]
self._init_iter = [] # type: List[int]

self.register_events(*Events)

@@ -232,16 +234,16 @@ def _handler_wrapper(self, handler: Callable, event_name: Any, event_filter: Callable
# signature of the following wrapper will be inspected during registering to check if engine is necessary
# we have to build a wrapper with relevant signature : solution is functools.wraps
@functools.wraps(handler)
def wrapper(*args, **kwargs) -> Any:
def wrapper(*args: Any, **kwargs: Any) -> Any:
event = self.state.get_event_attrib_value(event_name)
if event_filter(self, event):
return handler(*args, **kwargs)

# setup input handler as parent to make has_event_handler work
wrapper._parent = weakref.ref(handler)
wrapper._parent = weakref.ref(handler) # type: ignore[attr-defined]
return wrapper

def add_event_handler(self, event_name: Any, handler: Callable, *args, **kwargs):
def add_event_handler(self, event_name: Any, handler: Callable, *args: Any, **kwargs: Any) -> RemovableEventHandle:
"""Add an event handler to be executed when the specified event is fired.
Args:
@@ -312,7 +314,7 @@ def execute_something():
return RemovableEventHandle(event_name, handler, self)
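
For context, a short usage sketch of add_event_handler (illustration only, not part of the diff); the returned RemovableEventHandle lets the handler be detached again:

    from ignite.engine import Engine, Events

    trainer = Engine(lambda engine, batch: batch)

    def log_epoch(engine: Engine) -> None:
        print("epoch {} completed".format(engine.state.epoch))

    handle = trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
    # later, if the handler is no longer needed:
    handle.remove()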

@staticmethod
def _assert_non_filtered_event(event_name: Any):
def _assert_non_filtered_event(event_name: Any) -> None:
if (
isinstance(event_name, CallableEventWithFilter)
and event_name.filter != CallableEventWithFilter.default_event_filter
@@ -321,7 +323,7 @@ def _assert_non_filtered_event(event_name: Any):
"Argument event_name should not be a filtered event, " "please use event without any event filtering"
)

def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None):
def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
"""Check if the specified event has the specified handler.
Args:
@@ -332,7 +334,7 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
if event_name is not None:
if event_name not in self._event_handlers:
return False
events = [event_name]
events = [event_name] # type: Union[List[Any], Dict[Any, List]]
else:
events = self._event_handlers
for e in events:
@@ -344,10 +346,10 @@ def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None)
@staticmethod
def _compare_handlers(user_handler: Callable, registered_handler: Callable) -> bool:
if hasattr(registered_handler, "_parent"):
registered_handler = registered_handler._parent()
registered_handler = registered_handler._parent() # type: ignore[attr-defined]
return registered_handler == user_handler

def remove_event_handler(self, handler: Callable, event_name: Any):
def remove_event_handler(self, handler: Callable, event_name: Any) -> None:
"""Remove event handler `handler` from registered handlers of the engine
Args:
@@ -367,7 +369,7 @@ def remove_event_handler(self, handler: Callable, event_name: Any):
raise ValueError("Input handler '{}' is not found among registered event handlers".format(handler))
self._event_handlers[event_name] = new_event_handlers

def on(self, event_name, *args, **kwargs):
def on(self, event_name: Any, *args: Any, **kwargs: Any) -> Callable:
"""Decorator shortcut for add_event_handler.
Args:
@@ -398,7 +400,7 @@ def decorator(f: Callable) -> Callable:

return decorator

def _fire_event(self, event_name: Any, *event_args, **event_kwargs) -> None:
def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> None:
"""Execute all the handlers associated with given event.
This method executes all handlers associated with the event
@@ -460,7 +462,7 @@ def terminate_epoch(self) -> None:
)
self.should_terminate_single_epoch = True

def _handle_exception(self, e: Exception) -> None:
def _handle_exception(self, e: BaseException) -> None:
if Events.EXCEPTION_RAISED in self._event_handlers:
self._fire_event(Events.EXCEPTION_RAISED, e)
else:
@@ -497,7 +499,7 @@ def save_engine(_):
a dictionary containing engine's state
"""
keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],)
keys = self._state_dict_all_req_keys + (self._state_dict_one_of_opt_keys[0],) # type: Tuple[str, ...]
keys += tuple(self._state_dict_user_keys)
return OrderedDict([(k, getattr(self.state, k)) for k in keys])
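
As a rough illustration of the result (assuming the required keys in this version are epoch_length and max_epochs, with iteration as the default one-of key, and no user keys registered):

    from ignite.engine import Engine

    trainer = Engine(lambda engine, batch: batch)
    # before run() the values are still the defaults; the output resembles
    # OrderedDict([('epoch_length', None), ('max_epochs', None), ('iteration', 0)])
    print(trainer.state_dict())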

@@ -555,9 +557,9 @@ def load_state_dict(self, state_dict: Mapping) -> None:

@staticmethod
def _is_done(state: State) -> bool:
return state.iteration == state.epoch_length * state.max_epochs
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]

def set_data(self, data):
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
"""Method to set data. After calling the method the next batch passed to `processing_function` is
from newly provided data. Please, note that epoch length is not modified.
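
A hedged usage sketch along the lines of the documented intent of set_data (identifiers are illustrative, not from this commit):

    from ignite.engine import Engine, Events

    data1 = list(range(10))
    data2 = list(range(100, 110))

    trainer = Engine(lambda engine, batch: batch)

    @trainer.on(Events.ITERATION_COMPLETED(once=5))
    def switch_data(engine: Engine) -> None:
        # later batches are drawn from data2; the configured epoch length is unchanged
        engine.set_data(data2)

    trainer.run(data1, max_epochs=1, epoch_length=10)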
@@ -705,21 +707,25 @@ def switch_batch(engine):
return self._internal_run()

@staticmethod
def _init_timers(state: State):
def _init_timers(state: State) -> None:
state.times[Events.EPOCH_COMPLETED.name] = 0.0
state.times[Events.COMPLETED.name] = 0.0

def _get_data_length(self, data):
data_length = None
def _get_data_length(self, data: Iterable) -> Optional[int]:
try:
if hasattr(data, "__len__"):
data_length = len(data)
return len(data) # type: ignore[arg-type]
except TypeError:
# _InfiniteConstantSampler can raise a TypeError on DataLoader length of a IterableDataset
pass
return data_length
return None

def _setup_engine(self) -> None:
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)

iteration = self.state.iteration
self._dataloader_iter = iter(self.state.dataloader)

@@ -734,7 +740,7 @@ def _internal_run(self) -> State:
try:
start_time = time.time()
self._fire_event(Events.STARTED)
while self.state.epoch < self.state.max_epochs and not self.should_terminate:
while self.state.epoch < self.state.max_epochs and not self.should_terminate: # type: ignore[operator]
self.state.epoch += 1
self._fire_event(Events.EPOCH_STARTED)

@@ -785,6 +791,15 @@ def _run_once_on_dataset(self) -> float:
iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
should_exit = False
try:
if self._dataloader_iter is None:
raise RuntimeError(
"Internal error, self._dataloader_iter is None. Please, file an issue if you encounter this error."
)
if self.state.dataloader is None:
raise RuntimeError(
"Internal error, self.state.dataloader is None. Please, file an issue if you encounter this error."
)

while True:
try:
# Avoid Events.GET_BATCH_STARTED triggered twice when data iter is restarted
@@ -808,7 +823,8 @@
"Data iterator can not provide data anymore but required total number of "
"iterations to run is not reached. "
"Current iteration: {} vs Total iterations to run : {}".format(
self.state.iteration, self.state.epoch_length * self.state.max_epochs
self.state.iteration,
self.state.epoch_length * self.state.max_epochs, # type: ignore[operator]
)
)
break