Skip to content

Commit

Permalink
feat: introduce DatasetEnum for Dataset Names
Browse files Browse the repository at this point in the history
  • Loading branch information
mhd-medfa committed Sep 7, 2023
1 parent 6e72cbb commit f1549b5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 21 deletions.
4 changes: 2 additions & 2 deletions mmdet/core/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
get_classes, imagenet_det_classes,
imagenet_vid_classes, oid_challenge_classes,
oid_v6_classes, voc_classes)
oid_v6_classes, voc_classes, DatasetEnum)
from .eval_hooks import DistEvalHook, EvalHook
from .mean_ap import average_precision, eval_map, print_map_summary
from .panoptic_utils import INSTANCE_OFFSET
Expand All @@ -15,5 +15,5 @@
'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
'print_map_summary', 'eval_recalls', 'print_recall_summary',
'plot_num_recall', 'plot_iou_recall', 'oid_v6_classes',
'oid_challenge_classes', 'INSTANCE_OFFSET'
'oid_challenge_classes', 'INSTANCE_OFFSET', 'DatasetEnum',
]
59 changes: 40 additions & 19 deletions mmdet/core/evaluation/class_names.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
from enum import Enum

class DatasetEnum(Enum):
VOC = 'voc'
IMAGENET_DET = 'imagenet_det'
IMAGENET_VID = 'imagenet_vid'
COCO = 'coco'
LVIS = 'lvis'
WIDER_FACE = 'wider_face'
CITYSCAPES = 'cityscapes'
OID_CHALLENGE = 'oid_challenge'
OID_V6 = 'oid_v6'

def wider_face_classes():
return ['face']
Expand Down Expand Up @@ -544,30 +555,40 @@ def oid_v6_classes():


dataset_aliases = {
'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
'coco': ['coco', 'mscoco', 'ms_coco'],
'lvis': ['lvis'],
'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'],
'cityscapes': ['cityscapes'],
'oid_challenge': ['oid_challenge', 'openimages_challenge'],
'oid_v6': ['oid_v6', 'openimages_v6']
DatasetEnum.VOC: ['voc', 'pascal_voc', 'voc07', 'voc12'],
DatasetEnum.IMAGENET_DET: ['det', 'imagenet_det', 'ilsvrc_det'],
DatasetEnum.IMAGENET_VID: ['vid', 'imagenet_vid', 'ilsvrc_vid'],
DatasetEnum.COCO: ['coco', 'mscoco', 'ms_coco'],
DatasetEnum.LVIS: ['lvis'],
DatasetEnum.WIDER_FACE: ['WIDERFaceDataset', 'wider_face', 'WIDERFace'],
DatasetEnum.CITYSCAPES: ['cityscapes'],
DatasetEnum.OID_CHALLENGE: ['oid_challenge', 'openimages_challenge'],
DatasetEnum.OID_V6: ['oid_v6', 'openimages_v6']
}

# def get_classes(dataset):
# """Get class names of a dataset."""
# alias2name = {}
# for name, aliases in dataset_aliases.items():
# for alias in aliases:
# alias2name[alias] = name

# if mmcv.is_str(dataset):
# if dataset in alias2name:
# labels = eval(alias2name[dataset] + '_classes()')
# else:
# raise ValueError(f'Unrecognized dataset: {dataset}')
# else:
# raise TypeError(f'dataset must a str, but got {type(dataset)}')
# return labels

def get_classes(dataset):
"""Get class names of a dataset."""
alias2name = {}
for name, aliases in dataset_aliases.items():
for alias in aliases:
alias2name[alias] = name

if mmcv.is_str(dataset):
if dataset in alias2name:
labels = eval(alias2name[dataset] + '_classes()')
if isinstance(dataset, DatasetEnum):
if dataset in dataset_aliases:
labels = eval(dataset_aliases[dataset][0] + '_classes()')
else:
raise ValueError(f'Unrecognized dataset: {dataset}')
else:
raise TypeError(f'dataset must a str, but got {type(dataset)}')
return labels
raise TypeError(f'dataset must be a DatasetEnum member, but got {type(dataset)}')
return labels

0 comments on commit f1549b5

Please sign in to comment.