diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py index 2e743ec7..e46a8d2c 100644 --- a/configs/_base_/default_runtime.py +++ b/configs/_base_/default_runtime.py @@ -13,7 +13,7 @@ init_kwargs={ 'dir': "wandb_dir", 'entity': "video-da", - 'project': "multimodal", + 'project': "analysisPaper", 'resume': 'allow' }, interval=5, diff --git a/configs/mic/viperHR2csHR_mic_hrda.py b/configs/mic/viperHR2csHR_mic_hrda.py index 532b5ae3..79c5e9e7 100644 --- a/configs/mic/viperHR2csHR_mic_hrda.py +++ b/configs/mic/viperHR2csHR_mic_hrda.py @@ -122,9 +122,9 @@ n_gpus = None launcher = "slurm" #"slurm" gpu_model = 'A40' -runner = dict(type='IterBasedRunner', max_iters=15000) +runner = dict(type='IterBasedRunner', max_iters=40000) # Logging Configuration -checkpoint_config = dict(by_epoch=False, interval=3000, max_keep_ckpts=2) +checkpoint_config = dict(by_epoch=False, interval=8000, max_keep_ckpts=1) evaluation = dict(interval=3000, eval_settings={ "metrics": ["mIoU", "pred_pred", "gt_pred", "M5", "M5Fixed", "mIoU_gt_pred"], "sub_metrics": ["mask_count"], diff --git a/mmseg/models/uda/dacs.py b/mmseg/models/uda/dacs.py index 1ee32f57..fedb7f88 100644 --- a/mmseg/models/uda/dacs.py +++ b/mmseg/models/uda/dacs.py @@ -928,7 +928,7 @@ def forward_train(self, img, img_metas, img_extra, target_img, target_img_metas, subplotimg(axs[1][6], pseudo_label_warped[0], 'Original PL Warped', cmap="cityscapes") if self.consis_filter_rare_class: - pseudo_label_warped = rare_class_or_filter(pseudo_label, pseudo_label_warped) + pseudo_label_warped = rare_class_or_filter(pseudo_label, pseudo_label_warped, rare_common_compare=True) pseudo_weight_warped[pseudo_label_warped == 255] = 0 if self.oracle_mask: diff --git a/tools/aggregate_flows/flow/my_utils.py b/tools/aggregate_flows/flow/my_utils.py index c9b9b78d..80e6cbb6 100644 --- a/tools/aggregate_flows/flow/my_utils.py +++ b/tools/aggregate_flows/flow/my_utils.py @@ -489,10 +489,11 @@ def labelMapToIm(label, label_map): ([50,0,90], 31) ] -def rare_class_or_filter(pl1, pl2): +def rare_class_or_filter(pl1, pl2, rare_common_compare=False): """ pl1: (B, H, W) pl2: (B, H, W) + rare_common_compare: boolean that determines whether to do priority of rarity between classes, or just abs if it belongs to "rare" or "common" group returns a pseudolabel which keeps consistent pixels, and masks out inconsistent pixels except when the pixel is rare. In the case that both pl1 and pl2 are rare take the more rare pixel """ # most to least rare @@ -514,13 +515,21 @@ def rare_class_or_filter(pl1, pl2): # print("output1", output) output[consistent_pixels] = pl1[consistent_pixels] # print("output2", output) + + if rare_common_compare: + #only take prediction if one is rare and other is common + pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl2_rarity >= rarity_thresh) + output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent] + pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl1_rarity >= rarity_thresh) + output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent] - pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl1_rarity < pl2_rarity) - output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent] + else: + pl1_rarer_and_inconsistent = inconsis_pixels & (pl1_rarity < rarity_thresh) & (pl1_rarity < pl2_rarity) + output[pl1_rarer_and_inconsistent] = pl1[pl1_rarer_and_inconsistent] - pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl2_rarity < pl1_rarity) - output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent] + pl2_rarer_and_inconsistent = inconsis_pixels & (pl2_rarity < rarity_thresh) & (pl2_rarity < pl1_rarity) + output[pl2_rarer_and_inconsistent] = pl2[pl2_rarer_and_inconsistent] # breakpoint() return output