| """ | |
| Script for decoding summarization models available through Huggingface Transformers. | |
| Usage with Huggingface Datasets: | |
| python generation.py --model <model name> --data_path <path to data in jsonl format> | |
| Usage with custom datasets in JSONL format: | |
| python generation.py --model <model name> --dataset <dataset name> --split <data split> | |
| """ | |
| #!/usr/bin/env python | |
| # coding: utf-8 | |
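# Example invocations (the JSONL path below is illustrative):
#   python generation.py --model pegasus-xsum --dataset xsum --split test
#   python generation.py --model bart-cnndm --data_path ./my_articles.jsonl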
import argparse
import json
import os

import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
BATCH_SIZE = 8
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

BART_CNNDM_CHECKPOINT = 'facebook/bart-large-cnn'
BART_XSUM_CHECKPOINT = 'facebook/bart-large-xsum'
PEGASUS_CNNDM_CHECKPOINT = 'google/pegasus-cnn_dailymail'
PEGASUS_XSUM_CHECKPOINT = 'google/pegasus-xsum'
PEGASUS_NEWSROOM_CHECKPOINT = 'google/pegasus-newsroom'
PEGASUS_MULTINEWS_CHECKPOINT = 'google/pegasus-multi_news'

MODEL_CHECKPOINTS = {
    'bart-xsum': BART_XSUM_CHECKPOINT,
    'bart-cnndm': BART_CNNDM_CHECKPOINT,
    'pegasus-xsum': PEGASUS_XSUM_CHECKPOINT,
    'pegasus-cnndm': PEGASUS_CNNDM_CHECKPOINT,
    'pegasus-newsroom': PEGASUS_NEWSROOM_CHECKPOINT,
    'pegasus-multinews': PEGASUS_MULTINEWS_CHECKPOINT,
}
class JSONDataset(torch.utils.data.Dataset):
    """Dataset wrapping a JSONL file: one JSON object per line."""

    def __init__(self, data_path):
        super().__init__()
        with open(data_path) as fd:
            self.data = [json.loads(line) for line in fd]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
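# Note: the decoding loop below reads example["article"] and example["target"],
# so each line of a custom JSONL file is assumed to carry at least those two
# fields (source text and reference summary).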
def preprocess_data(raw_data, dataset):
    """
    Unify the column names of Huggingface Datasets to `article`/`target`.
    :param raw_data: loaded batch of data
    :param dataset: name of the dataset
    """
    if dataset == 'xsum':
        raw_data['article'] = raw_data['document']
        raw_data['target'] = raw_data['summary']
        del raw_data['document']
        del raw_data['summary']
    elif dataset == 'cnndm':
        raw_data['target'] = raw_data['highlights']
        del raw_data['highlights']
    elif dataset == 'gigaword':
        raw_data['article'] = raw_data['document']
        raw_data['target'] = raw_data['summary']
        del raw_data['document']
        del raw_data['summary']
    return raw_data
def postprocess_data(raw_data, decoded):
    """
    Remove generation artifacts and postprocess outputs.
    :param raw_data: loaded batch of data
    :param decoded: model outputs
    """
    raw_data['target'] = [x.replace('\n', ' ') for x in raw_data['target']]
    raw_data['decoded'] = [x.replace('<n>', ' ') for x in decoded]
    # Transpose the batched dict of lists into a list of per-example dicts.
    return [dict(zip(raw_data, t)) for t in zip(*raw_data.values())]
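# For example, {'article': [a1, a2], 'target': [t1, t2], 'decoded': [d1, d2]}
# becomes [{'article': a1, 'target': t1, 'decoded': d1},
#          {'article': a2, 'target': t2, 'decoded': d2}].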
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Decode summarization models on Huggingface or custom JSONL datasets.')
    parser.add_argument('--model', type=str, required=True, choices=['bart-xsum', 'bart-cnndm', 'pegasus-xsum', 'pegasus-cnndm', 'pegasus-newsroom', 'pegasus-multinews'])
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--dataset', type=str, choices=['xsum', 'cnndm', 'gigaword'])
    parser.add_argument('--split', type=str, choices=['train', 'validation', 'test'])
    args = parser.parse_args()

    if args.dataset and not args.split:
        raise RuntimeError('If the `dataset` flag is specified, `split` must also be provided.')
    if not args.dataset and not args.data_path:
        raise RuntimeError('Either `dataset` or `data_path` must be provided.')
    if args.data_path:
        # Custom data: derive a dataset name from the file name for output naming.
        args.dataset = os.path.splitext(os.path.basename(args.data_path))[0]
        args.split = 'user'
    # Load model & data
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINTS[args.model]).to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINTS[args.model])

    if not args.data_path:
        if args.dataset == 'cnndm':
            dataset = load_dataset('cnn_dailymail', '3.0.0', split=args.split)
        elif args.dataset == 'xsum':
            dataset = load_dataset('xsum', split=args.split)
        elif args.dataset == 'gigaword':
            dataset = load_dataset('gigaword', split=args.split)
    else:
        dataset = JSONDataset(args.data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE)
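    # Note: with string fields, PyTorch's default collate function turns a
    # batch of per-example dicts into a single dict of lists, which is the
    # shape preprocess_data and postprocess_data expect.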
    # Run decoding, streaming results to a JSONL output file
    filename = '%s.%s.%s.results' % (args.model.replace('/', '-'), args.dataset, args.split)
    results = []
    model.eval()
    with open(filename, 'w') as fd_out, torch.no_grad():
        for raw_data in tqdm(dataloader):
            raw_data = preprocess_data(raw_data, args.dataset)
            batch = tokenizer(raw_data['article'], return_tensors='pt', truncation=True, padding='longest').to(DEVICE)
            summaries = model.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask)
            decoded = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            result = postprocess_data(raw_data, decoded)
            results.extend(result)
            for example in result:
                fd_out.write(json.dumps(example) + '\n')
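    # Generation hyperparameters (beam size, length penalty, max length) fall
    # back to each checkpoint's default generation config; pass overrides to
    # model.generate(...) above if different decoding settings are needed.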