| import nltk | |
| import torch | |
| from summarizer import Summarizer | |
| from sumy.nlp.tokenizers import Tokenizer | |
| from sumy.summarizers.lsa import LsaSummarizer | |
| from sumy.parsers.plaintext import PlaintextParser | |
| from sumy.summarizers.lex_rank import LexRankSummarizer | |
| from sumy.summarizers.sum_basic import SumBasicSummarizer | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
# Ensure NLTK's "punkt" sentence-tokenizer data is available (presumably
# needed by sumy's `Tokenizer("en")`). `nltk.download` contacts the NLTK
# index every time it runs, so only call it when the data is not already
# installed locally; `nltk.data.find` raises LookupError when it is missing.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')
def extractive(method, file, sentences_count=5):
    """Return an extractive summary of *file* built with a sumy summarizer.

    Args:
        method: An instantiated sumy summarizer (e.g. ``LsaSummarizer()``).
            It is called as ``method(document, sentences_count)``.
        file: The raw document text to summarize. Despite the parameter
            name, this is a string, not a file object.
        sentences_count: How many sentences to request from the summarizer.
            Defaults to 5, the value this function previously hard-coded.

    Returns:
        The selected sentences joined with single spaces, or ``""`` if the
        summarizer selects none.
    """
    document = PlaintextParser(file, Tokenizer("en")).document
    return " ".join(str(sentence) for sentence in method(document, sentences_count))
# Abstractive checkpoints, grouped by the generation recipe they need.
_PEGASUS_CHECKPOINTS = {
    "Pegasus": "google/pegasus-billsum",
    "Distill": "sshleifer/distill-pegasus-cnn-16-4",
}
_LED_CHECKPOINTS = {
    "LEDBill": "d0r1h/LEDBill",
    "ILC": "d0r1h/led-base-ilc",
}
# Every value of ``model`` that ``summarize`` accepts.
_SUPPORTED_MODELS = (
    *_PEGASUS_CHECKPOINTS,
    *_LED_CHECKPOINTS,
    "TextRank",
    "SumBasic",
    "Lsa",
    "BERT",
)


def _summarize_pegasus(doc, checkpoint):
    """Summarize *doc* with a Pegasus-style seq2seq *checkpoint*.

    The input is truncated to 1024 tokens before generation. The first
    generated sequence is decoded and returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    inputs = tokenizer(doc, max_length=1024, truncation=True, return_tensors="pt")
    summary_ids = seq2seq.generate(inputs["input_ids"])
    decoded = tokenizer.batch_decode(
        summary_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]


def _summarize_led(doc, checkpoint):
    """Summarize *doc* with an LED (Longformer encoder-decoder) *checkpoint*.

    Global attention is enabled on the first token only, which is the setup
    the Transformers LED documentation recommends for summarization. The
    first generated sequence is decoded and returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # return_dict_in_generate=True makes generate() return an output object
    # whose `.sequences` attribute holds the generated token ids.
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, return_dict_in_generate=True)
    # Truncate to what the encoder's position embeddings can cover, so very
    # long documents do not fail inside generate(). If the config does not
    # expose that limit, tokenize without truncation (the previous behavior).
    max_len = getattr(seq2seq.config, "max_encoder_position_embeddings", None)
    input_ids = tokenizer(
        doc,
        max_length=max_len,
        truncation=max_len is not None,
        return_tensors="pt",
    ).input_ids
    global_attention_mask = torch.zeros_like(input_ids)
    global_attention_mask[:, 0] = 1
    sequences = seq2seq.generate(input_ids, global_attention_mask=global_attention_mask).sequences
    return tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]


def summarize(file, model):
    """Summarize the text file referenced by *file* using the method *model*.

    Args:
        file: An object exposing the path of the text file as ``.name``.
            The file is read as UTF-8.
        model: One of the following:
            - ``"Pegasus"``, ``"Distill"``, ``"LEDBill"``, ``"ILC"``:
              abstractive Hugging Face seq2seq models.
            - ``"TextRank"``, ``"SumBasic"``, ``"Lsa"``: sumy extractive
              summarizers.
            - ``"BERT"``: a ``Summarizer`` over ``distilbert-base-uncased``.

    Returns:
        The summary as a single string.

    Raises:
        ValueError: If *model* is not one of the supported names. This is
            checked before the file is opened.

    Note:
        Tokenizers and models are loaded via ``from_pretrained`` on every
        call. Consider caching them if latency matters.
    """
    if model not in _SUPPORTED_MODELS:
        raise ValueError(
            f"Unknown model {model!r}; expected one of: {', '.join(_SUPPORTED_MODELS)}"
        )
    # Explicit UTF-8, so decoding does not depend on the platform's locale.
    with open(file.name, encoding="utf-8") as f:
        doc = f.read()

    if model in _PEGASUS_CHECKPOINTS:
        return _summarize_pegasus(doc, _PEGASUS_CHECKPOINTS[model])
    if model in _LED_CHECKPOINTS:
        return _summarize_led(doc, _LED_CHECKPOINTS[model])
    if model == "TextRank":
        # NOTE(review): this option is labelled TextRank but runs sumy's
        # LexRank summarizer (sumy also ships TextRankSummarizer). Confirm
        # which algorithm is intended.
        return extractive(LexRankSummarizer(), doc)
    if model == "SumBasic":
        return extractive(SumBasicSummarizer(), doc)
    if model == "Lsa":
        return extractive(LsaSummarizer(), doc)
    # Only "BERT" can remain after the validation at the top.
    bert = Summarizer('distilbert-base-uncased', hidden=[-1, -2], hidden_concat=True)
    return ''.join(bert(doc))