Pytorch-PCGrad-GradVac-AMP-GradAccum

PyTorch 1.11 reimplementation of multi-task gradient adaptation ideas from the following papers:

  • Gradient Surgery for Multi-Task Learning (PCGrad), Yu et al., NeurIPS 2020
  • Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (GradVac), Wang et al., 2020
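
As a quick refresher (a generic sketch, not this repository's exact implementation): PCGrad resolves conflicts between per-task gradients by projecting out, from each task gradient, the component that points against another task's gradient (negative dot product). A minimal sketch on flattened gradient vectors:

import torch

def pcgrad_project(grads):
    # Sketch of the PCGrad rule (Yu et al., 2020) on flattened per-task gradients.
    # For each task gradient, remove the component that conflicts with any other
    # task's gradient. The paper also shuffles the task order at every step;
    # that detail is omitted here.
    projected = [g.clone() for g in grads]
    for i, g_i in enumerate(projected):
        for j, g_j in enumerate(grads):
            if i == j:
                continue
            dot = torch.dot(g_i, g_j)
            if dot < 0:  # conflicting gradients
                g_i.sub_(dot / (g_j.norm() ** 2 + 1e-12) * g_j)
    return projected

GradVac generalizes this idea: instead of only removing negatively aligned components, it nudges the cosine similarity between task gradients toward an exponential-moving-average target (the beta argument in the Usage section below).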

Supports:

  • Automatic mixed precision (AMP)
  • Gradient Accumulation (with CPU offload support)
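
For background, the plain-PyTorch version of AMP combined with gradient accumulation (single task) looks roughly like the sketch below; model, criterion, train_loader, optimizer, ACCUM_STEPS, and DEVICE are illustrative placeholders, not this repository's API:

import torch

scaler = torch.cuda.amp.GradScaler()

for step, (inputs, targets) in enumerate(train_loader, start=1):
    inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
    with torch.cuda.amp.autocast():
        # Normalize by the number of accumulation steps so the effective
        # gradient matches one large batch.
        loss = criterion(model(inputs), targets) / ACCUM_STEPS
    scaler.scale(loss).backward()   # scaled backward; gradients accumulate
    if step % ACCUM_STEPS == 0:
        scaler.step(optimizer)      # unscales gradients, then optimizer.step()
        scaler.update()
        optimizer.zero_grad()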

Parts of the implementation are adapted from the following repositories:

Setup

Install the required packages via:

pip install -r requirements.txt

Usage

import torch
from pcgrad_amp import PCGradAMP
from gradvac_amp import GradVacAMP

DEVICE = ...
ACCUM_STEPS = ...
NUM_EPOCHS = ...
LR = ...
BATCH_SIZE = ...

model = ...
train_loader = ...

optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# lr_scheduler = ...
scaler = torch.cuda.amp.GradScaler()
num_tasks = 2

# For Gradient Surgery / PCGrad:
# grad_optimizer = PCGradAMP(num_tasks, optimizer, scaler=scaler, reduction='sum', cpu_offload=False)

# For Gradient Vaccine:
grad_optimizer = GradVacAMP(num_tasks, optimizer, DEVICE, scaler=scaler, beta=1e-2, reduction='sum', cpu_offload=False)

total_steps = 0
for ep in range(NUM_EPOCHS):
    for mini_batch in train_loader:
        losses = [...]
        grad_optimizer.backward(losses) # Compute & Accumulate gradients
        total_steps += 1
        if (total_steps % ACCUM_STEPS) == 0:
            # lr_scheduler.step()
            grad_optimizer.step() # Parameter update step
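
Here, beta is the EMA rate GradVac uses to track its adaptive gradient-similarity targets. The losses list holds one scalar loss per task from a shared forward pass; for a two-task setup such as Multi-MNIST it could be built roughly as below (the two-head model and batch layout are hypothetical, not this repository's API; the forward pass runs under autocast as in standard AMP usage):

with torch.cuda.amp.autocast():
    images, left_labels, right_labels = mini_batch          # hypothetical loader output
    left_logits, right_logits = model(images.to(DEVICE))    # hypothetical two-head model
    criterion = torch.nn.CrossEntropyLoss()
    losses = [
        criterion(left_logits, left_labels.to(DEVICE)),     # task 1: left digit
        criterion(right_logits, right_labels.to(DEVICE)),   # task 2: right digit
    ]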

Training

  • Multi-MNIST

Run the training script with the following command. Part of the implementation is adapted from https://github.com/intel-isl/MultiObjectiveOptimization

python main_multi_mnist_amp.py

Results obtained with default settings:

Method                      left-digit   right-digit
Jointly Training            89.88        87.51
Gradient Surgery (PCGrad)   90.92        88.13
Gradient Vaccine            91.07        88.79

The training loss curves indicate that GradVac performs slightly better than PCGrad.

Reference

Please cite as:

@article{yu2020gradient,
  title={Gradient surgery for multi-task learning},
  author={Yu, Tianhe and Kumar, Saurabh and Gupta, Abhishek and Levine, Sergey and Hausman, Karol and Finn, Chelsea},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  pages={5824--5836},
  year={2020}
}

@article{wang2020gradient,
  title={Gradient vaccine: Investigating and improving multi-task optimization in massively multilingual models},
  author={Wang, Zirui and Tsvetkov, Yulia and Firat, Orhan and Cao, Yuan},
  journal={arXiv preprint arXiv:2010.05874},
  year={2020}
}

@misc{Pytorch-PCGrad-GradVac-AMP-GradAccum,
  author = {Nzeyimana, Antoine},
  title = {Pytorch-PCGrad-GradVac-AMP-GradAccum},
  url = {https://github.com/anzeyimana/Pytorch-PCGrad-GradVac-AMP-GradAccum},
  year = {2022}
}
