
Commit 1c07eaf
Revert "Remove all deprecated args, kwargs for v0.5.0 (#1396) (#1397)"
This reverts commit a85da62.
vfdev-5 committed May 31, 2021
1 parent a65155b commit 1c07eaf
Showing 8 changed files with 89 additions and 20 deletions.
4 changes: 4 additions & 0 deletions ignite/contrib/engines/common.py
@@ -44,6 +44,7 @@ def setup_common_training_handlers(
with_pbars: bool = True,
with_pbar_on_iters: bool = True,
log_every_iters: int = 100,
device: Optional[Union[str, torch.device]] = None,
stop_on_nan: bool = True,
clear_cuda_cache: bool = True,
save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
@@ -87,7 +88,10 @@ def setup_common_training_handlers(
class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Checkpoint` for more details.
Argument is mutually exclusive with ``output_path``.
kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
device: deprecated argument, it will be removed in v0.5.0.
"""
if device is not None:
warnings.warn("Argument device is unused and deprecated. It will be removed in v0.5.0")

if idist.get_world_size() > 1:
_setup_common_distrib_training_handlers(
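With this change, `setup_common_training_handlers` accepts the `device` keyword again, but the argument is unused and only triggers a deprecation warning. A minimal sketch of the restored behaviour, mirroring the reverted test (the no-op trainer is illustrative, and the default progress bar assumes the optional tqdm dependency is installed):

```python
import warnings

from ignite.engine import Engine
from ignite.contrib.engines.common import setup_common_training_handlers

# Illustrative no-op trainer, mirroring the reverted test.
trainer = Engine(lambda engine, batch: None)

# `device` is accepted again but unused: only a UserWarning is emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    setup_common_training_handlers(trainer, device="cpu")

assert any("device is unused and deprecated" in str(w.message) for w in caught)
```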
16 changes: 12 additions & 4 deletions ignite/engine/engine.py
@@ -474,7 +474,7 @@ def state_dict_user_keys(self) -> List:
return self._state_dict_user_keys

def state_dict(self) -> OrderedDict:
"""Returns a dictionary containing engine's state: "epoch_length", "max_epochs" and "iteration" and
"""Returns a dictionary containing engine's state: "seed", "epoch_length", "max_epochs" and "iteration" and
other state values defined by `engine.state_dict_user_keys`
.. code-block:: python
@@ -507,8 +507,8 @@ def save_engine(_):
def load_state_dict(self, state_dict: Mapping) -> None:
"""Setups engine from `state_dict`.
State dictionary should contain keys: `iteration` or `epoch`, `max_epochs` and `epoch_length`.
If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
State dictionary should contain keys: `iteration` or `epoch` and `max_epochs`, `epoch_length` and
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
This method does not remove any custom attributes added by user.
@@ -607,12 +607,13 @@ def run(
max_epochs: Optional[int] = None,
max_iters: Optional[int] = None,
epoch_length: Optional[int] = None,
seed: Optional[int] = None,
) -> State:
"""Runs the `process_function` over the passed data.
Engine has a state and the following logic is applied in this function:
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided.
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
provided, state is kept and used in the function.
@@ -631,6 +632,7 @@ def run(
This argument should not change if run is resuming from a state.
max_iters: Number of iterations to run for.
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
Returns:
State: output state.
@@ -659,6 +661,12 @@ def switch_batch(engine):
trainer.run(train_loader, max_epochs=2)
"""
if seed is not None:
warnings.warn(
"Argument seed is deprecated. It will be removed in 0.5.0. "
"Please, use torch.manual_seed or ignite.utils.manual_seed"
)

if not isinstance(data, Iterable):
raise TypeError("Argument data should be iterable")

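After this revert, `Engine.run` once again accepts a deprecated `seed` argument (and the `state_dict`/`load_state_dict` docstrings list `seed` among the state keys). Passing `seed` only emits a warning; the recommended pattern is to seed explicitly before calling `run`. A minimal sketch, assuming a trivial process function:

```python
from ignite.engine import Engine
from ignite.utils import manual_seed

trainer = Engine(lambda engine, batch: batch)

# Deprecated style restored by this revert: emits a UserWarning.
trainer.run([0, 1, 2, 3], max_epochs=1, seed=1234)

# Recommended replacement: seed explicitly, then call run() without `seed`.
manual_seed(1234)
trainer.run([0, 1, 2, 3], max_epochs=1)
```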
12 changes: 12 additions & 0 deletions ignite/engine/events.py
@@ -1,4 +1,5 @@
import numbers
import warnings
import weakref
from enum import Enum
from types import DynamicClassAttribute
@@ -138,6 +139,17 @@ def __or__(self, other: Any) -> "EventsList":
return EventsList() | self | other


class CallableEvents(CallableEventWithFilter):
# For backward compatibility
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(CallableEvents, self).__init__(*args, **kwargs)
warnings.warn(
"Class ignite.engine.events.CallableEvents is deprecated. It will be removed in 0.5.0. "
"Please, use ignite.engine.EventEnum instead",
DeprecationWarning,
)


class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc]
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
this class. For example, Custom events based on the loss calculation and backward pass can be created as follows:
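`CallableEvents` is restored here only as a deprecated backward-compatibility shim; the warning points new code at `EventEnum` instead. A short sketch of the recommended pattern (the custom event names below are illustrative):

```python
from ignite.engine import Engine, EventEnum


class BackwardEvents(EventEnum):
    # Illustrative custom events; any string values work.
    BACKWARD_STARTED = "backward_started"
    BACKWARD_COMPLETED = "backward_completed"


engine = Engine(lambda engine, batch: batch)
engine.register_events(*BackwardEvents)


@engine.on(BackwardEvents.BACKWARD_COMPLETED)
def on_backward_completed(_):
    print("backward pass finished")
```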
31 changes: 31 additions & 0 deletions ignite/handlers/checkpoint.py
@@ -92,6 +92,7 @@ class Checkpoint(Serializable):
Input of the function is ``(engine, event_name)``. Output of function should be an integer.
Default is None, global_step based on attached engine. If provided, uses function output as global_step.
To setup global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
archived: Deprecated argument as models saved by ``torch.save`` are already compressed.
filename_pattern: If ``filename_pattern`` is provided, this pattern will be used to render
checkpoint filenames. If the pattern is not defined, the default pattern would be used. See Note for
details.
@@ -263,6 +264,7 @@ def __init__(
score_name: Optional[str] = None,
n_saved: Optional[int] = 1,
global_step_transform: Optional[Callable] = None,
archived: bool = False,
filename_pattern: Optional[str] = None,
include_self: bool = False,
greater_or_equal: bool = False,
@@ -290,6 +292,8 @@ def __init__(

if global_step_transform is not None and not callable(global_step_transform):
raise TypeError(f"global_step_transform should be a function, got {type(global_step_transform)} instead.")
if archived:
warnings.warn("Argument archived is deprecated and will be removed in 0.5.0")

self.to_save = to_save
self.filename_prefix = filename_prefix
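`Checkpoint` takes the deprecated `archived` flag again; it is a no-op because files written by `torch.save` are already compressed, so only a warning is raised. A minimal sketch, using a dummy save handler purely for illustration:

```python
import warnings

import torch.nn as nn
from ignite.handlers import Checkpoint

model = nn.Linear(2, 2)

# `archived=True` is accepted again but does nothing beyond a UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    Checkpoint({"model": model}, save_handler=lambda checkpoint, filename: None, archived=True)

assert any("archived is deprecated" in str(w.message) for w in caught)
```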
@@ -732,6 +736,11 @@ class ModelCheckpoint(Checkpoint):
Behaviour of this class has been changed since v0.3.0.
Argument ``save_as_state_dict`` is deprecated and should not be used. It is considered as True.
Argument ``save_interval`` is deprecated and should not be used. Please, use events filtering instead, e.g.
:attr:`~ignite.engine.events.Events.ITERATION_STARTED(every=1000)`
There is no more internal counter that has been used to indicate the number of save actions. User could
see its value `step_number` in the filename, e.g. `{filename_prefix}_{name}_{step_number}.pt`. Actually,
`step_number` is replaced by current engine's epoch if `score_function` is specified and current iteration
@@ -762,6 +771,7 @@ class ModelCheckpoint(Checkpoint):
To setup global step from another engine, please use :meth:`~ignite.handlers.global_step_from_engine`.
include_self: Whether to include the `state_dict` of this object in the checkpoint. If `True`, then
there must not be another object in ``to_save`` with key ``checkpointer``.
archived: Deprecated argument as models saved by `torch.save` are already compressed.
kwargs: Accepted keyword arguments for `torch.save` or `xm.save` in `DiskSaver`.
.. versionchanged:: 0.4.2
@@ -787,17 +797,37 @@ def __init__(
self,
dirname: str,
filename_prefix: str,
save_interval: Optional[Callable] = None,
score_function: Optional[Callable] = None,
score_name: Optional[str] = None,
n_saved: Union[int, None] = 1,
atomic: bool = True,
require_empty: bool = True,
create_dir: bool = True,
save_as_state_dict: bool = True,
global_step_transform: Optional[Callable] = None,
archived: bool = False,
include_self: bool = False,
**kwargs: Any,
):

if not save_as_state_dict:
raise ValueError(
"Argument save_as_state_dict is deprecated and should be True."
"This argument will be removed in 0.5.0."
)
if save_interval is not None:
msg = (
"Argument save_interval is deprecated and should be None. This argument will be removed in 0.5.0."
"Please, use events filtering instead, e.g. Events.ITERATION_STARTED(every=1000)"
)
if save_interval == 1:
# Do not break for old version who used `save_interval=1`
warnings.warn(msg)
else:
# No choice
raise ValueError(msg)

disk_saver = DiskSaver(dirname, atomic=atomic, create_dir=create_dir, require_empty=require_empty, **kwargs)

super(ModelCheckpoint, self).__init__(
@@ -808,6 +838,7 @@ def __init__(
score_name=score_name,
n_saved=n_saved,
global_step_transform=global_step_transform,
archived=archived,
include_self=include_self,
)

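`ModelCheckpoint` regains the deprecated `save_interval`, `save_as_state_dict` and `archived` arguments, but only to warn or raise; periodic saving should instead be expressed through event filtering, as the deprecation message suggests. A sketch of the recommended replacement for `save_interval` (directory and prefix names are hypothetical):

```python
import torch.nn as nn

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

model = nn.Linear(2, 2)
trainer = Engine(lambda engine, batch: batch)

# Hypothetical output directory and filename prefix.
checkpointer = ModelCheckpoint("/tmp/models", "run", n_saved=2, create_dir=True, require_empty=False)

# Instead of the deprecated `save_interval=1000`, filter the triggering event.
trainer.add_event_handler(Events.ITERATION_STARTED(every=1000), checkpointer, {"model": model})
```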
17 changes: 2 additions & 15 deletions tests/ignite/contrib/engines/test_common.py
@@ -148,21 +148,8 @@ def test_asserts_setup_common_training_handlers():
train_sampler = MagicMock(spec=DistributedSampler)
setup_common_training_handlers(trainer, train_sampler=train_sampler)

with pytest.raises(RuntimeError, match=r"This contrib module requires available GPU"):
setup_common_training_handlers(trainer, with_gpu_stats=True)

with pytest.raises(TypeError, match=r"Unhandled type of update_function's output."):
trainer = Engine(lambda e, b: None)
setup_common_training_handlers(
trainer,
output_names=["loss"],
with_pbar_on_iters=False,
with_pbars=False,
with_gpu_stats=False,
stop_on_nan=False,
clear_cuda_cache=False,
)
trainer.run([1])
with pytest.warns(UserWarning, match=r"Argument device is unused and deprecated"):
setup_common_training_handlers(trainer, device="cpu")


def test_no_warning_with_train_sampler(recwarn):
14 changes: 13 additions & 1 deletion tests/ignite/engine/test_custom_events.py
@@ -6,7 +6,19 @@

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.engine.events import CallableEventWithFilter, EventEnum, EventsList
from ignite.engine.events import CallableEvents, CallableEventWithFilter, EventEnum, EventsList


def test_deprecated_callable_events_class():
engine = Engine(lambda engine, batch: 0)

with pytest.warns(DeprecationWarning, match=r"Class ignite\.engine\.events\.CallableEvents is deprecated"):

class CustomEvents(CallableEvents, Enum):
TEST_EVENT = "test_event"

with pytest.raises(TypeError, match=r"Value at \d of event_names should be a str or EventEnum"):
engine.register_events(*CustomEvents)


def test_custom_events():
3 changes: 3 additions & 0 deletions tests/ignite/engine/test_engine.py
@@ -352,6 +352,9 @@ def test_run_asserts():
with pytest.raises(ValueError, match=r"Input data has zero size. Please provide non-empty data"):
engine.run([])

with pytest.warns(UserWarning, match="Argument seed is deprecated"):
engine.run([0, 1, 2, 3, 4], seed=1234)


def test_state_get_event_attrib_value():
state = State()
12 changes: 12 additions & 0 deletions tests/ignite/handlers/test_checkpoint.py
@@ -60,6 +60,9 @@ def test_checkpoint_wrong_input():
with pytest.raises(TypeError, match=r"global_step_transform should be a function."):
Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", global_step_transform=123)

with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
Checkpoint(to_save, lambda x: x, score_function=lambda e: 123, score_name="acc", archived=True)

with pytest.raises(ValueError, match=r"Cannot have key 'checkpointer' if `include_self` is True"):
Checkpoint({"checkpointer": model}, lambda x: x, include_self=True)

@@ -502,15 +505,24 @@ def test_model_checkpoint_args_validation(dirname):
with pytest.raises(ValueError, match=r"with extension '.pt' are already present "):
ModelCheckpoint(nonempty, _PREFIX)

with pytest.raises(ValueError, match=r"Argument save_interval is deprecated and should be None"):
ModelCheckpoint(existing, _PREFIX, save_interval=42)

with pytest.raises(ValueError, match=r"Directory path '\S+' is not found"):
ModelCheckpoint(os.path.join(dirname, "non_existing_dir"), _PREFIX, create_dir=False)

with pytest.raises(ValueError, match=r"Argument save_as_state_dict is deprecated and should be True"):
ModelCheckpoint(existing, _PREFIX, create_dir=False, save_as_state_dict=False)

with pytest.raises(ValueError, match=r"If `score_name` is provided, then `score_function` "):
ModelCheckpoint(existing, _PREFIX, create_dir=False, score_name="test")

with pytest.raises(TypeError, match=r"global_step_transform should be a function"):
ModelCheckpoint(existing, _PREFIX, create_dir=False, global_step_transform=1234)

with pytest.warns(UserWarning, match=r"Argument archived is deprecated"):
ModelCheckpoint(existing, _PREFIX, create_dir=False, archived=True)

h = ModelCheckpoint(dirname, _PREFIX, create_dir=False)
assert h.last_checkpoint is None
with pytest.raises(RuntimeError, match=r"No objects to checkpoint found."):
