Improve typing for ignite.handlers module (1343) (pytorch#1349)
* Improve typing for ignite.handlers module (1343)

* autopep8 fix

* Fix typing for py35, remove handlers block from mypy.ini

* Add exception to ModelCheckpoint when saving last checkpoint

* Add test for ModelCheckpoint with redefined save_handler case

* autopep8 fix

Co-authored-by: AutoPEP8 <>
Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: trsvchn <trsvchn@users.noreply.github.com>
4 people committed Oct 6, 2020
1 parent bf2faa4 commit f9e236e
Showing 3 changed files with 32 additions and 14 deletions.
28 changes: 18 additions & 10 deletions ignite/handlers/checkpoint.py
@@ -5,6 +5,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict, namedtuple
+from tempfile import _TemporaryFileWrapper  # type: ignore
 from typing import Callable, Mapping, Optional, Union
 
 import torch
@@ -235,7 +236,7 @@ def score_function(engine):
 
     def __init__(
         self,
-        to_save: Mapping,
+        to_save: Optional[Mapping],
         save_handler: Union[Callable, BaseSaveHandler],
         filename_prefix: str = "",
         score_function: Optional[Callable] = None,
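Note: the widened annotation is needed because ModelCheckpoint builds the base Checkpoint before any objects to save exist, passing to_save=None and filling the attribute only at call time. A minimal sketch of that pattern (reduced to the relevant attribute; not the actual implementation):

    from typing import Mapping, Optional

    class Checkpoint:
        def __init__(self, to_save: Optional[Mapping]) -> None:
            self.to_save = to_save  # may legitimately be None here

    class ModelCheckpoint(Checkpoint):
        def __init__(self) -> None:
            super().__init__(to_save=None)  # objects are only supplied later

        def __call__(self, to_save: Mapping) -> None:
            self.to_save = to_save  # now a real Mapping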
@@ -287,7 +288,7 @@ def __init__(
         self.ext = "pt"
         self.global_step_transform = global_step_transform
         self.filename_pattern = filename_pattern
-        self._saved = []
+        self._saved = []  # type: list
         self.include_self = include_self
 
     @property
@@ -378,10 +379,11 @@ def __call__(self, engine: Engine) -> None:
 
     def _setup_checkpoint(self) -> dict:
         checkpoint = {}
-        for k, obj in self.to_save.items():
-            if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
-                obj = obj.module
-            checkpoint[k] = obj.state_dict()
+        if self.to_save is not None:
+            for k, obj in self.to_save.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                checkpoint[k] = obj.state_dict()
         return checkpoint
 
     @staticmethod
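Note: with to_save now Optional[Mapping], the is-not-None guard keeps runtime behavior identical while letting mypy narrow the type inside the loop. A standalone sketch of the narrowing (function name hypothetical):

    from typing import Mapping, Optional

    def setup_checkpoint(to_save: Optional[Mapping]) -> dict:
        checkpoint = {}
        if to_save is not None:  # mypy narrows Optional[Mapping] to Mapping
            for k, obj in to_save.items():
                checkpoint[k] = obj.state_dict()
        return checkpoint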
@@ -572,7 +574,7 @@ def _save_native(self, checkpoint: Mapping, path: str):
         self._save_func(checkpoint, path, torch.save)
 
     def _save_xla(self, checkpoint: Mapping, path: str):
-        import torch_xla.core.xla_model as xm
+        import torch_xla.core.xla_model as xm  # type: ignore
 
         # all tpu procs should enter here as internally performs sync across device
         self._save_func(checkpoint, path, xm.save, rank=idist.get_rank())
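Note: the inline # type: ignore silences mypy for torch_xla, which ships without type stubs. A per-module setting in mypy.ini would be an equivalent alternative (a sketch only; the commit uses the inline comment instead):

    [mypy-torch_xla.*]
    ignore_missing_imports = True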
@@ -582,8 +584,8 @@ def _save_func(self, checkpoint: Mapping, path: str, func: Callable, rank: int =
             func(checkpoint, path, **self.kwargs)
         else:
             tmp_file = None
-            tmp_name = None
-            tmp = None
+            tmp_name = ""
+            tmp = None  # type: _TemporaryFileWrapper
             if rank == 0:
                 tmp = tempfile.NamedTemporaryFile(delete=False, dir=self.dirname)
                 tmp_file = tmp.file
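Note: the comment-style annotation matches the "Fix typing for py35" item in the commit message: PEP 526 variable annotations require Python 3.6+, so on Python 3.5 the comment form is the only option, and initializing tmp_name to "" lets mypy infer a plain str. A sketch of the two styles:

    tmp = None  # type: _TemporaryFileWrapper      # comment form, parses on 3.5
    # tmp: Optional[_TemporaryFileWrapper] = None  # PEP 526 form, needs 3.6+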
@@ -728,9 +730,15 @@ def __init__(
     def last_checkpoint(self) -> Union[str, None]:
         if len(self._saved) < 1:
             return None
+
+        if not isinstance(self.save_handler, DiskSaver):
+            raise RuntimeError(
+                "Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(self.save_handler))
+            )
+
         return os.path.join(self.save_handler.dirname, self._saved[-1].filename)
 
-    def __call__(self, engine: Engine, to_save: Mapping) -> None:
+    def __call__(self, engine: Engine, to_save: Mapping) -> None:  # type: ignore
 
         if len(to_save) == 0:
             raise RuntimeError("No objects to checkpoint found.")
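Note: previously, replacing save_handler with a plain callable made last_checkpoint fail with an AttributeError on the missing dirname; the isinstance guard turns that into an explicit RuntimeError. A hedged usage sketch (variable names assumed):

    handler = ModelCheckpoint(dirname, filename_prefix="prefix")
    handler.save_handler = lambda checkpoint, filename: None  # no longer a DiskSaver
    handler(engine, {"model": model})
    handler.last_checkpoint  # raises RuntimeError instead of AttributeError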
4 changes: 0 additions & 4 deletions mypy.ini
@@ -7,10 +7,6 @@ show_error_codes = True
 
 ignore_errors = True
 
-[mypy-ignite.handlers.*]
-
-ignore_errors = True
-
 [mypy-ignite.engine.*]
 
 ignore_errors = True
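Note: deleting the [mypy-ignite.handlers.*] override means mypy no longer skips this package, so a run over ignite/handlers now reports errors there; the checkpoint.py annotations above are what make the package pass. The remaining sections (e.g. [mypy-ignite.engine.*]) keep their opt-outs, presumably until those modules are typed in turn.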
14 changes: 14 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
@@ -563,6 +563,20 @@ def _test(ext, require_empty):
     _test(".pt", require_empty=False)
 
 
+def test_model_checkpoint_invalid_save_handler(dirname):
+    h = ModelCheckpoint(dirname, _PREFIX)
+    to_save = {"model": DummyModel()}
+    # Redefine save_handler
+    h.save_handler = lambda x, y: None
+    h(Engine(lambda x, y: None), to_save)
+
+    with pytest.raises(
+        RuntimeError,
+        match=r"Unable to save checkpoint, save_handler should be DiskSaver, got {}.".format(type(h.save_handler)),
+    ):
+        h.last_checkpoint
+
+
 def test_disk_saver_atomic(dirname):
 
     model = DummyModel()
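Note: last_checkpoint is accessed as a bare expression because it is a property; the pytest.raises block is what triggers the getter. To run just this case (standard pytest selection, path as in this diff):

    pytest tests/ignite/handlers/test_checkpoint.py -k invalid_save_handler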
