from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import TrainingArguments, Trainer
import os
import torch

# Load dataset
ds = load_dataset("knkarthick/dialogsum")
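
# Quick look at what was loaded (a minimal sanity check, not in the original
# script): dialogsum ships train/validation/test splits whose examples carry
# 'dialogue' and 'summary' fields, which the preprocessing below relies on.
print(ds)
print(ds['train'][0]['dialogue'][:200])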

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
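
# torch is imported above but never used directly; one common use at this
# point is confirming a GPU is visible, since Trainer silently falls back to
# CPU otherwise (a hedged check, not part of the original script).
print("CUDA available:", torch.cuda.is_available())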

# Preprocessing function: tokenize dialogues and summaries, then replace pad
# token ids in the labels with -100 so padding is ignored by the loss.
def preprocess_function(batch):
    source = batch['dialogue']
    target = batch['summary']
    source_enc = tokenizer(source, padding='max_length', truncation=True, max_length=128)
    target_enc = tokenizer(target, padding='max_length', truncation=True, max_length=128)
    labels = target_enc['input_ids']
    labels = [[(token if token != tokenizer.pad_token_id else -100) for token in label] for label in labels]
    return {
        'input_ids': source_enc['input_ids'],
        'attention_mask': source_enc['attention_mask'],
        'labels': labels,
    }

# Apply preprocessing
df_source = ds.map(preprocess_function, batched=True)
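
# Sanity-check the mapped output (a sketch added for illustration): decode one
# encoded dialogue and confirm that padded label positions were replaced with
# -100, which the cross-entropy loss ignores.
sample = df_source['train'][0]
print(tokenizer.decode(sample['input_ids'], skip_special_tokens=True)[:200])
print("ignored label positions:", sum(t == -100 for t in sample['labels']))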

# Training arguments
training_args = TrainingArguments(
    output_dir='/content/TextSummarizer_output',
    per_device_train_batch_size=8,
    num_train_epochs=2,
    save_total_limit=1,
    save_strategy="epoch",
    remove_unused_columns=True,
    logging_dir='/content/logs',
    logging_steps=50,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=df_source['train'],
    eval_dataset=df_source['test'],  # dialogsum also has a 'validation' split if you want to keep 'test' held out
)
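
# Alternative worth knowing (hedged; not what this script uses): for seq2seq
# tasks, Seq2SeqTrainer plus Seq2SeqTrainingArguments(predict_with_generate=True)
# makes evaluate() run real generation, which is needed for metrics like ROUGE.
# Sketch, commented out so it does not clash with the Trainer above:
#
# from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
# seq2seq_args = Seq2SeqTrainingArguments(
#     output_dir='/content/TextSummarizer_output',
#     per_device_train_batch_size=8,
#     num_train_epochs=2,
#     predict_with_generate=True,
# )
# trainer = Seq2SeqTrainer(
#     model=model,
#     args=seq2seq_args,
#     train_dataset=df_source['train'],
#     eval_dataset=df_source['test'],
# )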

# Train
trainer.train()

# Evaluate (with a plain Trainer this reports eval_loss, not generation metrics)
eval_results = trainer.evaluate()
print("Evaluation Results:", eval_results)

# ===> Save to Google Drive path (assumes Drive is already mounted in Colab,
# e.g. via google.colab.drive.mount('/content/drive'))
save_path = "/content/drive/MyDrive/TextSummarizer2/model_directory"
os.makedirs(save_path, exist_ok=True)

# Save model and tokenizer (safe_serialization writes model.safetensors)
model.save_pretrained(save_path, safe_serialization=True)
tokenizer.save_pretrained(save_path)
print(f"✅ Model and tokenizer saved to: {save_path}")
print("📦 Files saved:", os.listdir(save_path))