Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,13 @@ model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path)
|
|
| 12 |
vocab = tokenizer.vocab
|
| 13 |
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def func_macro_correct(text):
|
| 16 |
with torch.no_grad():
|
| 17 |
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
|
|
@@ -29,7 +36,53 @@ def func_macro_correct(text):
|
|
| 29 |
return False
|
| 30 |
return True
|
| 31 |
|
| 32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""Get new corrected text and errors between corrected text and origin text
|
| 34 |
code from: https://github.com/shibing624/pycorrector
|
| 35 |
"""
|
|
@@ -57,7 +110,14 @@ def func_macro_correct(text):
|
|
| 57 |
|
| 58 |
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
|
| 59 |
corrected_text = _text[:len(text)]
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
print(text, ' => ', corrected_text, details)
|
| 62 |
return corrected_text + ' ' + str(details)
|
| 63 |
|
|
@@ -88,6 +148,6 @@ if __name__ == '__main__':
|
|
| 88 |
description="Copy or input error Chinese text. Submit and the machine will correct text.",
|
| 89 |
article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
|
| 90 |
examples=examples
|
| 91 |
-
).launch()
|
| 92 |
|
| 93 |
|
|
|
|
| 12 |
vocab = tokenizer.vocab
|
| 13 |
|
| 14 |
|
| 15 |
+
# from modelscope import AutoTokenizer, AutoModelForMaskedLM
|
| 16 |
+
# pretrained_model_name_or_path = "Macadam/macbert4mdcspell_v2"
|
| 17 |
+
# tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
| 18 |
+
# model = AutoModelForMaskedLM.from_pretrained(pretrained_model_name_or_path)
|
| 19 |
+
# vocab = tokenizer.vocab
|
| 20 |
+
|
| 21 |
+
|
| 22 |
def func_macro_correct(text):
|
| 23 |
with torch.no_grad():
|
| 24 |
outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
|
|
|
|
| 36 |
return False
|
| 37 |
return True
|
| 38 |
|
| 39 |
+
def get_errors_from_diff_length(corrected_text, origin_text, unk_tokens=None, know_tokens=None):
    """Align a corrected string with its original and list the differences.

    Adapted from https://github.com/shibing624/pycorrector

    Two cursors walk ``origin_text`` (``i``) and ``corrected_text`` (``j``)
    in parallel. The walk stops as soon as either string is exhausted, so
    any unread tail of the other string is neither copied nor reported.
    NOTE(review): confirm that dropping the tail is intended by the caller.

    Args:
        corrected_text: Text produced by the correction model.
        origin_text: The text the user originally supplied.
        unk_tokens: Characters that are copied through verbatim and never
            reported as errors. ``None`` or an empty value selects the
            built-in default list.
        know_tokens: Container supporting ``in`` (e.g. the tokenizer
            vocabulary). An origin character that is not in it is copied
            through unchanged. With ``None`` or an empty container every
            origin character is therefore copied through unchanged.

    Returns:
        A ``(new_corrected_text, errors)`` tuple. ``errors`` is a list of
        ``(origin_char, corrected_char, position)`` tuples sorted by
        position: insertions are ``('', char, j)`` indexed into
        ``corrected_text``; deletions ``(char, '', i)`` and replacements
        ``(old, new, i)`` are indexed into ``origin_text``.
    """
    if not unk_tokens:
        # Default pass-through characters, written as escapes so they survive
        # re-encoding: space, curly double/single quotes, U+740A, newline,
        # horizontal ellipsis, U+64E4, tab, U+7395.
        # NOTE(review): the final entry was unrecoverable in the reviewed copy
        # (it looked like a mis-decoded private-use character) and is kept as
        # it appeared there; verify it against the upstream pycorrector list.
        unk_tokens = [
            ' ', '\u201c', '\u201d', '\u2018', '\u2019', '\u740a', '\n',
            '\u2026', '\u64e4', '\t', '\u7395', '\u03bf',
        ]
    if know_tokens is None:
        know_tokens = ()

    pieces = []  # fragments of the rebuilt text, joined once at the end
    errors = []
    i = j = 0
    n_origin = len(origin_text)
    n_corrected = len(corrected_text)

    while i < n_origin and j < n_corrected:
        src = origin_text[i]
        dst = corrected_text[j]

        if src in unk_tokens or src not in know_tokens:
            # Characters the model cannot represent are kept from the origin.
            pieces.append(src)
            i += 1
        elif dst in unk_tokens:
            pieces.append(dst)
            j += 1
        elif flag_total_chinese(src) and flag_total_chinese(dst):
            if src == dst:
                # Same character: advance both cursors together.
                pieces.append(dst)
                i += 1
                j += 1
            elif j + 1 < n_corrected and src == corrected_text[j + 1]:
                # Insertion: the corrected text has an extra character here.
                errors.append(('', dst, j))
                pieces.append(dst)
                j += 1
            elif i + 1 < n_origin and origin_text[i + 1] == dst:
                # Deletion: the origin character was removed by the model.
                errors.append((src, '', i))
                i += 1
            else:
                # Replacement: one character swapped for another.
                errors.append((src, dst, i))
                pieces.append(dst)
                i += 1
                j += 1
        else:
            # Non-Chinese characters are always taken from the origin; the
            # corrected cursor only advances when it holds the same character.
            pieces.append(src)
            if src == dst:
                j += 1
            i += 1

    errors.sort(key=operator.itemgetter(2))
    return ''.join(pieces), errors
|
| 84 |
+
|
| 85 |
+
def get_errors_from_same_length(corrected_text, origin_text, unk_tokens=[], know_tokens=[]):
|
| 86 |
"""Get new corrected text and errors between corrected text and origin text
|
| 87 |
code from: https://github.com/shibing624/pycorrector
|
| 88 |
"""
|
|
|
|
| 110 |
|
| 111 |
_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
|
| 112 |
corrected_text = _text[:len(text)]
|
| 113 |
+
print("#" * 128)
|
| 114 |
+
print(text)
|
| 115 |
+
print(corrected_text)
|
| 116 |
+
print(len(text), len(corrected_text))
|
| 117 |
+
if len(corrected_text) == len(text):
|
| 118 |
+
corrected_text, details = get_errors_from_same_length(corrected_text, text, know_tokens=vocab)
|
| 119 |
+
else:
|
| 120 |
+
corrected_text, details = get_errors_from_diff_length(corrected_text, text, know_tokens=vocab)
|
| 121 |
print(text, ' => ', corrected_text, details)
|
| 122 |
return corrected_text + ' ' + str(details)
|
| 123 |
|
|
|
|
| 148 |
description="Copy or input error Chinese text. Submit and the machine will correct text.",
|
| 149 |
article="Link to <a href='https://github.com/yongzhuo/macro-correct' style='color:blue;' target='_blank\'>Github REPO: macro-correct</a>",
|
| 150 |
examples=examples
|
| 151 |
+
).launch() # .launch(server_name="0.0.0.0", server_port=8036, share=False, debug=True)
|
| 152 |
|
| 153 |
|