diff --git a/mmdet/core/evaluation/__init__.py b/mmdet/core/evaluation/__init__.py index 67e7c55..ef7c0ca 100644 --- a/mmdet/core/evaluation/__init__.py +++ b/mmdet/core/evaluation/__init__.py @@ -2,7 +2,7 @@ from .class_names import (cityscapes_classes, coco_classes, dataset_aliases, get_classes, imagenet_det_classes, imagenet_vid_classes, oid_challenge_classes, - oid_v6_classes, voc_classes) + oid_v6_classes, voc_classes, DatasetEnum) from .eval_hooks import DistEvalHook, EvalHook from .mean_ap import average_precision, eval_map, print_map_summary from .panoptic_utils import INSTANCE_OFFSET @@ -15,5 +15,5 @@ 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary', 'plot_num_recall', 'plot_iou_recall', 'oid_v6_classes', - 'oid_challenge_classes', 'INSTANCE_OFFSET' + 'oid_challenge_classes', 'INSTANCE_OFFSET', 'DatasetEnum', ] diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py index 13531e2..f10b95f 100644 --- a/mmdet/core/evaluation/class_names.py +++ b/mmdet/core/evaluation/class_names.py @@ -1,6 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. import mmcv +from enum import Enum +class DatasetEnum(Enum): + VOC = 'voc' + IMAGENET_DET = 'imagenet_det' + IMAGENET_VID = 'imagenet_vid' + COCO = 'coco' + LVIS = 'lvis' + WIDER_FACE = 'wider_face' + CITYSCAPES = 'cityscapes' + OID_CHALLENGE = 'oid_challenge' + OID_V6 = 'oid_v6' def wider_face_classes(): return ['face'] @@ -544,30 +555,40 @@ def oid_v6_classes(): dataset_aliases = { - 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], - 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], - 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'], - 'coco': ['coco', 'mscoco', 'ms_coco'], - 'lvis': ['lvis'], - 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'], - 'cityscapes': ['cityscapes'], - 'oid_challenge': ['oid_challenge', 'openimages_challenge'], - 'oid_v6': ['oid_v6', 'openimages_v6'] + DatasetEnum.VOC: ['voc', 'pascal_voc', 'voc07', 'voc12'], + DatasetEnum.IMAGENET_DET: ['det', 'imagenet_det', 'ilsvrc_det'], + DatasetEnum.IMAGENET_VID: ['vid', 'imagenet_vid', 'ilsvrc_vid'], + DatasetEnum.COCO: ['coco', 'mscoco', 'ms_coco'], + DatasetEnum.LVIS: ['lvis'], + DatasetEnum.WIDER_FACE: ['WIDERFaceDataset', 'wider_face', 'WIDERFace'], + DatasetEnum.CITYSCAPES: ['cityscapes'], + DatasetEnum.OID_CHALLENGE: ['oid_challenge', 'openimages_challenge'], + DatasetEnum.OID_V6: ['oid_v6', 'openimages_v6'] } +# def get_classes(dataset): +# """Get class names of a dataset.""" +# alias2name = {} +# for name, aliases in dataset_aliases.items(): +# for alias in aliases: +# alias2name[alias] = name + +# if mmcv.is_str(dataset): +# if dataset in alias2name: +# labels = eval(alias2name[dataset] + '_classes()') +# else: +# raise ValueError(f'Unrecognized dataset: {dataset}') +# else: +# raise TypeError(f'dataset must a str, but got {type(dataset)}') +# return labels def get_classes(dataset): """Get class names of a dataset.""" - alias2name = {} - for name, aliases in dataset_aliases.items(): - for alias in aliases: - alias2name[alias] = name - - if mmcv.is_str(dataset): - if dataset in alias2name: - labels = eval(alias2name[dataset] + '_classes()') + if isinstance(dataset, DatasetEnum): + if dataset in dataset_aliases: + labels = eval(dataset_aliases[dataset][0] + '_classes()') else: raise ValueError(f'Unrecognized dataset: {dataset}') else: - raise TypeError(f'dataset must a str, but got {type(dataset)}') - return labels + raise TypeError(f'dataset must be a DatasetEnum member, but got {type(dataset)}') + return labels \ No newline at end of file