Aidan Phillips
committed on
Commit
·
dc76b04
1
Parent(s):
b837a10
sussy math works with default sentence
Browse files
- categories/fluency.py +74 -50
- requirements.txt +2 -1
- scorer.ipynb +33 -20
categories/fluency.py
CHANGED
|
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForMaskedLM
|
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
import spacy
|
|
|
|
| 6 |
|
| 7 |
tool = language_tool_python.LanguageTool('en-US')
|
| 8 |
model_name="distilbert-base-multilingual-cased"
|
|
@@ -12,7 +13,10 @@ model.eval()
|
|
| 12 |
|
| 13 |
nlp = spacy.load("en_core_web_sm")
|
| 14 |
|
| 15 |
-
def
|
|
|
|
|
|
|
|
|
|
| 16 |
"""
|
| 17 |
We want to return
|
| 18 |
{
|
|
@@ -26,67 +30,87 @@ def pseudo_perplexity(text, max_len=128):
|
|
| 26 |
]
|
| 27 |
}
|
| 28 |
"""
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
-
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
loss_values = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
with torch.no_grad():
|
| 41 |
-
outputs = model(
|
| 42 |
-
logits = outputs.logits[0
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
current_end = i + 1
|
| 62 |
-
curr_loss = loss
|
| 63 |
-
else:
|
| 64 |
-
if current_end - current_start > max_length:
|
| 65 |
-
longest_start, longest_end = current_start, current_end
|
| 66 |
-
max_length = current_end - current_start
|
| 67 |
-
current_start, current_end = 0, 0
|
| 68 |
-
|
| 69 |
-
if current_end - current_start > max_length: # Check the last sequence
|
| 70 |
-
longest_start, longest_end = current_start, current_end
|
| 71 |
-
|
| 72 |
-
longest_sequence = (longest_start, longest_end)
|
| 73 |
|
| 74 |
-
|
|
|
|
|
|
|
| 75 |
|
| 76 |
res = {
|
| 77 |
-
"score": __fluency_score_from_ppl(
|
| 78 |
-
"errors":
|
| 79 |
-
{
|
| 80 |
-
"start": longest_sequence[0],
|
| 81 |
-
"end": longest_sequence[1],
|
| 82 |
-
"message": f"Perplexity above threshold: {curr_loss}"
|
| 83 |
-
}
|
| 84 |
-
]
|
| 85 |
}
|
| 86 |
|
| 87 |
return res
|
| 88 |
|
| 89 |
-
def __fluency_score_from_ppl(ppl, midpoint=
|
| 90 |
"""
|
| 91 |
Use a logistic function to map perplexity to 0–100.
|
| 92 |
Midpoint is the PPL where score is 50.
|
|
@@ -135,12 +159,12 @@ def grammar_errors(text) -> tuple[int, list[str]]:
|
|
| 135 |
|
| 136 |
return res
|
| 137 |
|
| 138 |
-
def __grammar_score_from_prob(error_ratio
|
| 139 |
"""
|
| 140 |
Transform the number of errors divided by words into a score from 0 to 100.
|
| 141 |
Steepness controls how quickly the score drops as errors increase.
|
| 142 |
"""
|
| 143 |
-
score = 100
|
| 144 |
return round(score, 2)
|
| 145 |
|
| 146 |
|
|
|
|
| 3 |
import torch
|
| 4 |
import numpy as np
|
| 5 |
import spacy
|
| 6 |
+
import wordfreq
|
| 7 |
|
| 8 |
tool = language_tool_python.LanguageTool('en-US')
|
| 9 |
model_name="distilbert-base-multilingual-cased"
|
|
|
|
| 13 |
|
| 14 |
nlp = spacy.load("en_core_web_sm")
|
| 15 |
|
| 16 |
+
def __get_word_pr_score(word, lang="en") -> float:
    """
    Return the negative natural log of a word's frequency.

    The frequency comes from `wordfreq.word_frequency(word, lang)`; 1e-12 is
    added before taking the log so the result stays finite even if the lookup
    returns 0. Lower values correspond to more frequent words.

    Note: the result is a single scalar, not a list.
    """
    return -np.log(wordfreq.word_frequency(word, lang) + 1e-12)
|
| 18 |
+
|
| 19 |
+
def pseudo_perplexity(text, threshold=20, max_len=128):
|
| 20 |
"""
|
| 21 |
We want to return
|
| 22 |
{
|
|
|
|
| 30 |
]
|
| 31 |
}
|
| 32 |
"""
|
| 33 |
+
encoding = tokenizer(text, return_tensors="pt", return_offsets_mapping=True)
|
| 34 |
+
input_ids = encoding["input_ids"][0]
|
| 35 |
+
print(input_ids)
|
| 36 |
+
offset_mapping = encoding["offset_mapping"][0]
|
| 37 |
+
print(offset_mapping)
|
| 38 |
+
tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
| 39 |
+
|
| 40 |
+
# Group token indices by word based on offset mapping
|
| 41 |
+
word_groups = []
|
| 42 |
+
current_group = []
|
| 43 |
+
|
| 44 |
+
prev_end = None
|
| 45 |
+
|
| 46 |
+
for i, (start, end) in enumerate(offset_mapping):
|
| 47 |
+
if input_ids[i] in tokenizer.all_special_ids:
|
| 48 |
+
continue # skip special tokens like [CLS] and [SEP]
|
| 49 |
+
|
| 50 |
+
if prev_end is not None and start > prev_end:
|
| 51 |
+
# Word boundary detected → start new group
|
| 52 |
+
word_groups.append(current_group)
|
| 53 |
+
current_group = [i]
|
| 54 |
+
else:
|
| 55 |
+
current_group.append(i)
|
| 56 |
+
|
| 57 |
+
prev_end = end
|
| 58 |
|
| 59 |
+
# Append final group
|
| 60 |
+
if current_group:
|
| 61 |
+
word_groups.append(current_group)
|
| 62 |
|
| 63 |
loss_values = []
|
| 64 |
+
tok_loss = []
|
| 65 |
+
for group in word_groups:
|
| 66 |
+
if group[0] == 0 or group[-1] == len(input_ids) - 1:
|
| 67 |
+
continue # skip [CLS] and [SEP]
|
| 68 |
|
| 69 |
+
masked = input_ids.clone()
|
| 70 |
+
for i in group:
|
| 71 |
+
masked[i] = tokenizer.mask_token_id
|
| 72 |
|
| 73 |
with torch.no_grad():
|
| 74 |
+
outputs = model(masked.unsqueeze(0))
|
| 75 |
+
logits = outputs.logits[0]
|
| 76 |
+
|
| 77 |
+
log_probs = []
|
| 78 |
+
for i in group:
|
| 79 |
+
probs = torch.softmax(logits[i], dim=-1)
|
| 80 |
+
true_token_id = input_ids[i].item()
|
| 81 |
+
prob = probs[true_token_id].item()
|
| 82 |
+
log_probs.append(np.log(prob + 1e-12))
|
| 83 |
+
tok_loss.append(-np.log(prob + 1e-12))
|
| 84 |
+
|
| 85 |
+
word_loss = -np.sum(log_probs) / len(log_probs)
|
| 86 |
+
word = tokenizer.decode(input_ids[group[0]])
|
| 87 |
+
word_loss -= 0.6 * __get_word_pr_score(word)
|
| 88 |
+
loss_values.append(word_loss)
|
| 89 |
|
| 90 |
+
print(loss_values)
|
| 91 |
+
|
| 92 |
+
errors = []
|
| 93 |
+
for i, l in enumerate(loss_values):
|
| 94 |
+
if l < threshold:
|
| 95 |
+
continue
|
| 96 |
+
errors.append({
|
| 97 |
+
"start": i,
|
| 98 |
+
"end": i,
|
| 99 |
+
"message": f"Perplexity {l} over threshold {threshold}"
|
| 100 |
+
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
print(tok_loss)
|
| 103 |
+
s_ppl = np.mean(tok_loss)
|
| 104 |
+
print(s_ppl)
|
| 105 |
|
| 106 |
res = {
|
| 107 |
+
"score": __fluency_score_from_ppl(s_ppl),
|
| 108 |
+
"errors": errors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
}
|
| 110 |
|
| 111 |
return res
|
| 112 |
|
| 113 |
+
def __fluency_score_from_ppl(ppl, midpoint=8, steepness=0.3):
|
| 114 |
"""
|
| 115 |
Use a logistic function to map perplexity to 0–100.
|
| 116 |
Midpoint is the PPL where score is 50.
|
|
|
|
| 159 |
|
| 160 |
return res
|
| 161 |
|
| 162 |
+
def __grammar_score_from_prob(error_ratio):
    """
    Transform the number of errors divided by words into a score from 0 to 100.

    The mapping is linear: a ratio of 0 scores 100 and a ratio of 1 scores 0.
    Ratios outside [0, 1] are clamped so the result stays within 0-100, e.g.
    when more errors than words are reported.

    Returns the score rounded to two decimal places.
    """
    score = 100 * (1 - error_ratio)
    # Clamp only when out of range so in-range results are left untouched.
    if score < 0:
        score = 0
    elif score > 100:
        score = 100
    return round(score, 2)
|
| 169 |
|
| 170 |
|
requirements.txt
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
language_tool_python
|
| 2 |
transformers
|
| 3 |
-
torch
|
|
|
|
|
|
| 1 |
language_tool_python
|
| 2 |
transformers
|
| 3 |
+
torch
|
| 4 |
+
wordfreq
|
scorer.ipynb
CHANGED
|
@@ -4,16 +4,7 @@
|
|
| 4 |
"cell_type": "code",
|
| 5 |
"execution_count": 1,
|
| 6 |
"metadata": {},
|
| 7 |
-
"outputs": [
|
| 8 |
-
{
|
| 9 |
-
"name": "stderr",
|
| 10 |
-
"output_type": "stream",
|
| 11 |
-
"text": [
|
| 12 |
-
"/opt/anaconda3/envs/teach-bs/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
| 13 |
-
" from .autonotebook import tqdm as notebook_tqdm\n"
|
| 14 |
-
]
|
| 15 |
-
}
|
| 16 |
-
],
|
| 17 |
"source": [
|
| 18 |
"from categories.fluency import *"
|
| 19 |
]
|
|
@@ -27,7 +18,25 @@
|
|
| 27 |
"name": "stdout",
|
| 28 |
"output_type": "stream",
|
| 29 |
"text": [
|
| 30 |
-
"Sentence: The
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
]
|
| 32 |
}
|
| 33 |
],
|
|
@@ -40,7 +49,7 @@
|
|
| 40 |
"print(\"Sentence:\", s) # Print the input sentence\n",
|
| 41 |
"\n",
|
| 42 |
"err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
|
| 43 |
-
"flu = pseudo_perplexity(s) # Call the function to execute the fluency checking"
|
| 44 |
]
|
| 45 |
},
|
| 46 |
{
|
|
@@ -52,8 +61,12 @@
|
|
| 52 |
"name": "stdout",
|
| 53 |
"output_type": "stream",
|
| 54 |
"text": [
|
| 55 |
-
"
|
| 56 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
]
|
| 58 |
}
|
| 59 |
],
|
|
@@ -62,26 +75,26 @@
|
|
| 62 |
"\n",
|
| 63 |
"for e in combined_err:\n",
|
| 64 |
" substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
|
| 65 |
-
" print(f\"{e['message']}: {substr}\") # Print the error messages\n"
|
| 66 |
-
"\n",
|
| 67 |
-
"print(combined_err)\n"
|
| 68 |
]
|
| 69 |
},
|
| 70 |
{
|
| 71 |
"cell_type": "code",
|
| 72 |
-
"execution_count":
|
| 73 |
"metadata": {},
|
| 74 |
"outputs": [
|
| 75 |
{
|
| 76 |
"name": "stdout",
|
| 77 |
"output_type": "stream",
|
| 78 |
"text": [
|
| 79 |
-
"
|
|
|
|
| 80 |
]
|
| 81 |
}
|
| 82 |
],
|
| 83 |
"source": [
|
| 84 |
-
"fluency_score = 0.
|
|
|
|
| 85 |
"print(\"Fluency Score:\", fluency_score) # Print the fluency score"
|
| 86 |
]
|
| 87 |
}
|
|
|
|
| 4 |
"cell_type": "code",
|
| 5 |
"execution_count": 1,
|
| 6 |
"metadata": {},
|
| 7 |
+
"outputs": [],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"source": [
|
| 9 |
"from categories.fluency import *"
|
| 10 |
]
|
|
|
|
| 18 |
"name": "stdout",
|
| 19 |
"output_type": "stream",
|
| 20 |
"text": [
|
| 21 |
+
"Sentence: The cat sat the quickly up apples banana.\n",
|
| 22 |
+
"tensor([ 101, 10117, 41163, 20694, 10105, 23590, 10741, 72894, 11268, 99304,\n",
|
| 23 |
+
" 10219, 119, 102])\n",
|
| 24 |
+
"tensor([[ 0, 0],\n",
|
| 25 |
+
" [ 0, 3],\n",
|
| 26 |
+
" [ 4, 7],\n",
|
| 27 |
+
" [ 8, 11],\n",
|
| 28 |
+
" [12, 15],\n",
|
| 29 |
+
" [16, 23],\n",
|
| 30 |
+
" [24, 26],\n",
|
| 31 |
+
" [27, 30],\n",
|
| 32 |
+
" [30, 33],\n",
|
| 33 |
+
" [34, 38],\n",
|
| 34 |
+
" [38, 40],\n",
|
| 35 |
+
" [40, 41],\n",
|
| 36 |
+
" [ 0, 0]])\n",
|
| 37 |
+
"[np.float64(0.00905743383887514), np.float64(1.1257066968185931), np.float64(4.8056646935577145), np.float64(4.473408069089179), np.float64(4.732453441503642), np.float64(3.028744414819041), np.float64(5.1115574262487735), np.float64(-0.6523823890571343)]\n",
|
| 38 |
+
"[np.float64(1.7636628003080927), np.float64(6.955413759407024), np.float64(10.828562153345375), np.float64(6.228013435558396), np.float64(10.258657658689351), np.float64(6.635744767229443), np.float64(11.163667119285972), np.float64(10.499412826924114), np.float64(11.96113847381264), np.float64(10.010973250156082), np.float64(2.470404176100153)]\n",
|
| 39 |
+
"0.5208035409471965\n"
|
| 40 |
]
|
| 41 |
}
|
| 42 |
],
|
|
|
|
| 49 |
"print(\"Sentence:\", s) # Print the input sentence\n",
|
| 50 |
"\n",
|
| 51 |
"err = grammar_errors(s) # Call the function to execute the grammar error checking\n",
|
| 52 |
+
"flu = pseudo_perplexity(s, threshold=2.5) # Call the function to execute the fluency checking"
|
| 53 |
]
|
| 54 |
},
|
| 55 |
{
|
|
|
|
| 61 |
"name": "stdout",
|
| 62 |
"output_type": "stream",
|
| 63 |
"text": [
|
| 64 |
+
"An apostrophe may be missing.: apples banana.\n",
|
| 65 |
+
"Perplexity 4.8056646935577145 over threshold 2.5: sat\n",
|
| 66 |
+
"Perplexity 4.473408069089179 over threshold 2.5: the\n",
|
| 67 |
+
"Perplexity 4.732453441503642 over threshold 2.5: quickly\n",
|
| 68 |
+
"Perplexity 3.028744414819041 over threshold 2.5: up\n",
|
| 69 |
+
"Perplexity 5.1115574262487735 over threshold 2.5: apples\n"
|
| 70 |
]
|
| 71 |
}
|
| 72 |
],
|
|
|
|
| 75 |
"\n",
|
| 76 |
"for e in combined_err:\n",
|
| 77 |
" substr = \" \".join(s.split(\" \")[e[\"start\"]:e[\"end\"]+1])\n",
|
| 78 |
+
" print(f\"{e['message']}: {substr}\") # Print the error messages\n"
|
|
|
|
|
|
|
| 79 |
]
|
| 80 |
},
|
| 81 |
{
|
| 82 |
"cell_type": "code",
|
| 83 |
+
"execution_count": null,
|
| 84 |
"metadata": {},
|
| 85 |
"outputs": [
|
| 86 |
{
|
| 87 |
"name": "stdout",
|
| 88 |
"output_type": "stream",
|
| 89 |
"text": [
|
| 90 |
+
"87.5 99.71\n",
|
| 91 |
+
"Fluency Score: 92.384\n"
|
| 92 |
]
|
| 93 |
}
|
| 94 |
],
|
| 95 |
"source": [
|
| 96 |
+
"fluency_score = 0.7 * err[\"score\"] + 0.3 * flu[\"score\"] # Calculate the fluency score\n",
|
| 97 |
+
"print(err[\"score\"], flu[\"score\"]) # Print the individual scores\n",
|
| 98 |
"print(\"Fluency Score:\", fluency_score) # Print the fluency score"
|
| 99 |
]
|
| 100 |
}
|