Skip to content

Commit

Permalink
添加RFL CNT分支infer支持
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiminzhang0830 committed Oct 8, 2022
1 parent 3f8602c commit c459b72
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 11 deletions.
5 changes: 2 additions & 3 deletions ppocr/modeling/heads/rec_rfl_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def forward(self, x, targets=None):
else:
seq_outputs = self.seq_head(seq_inputs, None,
self.batch_max_legnth)
return cnt_outputs, seq_outputs
else:
seq_outputs = None

return cnt_outputs, seq_outputs
return cnt_outputs
10 changes: 6 additions & 4 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,8 @@ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
return result_list

def __call__(self, preds, label=None, *args, **kwargs):
cnt_pred, preds = preds
if preds is not None:

if len(preds) == 2:
cnt_pred, preds = preds
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
Expand All @@ -302,9 +301,12 @@ def __call__(self, preds, label=None, *args, **kwargs):
return text, label

else:
cnt_pred = preds
if isinstance(cnt_pred, paddle.Tensor):
cnt_pred = cnt_pred.numpy()
cnt_length = []
for lens in cnt_pred:
length = round(paddle.sum(lens).item())
length = round(np.sum(lens))
cnt_length.append(length)
if label is None:
return cnt_length
Expand Down
14 changes: 10 additions & 4 deletions tools/infer_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def main():
elif config['Architecture']['algorithm'] == "SAR":
op[op_name]['keep_keys'] = ['image', 'valid_ratio']
elif config['Architecture']['algorithm'] == "RobustScanner":
op[op_name]['keep_keys'] = ['image', 'valid_ratio', 'word_positons']
op[op_name][
'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
Expand Down Expand Up @@ -136,9 +137,10 @@ def main():
if config['Architecture']['algorithm'] == "RobustScanner":
valid_ratio = np.expand_dims(batch[1], axis=0)
word_positons = np.expand_dims(batch[2], axis=0)
img_metas = [paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
img_metas = [
paddle.to_tensor(valid_ratio),
paddle.to_tensor(word_positons),
]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
Expand All @@ -160,6 +162,10 @@ def main():
"score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info, ensure_ascii=False)
elif isinstance(post_result, list) and isinstance(post_result[0],
int):
# for RFLearning CNT branch
info = str(post_result[0])
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
Expand Down

0 comments on commit c459b72

Please sign in to comment.