|
|
--- |
|
|
license: mit |
|
|
language: |
|
|
- vi |
|
|
pipeline_tag: translation |
|
|
--- |
|
|
## Dataset |
|
|
The model is trained on a high-quality English-Vietnamese translation dataset: [stefan-it/nmt-en-vi](https://github.com/stefan-it/nmt-en-vi) |
|
|
## Usage |
|
|
```python |
|
|
import tensorflow as tf |
|
|
from translator import Translator |
|
|
from utils import tokenizer_utils |
|
|
from utils.preprocessing import input_processing, output_processing |
|
|
from models.transformer import Transformer |
|
|
from models.encoder import Encoder |
|
|
from models.decoder import Decoder |
|
|
from models.layers import EncoderLayer, DecoderLayer, MultiHeadAttention, point_wise_feed_forward_network |
|
|
from models.utils import masked_loss, masked_accuracy |
|
|
|
|
|
def main(sentence, model):
    """Translate a single sentence and print the input and the translation.

    Args:
        sentence: Raw source text; it is passed through ``input_processing``
            before being handed to the translator.
        model: Loaded translation model that the ``Translator`` wraps.
    """
    # Load tokenizers
    # NOTE(review): the original carried the remark "Update to include
    # tokenizers.tokenizer_utils" here -- presumably the import path of
    # tokenizer_utils may need adjusting; confirm against the repo layout.
    en_tokenizer, vi_tokenizer = tokenizer_utils.load_tokenizers()

    # Create translator.  Use the `model` argument rather than the
    # module-level `loaded_model`, which only exists when this file is run
    # as a script and would otherwise make the parameter dead.
    translator = Translator(en_tokenizer, vi_tokenizer, model)

    # Process and translate the input sentence
    processed_sentence = input_processing(sentence)
    translated_text = translator(processed_sentence)
    translated_text = output_processing(translated_text)

    print("Input:", processed_sentence)
    print("Translated:", translated_text)
|
|
|
|
|
if __name__ == "__main__":
    # Sample passage fed to the translator
    sentence = """
For at least six centuries, residents along a lake in the mountains of central Japan
have marked the depth of winter by celebrating the return of a natural phenomenon
once revered as the trail of a wandering god.
"""

    # Project classes and functions handed to Keras as `custom_objects`
    # when the saved model is loaded
    custom_objects = dict(
        Transformer=Transformer,
        Encoder=Encoder,
        Decoder=Decoder,
        EncoderLayer=EncoderLayer,
        DecoderLayer=DecoderLayer,
        MultiHeadAttention=MultiHeadAttention,
        point_wise_feed_forward_network=point_wise_feed_forward_network,
        masked_loss=masked_loss,
        masked_accuracy=masked_accuracy,
    )

    # Restore the saved model from the checkpoint file
    loaded_model = tf.keras.models.load_model(
        'ckpts/en_vi_translation.keras',
        custom_objects=custom_objects,
    )

    main(sentence=sentence, model=loaded_model)
|
|
``` |