# src/train.py

import argparse
from datasets import Dataset
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

def train_model():
    """
    Fine-tunes a pre-trained NLLB model on a parallel dataset.
    """
    parser = argparse.ArgumentParser(description="Fine-tune a translation model.")
    parser.add_argument("--model_checkpoint", type=str, default="facebook/nllb-200-distilled-600M")
    parser.add_argument("--source_lang", type=str, required=True, help="Short source language key used in the dataset (e.g., 'ne')")
    parser.add_argument("--target_lang", type=str, default="en")
    parser.add_argument("--source_lang_tokenizer", type=str, required=True, help="NLLB/FLORES-200 source language code for the tokenizer (e.g., 'npi_Deva' for Nepali)")
    parser.add_argument("--train_file_source", type=str, required=True, help="Path to the source language training file")
    parser.add_argument("--train_file_target", type=str, required=True, help="Path to the target language training file")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the fine-tuned model")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)

    args = parser.parse_args()

    # --- 1. Configuration ---
    MODEL_CHECKPOINT = args.model_checkpoint
    SOURCE_LANG = args.source_lang
    TARGET_LANG = args.target_lang
    MODEL_OUTPUT_DIR = args.output_dir

    # --- 2. Load Tokenizer and Model ---
    print("Loading tokenizer and model...")
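    # Note: the tokenizer's target side is fixed to English ("eng_Latn");
    # --target_lang only names the English column in the dataset built below.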
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_CHECKPOINT, src_lang=args.source_lang_tokenizer, tgt_lang="eng_Latn"
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

    # --- 3. Load and Preprocess Data (Memory-Efficiently) ---
    print("Loading and preprocessing data...")
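    # Stream both parallel files line by line; Dataset.from_generator consumes the
    # generator once and writes an on-disk Arrow cache instead of holding the whole
    # corpus in Python memory.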
    
    def generate_examples():
        with open(args.train_file_source, "r", encoding="utf-8") as f_src, \
             open(args.train_file_target, "r", encoding="utf-8") as f_tgt:
            for src_line, tgt_line in zip(f_src, f_tgt):
                yield {"translation": {SOURCE_LANG: src_line.strip(), TARGET_LANG: tgt_line.strip()}}

    dataset = Dataset.from_generator(generate_examples)
    
    split_datasets = dataset.train_test_split(train_size=0.95, seed=42)
    split_datasets["validation"] = split_datasets.pop("test")

    def preprocess_function(examples):
        inputs = [ex[SOURCE_LANG] for ex in examples["translation"]]
        targets = [ex[TARGET_LANG] for ex in examples["translation"]]
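        # text_target tokenizes the target sentences into "labels"; both sides are
        # truncated to 128 tokens.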
        
        model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
        return model_inputs

    tokenized_datasets = split_datasets.map(
        preprocess_function,
        batched=True,
        remove_columns=split_datasets["train"].column_names,
    )

    # --- 4. Set Up Training Arguments ---
    print("Setting up training arguments...")
    training_args = Seq2SeqTrainingArguments(
        output_dir=MODEL_OUTPUT_DIR,
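        # `eval_strategy` requires transformers >= 4.41; older releases spell it `evaluation_strategy`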
        eval_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        weight_decay=0.01,
        save_total_limit=3,
        num_train_epochs=args.epochs,
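        # run generation during evaluation, so sequence-level metrics (e.g. BLEU via
        # a compute_metrics function) can be added later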
        predict_with_generate=True,
        fp16=False, # Set to True if you have a compatible GPU
    )

    # --- 5. Create the Trainer ---
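    # The collator pads each batch dynamically and pads labels with -100 so the
    # loss ignores padded positions.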
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
    
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
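        # recent transformers releases prefer `processing_class=` here; `tokenizer=`
        # still works but may emit a deprecation warning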
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # --- 6. Start Training ---
    print("\n--- Starting model fine-tuning ---")
    trainer.train()
    print("--- Training complete ---")

    # --- 7. Save the Final Model ---
    print(f"Saving final model to {MODEL_OUTPUT_DIR}")
    trainer.save_model()
    print("Model saved successfully!")

if __name__ == "__main__":
    train_model()
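
# Example invocation (the paths and Nepali codes below are illustrative; substitute
# your own data files and NLLB/FLORES-200 language codes):
#
#   python src/train.py \
#       --source_lang ne \
#       --source_lang_tokenizer npi_Deva \
#       --train_file_source data/train.ne \
#       --train_file_target data/train.en \
#       --output_dir models/nllb-ne-en \
#       --epochs 3 \
#       --batch_size 8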