BYOL Single GPU implementation #1
base: master
Conversation
Summary: Pull Request resolved: facebookresearch#343

Some basic changes to make this script work within FB infra:
1. Register Manifold in PathManager.
2. In order to do 1, create fb/extra_scripts/convert_sharded_checkpoint.y and add necessary dependencies in TARGETS.
3. Replace some torch.loads using PathManager.

Reviewed By: prigoyal
Differential Revision: D29166520
fbshipit-source-id: a61b4eb80d74526b0a7e2d38f973eb688b311a94
As per commit fc1217d, the following review comments were addressed:
def __init__(self, base_momentum: float, shuffle_batch: bool = True):
    super().__init__()
    self.base_momentum = base_momentum
    self.is_distributed = False
Question: Do we need this?
class BYOLHook(ClassyHook):
    """
    TODO: Update description
Update description to BYOL.
@staticmethod
def cosine_decay(training_iter, max_iters, initial_value):
    # TODO: Why do we need this min statement?
Comment this method.
Add type hints to the method.
    return initial_value * cosine_decay_value
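As a reference for the review comments above, here is a sketch of how this helper is commonly written in BYOL implementations, with the requested docstring and type hints. The body is an assumption reconstructed from the snippet (only the `return` line is visible in the diff), but it also suggests one answer to the TODO: the `min()` clamps progress at 1.0 so the schedule stays at its fully decayed value if training runs past `max_iters`.

```python
import math

def cosine_decay(training_iter: int, max_iters: int, initial_value: float) -> float:
    """Cosine-anneal `initial_value` down to 0 over `max_iters` iterations.

    Note: the min() clamps progress to 1.0, so iterations beyond
    `max_iters` keep the final (fully decayed) value instead of
    letting the cosine rise again.
    """
    progress = min(float(training_iter) / float(max_iters), 1.0)
    cosine_decay_value = 0.5 * (1.0 + math.cos(math.pi * progress))
    return initial_value * cosine_decay_value
```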
@staticmethod
def target_ema(training_iter, base_ema, max_iters):
Let's comment every method in byol_hooks.py and byol_losses.py and make sure they all have type hints.
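For reference, a commented, type-hinted sketch of what `target_ema` typically computes in BYOL (the schedule from the paper, tau = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2). Only the signature appears in the diff, so the body here is an assumption, not the PR's actual code:

```python
import math

def target_ema(training_iter: int, base_ema: float, max_iters: int) -> float:
    """EMA momentum schedule for the target network (BYOL paper):
    starts at `base_ema` and cosine-increases toward 1.0 over training,
    so target updates slow down as training progresses.
    """
    progress = min(float(training_iter) / float(max_iters), 1.0)
    decay = 0.5 * (1.0 + math.cos(math.pi * progress)) * (1.0 - base_ema)
    return 1.0 - decay
```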
def _build_byol_target_network(self, task: tasks.ClassyTask) -> None:
    """
    Create the model replica called the target. This will slowly track
Improve comment. Something like: "Target network is exponential moving average of online network, ... "
@torch.no_grad()
def on_forward(self, task: tasks.ClassyTask) -> None:
    """
    - Update the target model.
Update comment for BYOL (this was copy/pasted from moco).
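For context on what the updated comment should describe, a sketch of the EMA update that a BYOL `on_forward` hook typically performs under `@torch.no_grad()` (function name and signature are illustrative assumptions, not the PR's code):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def update_target_network(online_net: nn.Module,
                          target_net: nn.Module,
                          momentum: float) -> None:
    """EMA update of the target parameters:
    target <- momentum * target + (1 - momentum) * online.
    Runs under no_grad because the target is never backpropagated."""
    for tgt_p, onl_p in zip(target_net.parameters(), online_net.parameters()):
        tgt_p.data = tgt_p.data * momentum + onl_p.data * (1.0 - momentum)
```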
@register_loss("byol_loss")
class BYOLLoss(ClassyLoss):
    """
    TODO: change description
Change loss description.
and https://github.com/facebookresearch/moco for a reference implementation, reused here
Config params:
    embedding_dim (int): head output dimension
Do we need/use these vars?
    "_BYOLLossConfig", ["embedding_dim", "momentum"]
)
def regression_loss(x, y):
Type-hints + comments for all these functions.
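As an example of the requested style, a type-hinted, commented sketch of the BYOL regression loss (the standard formulation from the paper: mean squared error between l2-normalized vectors, equivalent to 2 - 2 * cosine similarity). The body is an assumption since only the signature appears in the diff:

```python
import torch
import torch.nn.functional as F

def regression_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """BYOL regression loss between online predictions `x` and target
    projections `y`: l2-normalize both, then 2 - 2 * cosine similarity,
    averaged over the batch. Ranges from 0 (aligned) to 4 (opposite)."""
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (2.0 - 2.0 * (x * y).sum(dim=-1)).mean()
```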
@classmethod
def from_config(cls, config: BYOLLossConfig):
    """
    Instantiates BYOLLoss from configuration.
Put the config options in the docstring here.
def forward(self, online_network_prediction: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """
    Given the encoder queries, the key and the queue of the previous queries,
I think this comment was copy/pasted.
self.is_distributed = False

self.momentum = None
self.inv_momentum = None
Remove this.
self.momentum = None
self.inv_momentum = None
self.total_iters = None
Rename this max_iters.
Implementation of BYOL (https://arxiv.org/abs/2006.07733) on a single GPU.
Issue #190