Skip to content

Commit

Permalink
modified pr
Browse files Browse the repository at this point in the history
  • Loading branch information
smilelite committed Jul 11, 2022
1 parent 4a3b874 commit cb37041
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 68 deletions.
2 changes: 1 addition & 1 deletion ppocr/data/imaug/label_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,7 +1217,7 @@ def add_special_char(self, dict_character):
dict_character = ['</s>'] + dict_character
return dict_character

class SPINAttnLabelEncode(BaseRecLabelEncode):
class SPINAttnLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """

def __init__(self,
Expand Down
5 changes: 5 additions & 0 deletions ppocr/modeling/heads/rec_spin_att_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is refer from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand Down
147 changes: 81 additions & 66 deletions ppocr/postprocess/rec_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,86 @@ def add_special_char(self, dict_character):
return dict_character


class SPINAttnLabelDecode(BaseRecLabelDecode):
# class SPINAttnLabelDecode(BaseRecLabelDecode):
# """ Convert between text-label and text-index """

# def __init__(self, character_dict_path=None, use_space_char=False,
# **kwargs):
# super(SPINAttnLabelDecode, self).__init__(character_dict_path,
# use_space_char)

# def add_special_char(self, dict_character):
# self.beg_str = "sos"
# self.end_str = "eos"
# dict_character = dict_character
# dict_character = [self.beg_str] + [self.end_str] + dict_character
# return dict_character

# def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
# """ convert text-index into text-label. """
# result_list = []
# ignored_tokens = self.get_ignored_tokens()
# [beg_idx, end_idx] = self.get_ignored_tokens()
# batch_size = len(text_index)
# for batch_idx in range(batch_size):
# char_list = []
# conf_list = []
# for idx in range(len(text_index[batch_idx])):
# if text_index[batch_idx][idx] == int(beg_idx):
# continue
# if int(text_index[batch_idx][idx]) == int(end_idx):
# break
# if is_remove_duplicate:
# # only for predict
# if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
# batch_idx][idx]:
# continue
# char_list.append(self.character[int(text_index[batch_idx][
# idx])])
# if text_prob is not None:
# conf_list.append(text_prob[batch_idx][idx])
# else:
# conf_list.append(1)
# text = ''.join(char_list)
# result_list.append((text.lower(), np.mean(conf_list).tolist()))
# return result_list

# def __call__(self, preds, label=None, *args, **kwargs):
# """
# text = self.decode(text)
# if label is None:
# return text
# else:
# label = self.decode(label, is_remove_duplicate=False)
# return text, label
# """
# if isinstance(preds, paddle.Tensor):
# preds = preds.numpy()

# preds_idx = preds.argmax(axis=2)
# preds_prob = preds.max(axis=2)
# text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
# if label is None:
# return text
# label = self.decode(label, is_remove_duplicate=False)
# return text, label

# def get_ignored_tokens(self):
# beg_idx = self.get_beg_end_flag_idx("beg")
# end_idx = self.get_beg_end_flag_idx("end")
# return [beg_idx, end_idx]

# def get_beg_end_flag_idx(self, beg_or_end):
# if beg_or_end == "beg":
# idx = np.array(self.dict[self.beg_str])
# elif beg_or_end == "end":
# idx = np.array(self.dict[self.end_str])
# else:
# assert False, "unsupport type %s in get_beg_end_flag_idx" \
# % beg_or_end
# return idx

class SPINAttnLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """

def __init__(self, character_dict_path=None, use_space_char=False,
Expand All @@ -682,68 +761,4 @@ def add_special_char(self, dict_character):
self.end_str = "eos"
dict_character = dict_character
dict_character = [self.beg_str] + [self.end_str] + dict_character
return dict_character

def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
""" convert text-index into text-label. """
result_list = []
ignored_tokens = self.get_ignored_tokens()
[beg_idx, end_idx] = self.get_ignored_tokens()
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
conf_list = []
for idx in range(len(text_index[batch_idx])):
if text_index[batch_idx][idx] == int(beg_idx):
continue
if int(text_index[batch_idx][idx]) == int(end_idx):
break
if is_remove_duplicate:
# only for predict
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
batch_idx][idx]:
continue
char_list.append(self.character[int(text_index[batch_idx][
idx])])
if text_prob is not None:
conf_list.append(text_prob[batch_idx][idx])
else:
conf_list.append(1)
text = ''.join(char_list)
result_list.append((text.lower(), np.mean(conf_list).tolist()))
return result_list

def __call__(self, preds, label=None, *args, **kwargs):
"""
text = self.decode(text)
if label is None:
return text
else:
label = self.decode(label, is_remove_duplicate=False)
return text, label
"""
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()

preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
if label is None:
return text
label = self.decode(label, is_remove_duplicate=False)
return text, label

def get_ignored_tokens(self):
beg_idx = self.get_beg_end_flag_idx("beg")
end_idx = self.get_beg_end_flag_idx("end")
return [beg_idx, end_idx]

def get_beg_end_flag_idx(self, beg_or_end):
if beg_or_end == "beg":
idx = np.array(self.dict[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict[self.end_str])
else:
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
return dict_character
2 changes: 1 addition & 1 deletion tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def export_single_model(model,
]
# print([None, 3, 32, 128])
model = to_static(model, input_spec=other_shape)
elif arch_config["algorithm"] == "NRTR" or arch_config["algorithm"] == "SPIN":
elif arch_config["algorithm"] in ["NRTR", "SPIN"]:
other_shape = [
paddle.static.InputSpec(
shape=[None, 1, 32, 100], dtype="float32"),
Expand Down

0 comments on commit cb37041

Please sign in to comment.