Update Summarizer/Extractive.py
Browse files- Summarizer/Extractive.py +26 -0
Summarizer/Extractive.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import nltk
|
|
|
|
| 2 |
from summarizer import Summarizer
|
| 3 |
from sumy.nlp.tokenizers import Tokenizer
|
| 4 |
from sumy.summarizers.lsa import LsaSummarizer
|
|
@@ -37,6 +38,31 @@ def summarize(file, model):
|
|
| 37 |
skip_special_tokens=True,
|
| 38 |
clean_up_tokenization_spaces=False)
|
| 39 |
summary = summary[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
elif model == "TextRank":
|
| 42 |
summary = extractive(LexRankSummarizer(), doc)
|
|
|
|
| 1 |
import nltk
|
| 2 |
+
import torch
|
| 3 |
from summarizer import Summarizer
|
| 4 |
from sumy.nlp.tokenizers import Tokenizer
|
| 5 |
from sumy.summarizers.lsa import LsaSummarizer
|
|
|
|
| 38 |
skip_special_tokens=True,
|
| 39 |
clean_up_tokenization_spaces=False)
|
| 40 |
summary = summary[0]
|
| 41 |
+
elif model == "LEDBill":
|
| 42 |
+
tokenizer = AutoTokenizer.from_pretrained("d0r1h/LEDBill")
|
| 43 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/LEDBill", return_dict_in_generate=True)
|
| 44 |
+
|
| 45 |
+
input_ids = tokenizer(doc, return_tensors="pt").input_ids
|
| 46 |
+
global_attention_mask = torch.zeros_like(input_ids)
|
| 47 |
+
global_attention_mask[:, 0] = 1
|
| 48 |
+
|
| 49 |
+
sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
|
| 50 |
+
summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
| 51 |
+
|
| 52 |
+
summary = summary[0]
|
| 53 |
+
|
| 54 |
+
elif model == "ILC":
|
| 55 |
+
tokenizer = AutoTokenizer.from_pretrained("d0r1h/led-base-ilc")
|
| 56 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("d0r1h/led-base-ilc", return_dict_in_generate=True)
|
| 57 |
+
|
| 58 |
+
input_ids = tokenizer(doc, return_tensors="pt").input_ids
|
| 59 |
+
global_attention_mask = torch.zeros_like(input_ids)
|
| 60 |
+
global_attention_mask[:, 0] = 1
|
| 61 |
+
|
| 62 |
+
sequences = model.generate(input_ids, global_attention_mask=global_attention_mask).sequences
|
| 63 |
+
summary = tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
| 64 |
+
|
| 65 |
+
summary = summary[0]
|
| 66 |
|
| 67 |
elif model == "TextRank":
|
| 68 |
summary = extractive(LexRankSummarizer(), doc)
|