Skip to content

Commit

Permalink
Merge pull request jwyang#35 from cclauss/modernize-python2-code
Browse files Browse the repository at this point in the history
Modernize Python 2 code to get ready for Python 3
  • Loading branch information
jwyang committed Jan 20, 2018
2 parents df76857 + 35ad4ed commit 1cc6121
Show file tree
Hide file tree
Showing 22 changed files with 153 additions and 81 deletions.
20 changes: 13 additions & 7 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,20 @@
from model.faster_rcnn.resnet import resnet
import pdb

try:
xrange # Python 2
except NameError:
xrange = range # Python 3


def parse_args():
"""
Parse input arguments
"""
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--dataset', dest='dataset',
help='training dataset',
default='pascal_voc', type=str)
default='pascal_voc', type=str)
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
default='cfgs/vgg16.yml', type=str)
Expand All @@ -59,7 +65,7 @@ def parse_args():
nargs=argparse.REMAINDER)
parser.add_argument('--image_dir', dest='image_dir',
help='directory to load images for demo', default="images",
nargs=argparse.REMAINDER)
nargs=argparse.REMAINDER)
parser.add_argument('--cuda', dest='cuda',
help='whether use CUDA',
action='store_true')
Expand All @@ -68,7 +74,7 @@ def parse_args():
action='store_true')
parser.add_argument('--cag', dest='class_agnostic',
help='whether perform class_agnostic bbox regression',
action='store_true')
action='store_true')
parser.add_argument('--parallel_type', dest='parallel_type',
help='which part of model to parallel, 0: all, 1: model before roi pooling',
default=0, type=int)
Expand All @@ -86,7 +92,7 @@ def parse_args():
default=1, type=int)
parser.add_argument('--vis', dest='vis',
help='visualization mode',
action='store_true')
action='store_true')
args = parser.parse_args()
return args

Expand Down Expand Up @@ -174,7 +180,7 @@ def _get_image_blob(im):
pdb.set_trace()

fasterRCNN.create_architecture()

print("load checkpoint %s" % (load_name))
checkpoint = torch.load(load_name)
fasterRCNN.load_state_dict(checkpoint['model'])
Expand Down Expand Up @@ -259,7 +265,7 @@ def _get_image_blob(im):
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

scores = cls_prob.data
boxes = rois.data[:, :, 1:5]

Expand All @@ -285,7 +291,7 @@ def _get_image_blob(im):

pred_boxes /= im_scales[0]

scores = scores.squeeze()
scores = scores.squeeze()
pred_boxes = pred_boxes.squeeze()
det_toc = time.time()
detect_time = det_toc - det_tic
Expand Down
11 changes: 9 additions & 2 deletions lib/datasets/imagenet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
Expand All @@ -17,6 +18,12 @@
import subprocess
import pdb

try:
xrange # Python 2
except NameError:
xrange = range # Python 3


class imagenet(imdb):
def __init__(self, image_set, devkit_path, data_path):
imdb.__init__(self, image_set)
Expand Down Expand Up @@ -151,14 +158,14 @@ def gt_roidb(self):
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
print('{} gt roidb loaded from {}'.format(self.name, cache_file))
return roidb

gt_roidb = [self._load_imagenet_annotation(index)
for index in self.image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
print('wrote gt roidb to {}'.format(cache_file))

return gt_roidb

Expand Down
28 changes: 17 additions & 11 deletions lib/datasets/pascal_voc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import print_function
from __future__ import absolute_import
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
Expand All @@ -21,13 +23,17 @@

from .imdb import imdb
from .imdb import ROOT_DIR
import ds_utils
from . import ds_utils
from .voc_eval import voc_eval

# TODO: make fast_rcnn irrelevant
# >>>> obsolete, because it depends on sth outside of this project
from model.utils.config import cfg

try:
xrange # Python 2
except NameError:
xrange = range # Python 3

# <<<< obsolete

Expand Down Expand Up @@ -120,14 +126,14 @@ def gt_roidb(self):
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
print('{} gt roidb loaded from {}'.format(self.name, cache_file))
return roidb

gt_roidb = [self._load_pascal_annotation(index)
for index in self.image_index]
with open(cache_file, 'wb') as fid:
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote gt roidb to {}'.format(cache_file)
print('wrote gt roidb to {}'.format(cache_file))

return gt_roidb

Expand All @@ -144,7 +150,7 @@ def selective_search_roidb(self):
if os.path.exists(cache_file):
with open(cache_file, 'rb') as fid:
roidb = cPickle.load(fid)
print '{} ss roidb loaded from {}'.format(self.name, cache_file)
print('{} ss roidb loaded from {}'.format(self.name, cache_file))
return roidb

if int(self._year) == 2007 or self._image_set != 'test':
Expand All @@ -155,7 +161,7 @@ def selective_search_roidb(self):
roidb = self._load_selective_search_roidb(None)
with open(cache_file, 'wb') as fid:
cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
print 'wrote ss roidb to {}'.format(cache_file)
print('wrote ss roidb to {}'.format(cache_file))

return roidb

Expand All @@ -171,7 +177,7 @@ def rpn_roidb(self):

def _load_rpn_roidb(self, gt_roidb):
filename = self.config['rpn_file']
print 'loading {}'.format(filename)
print('loading {}'.format(filename))
assert os.path.exists(filename), \
'rpn data not found at: {}'.format(filename)
with open(filename, 'rb') as f:
Expand Down Expand Up @@ -268,7 +274,7 @@ def _write_voc_results_file(self, all_boxes):
for cls_ind, cls in enumerate(self.classes):
if cls == '__background__':
continue
print 'Writing {} VOC results file'.format(cls)
print('Writing {} VOC results file'.format(cls))
filename = self._get_voc_results_file_template().format(cls)
with open(filename, 'wt') as f:
for im_ind, index in enumerate(self.image_index):
Expand Down Expand Up @@ -298,7 +304,7 @@ def _do_python_eval(self, output_dir='output'):
aps = []
# The PASCAL VOC metric changed in 2010
use_07_metric = True if int(self._year) < 2010 else False
print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
for i, cls in enumerate(self._classes):
Expand Down Expand Up @@ -328,9 +334,9 @@ def _do_python_eval(self, output_dir='output'):
print('--------------------------------------------------------------')

def _do_matlab_eval(self, output_dir='output'):
print '-----------------------------------------------------'
print 'Computing results with the official MATLAB eval code.'
print '-----------------------------------------------------'
print('-----------------------------------------------------')
print('Computing results with the official MATLAB eval code.')
print('-----------------------------------------------------')
path = os.path.join(cfg.ROOT_DIR, 'lib', 'datasets',
'VOCdevkit-matlab-wrapper')
cmd = 'cd {} && '.format(path)
Expand Down
3 changes: 2 additions & 1 deletion lib/datasets/tools/mcg_munge.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import sys

Expand Down Expand Up @@ -28,7 +29,7 @@ def munge(src_dir):
os.makedirs(dst_dir)
src = os.path.join(src_dir, fn)
dst = os.path.join(dst_dir, fn)
print 'MV: {} -> {}'.format(src, dst)
print('MV: {} -> {}'.format(src, dst))
os.rename(src, dst)

if __name__ == '__main__':
Expand Down
28 changes: 18 additions & 10 deletions lib/datasets/vg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import print_function
from __future__ import absolute_import
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
Expand All @@ -15,11 +17,17 @@
import gzip
import PIL
import json
from vg_eval import vg_eval
from .vg_eval import vg_eval
from model.utils.config import cfg
import pickle
import pdb

try:
xrange # Python 2
except NameError:
xrange = range # Python 3


class vg(imdb):
def __init__(self, version, image_set, ):
imdb.__init__(self, 'vg_' + version + '_' + image_set)
Expand Down Expand Up @@ -121,11 +129,11 @@ def _image_split_path(self):
if self._image_set == "minitrain":
return os.path.join(self._data_path, 'train.txt')
if self._image_set == "smalltrain":
return os.path.join(self._data_path, 'train.txt')
return os.path.join(self._data_path, 'train.txt')
if self._image_set == "minival":
return os.path.join(self._data_path, 'val.txt')
if self._image_set == "smallval":
return os.path.join(self._data_path, 'val.txt')
return os.path.join(self._data_path, 'val.txt')
else:
return os.path.join(self._data_path, self._image_set+'.txt')

Expand All @@ -141,7 +149,7 @@ def _load_image_set_index(self):
if self._image_set == "minitrain":
metadata = metadata[:1000]
elif self._image_set == "smalltrain":
metadata = metadata[:20000]
metadata = metadata[:20000]
elif self._image_set == "minival":
metadata = metadata[:100]
elif self._image_set == "smallval":
Expand Down Expand Up @@ -172,21 +180,21 @@ def gt_roidb(self):
Return the database of ground-truth regions of interest.
This function loads/saves from/to a cache file to speed up future calls.
"""
"""
cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
if os.path.exists(cache_file):
fid = gzip.open(cache_file,'rb')
roidb = cPickle.load(fid)
fid.close()
print '{} gt roidb loaded from {}'.format(self.name, cache_file)
print('{} gt roidb loaded from {}'.format(self.name, cache_file))
return roidb

gt_roidb = [self._load_vg_annotation(index)
for index in self.image_index]
fid = gzip.open(cache_file,'wb')
cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
fid.close()
print 'wrote gt roidb to {}'.format(cache_file)
print('wrote gt roidb to {}'.format(cache_file))
return gt_roidb

def _get_size(self, index):
Expand Down Expand Up @@ -227,7 +235,7 @@ def _load_vg_annotation(self, index):
y2 = min(height-1,float(bbox.find('ymax').text))
# If bboxes are not positive, just give whole image coords (there are a few examples)
if x2 < x1 or y2 < y1:
print 'Failed bbox in %s, object %s' % (filename, obj_name)
print('Failed bbox in %s, object %s' % (filename, obj_name))
x1 = 0
y1 = 0
x2 = width-1
Expand Down Expand Up @@ -312,7 +320,7 @@ def _write_voc_results_file(self, classes, all_boxes, output_dir):
for cls_ind, cls in enumerate(classes):
if cls == '__background__':
continue
print 'Writing "{}" vg results file'.format(cls)
print('Writing "{}" vg results file'.format(cls))
filename = self._get_vg_results_file_template(output_dir).format(cls)
with open(filename, 'wt') as f:
for im_ind, index in enumerate(self.image_index):
Expand All @@ -334,7 +342,7 @@ def _do_python_eval(self, output_dir, pickle=True, eval_attributes = False):
thresh = []
# The PASCAL VOC metric changed in 2010
use_07_metric = False
print 'VOC07 metric? ' + ('Yes' if use_07_metric else 'No')
print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
if not os.path.isdir(output_dir):
os.mkdir(output_dir)
# Load ground truth
Expand Down
3 changes: 2 additions & 1 deletion lib/datasets/vg_eval.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import absolute_import
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
Expand All @@ -8,7 +9,7 @@
import os
import cPickle
import numpy as np
from voc_eval import voc_ap
from .voc_eval import voc_ap



Expand Down
1 change: 1 addition & 0 deletions lib/model/nms/build.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import torch
from torch.utils.ffi import create_extension
Expand Down
3 changes: 2 additions & 1 deletion lib/model/nms/nms_gpu.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import absolute_import
import torch
import numpy as np
from _ext import nms
from ._ext import nms
import pdb

def nms_gpu(dets, thresh):
Expand Down
1 change: 1 addition & 0 deletions lib/model/roi_align/build.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import torch
from torch.utils.ffi import create_extension
Expand Down
1 change: 1 addition & 0 deletions lib/model/roi_crop/build.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import torch
from torch.utils.ffi import create_extension
Expand Down
1 change: 1 addition & 0 deletions lib/model/roi_pooling/build.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import print_function
import os
import torch
from torch.utils.ffi import create_extension
Expand Down
Loading

0 comments on commit 1cc6121

Please sign in to comment.