add nllb translate codes
Browse files
app.py
CHANGED
|
@@ -24,7 +24,7 @@ from gradio.themes.utils import colors, fonts, sizes
|
|
| 24 |
import argparse
|
| 25 |
|
| 26 |
import langid
|
| 27 |
-
from transformers import pipeline
|
| 28 |
|
| 29 |
|
| 30 |
class myTheme(Base):
|
|
@@ -112,8 +112,6 @@ def opus_trans(article, target_language):
|
|
| 112 |
target_lang = "en"
|
| 113 |
elif target_language == "Chinese":
|
| 114 |
target_lang = "zh"
|
| 115 |
-
elif target_language == "Spanish":
|
| 116 |
-
target_lang = "es"
|
| 117 |
|
| 118 |
if result_lang != target_lang:
|
| 119 |
task_name = f"translation_{result_lang}_to_{target_lang}"
|
|
@@ -129,15 +127,31 @@ def opus_trans(article, target_language):
|
|
| 129 |
|
| 130 |
|
| 131 |
def nllb_trans(article, target_language):
|
| 132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
|
| 134 |
|
| 135 |
def translate(article, toolkit, target_language):
|
| 136 |
if toolkit == "OPUS":
|
| 137 |
translated = opus_trans(article, target_language)
|
| 138 |
-
return translated
|
| 139 |
elif toolkit == "NLLB":
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
|
| 142 |
|
| 143 |
myTheme = myTheme()
|
|
|
|
| 24 |
import argparse
|
| 25 |
|
| 26 |
import langid
|
| 27 |
+
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
| 28 |
|
| 29 |
|
| 30 |
class myTheme(Base):
|
|
|
|
| 112 |
target_lang = "en"
|
| 113 |
elif target_language == "Chinese":
|
| 114 |
target_lang = "zh"
|
|
|
|
|
|
|
| 115 |
|
| 116 |
if result_lang != target_lang:
|
| 117 |
task_name = f"translation_{result_lang}_to_{target_lang}"
|
|
|
|
| 127 |
|
| 128 |
|
| 129 |
def nllb_trans(article, target_language):
|
| 130 |
+
tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 131 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
|
| 132 |
+
inputs = tokenizer(article, return_tensors="pt")
|
| 133 |
+
|
| 134 |
+
if target_language == "English":
|
| 135 |
+
target_lang = "Eng_Latn"
|
| 136 |
+
elif target_language == "Chinese":
|
| 137 |
+
target_lang = "zho_Hans"
|
| 138 |
+
|
| 139 |
+
translated_tokens = model.generate(
|
| 140 |
+
**inputs,
|
| 141 |
+
forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
|
| 142 |
+
max_length=30,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
return tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
|
| 146 |
|
| 147 |
|
| 148 |
def translate(article, toolkit, target_language):
|
| 149 |
if toolkit == "OPUS":
|
| 150 |
translated = opus_trans(article, target_language)
|
|
|
|
| 151 |
elif toolkit == "NLLB":
|
| 152 |
+
translated = nllb_trans(article, target_language)
|
| 153 |
+
|
| 154 |
+
return translated
|
| 155 |
|
| 156 |
|
| 157 |
myTheme = myTheme()
|