import argparse

from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)


def train_model():
    """
    Fine-tunes a pre-trained NLLB model on a parallel dataset.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a translation model.")
    parser.add_argument("--model_checkpoint", type=str, default="facebook/nllb-200-distilled-600M")
    parser.add_argument("--source_lang", type=str, required=True, help="Source language code (e.g., 'ne')")
    parser.add_argument("--target_lang", type=str, default="en", help="Target language code")
    parser.add_argument("--source_lang_tokenizer", type=str, required=True, help="FLORES-200 code for the source language used by the NLLB tokenizer (e.g., 'npi_Deva' for Nepali)")
    parser.add_argument("--train_file_source", type=str, required=True, help="Path to the source-language training file")
    parser.add_argument("--train_file_target", type=str, required=True, help="Path to the target-language training file")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the fine-tuned model")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)

    args = parser.parse_args()

    MODEL_CHECKPOINT = args.model_checkpoint
    SOURCE_LANG = args.source_lang
    TARGET_LANG = args.target_lang
    MODEL_OUTPUT_DIR = args.output_dir

    print("Loading tokenizer and model...")
    # Note: tgt_lang is hard-coded to English ('eng_Latn'); adjust it if you
    # pass a --target_lang other than English.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CHECKPOINT, src_lang=args.source_lang_tokenizer, tgt_lang="eng_Latn"
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

    print("Loading and preprocessing data...")

    def generate_examples():
        # Stream aligned sentence pairs from the two files. zip() stops at the
        # shorter file, so a misaligned corpus is silently truncated.
        with open(args.train_file_source, "r", encoding="utf-8") as f_src, \
             open(args.train_file_target, "r", encoding="utf-8") as f_tgt:
            for src_line, tgt_line in zip(f_src, f_tgt):
                yield {"translation": {SOURCE_LANG: src_line.strip(), TARGET_LANG: tgt_line.strip()}}

    dataset = Dataset.from_generator(generate_examples)

    # Hold out 5% of the pairs as a validation set.
    split_datasets = dataset.train_test_split(train_size=0.95, seed=42)
    split_datasets["validation"] = split_datasets.pop("test")

    def preprocess_function(examples):
        inputs = [ex[SOURCE_LANG] for ex in examples["translation"]]
        targets = [ex[TARGET_LANG] for ex in examples["translation"]]

        # text_target tokenizes the targets as labels in the same call.
        model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
        return model_inputs

    tokenized_datasets = split_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=split_datasets["train"].column_names,
    )

    print("Setting up training arguments...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
        eval_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=args.epochs,
        predict_with_generate=True,
        fp16=False,  # set to True on GPUs with mixed-precision support
    )

    # The collator dynamically pads inputs and labels to the longest sequence
    # in each batch.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    print("\n--- Starting model fine-tuning ---")
    trainer.train()
    print("--- Training complete ---")

    print(f"Saving final model to {MODEL_OUTPUT_DIR}")
    trainer.save_model()
    print("Model saved successfully!")


if __name__ == "__main__":
    train_model()
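
# Example invocation (illustrative only: the script name, file paths, and
# language codes below are assumptions for a Nepali-to-English run, not values
# fixed by this script):
#
#   python finetune_nllb.py \
#       --source_lang ne \
#       --source_lang_tokenizer npi_Deva \
#       --train_file_source data/train.ne \
#       --train_file_target data/train.en \
#       --output_dir models/nllb-ne-en-finetuned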
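#
# A minimal sketch of how the saved model might be loaded to sanity-check
# translations (kept as comments so the script stays runnable; it assumes the
# 'npi_Deva' source code and output path from the example above):
#
#   from transformers import pipeline
#
#   translator = pipeline(
#       "translation",
#       model="models/nllb-ne-en-finetuned",
#       src_lang="npi_Deva",
#       tgt_lang="eng_Latn",
#   )
#   print(translator("नमस्ते संसार")[0]["translation_text"])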