From 46f3f490fb977c3f2d09460fe099b786c95ea80a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 14 Oct 2020 09:59:52 +0200 Subject: [PATCH] Removed state.restart method --- ignite/engine/engine.py | 7 +++---- ignite/engine/events.py | 3 --- tests/ignite/engine/test_engine_state_dict.py | 4 ++-- 3 files changed, 5 insertions(+), 9 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index c50bcfe3f3d..c9583c56688 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -639,8 +639,7 @@ def run( def switch_batch(engine): engine.state.batch = preprocess_batch(engine.state.batch) - Restart the training from the beginning. User can reset `max_epochs = None` or either call - `trainer.state.restart()`: + Restart the training from the beginning. User can reset `max_epochs = None`: .. code-block:: python @@ -648,7 +647,7 @@ def switch_batch(engine): trainer.run(train_loader, max_epochs=5) # Reset model weights etc. and restart the training - trainer.state.restart() # equivalent to trainer.state.max_epochs = None + trainer.state.max_epochs = None trainer.run(train_loader, max_epochs=2) """ @@ -667,7 +666,7 @@ def switch_batch(engine): if max_epochs < self.state.epoch: raise ValueError( "Argument max_epochs should be larger than the start epoch " - "defined in the state: {} vs {}. Please, call state.restart() " + "defined in the state: {} vs {}. Please, set engine.state.max_epochs = None " "before calling engine.run() in order to restart the training from the beginning.".format( max_epochs, self.state.epoch ) diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 7cf2b4627a5..818c491770e 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -396,9 +396,6 @@ def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum raise RuntimeError("Unknown event name '{}'".format(event_name)) return getattr(self, State.event_to_attr[event_name]) - def restart(self) -> None: - self.max_epochs = None - def __repr__(self) -> str: s = "State:\n" for attr, value in self.__dict__.items(): diff --git a/tests/ignite/engine/test_engine_state_dict.py b/tests/ignite/engine/test_engine_state_dict.py index 7bc55220fbb..f65cee08d52 100644 --- a/tests/ignite/engine/test_engine_state_dict.py +++ b/tests/ignite/engine/test_engine_state_dict.py @@ -276,9 +276,9 @@ def test_restart_training(): with pytest.raises( ValueError, match=r"Argument max_epochs should be larger than the start epoch defined in the state: 2 vs 5. " - r"Please, call state.restart\(\) " + r"Please, .+ " r"before calling engine.run\(\) in order to restart the training from the beginning.", ): state = engine.run(data, max_epochs=2) - state.restart() + state.max_epochs = None engine.run(data, max_epochs=2)