
Commit

* Fixes pytorch#1162
- relaxed check of optimizer type

* Updated docs
vfdev-5 committed Jun 26, 2020
1 parent 543ae1e commit 9a7be7b
Showing 13 changed files with 122 additions and 21 deletions.
7 changes: 7 additions & 0 deletions docs/source/contrib/handlers.rst
@@ -75,6 +75,10 @@ mlflow_logger
tqdm_logger
-----------

See `tqdm mnist example <https://github.com/pytorch/ignite/blob/master/examples/contrib/mnist/mnist_with_tqdm_logger.py>`_
for detailed usage.


.. automodule:: ignite.contrib.handlers.tqdm_logger
:members:

@@ -88,6 +92,9 @@ polyaxon_logger
wandb_logger
---------------

See `wandb mnist example <https://github.com/pytorch/ignite/blob/master/examples/contrib/mnist/mnist_with_wandb_logger.py>`_
for detailed usage.

.. automodule:: ignite.contrib.handlers.wandb_logger
:members:
:inherited-members:
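
(Orientation sketch, not part of this commit; the linked MNIST examples show the full setup, and the ProgressBar API used here is assumed from the tqdm_logger docs above.)

    from ignite.contrib.handlers import ProgressBar
    from ignite.engine import Engine

    # toy engine whose update step returns a dict of values to display
    trainer = Engine(lambda engine, batch: {"loss": 0.0})

    # attach a tqdm progress bar that shows the returned "loss"
    ProgressBar(persist=True).attach(trainer, output_transform=lambda out: out)

    trainer.run([0] * 10, max_epochs=2)
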
10 changes: 8 additions & 2 deletions ignite/contrib/handlers/base_logger.py
@@ -1,9 +1,11 @@
import numbers
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from typing import Any, Mapping

import torch
from torch.optim import Optimizer

from ignite.engine import Engine, State

@@ -20,9 +22,13 @@ class BaseOptimizerParamsHandler(BaseHandler):
"""

def __init__(self, optimizer, param_name="lr", tag=None):
if not isinstance(optimizer, torch.optim.Optimizer):
if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be of type torch.optim.Optimizer, " "but given {}".format(type(optimizer))
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
"but given {}".format(type(optimizer))
)

self.optimizer = optimizer
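
(Illustration only, not part of the diff.) The relaxed check above duck-types the optimizer: anything that is a torch.optim.Optimizer, or that exposes a list/tuple ``param_groups`` attribute, is accepted; this covers wrappers such as DeepSpeed's FP16 ZeRO optimizer. A minimal sketch of the rule, applied to a hypothetical wrapper (``is_accepted_optimizer`` and ``WrappedSGD`` are illustrative names, not library code):

    from collections.abc import Sequence

    import torch
    from torch.optim import Optimizer

    def is_accepted_optimizer(obj):
        # same rule as the new check in BaseOptimizerParamsHandler.__init__
        return isinstance(obj, Optimizer) or (
            hasattr(obj, "param_groups") and isinstance(obj.param_groups, Sequence)
        )

    class WrappedSGD:
        # hypothetical wrapper: not an Optimizer subclass, but proxies param_groups
        def __init__(self, optimizer):
            self._optimizer = optimizer

        @property
        def param_groups(self):
            return self._optimizer.param_groups

    sgd = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.01)
    assert is_accepted_optimizer(sgd)
    assert is_accepted_optimizer(WrappedSGD(sgd))  # accepted after this commit
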
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/mlflow_logger.py
@@ -167,7 +167,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, 'generator'
"""
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/neptune_logger.py
@@ -191,7 +191,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
"""
25 changes: 18 additions & 7 deletions ignite/contrib/handlers/param_scheduler.py
@@ -18,7 +18,8 @@ class ParamScheduler(metaclass=ABCMeta):
training.
Args:
optimizer (`torch.optim.Optimizer`): optimizer
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): name of optimizer's parameter to update.
save_history (bool, optional): whether to log the parameter values to
`engine.state.param_history`, (default=False).
@@ -33,8 +34,14 @@ class ParamScheduler(metaclass=ABCMeta):

def __init__(self, optimizer, param_name, save_history=False, param_group_index=None):

if not isinstance(optimizer, Optimizer):
raise TypeError("Argument optimizer should be torch.optim.Optimizer")
if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
):
raise TypeError(
"Argument optimizer should be torch.optim.Optimizer or has attribute 'param_groups' as list/tuple, "
"but given {}".format(type(optimizer))
)

self.optimizer = optimizer
self.param_group_index = param_group_index
@@ -212,7 +219,8 @@ class CyclicalScheduler(ParamScheduler):
cycle of some size.
Args:
optimizer (`torch.optim.Optimizer`): optimizer
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): name of optimizer's parameter to update.
start_value (float): value at start of cycle.
end_value (float): value at the middle of the cycle.
@@ -287,7 +295,8 @@ class LinearCyclicalScheduler(CyclicalScheduler):
adjusts it back to 'start_value' for a half-cycle.
Args:
optimizer (`torch.optim.Optimizer`): optimizer
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): name of optimizer's parameter to update.
start_value (float): value at start of cycle.
end_value (float): value at the middle of the cycle.
@@ -332,7 +341,8 @@ class CosineAnnealingScheduler(CyclicalScheduler):
wave (as suggested in [Smith17]_).
Args:
optimizer (`torch.optim.Optimizer`): optimizer
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): name of optimizer's parameter to update.
start_value (float): value at start of cycle.
end_value (float): value at the end of the cycle.
@@ -818,7 +828,8 @@ class PiecewiseLinear(ParamScheduler):
Piecewise linear parameter scheduler
Args:
optimizer (`torch.optim.Optimizer`): optimizer.
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): name of optimizer's parameter to update.
milestones_values (list of tuples (int, float)): list of tuples (event index, parameter value)
represents milestones and parameter. Milestones should be increasing integers.
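
(Usage sketch, not part of the diff; it mirrors the new test added further below. ``ParamGroupsProxy`` is a hypothetical wrapper, not library code.) With the relaxed check, a parameter scheduler can drive any object that proxies ``param_groups`` to a real optimizer:

    import torch
    from ignite.contrib.handlers import PiecewiseLinear
    from ignite.engine import Engine, Events

    class ParamGroupsProxy:
        # hypothetical wrapper: exposes the inner optimizer's param_groups
        def __init__(self, optimizer):
            self.optimizer = optimizer

        @property
        def param_groups(self):
            return self.optimizer.param_groups

    tensor = torch.zeros(1, requires_grad=True)
    wrapped = ParamGroupsProxy(torch.optim.SGD([tensor], lr=0.0))

    # hold lr at 0.5 until event 5, then ramp linearly to 1.0 at event 15
    scheduler = PiecewiseLinear(wrapped, "lr", milestones_values=[(5, 0.5), (15, 1.0)])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
    trainer.run([0] * 15, max_epochs=1)
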
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/polyaxon_logger.py
@@ -154,7 +154,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
"""
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/tensorboard_logger.py
@@ -163,7 +163,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
"""
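
(Usage sketch, not part of the diff; assumes tensorboardX or torch.utils.tensorboard is installed.) The handler documented above is attached to a logger as in the ignite docs; after this commit ``optimizer`` may also be a wrapper object that exposes ``param_groups``:

    import torch
    from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OptimizerParamsHandler
    from ignite.engine import Engine, Events

    param = torch.zeros(1, requires_grad=True)
    optimizer = torch.optim.SGD([param], lr=0.01)  # or any object exposing param_groups
    trainer = Engine(lambda engine, batch: None)

    tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
    tb_logger.attach(
        trainer,
        log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
        event_name=Events.ITERATION_STARTED,
    )

    trainer.run([0] * 10, max_epochs=1)
    tb_logger.close()
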
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/trains_logger.py
@@ -182,7 +182,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
"""
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/visdom_logger.py
@@ -218,7 +218,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler, _BaseVisDrawer):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
show_legend (bool, optional): flag to show legend in the window
3 changes: 2 additions & 1 deletion ignite/contrib/handlers/wandb_logger.py
@@ -166,7 +166,8 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
)
Args:
optimizer (torch.optim.Optimizer): torch optimizer which parameters to log
optimizer (torch.optim.Optimizer or object): torch optimizer or any object with attribute ``param_groups``
as a sequence.
param_name (str): parameter name
tag (str, optional): common title for all produced plots. For example, "generator"
sync (bool, optional): If set to False, process calls to log in a separate thread. Default (None) uses whatever
14 changes: 14 additions & 0 deletions tests/ignite/contrib/handlers/__init__.py
@@ -0,0 +1,14 @@
class MockFP16DeepSpeedZeroOptimizer:
def __init__(self, optimizer):
self.optimizer = optimizer

def step(self, closure=None):
self.optimizer.step()

def _get_param_groups(self):
return self.optimizer.param_groups

def _set_param_groups(self, value):
self.optimizer.param_groups = value

param_groups = property(_get_param_groups, _set_param_groups)
23 changes: 21 additions & 2 deletions tests/ignite/contrib/handlers/test_base_logger.py
@@ -7,6 +7,7 @@
from ignite.contrib.handlers import CustomPeriodicEvent
from ignite.contrib.handlers.base_logger import BaseLogger, BaseOptimizerParamsHandler, BaseOutputHandler
from ignite.engine import Engine, Events, State
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer


class DummyOutputHandler(BaseOutputHandler):
@@ -15,8 +16,13 @@ def __call__(self, *args, **kwargs):


class DummyOptParamsHandler(BaseOptimizerParamsHandler):
def __call__(self, *args, **kwargs):
pass
def __call__(self, engine, logger, event_name, **kwargs):
tag_prefix = "{}/".format(self.tag) if self.tag else ""
params = {
"{}{}/group_{}".format(tag_prefix, self.param_name, i): float(param_group[self.param_name])
for i, param_group in enumerate(self.optimizer.param_groups)
}
return params


class DummyLogger(BaseLogger):
@@ -41,6 +47,9 @@ def test_base_output_handler_wrong_setup():
with pytest.raises(TypeError, match="global_step_transform should be a function"):
DummyOutputHandler("tag", metric_names=["loss"], global_step_transform="abc")

with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"):
DummyOptParamsHandler({}, "lr")


def test_base_output_handler_setup_output_metrics():

@@ -81,6 +90,16 @@ def test_base_output_handler_setup_output_metrics():
assert metrics == true_metrics


def test_opt_params_handler_on_non_torch_optimizers():
tensor = torch.zeros([1], requires_grad=True)
base_optimizer = torch.optim.SGD([tensor], lr=0.1234)
optimizer = MockFP16DeepSpeedZeroOptimizer(base_optimizer)
handler = DummyOptParamsHandler(optimizer=optimizer, param_name="lr")
res = handler(engine=None, logger=None, event_name=None)
assert isinstance(res, dict)
assert "lr/group_0" in res and res["lr/group_0"] == 0.1234


def test_attach():

n_epochs = 5
43 changes: 40 additions & 3 deletions tests/ignite/contrib/handlers/test_param_scheduler.py
@@ -1,3 +1,5 @@
from unittest.mock import MagicMock

import numpy as np
import pytest
import torch
@@ -14,6 +16,7 @@
create_lr_scheduler_with_warmup,
)
from ignite.engine import Engine, Events
from tests.ignite.contrib.handlers import MockFP16DeepSpeedZeroOptimizer

try:
from torch.optim.lr_scheduler import MultiplicativeLR
@@ -26,10 +29,12 @@
has_multiplicative_lr = LooseVersion(torch.__version__) >= LooseVersion("1.5.0")


class FakeParamScheduler(ParamScheduler):
def get_param(self):
return [0]


def test_param_scheduler_asserts():
class FakeParamScheduler(ParamScheduler):
def get_param(self):
return [0]

t1 = torch.zeros([1], requires_grad=True)
t2 = torch.zeros([1], requires_grad=True)
@@ -46,6 +51,9 @@ def get_param(self):
with pytest.raises(ValueError, match=r"Required state attribute 'event_index' is absent in provided state_dict"):
lr_scheduler.load_state_dict({})

with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"):
FakeParamScheduler({}, "lr")


def test_linear_scheduler():

@@ -1220,3 +1228,32 @@ def save_lr():

torch_lr_scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
_test(LRScheduler(torch_lr_scheduler), optimizer)


def test_lr_scheduling_on_non_torch_optimizers():
# tests https://github.com/pytorch/ignite/issues/1162
optimizer = MagicMock()
optimizer.param_groups = [{"params": 0}]
FakeParamScheduler(optimizer, "lr")

tensor = torch.zeros([1], requires_grad=True)
base_optimizer = torch.optim.SGD([tensor], lr=0)
optimizer = MockFP16DeepSpeedZeroOptimizer(base_optimizer)

milestones_values = [(5, 0.5), (15, 1.0)]

scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)

def save_lr(engine):
lrs.append(optimizer.param_groups[0]["lr"])

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

lrs = []
trainer.run([0] * 15, max_epochs=1)

assert lrs == list(
map(pytest.approx, [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95,],)
)
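
(Reference note, not part of the diff.) The expected values follow from holding the parameter at 0.5 until the first milestone at event index 5 and then interpolating linearly to 1.0 at event index 15; with 15 iterations the scheduler is applied at event indices 0-14:

    # expected PiecewiseLinear schedule for milestones (5, 0.5) and (15, 1.0)
    expected = [0.5 if e <= 5 else 0.5 + 0.5 * (e - 5) / 10 for e in range(15)]
    # -> [0.5] * 6 + [0.55, 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95]
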
