Spaces:
Build error
Build error
Joe Davison
committed on
Commit
·
85dd546
1
Parent(s):
039194f
fix model caching error
Browse files
app.py
CHANGED
|
@@ -13,6 +13,9 @@ import psutil
|
|
| 13 |
with open("hit_log.txt", mode='a') as file:
|
| 14 |
file.write(str(datetime.datetime.now()) + '\n')
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
MODEL_DESC = {
|
| 17 |
'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
| 18 |
'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
|
@@ -58,10 +61,22 @@ models = load_models()
|
|
| 58 |
def load_tokenizer(tok_id):
|
| 59 |
return AutoTokenizer.from_pretrained(tok_id)
|
| 60 |
|
| 61 |
-
@st.cache(allow_output_mutation=True, show_spinner=False
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
return outputs['labels'], outputs['scores']
|
| 66 |
|
| 67 |
def load_examples(model_id):
|
|
@@ -88,7 +103,6 @@ def plot_result(top_topics, scores):
|
|
| 88 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
| 89 |
st.plotly_chart(fig)
|
| 90 |
|
| 91 |
-
|
| 92 |
|
| 93 |
def main():
|
| 94 |
with open("style.css") as f:
|
|
@@ -124,18 +138,11 @@ def main():
|
|
| 124 |
st.markdown(CODE_DESC.format(model_id))
|
| 125 |
|
| 126 |
with st.spinner('Classifying...'):
|
| 127 |
-
top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class
|
| 128 |
-
|
| 129 |
-
plot_result(top_topics[::-1][-10:], scores[::-1][-10:])
|
| 130 |
-
|
| 131 |
-
if "socat" not in [p.name() for p in psutil.process_iter()]:
|
| 132 |
-
os.system('socat tcp-listen:8000,reuseaddr,fork tcp:localhost:8001 &')
|
| 133 |
-
|
| 134 |
-
|
| 135 |
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
|
| 139 |
if __name__ == '__main__':
|
| 140 |
main()
|
| 141 |
-
|
|
|
|
| 13 |
with open("hit_log.txt", mode='a') as file:
|
| 14 |
file.write(str(datetime.datetime.now()) + '\n')
|
| 15 |
|
| 16 |
+
|
| 17 |
+
MAX_GRAPH_ROWS = 10
|
| 18 |
+
|
| 19 |
MODEL_DESC = {
|
| 20 |
'Bart MNLI': """Bart with a classification head trained on MNLI.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
| 21 |
'Bart MNLI + Yahoo Answers': """Bart with a classification head trained on MNLI and then further fine-tuned on Yahoo Answers topic classification.\n\nSequences are posed as NLI premises and topic labels are turned into premises, i.e. `business` -> `This text is about business.`""",
|
|
|
|
| 61 |
def load_tokenizer(tok_id):
|
| 62 |
return AutoTokenizer.from_pretrained(tok_id)
|
| 63 |
|
| 64 |
+
@st.cache(allow_output_mutation=True, show_spinner=False, hash_funcs={
|
| 65 |
+
torch.nn.Parameter: lambda _: None
|
| 66 |
+
})
|
| 67 |
+
def get_most_likely(nli_model_id, sequence, labels, hypothesis_template, multi_class):
|
| 68 |
+
classifier = pipeline(
|
| 69 |
+
'zero-shot-classification',
|
| 70 |
+
model=models[nli_model_id],
|
| 71 |
+
tokenizer=load_tokenizer(nli_model_id),
|
| 72 |
+
device=device
|
| 73 |
+
)
|
| 74 |
+
outputs = classifier(
|
| 75 |
+
sequence,
|
| 76 |
+
candidate_labels=labels,
|
| 77 |
+
hypothesis_template=hypothesis_template,
|
| 78 |
+
multi_label=multi_class
|
| 79 |
+
)
|
| 80 |
return outputs['labels'], outputs['scores']
|
| 81 |
|
| 82 |
def load_examples(model_id):
|
|
|
|
| 103 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
| 104 |
st.plotly_chart(fig)
|
| 105 |
|
|
|
|
| 106 |
|
| 107 |
def main():
|
| 108 |
with open("style.css") as f:
|
|
|
|
| 138 |
st.markdown(CODE_DESC.format(model_id))
|
| 139 |
|
| 140 |
with st.spinner('Classifying...'):
|
| 141 |
+
top_topics, scores = get_most_likely(model_id, sequence, labels, hypothesis_template, multi_class)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
+
plot_result(top_topics[::-1][-MAX_GRAPH_ROWS:], scores[::-1][-MAX_GRAPH_ROWS:])
|
| 144 |
|
| 145 |
|
| 146 |
|
| 147 |
if __name__ == '__main__':
|
| 148 |
main()
|
|
|