add mbart for translation
Browse files
app.py
CHANGED
|
@@ -25,10 +25,16 @@ import argparse
|
|
| 25 |
|
| 26 |
import langid
|
| 27 |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
|
|
| 28 |
|
|
|
|
| 29 |
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 30 |
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
class myTheme(Base):
|
| 34 |
def __init__(
|
|
@@ -155,11 +161,23 @@ def nllb_trans(article, target_language):
|
|
| 155 |
return translated
|
| 156 |
|
| 157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
def translate(article, toolkit, target_language):
|
| 159 |
if toolkit == "OPUS":
|
| 160 |
translated = opus_trans(article, target_language)
|
| 161 |
elif toolkit == "NLLB":
|
| 162 |
translated = nllb_trans(article, target_language)
|
|
|
|
|
|
|
| 163 |
|
| 164 |
return translated
|
| 165 |
|
|
@@ -169,7 +187,7 @@ myTheme = myTheme()
|
|
| 169 |
with gr.Blocks(theme=myTheme) as demo:
|
| 170 |
article = gr.Textbox(label="Article")
|
| 171 |
toolkit_select = gr.Radio(
|
| 172 |
-
["OPUS", "NLLB"], label="Select Translation Model", value="OPUS"
|
| 173 |
)
|
| 174 |
lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
|
| 175 |
result = gr.Textbox(label="Translated Result")
|
|
|
|
| 25 |
|
| 26 |
import langid
|
| 27 |
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
| 28 |
+
from easynmt import EasyNMT
|
| 29 |
|
| 30 |
+
# Initialize nllb-200 models
|
| 31 |
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 32 |
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 33 |
|
| 34 |
+
# Initialize mbart50 models
|
| 35 |
+
mbart_m2en_model = EasyNMT("mbart50_m2en")
|
| 36 |
+
mbart_en2m_model = EasyNMT("mbart50_en2m")
|
| 37 |
+
|
| 38 |
|
| 39 |
class myTheme(Base):
|
| 40 |
def __init__(
|
|
|
|
| 161 |
return translated
|
| 162 |
|
| 163 |
|
| 164 |
+
def mbart_trans(article, target_language):
|
| 165 |
+
result_lang = detect_lang(article)
|
| 166 |
+
|
| 167 |
+
if result_lang != target_language:
|
| 168 |
+
if target_language == "English":
|
| 169 |
+
return mbart_m2en_model.translate(article)
|
| 170 |
+
else:
|
| 171 |
+
return mbart_en2m_model.translate(article, target_lang="zh")
|
| 172 |
+
|
| 173 |
+
|
| 174 |
def translate(article, toolkit, target_language):
|
| 175 |
if toolkit == "OPUS":
|
| 176 |
translated = opus_trans(article, target_language)
|
| 177 |
elif toolkit == "NLLB":
|
| 178 |
translated = nllb_trans(article, target_language)
|
| 179 |
+
elif toolkit == "MBART":
|
| 180 |
+
translated = mbart_trans(article, target_language)
|
| 181 |
|
| 182 |
return translated
|
| 183 |
|
|
|
|
| 187 |
with gr.Blocks(theme=myTheme) as demo:
|
| 188 |
article = gr.Textbox(label="Article")
|
| 189 |
toolkit_select = gr.Radio(
|
| 190 |
+
["OPUS", "NLLB", "MBART"], label="Select Translation Model", value="OPUS"
|
| 191 |
)
|
| 192 |
lang_select = gr.Radio(["English", "Chinese"], label="Select Desired Language")
|
| 193 |
result = gr.Textbox(label="Translated Result")
|