Skip to content

Commit

Permalink
analysis paper settings
Browse files Browse the repository at this point in the history
  • Loading branch information
SimarKareer committed Jul 1, 2023
1 parent 391bd61 commit 4bffbab
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
init_kwargs={
'dir': "wandb_dir",
'entity': "video-da",
'project': "multimodal",
'project': "analysisPaper",
'resume': 'allow'
},
interval=5,
Expand Down
4 changes: 2 additions & 2 deletions configs/mic/viperHR2csHR_mic_hrda.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@
n_gpus = None
launcher = "slurm" #"slurm"
gpu_model = 'A40'
runner = dict(type='IterBasedRunner', max_iters=15000)
runner = dict(type='IterBasedRunner', max_iters=40000)
# Logging Configuration
checkpoint_config = dict(by_epoch=False, interval=3000, max_keep_ckpts=2)
checkpoint_config = dict(by_epoch=False, interval=8000, max_keep_ckpts=1)
evaluation = dict(interval=3000, eval_settings={
"metrics": ["mIoU", "pred_pred", "gt_pred", "M5", "M5Fixed", "mIoU_gt_pred"],
"sub_metrics": ["mask_count"],
Expand Down
2 changes: 1 addition & 1 deletion mmseg/models/uda/dacs.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def forward_train(self, img, img_metas, img_extra, target_img, target_img_metas,
subplotimg(axs[1][6], pseudo_label_warped[0], 'Original PL Warped', cmap="cityscapes")

if self.consis_filter_rare_class:
pseudo_label_warped = rare_class_or_filter(pseudo_label, pseudo_label_warped)
pseudo_label_warped = rare_class_or_filter(pseudo_label, pseudo_label_warped, rare_common_compare=True)
pseudo_weight_warped[pseudo_label_warped == 255] = 0

if self.oracle_mask:
Expand Down
19 changes: 14 additions & 5 deletions tools/aggregate_flows/flow/my_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,10 +489,11 @@ def labelMapToIm(label, label_map):
([50,0,90], 31)
]

def rare_class_or_filter(pl1, pl2):
def rare_class_or_filter(pl1, pl2, rare_common_compare=False):
"""
pl1: (B, H, W)
pl2: (B, H, W)
rare_common_compare: boolean that determines whether to do priority of rarity between classes, or just abs if it belongs to "rare" or "common" group
returns a pseudolabel which keeps consistent pixels, and masks out inconsistent pixels except when the pixel is rare. In the case that both pl1 and pl2 are rare take the more rare pixel
"""
# most to least rare
Expand All @@ -514,13 +515,21 @@ def rare_class_or_filter(pl1, pl2):
# print("output1", output)
output[consistent_pixels] = pl1[consistent_pixels]
# print("output2", output)

if rare_common_compare:
#only take prediction if one is rare and other is common
pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl2_rarity >= rarity_thresh)
output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent]

pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl1_rarity >= rarity_thresh)
output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent]

pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl1_rarity < pl2_rarity)
output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent]
else:
pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl1_rarity < pl2_rarity)
output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent]

pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl2_rarity < pl1_rarity)
output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent]
pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl2_rarity < pl1_rarity)
output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent]
# breakpoint()

return output
Expand Down

0 comments on commit 4bffbab

Please sign in to comment.