diff --git a/checklist/editor.py b/checklist/editor.py
index 51dde25..ac4214c 100644
--- a/checklist/editor.py
+++ b/checklist/editor.py
@@ -478,9 +478,9 @@ def visual_suggest(self, templates, **kwargs):
         tagged_keys = find_all_keys(templates)
         template_strs = get_all_strings_ordered(templates)
         items = self._get_fillin_items(tagged_keys, max_count=5, **kwargs)
+        kwargs["verbose"] = False
         mask_suggests = self.suggest(templates, **kwargs)
-
         if not mask_suggests:
             raise Exception('No valid suggestions for the given template!')
         self.selected_suggestions = []
@@ -490,7 +490,8 @@ def visual_suggest(self, templates, **kwargs):
             tag_dict=items,
             mask_suggests=mask_suggests[:50],
             format_fn=recursive_format,
-            select_suggests_fn=self._set_selected_suggestions
+            select_suggests_fn=self._set_selected_suggestions,
+            tokenizer=self.tg.tokenizer
         )
 
     def add_lexicon(self, name, values, overwrite=False):
diff --git a/checklist/viewer/template_editor.py b/checklist/viewer/template_editor.py
index 97decd1..0d0cff6 100644
--- a/checklist/viewer/template_editor.py
+++ b/checklist/viewer/template_editor.py
@@ -3,8 +3,6 @@
 import os
 import typing
 import itertools
-from spacy.attrs import LEMMA, ORTH, NORM
-from spacy.lang.en import English
 
 try:
     from IPython.core.display import display, Javascript
@@ -37,51 +35,54 @@ def __init__(self, \
         mask_suggests: typing.List[typing.Union[str, tuple]], \
         format_fn: typing.Callable, \
         select_suggests_fn: typing.Callable, \
+        tokenizer, \
         **kwargs):
         widgets.DOMWidget.__init__(self, **kwargs)
         self.format_fn = format_fn
         self.select_suggests_fn = select_suggests_fn
-
-        nlp = English() # ONLY do tokenization here
-        self.tokenizer = nlp.Defaults.create_tokenizer(nlp)
+        self.tokenizer = tokenizer
         self.bert_suggests = mask_suggests
         self.templates = [
             self.tokenize_template_str(s, tagged_keys, tag_dict) for \
             s in template_strs]
         self.on_msg(self.handle_events)
 
+
     def tokenize_template_str(self, template_str, tagged_keys, tag_dict, max_count=5):
         tagged_keys = list(tagged_keys)
         trans_keys = ["{" + key + "}" for key in tagged_keys]
-        #keys = list(fillins.keys()) + [bert_key]
-        for idx, key in enumerate(tagged_keys):
-            case = [{LEMMA: key.split(":")[-1], NORM: key, ORTH: trans_keys[idx] }]
-            self.tokenizer.add_special_case(trans_keys[idx], case)
-        tokens = self.tokenizer(template_str)
-        template_tokens = []
         item_keys = [x[0] for x in tag_dict.items()]
         item_vals = [[x[1][:max_count]] if type(x[1]) not in [list, tuple] else x[1][:max_count] for x in tag_dict.items()]
         local_items = []
+        for idx, key in enumerate(tagged_keys):
+            self.tokenizer.add_tokens(trans_keys[idx])
         for item_val in itertools.product(*item_vals):
             if len(item_val) != len(set([str(x) for x in item_val])):
                 continue
             local_item = {item_keys[i]: item_val[i] for i, _ in enumerate(item_val)}
             local_items.append(local_item)
-
-        for t in tokens:
-            if t.norm_ in tagged_keys:
-                tag = t.norm_
+
+        def _tokenize(text):
+            tokens = [self.tokenizer.decode(x) for x in self.tokenizer.encode(text, add_special_tokens=False)]
+            return [t for t in tokens if t]
+        def get_meta(text):
+            if text in trans_keys:
+                idx = trans_keys.index(text)
+                norm = tagged_keys[idx]
+                lemma = norm.split(":")[-1]
+                normalized_key = lemma.split('[')[0].split('.')[0]
                 texts = list()
                 for local_item in local_items:
                     try:
-                        text = self.format_fn(["{" + t.lemma_ +"}"], local_item)[0]
-                        texts.append(text)
+                        texts.append(self.format_fn(["{" + lemma +"}"], local_item)[0])
                     except:
                         pass
-                template_tokens.append((texts, t.norm_, t.lemma_.split('[')[0].split('.')[0]))
+                return (texts, norm, normalized_key)
             else:
-                template_tokens.append(t.text)
+                return text
+
+        template_tokens = [get_meta(t) for t in _tokenize(template_str)]
         return template_tokens
 
     def handle_events(self, _, content, buffers):
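Note (not part of the patch): the new tokenize_template_str path drops spaCy and instead expects self.tg.tokenizer to expose the add_tokens / encode / decode API used in the hunk above, i.e. a Hugging Face transformers-style tokenizer. A minimal sketch of that round-trip, with the model name and placeholder keys chosen purely for illustration:

from transformers import AutoTokenizer

# Illustrative model; in the widget, whatever tokenizer self.tg carries is passed in.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Mirrors the add_tokens call in the patch: each "{key}" placeholder is registered
# as a whole token so it is never split by the subword tokenizer.
tokenizer.add_tokens(["{first_name}", "{mask}"])

def _tokenize(text):
    # Same encode/decode round-trip as the new _tokenize helper: encode without
    # special tokens, then decode each id back to its surface form.
    tokens = [tokenizer.decode(i) for i in tokenizer.encode(text, add_special_tokens=False)]
    return [t for t in tokens if t]

print(_tokenize("{first_name} is a {mask} person."))
# e.g. ['{first_name}', 'is', 'a', '{mask}', 'person', '.']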