Skip to content

Commit

Permalink
added focal loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jwyang committed Sep 21, 2017
1 parent c726bc8 commit c69573f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 2 additions & 0 deletions cfgs/res101.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ TRAIN:
TEST:
HAS_RPN: True
POOLING_SIZE: 7
POOLING_MODE: align
CROP_RESIZE_WITH_MAX_POOL: False
7 changes: 6 additions & 1 deletion lib/model/faster_rcnn/faster_rcnn_cascade.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,12 @@ def forward(self, im_data, im_info, gt_boxes, num_boxes):
self.fg_cnt = torch.sum(label.data.ne(0))
self.bg_cnt = label.data.numel() - self.fg_cnt

self.RCNN_loss_cls = F.cross_entropy(cls_score, label)
# focal loss
weights = (1 - cls_prob) ** 2
focal_prob = torch.log(cls_prob) * weights.detach()
self.RCNN_loss_cls = F.nll_loss(focal_prob, label)

# self.RCNN_loss_cls = F.cross_entropy(cls_score, label, weights.detach())

# bounding box regression L1 loss
self.RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)
Expand Down

0 comments on commit c69573f

Please sign in to comment.