Skip to content

Commit

Permalink
Add None check for max_epochs (pytorch#1519)
Browse files Browse the repository at this point in the history
* Add None check for max_epochs

* Small CR fix

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
gruebel and vfdev-5 committed Dec 22, 2020
1 parent cfea1e4 commit e3ef192
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
if isinstance(value, list):
if len(value) != len(self.optimizer_param_groups):
raise ValueError(
f"size of value is different than optimizer_param_groups {len(value)} != {len(self.optimizer_param_groups)}"
"size of value is different than optimizer_param_groups "
f"{len(value)} != {len(self.optimizer_param_groups)}"
)

for i, param_group in enumerate(self.optimizer_param_groups):
Expand Down
12 changes: 9 additions & 3 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,8 @@ def _is_done(state: State) -> bool:
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
is_done_count = (
state.epoch_length is not None
and state.iteration >= state.epoch_length * state.max_epochs # type: ignore[operator]
and state.max_epochs is not None
and state.iteration >= state.epoch_length * state.max_epochs
)
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
return is_done_iters or is_done_count or is_done_epochs
Expand Down Expand Up @@ -833,12 +834,17 @@ def _run_once_on_dataset(self) -> float:
# Should exit while loop if we can not iterate
if should_exit:
if not self._is_done(self.state):
total_iters = (
self.state.epoch_length * self.state.max_epochs
if self.state.max_epochs is not None
else self.state.max_iters
)

warnings.warn(
"Data iterator can not provide data anymore but required total number of "
"iterations to run is not reached. "
"Current iteration: {} vs Total iterations to run : {}".format(
self.state.iteration,
self.state.epoch_length * self.state.max_epochs, # type: ignore[operator]
self.state.iteration, total_iters,
)
)
break
Expand Down

0 comments on commit e3ef192

Please sign in to comment.