Skip to content

Commit

Permalink
Merge pull request #52 from mhd-medfa/main
Browse files Browse the repository at this point in the history
feat: Introduce Co-DINO with SwinL Backbone (LVIS) Inference Config Files
  • Loading branch information
TempleX98 committed Sep 10, 2023
2 parents a43a29c + 4a7da20 commit 8d693d5
Show file tree
Hide file tree
Showing 8 changed files with 776 additions and 30 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ In this paper, we present a novel collaborative hybrid assignments training sche
| Co-DINO | Swin-L | 12 | LSJ | COCO | 59.3 | [config](https://github.com/Sense-X/Co-DETR/blob/main/projects/configs/co_dino/co_dino_5scale_lsj_swin_large_1x_coco.py) | [model](https://drive.google.com/drive/folders/1nAXOkzqrEgz-YnXxIEs4d5j9li_kmrnv?usp=sharing) |
| Co-DINO | Swin-L | 24 | LSJ | COCO | 60.4 | [config](https://github.com/Sense-X/Co-DETR/blob/main/projects/configs/co_dino/co_dino_5scale_lsj_swin_large_2x_coco.py) | [model](https://drive.google.com/drive/folders/1nAXOkzqrEgz-YnXxIEs4d5j9li_kmrnv?usp=sharing) |
| Co-DINO | Swin-L | 36 | LSJ | COCO | 60.7 | [config](https://github.com/Sense-X/Co-DETR/blob/main/projects/configs/co_dino/co_dino_5scale_lsj_swin_large_3x_coco.py) | [model](https://drive.google.com/drive/folders/1nAXOkzqrEgz-YnXxIEs4d5j9li_kmrnv?usp=sharing) |
| Co-DINO | Swin-L | 36 | LSJ | LVIS | 56.9 | config | [model](https://drive.google.com/drive/folders/1nAXOkzqrEgz-YnXxIEs4d5j9li_kmrnv?usp=sharing) |
| Co-DINO | Swin-L | 36 | LSJ | LVIS | 56.9 | [config](https://github.com/Sense-X/Co-DETR/blob/main/projects/configs/co_dino/co_dino_5scale_swin_large_3x_lvis.py) | [model](https://drive.google.com/drive/folders/1nAXOkzqrEgz-YnXxIEs4d5j9li_kmrnv?usp=sharing) |

### Co-Deformable-DETR

Expand Down
7 changes: 4 additions & 3 deletions demo/inference_demo.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.core import get_classes, DatasetEnum
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector
from projects import *


def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
def init_detector(config, checkpoint=None, dataset=DatasetEnum.COCO, device='cuda:0', cfg_options=None):
"""Initialize a detector from config file.
Args:
Expand Down Expand Up @@ -49,9 +49,9 @@ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
model.CLASSES = checkpoint['meta']['CLASSES']
else:
warnings.simplefilter('once')
warnings.warn('Class names are not saved in the checkpoint\'s '
'meta data, use COCO classes by default.')
model.CLASSES = get_classes('coco')
warnings.warn(f'Class names are not saved in the checkpoint\'s '
f'meta data, use {dataset.value} classes.')
model.CLASSES = get_classes(dataset)
model.cfg = config # save the config in the model for convenience
model.to(device)
model.eval()
Expand Down
4 changes: 2 additions & 2 deletions mmdet/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
get_classes, imagenet_det_classes,
imagenet_vid_classes, oid_challenge_classes,
oid_v6_classes, voc_classes)
oid_v6_classes, voc_classes, DatasetEnum)
from .eval_hooks import DistEvalHook, EvalHook
from .mean_ap import average_precision, eval_map, print_map_summary
from .panoptic_utils import INSTANCE_OFFSET
Expand All @@ -15,5 +15,5 @@
'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall', 'oid_v6_classes',
'oid_challenge_classes', 'INSTANCE_OFFSET'
'oid_challenge_classes', 'INSTANCE_OFFSET', 'DatasetEnum',
]
300 changes: 281 additions & 19 deletions mmdet/core/evaluation/class_names.py

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions projects/configs/co_dino/co_dino_5scale_lsj_r50_1x_lvis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
_base_ = [
'co_dino_5scale_r50_1x_lvis.py'
]

model = dict(with_attn_mask=False)

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

image_size = (1536, 1536)
load_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='Resize',
img_scale=image_size,
ratio_range=(0.1, 2.0),
multiscale_mode='range',
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=image_size,
recompute_bbox=True,
allow_negative_crop=True),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
]
train_pipeline = [
dict(type='CopyPaste', max_num_pasted=100),
dict(type='Normalize', **img_norm_cfg),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=image_size,
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
])
]
dataset_type = 'LVISV1Dataset'
data_root = 'data/lvis_v1/'
img_data_root = 'data/coco/'
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='MultiImageMixDataset',
dataset=dict(
type=dataset_type,
ann_file=data_root + 'lvis_v1_train.json',
img_prefix=img_data_root,
filter_empty_gt=False,
pipeline=load_pipeline),
pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
Loading

0 comments on commit 8d693d5

Please sign in to comment.