This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-535] Fix bugs in LR Schedulers and add warmup #11234

Merged: 3 commits into apache:master on Aug 26, 2018

Conversation

rahul003 (Member) commented Jun 11, 2018:

Description

  • Adds warmup to all LR schedulers, with two modes: linear increase and constant warmup (see the sketch after this list).
  • Also fixes inconsistencies/bugs where base_lr was not picked up by MultiFactorScheduler and FactorScheduler.
  • Adds tests for the LR schedulers.
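
Roughly, the two warmup modes work as in the following sketch (illustration only, not the exact code in this PR; the helper name warmup_lr is made up here):

    def warmup_lr(num_update, base_lr, warmup_begin_lr, warmup_steps, warmup_mode):
        # LR used while num_update is still inside the warmup period
        assert num_update < warmup_steps
        if warmup_mode == 'linear':
            # ramp linearly from warmup_begin_lr up to base_lr over warmup_steps updates
            increase = (base_lr - warmup_begin_lr) * float(num_update) / float(warmup_steps)
            return warmup_begin_lr + increase
        elif warmup_mode == 'constant':
            # hold the LR at warmup_begin_lr for the whole warmup period
            return warmup_begin_lr
        else:
            raise ValueError("Invalid warmup_mode %s" % warmup_mode)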

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, the expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

rahul003 requested a review from szha as a code owner on Jun 11, 2018 23:59
rahul003 changed the title from "[MXNET-535] Add Warmup Learning Rate Scheduler and fix inconsistencies in LR Schedulers" to "[MXNET-535] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" on Jun 18, 2018
@@ -153,18 +153,57 @@ class PolyScheduler(LRScheduler):

"""

def __init__(self, max_update, base_lr=0.01, pwr=2):
Contributor:
Don't remove base_lr; it will break the API. Pass it to the super init instead.

        if warmup_steps <= 0:
            raise ValueError("Warmup steps has to be positive")
        self.warmup_steps = warmup_steps
        self.lrs_updates = {}
Contributor:
What's the point of this cache? It looks like it will always miss.

rahul003 (Member, author):
For each batch, the number of calls to __call__ is equal to the number of learnable parameter arrays, so the cache avoids recomputing the LR for the same num_update.

            self.lrs_updates[num_update] = self.lr_begin + increase
        else:
            if isinstance(self.scheduler, PolyScheduler):
                self.lrs_updates[num_update] = self.scheduler(num_update - self.warmup_steps)
Contributor:
Why the special case for PolyScheduler? Is num_update - self.warmup_steps standard? Does TF or PyTorch do it this way? Why not num_update directly?

Contributor:
PolyScheduler and CosineScheduler (not implemented here) reduce the lr smoothly from a "starting lr" to an "ending lr", for example from 0.1 to 0.

With warmup, we first increase the lr from a small value (e.g. 0) to the starting lr, then apply the main scheduler. Assuming we have warmup for the first 5 epochs and a total of 90 training epochs, the effective number of epochs for the poly scheduler is 85.

As for the piecewise-constant/factor schedulers, they decay the lr only at certain step points, which the warmup stage does not affect.
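
A quick numeric sketch of why the main scheduler is given num_update - self.warmup_steps rather than num_update (illustrative values, not taken from this PR's tests):

    # base_lr=0.1, final_lr=0, pwr=2, warmup for 5 of 90 total updates
    base_lr, final_lr, pwr = 0.1, 0.0, 2
    warmup_steps, max_update = 5, 90

    def poly(t, span):
        # poly decay evaluated at update t over a decay span of `span` updates
        return final_lr + (base_lr - final_lr) * (1.0 - float(t) / span) ** pwr

    # Shifted: the first post-warmup update starts exactly at base_lr and the
    # decay runs over the remaining 90 - 5 = 85 updates.
    print(poly(5 - warmup_steps, max_update - warmup_steps))  # 0.1
    # Unshifted: the LR would already have decayed during the 5 warmup updates.
    print(poly(5, max_update))  # ~0.089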

rahul003 (Member, author):
Yeah, for the above reason, but I'm updating the code to remove this special case and pass a warmup_steps param to such schedulers so we can handle it cleanly.

rahul003 changed the title from "[MXNET-535] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" to "[MXNET-535] [WIP] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" on Jun 21, 2018
rahul003 (Member, author):
Improved how warmup is handled; please review.

rahul003 changed the title from "[MXNET-535] [WIP] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" to "[MXNET-535] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" on Jul 5, 2018
rahul003 (Member, author) commented Jul 5, 2018:

@piiswrong Could you please review? This includes a fix for an important bug where MultiFactorScheduler previously didn't take a base_lr. This meant the example/image_classification/ scripts didn't use the given LR correctly: it would drop from x to 0.001 and then 0.0001 regardless of the LR given.
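
For illustration (made-up step points and values, assuming the mx.lr_scheduler module path):

    import mxnet as mx

    # Earlier, MultiFactorScheduler took no base_lr:
    #     sched = mx.lr_scheduler.MultiFactorScheduler(step=[30, 60], factor=0.1)
    # so with the base class default of 0.01 and a factor of 0.1 the LR dropped
    # to 0.001 and then 0.0001 no matter what LR the training script passed in.
    #
    # With this PR, base_lr is accepted and respected:
    sched = mx.lr_scheduler.MultiFactorScheduler(step=[30, 60], factor=0.1, base_lr=0.4)
    print(sched(31), sched(61))  # drops to 0.04 after step 30 and 0.004 after step 60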

rahul003 changed the title from "[MXNET-535] Add Warmup Learning Rate Scheduler and fix bugs in LR Schedulers" to "[MXNET-535] Add Warmup to learning rate schedulers and fix bugs in LR Schedulers" on Jul 6, 2018
rahul003 (Member, author):
@piiswrong @szha @eric-haibin-lin Please review

rahul003 changed the title from "[MXNET-535] Add Warmup to learning rate schedulers and fix bugs in LR Schedulers" to "[MXNET-535] Fix bugs in LR Schedulers and add warmup" on Jul 24, 2018
rahul003 (Member, author):
@szha @eric-haibin-lin please review

@@ -29,8 +30,31 @@ class LRScheduler(object):
base_lr : float, optional
The initial learning rate.
"""
def __init__(self, base_lr=0.01):
def __init__(self, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):
Member:
Do you mind adding documentation for warmup_begin_lr?

rahul003 (Member, author):
There was some documentation for the inherited classes, but not for this abstract base class. Anyway, it's now added for all. Please check.

"""

def __init__(self, max_update, base_lr=0.01, pwr=2):
super(PolyScheduler, self).__init__(base_lr)
def __init__(self, max_update, base_lr=0.01, final_lr=0,
Contributor:
Why did you remove pwr? This is API breakage.

rahul003 (Member, author) commented Aug 16, 2018:
I've not removed it. Git is getting confused :/ It thinks I've changed PolyScheduler to CosineScheduler when in fact I've modified PolyScheduler and added a new CosineScheduler.

rahul003 (Member, author):
Please refer to #11234 (comment).

rahul003 (Member, author) commented Aug 17, 2018:

Hopefully this will give committers confidence to merge

Interfaces

CosineScheduler

def __init__(self, max_update, base_lr=0.01, final_lr=0,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):

PolyScheduler

This PR

def __init__(self, max_update, base_lr=0.01, pwr=2, final_lr=0,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):

Earlier

def __init__(self, max_update, base_lr=0.01, pwr=2):

MultiFactorScheduler

This PR

def __init__(self, step, factor=1, base_lr=0.01, warmup_steps=0, warmup_begin_lr=0,
                 warmup_mode='linear'):

Earlier

def __init__(self, step, factor=1):

FactorScheduler

This PR

def __init__(self, step, factor=1, stop_factor_lr=1e-8, base_lr=0.01,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):

Earlier

def __init__(self, step, factor=1, stop_factor_lr=1e-8):

Scheduler

This PR

def __init__(self, base_lr=0.01,
                 warmup_steps=0, warmup_begin_lr=0, warmup_mode='linear'):

Earlier

def __init__(self, base_lr=0.01):
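
For example, with these signatures a scheduler with warmup can be set up roughly as follows (the optimizer wiring is illustrative, not part of this PR's diff):

    import mxnet as mx

    # Poly decay from base_lr towards final_lr, after 5 linearly increasing
    # warmup updates that start at warmup_begin_lr.
    sched = mx.lr_scheduler.PolyScheduler(max_update=90, base_lr=0.1, pwr=2, final_lr=0,
                                          warmup_steps=5, warmup_begin_lr=0,
                                          warmup_mode='linear')
    opt = mx.optimizer.SGD(learning_rate=0.1, lr_scheduler=sched)

    # The scheduler maps an update count to a learning rate:
    for t in (0, 3, 5, 45, 90):
        print(t, sched(t))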

Plots of LR decay from unit tests


eric-haibin-lin merged commit 48d2155 into apache:master on Aug 26, 2018
anirudh2290 pushed a commit to anirudh2290/mxnet that referenced this pull request on Sep 19, 2018:
* Add warmup and fix inconsistencies with learning rate schedulers

* add comments

* remove assert