This directory contains the segmentation code of MasKD (Masked Knowledge Distillation with Receptive Tokens).
Put the Cityscapes dataset into the `./data/cityscapes` folder.
Download the required checkpoints into the `./ckpts` folder.
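You can confirm that the folders are in place before launching a run with a small path check. This is a hypothetical helper, not part of the repository. It assumes the standard Cityscapes download layout (`leftImg8bit/` and `gtFine/`, each with `train`, `val`, and `test` splits) and checks only directories, not the individual files the data loader reads.

```python
from pathlib import Path

# Assumption: standard Cityscapes layout with leftImg8bit/ and gtFine/, each
# containing train/, val/, and test/ splits. Individual files are not checked.
EXPECTED_SUBDIRS = [
    f"{top}/{split}"
    for top in ("leftImg8bit", "gtFine")
    for split in ("train", "val", "test")
]


def missing_paths(project_root="."):
    """Return the expected directories that do not exist under project_root."""
    root = Path(project_root)
    expected = [root / "data" / "cityscapes" / sub for sub in EXPECTED_SUBDIRS]
    expected.append(root / "ckpts")  # checkpoint folder from the step above
    return [str(p) for p in expected if not p.is_dir()]
```

Running `print(missing_paths())` from the repository root lists any folder that is still missing. An empty list means the layout matches the assumption above.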
Backbones pretrained on ImageNet:
Teacher backbones:
Student models are trained on 8 NVIDIA Tesla V100 GPUs.
*: The backbone parameters are randomly initialized (no ImageNet pretraining).
Role | Network | Method | val mIoU (%) | test mIoU (%) | train script | log | ckpt |
---|---|---|---|---|---|---|---|
Teacher | DeepLabV3-ResNet101 | - | 78.07 | 77.46 | sh | - | Google Drive |
Student | DeepLabV3-ResNet18 | MasKD | 77.00 | 75.59 | sh | | |
Student | DeepLabV3-ResNet18* | MasKD | 73.95 | 73.74 | sh | | |
Student | DeepLabV3-MBV2 | MasKD | 75.26 | 74.23 | sh | | |
Student | PSPNet-ResNet18 | MasKD | 75.34 | 74.61 | sh | | |
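The mIoU columns use the usual definition. For each of the 19 Cityscapes evaluation classes, IoU = TP / (TP + FP + FN) over all pixels, and the mIoU is the average over classes. The sketch below computes it from a confusion matrix. The ignore label 255 and the helper names are assumptions, and the repository's own evaluation code may differ in details such as how absent classes are handled.

```python
import numpy as np

NUM_CLASSES = 19    # Cityscapes evaluation classes (train IDs 0-18)
IGNORE_INDEX = 255  # assumed label for void pixels in the ground truth


def confusion_matrix(pred, target, num_classes=NUM_CLASSES, ignore_index=IGNORE_INDEX):
    """Count pixels per (ground truth, prediction) pair; rows = GT, cols = prediction.

    Predictions at non-ignored pixels must be class indices < num_classes.
    Sum the matrices of all images to evaluate a whole split.
    """
    pred = np.asarray(pred).ravel()
    target = np.asarray(target).ravel()
    valid = target != ignore_index
    idx = num_classes * target[valid].astype(np.int64) + pred[valid].astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    """Average of per-class TP / (TP + FP + FN); classes with an empty union are skipped."""
    tp = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - tp
    iou = np.divide(tp, union, out=np.full_like(tp, np.nan), where=union > 0)
    return float(np.nanmean(iou))
```

For example, `mean_iou(confusion_matrix([[0, 0], [1, 1]], [[0, 1], [1, 255]], num_classes=2))` returns 0.5. The table reports this value multiplied by 100.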
Evaluate a trained model with `eval.py`. The example below evaluates the DeepLabV3-ResNet101 teacher:
```bash
python -m torch.distributed.launch --nproc_per_node=8 eval.py \
    --model deeplabv3 \
    --backbone resnet101 \
    --data [your dataset path]/cityscapes/ \
    --save-dir [your directory path to store log files] \
    --gpu-id 0,1,2,3,4,5,6,7 \
    --pretrained [your checkpoint path]/deeplabv3_resnet101_citys_best_model.pth
```
To generate test-set predictions for the student models, use `test_deeplabv3_mbv2.sh`, `test_deeplabv3_res18.sh`, or `test_pspnet_res18.sh`. You can also run `test.py` directly, as shown here for the teacher checkpoint:
```bash
python -m torch.distributed.launch --nproc_per_node=4 test.py \
    --model deeplabv3 \
    --backbone resnet101 \
    --data [your dataset path]/cityscapes/ \
    --save-dir [your directory path to store resulting images] \
    --gpu-id 0,1,2,3 \
    --save-pred \
    --pretrained [your checkpoint path]/deeplabv3_resnet101_citys_best_model.pth
```
You can submit the resulting images to the Cityscapes test server.
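The Cityscapes test server scores predictions encoded with the official label IDs, not the 19 contiguous train IDs used during training. This README does not say whether `--save-pred` already writes label IDs. If it writes train IDs (an assumption worth checking in the code), you can convert each prediction before submission with a lookup table like the one below. `train_ids_to_label_ids` is a hypothetical helper and is not part of this repository.

```python
import numpy as np

# Official Cityscapes label IDs of the 19 evaluation classes, indexed by train ID:
# road, sidewalk, building, wall, fence, pole, traffic light, traffic sign,
# vegetation, terrain, sky, person, rider, car, truck, bus, train,
# motorcycle, bicycle.
TRAIN_ID_TO_LABEL_ID = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25,
                        26, 27, 28, 31, 32, 33]


def train_ids_to_label_ids(pred, unlabeled_id=0):
    """Map an HxW array of train IDs (0-18) to label IDs; other values become unlabeled_id."""
    lut = np.full(256, unlabeled_id, dtype=np.uint8)
    lut[:len(TRAIN_ID_TO_LABEL_ID)] = TRAIN_ID_TO_LABEL_ID
    return lut[np.asarray(pred, dtype=np.uint8)]
```

You can then save each converted array as an 8-bit PNG, for example with Pillow's `Image.fromarray(...).save(...)`, following the server's file-naming rules.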
Our pretrained mask module for DeepLabV3-ResNet101 is available at [link] or in `work_dirs/dv3-r101/deeplabv3_resnet101_citys_best_model.pth` ([log]).
You can train your own mask module with the following script:
```bash
sh train_scripts/train_mask_module/deeplabv3_res101.sh
```
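MasKD distills features through masks produced by learnable receptive tokens, and the mask module trained above supplies those masks. The sketch below is a simplified NumPy illustration of that idea under stated assumptions. It is not the repository's implementation.

Each token's dot product with every pixel feature passes through a sigmoid to give a soft spatial mask. The distillation loss is the mask-weighted squared error between student and teacher features, normalized by each mask's area. The shapes, function names, and normalization are assumptions, and the student features are assumed to be already projected to the teacher's channel count. The actual losses and training details are in the scripts above.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def token_masks(tokens, feat):
    """Soft spatial masks from receptive tokens (simplified, hypothetical shapes).

    tokens: (N, C) array, one embedding per mask.
    feat:   (C, H, W) teacher feature map.
    Returns (N, H, W) masks in (0, 1).
    """
    c, h, w = feat.shape
    logits = tokens @ feat.reshape(c, h * w)  # (N, H*W): token-pixel similarity
    return sigmoid(logits).reshape(-1, h, w)


def masked_distill_loss(masks, feat_s, feat_t, eps=1e-6):
    """Mean over masks of the mask-weighted squared error, normalized by mask area.

    feat_s, feat_t: (C, H, W); the student is assumed to match the teacher's C.
    """
    sq_err = ((feat_s - feat_t) ** 2).sum(axis=0)  # (H, W), summed over channels
    per_mask = (masks * sq_err).sum(axis=(1, 2)) / (masks.sum(axis=(1, 2)) + eps)
    return float(per_mask.mean())
```

Normalizing by each mask's area keeps masks that cover small regions from contributing a vanishing loss. This is an illustrative design choice in the sketch, not a statement about the paper's exact formulation.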
The following script visualizes the learned masks:
```bash
sh train_scripts/train_mask_module/vis_deeplabv3_res101.sh
```
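To inspect a single learned mask outside the provided script, you can alpha-blend the soft mask onto the input image. `overlay_mask` is a hypothetical helper and not the repository's visualization code. It assumes the mask has already been upsampled to the image resolution and scaled to [0, 1].

```python
import numpy as np


def overlay_mask(image, mask, alpha=0.5, color=(255, 0, 0)):
    """Blend a soft mask onto an RGB image.

    image: (H, W, 3) uint8; mask: (H, W) floats, assumed already resized to H x W.
    Each pixel moves toward `color` in proportion to alpha * mask.
    """
    img = image.astype(np.float64)
    weight = alpha * np.clip(mask, 0.0, 1.0)[..., None]  # (H, W, 1)
    tint = np.asarray(color, dtype=np.float64)            # (3,)
    blended = (1.0 - weight) * img + weight * tint
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
```

You can save the result with any image library, for example `Image.fromarray(overlay).save("mask.png")` with Pillow.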
This code is largely based on CIRKD.