✨ mwe working aggregation
Signed-off-by: peter szemraj <[email protected]>
- aggregate.py +158 -67
- app.py +91 -10
aggregate.py
CHANGED
@@ -1,10 +1,12 @@
+import pprint as pp
 import logging
 import time
 
 import torch
 from transformers import GenerationConfig, pipeline
 
+from utils import compare_model_size
+
 # Setting up logging
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -12,94 +14,182 @@ logging.basicConfig(
 
 
 class BatchAggregator:
+    CONFIGURED_MODELS = [
+        "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+    ]  # TODO: Add models here
+    DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
+    GENERIC_CONFIG = GenerationConfig(
+        num_beams=8,
+        early_stopping=True,
+        do_sample=False,
+        min_new_tokens=32,
+        max_new_tokens=256,
+        repetition_penalty=1.1,
+        length_penalty=1.4,
+        no_repeat_ngram_size=4,
+        encoder_no_repeat_ngram_size=5,
+    )
+
     def __init__(
         self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
     ):
+        self.device = None
+        self.is_compiled = False
         self.logger = logging.getLogger(__name__)
+        self.init_model(model_name)
+
+    def init_model(self, model_name: str) -> None:
+        """
+        Initialize the model.
+
+        :param model_name: The name of the model to use.
+        """
+        # Free up memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        self.logger.info(f"Setting model to {model_name}")
         self.model_name = model_name
+        self.aggregator = self._create_pipeline(model_name)
+        self._configure_model()
+        # update the generation config with the specific tokenizer
+        tokenizer_params = {
+            "decoder_start_token_id": 0
+            if "t5" in model_name.lower()
+            else self.aggregator.tokenizer.eos_token_id,
+            "eos_token_id": 1
+            if "t5" in model_name.lower()
+            else self.aggregator.tokenizer.eos_token_id,
+            "pad_token_id": 0
+            if "t5" in model_name.lower()
+            else self.aggregator.tokenizer.pad_token_id,
+        }
+        self.update_generation_config(**tokenizer_params)
+
+    def _create_pipeline(
+        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+    ) -> pipeline:
+        """
+        _create_pipeline creates a pipeline for the model.
+
+        :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
+        :return pipeline: the pipeline for the model
+
+        :raises Exception: if the pipeline cannot be created
+        """
+        self.device = 0 if torch.cuda.is_available() else -1
         try:
+            self.logger.info(
+                f"Creating pipeline with model {model_name} on device {self.device}"
+            )
+            return pipeline(
+                "text2text-generation",
+                model_name,
+                device=self.device,
+                torch_dtype=torch.float32,
+            )
         except Exception as e:
+            self.logger.error(f"Failed to create pipeline: {e}")
+            raise
+
+    def _configure_model(self):
+        """
+        Configure the model for generation.
+        """
         try:
+            self.aggregator.model = torch.compile(self.aggregator.model)
+            self.is_compiled = True
         except Exception as e:
+            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
+
+        if self.model_name not in self.CONFIGURED_MODELS:
+            self.logger.info("Setting generation config to general defaults")
+            self._set_default_generation_config()
+        else:
+            try:
+                self.logger.info("Loading generation config from hub")
+                self.aggregator.model.generation_config = (
+                    GenerationConfig.from_pretrained(self.model_name)
+                )
+            except Exception as e:
+                self.logger.warning(
+                    f"Could not load generation config, using defaults: {e}"
+                )
+                self._set_default_generation_config()
-        if "bart" in model_name.lower():
-            self.logger.info("Using BART model, updating generation config")
-            upd = {
-                "num_beams": 8,
-                "repetition_penalty": 1.3,
-                "length_penalty": 1.0,
-                "_from_model_config": False,
-                "max_new_tokens": 256,
-                "min_new_tokens": 32,
-                "no_repeat_ngram_size": 3,
-                "encoder_no_repeat_ngram_size": 6,
-            }
-            self.aggregator.model.generation_config.update(**upd)
-        if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
-            self.logger.info("Updating generation config with defaults")
-            self.update_generation_config()
 
         self.logger.info(self.aggregator.model.generation_config.to_json_string())
 
+    def _set_default_generation_config(self):
+        """
+        Set the default generation configuration for the model.
+        """
+        self.aggregator.model.generation_config = self.GENERIC_CONFIG
+
+        if "bart" in self.model_name.lower():
+            self.logger.info("Using BART model, updating generation config")
+            upd = {
+                "num_beams": 8,
+                "repetition_penalty": 1.3,
+                "length_penalty": 1.0,
+                "_from_model_config": False,
+                "max_new_tokens": 256,
+                "min_new_tokens": 32,
+                "no_repeat_ngram_size": 3,
+                "encoder_no_repeat_ngram_size": 6,
+            }  # TODO: clean up
+            self.aggregator.model.generation_config.update(**upd)
+
+        if (
+            "large"
+            or "xl" in self.model_name.lower()
+            or compare_model_size(self.model_name, 500)
+        ):
+            upd = {"num_beams": 4}
+            self.update_generation_config(**upd)
+
     def update_generation_config(self, **kwargs):
+        """
+        Update the generation configuration with the specified parameters.
+
+        Args:
+            **kwargs: The parameters to update in the generation configuration.
+        """
+        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
+
+        self.aggregator.model.generation_config.update(**kwargs)
+
+    def update_loglevel(self, level: str = "INFO"):
+        """
+        Update the log level.
+
+        Args:
+            level (str): The log level to set. Defaults to "INFO".
+        """
+        self.logger.setLevel(level)
+
     def infer_aggregate(
         self,
         text_list: list,
+        instruction: str = DEFAULT_INSTRUCTION,
         **kwargs,
-    ):
+    ) -> str:
+        f"""
+        Generate a summary of the specified texts.
+
+        Args:
+            text_list (list): The texts to summarize.
+            instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}.
+            **kwargs: Additional parameters to update in the generation configuration.
+
+        Returns:
+            The generated summary.
+        """
         joined_text = "\n".join(text_list)
         prompt = f"{instruction}\n\n{joined_text}\n"
         if kwargs:
             self.update_generation_config(**kwargs)
         st = time.perf_counter()
+        self.logger.info(f"inference on {len(text_list)} texts ...")
         result = self.aggregator(
             prompt,
             generation_config=self.aggregator.model.generation_config,
@@ -110,7 +200,8 @@ class BatchAggregator:
         )
         return result
 
-    def count_tokens(self, text: str):
+    def count_tokens(self, text: str) -> int:
+        """count the number of tokens in a text"""
         return (
             len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
             if text
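For orientation, a minimal usage sketch of the new BatchAggregator, mirroring how app.py drives it below. The model id and the example texts here are placeholders; only BatchAggregator, infer_aggregate, and the kwargs-to-update_generation_config forwarding come from this diff:

from aggregate import BatchAggregator

# placeholder model id; app.py below uses "MBZUAI/LaMini-Flan-T5-783M"
aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")

# per-batch summaries produced by the earlier summarization step (placeholder text)
summary_batches = [
    "Batch 1: the report introduces the problem statement and the dataset.",
    "Batch 2: the report describes the method and the main results.",
]

# extra keyword arguments are forwarded to update_generation_config before generation
final_summary = aggregator.infer_aggregate(summary_batches, max_new_tokens=192)
print(final_summary)

One detail worth flagging in _set_default_generation_config: the condition "large" or "xl" in self.model_name.lower() or compare_model_size(self.model_name, 500) parses as "large" or (...), and a non-empty string literal is always truthy in Python, so the num_beams=4 override applies to every model; a check along the lines of any(tag in self.model_name.lower() for tag in ("large", "xl")) or compare_model_size(self.model_name, 500) is presumably what was intended.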
app.py
CHANGED
@@ -1,5 +1,5 @@
 """
-app.py - the main module for the gradio app
+app.py - the main module for the gradio app for summarization
 
 Usage:
     python app.py
@@ -19,6 +19,7 @@ import random
 import re
 import time
 from pathlib import Path
+import pprint as pp
 
 os.environ["USE_TORCH"] = "1"
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -31,16 +32,18 @@ logging.basicConfig(
 import gradio as gr
 import nltk
 import torch
+from aggregate import BatchAggregator
 from cleantext import clean
 from doctr.models import ocr_predictor
-
 from pdf2text import convert_PDF_to_Text
 from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
 from utils import (
+    extract_batches,
     load_example_filenames,
     saves_summary,
     textlist2html,
    truncate_word_count,
+    remove_stagnant_files,
 )
 
 _here = Path(__file__).parent
@@ -57,10 +60,76 @@ MODEL_OPTIONS = [
     "pszemraj/pegasus-x-large-book-summary",
 ]  # models users can choose from
 
+SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
+
 # if duplicating space,, uncomment this line to adjust the max words
 # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
 # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
 
+aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")
+
+
+def aggregate_text(
+    summary_text: str,
+    text_file: gr.inputs.File = None,
+):
+    """
+    Aggregate the text from the batches.
+
+    NOTE: you should probably include passing the BatchAggregator object as a parameter if using this code
+    outside of this file.
+    :param batches_html: The batches to aggregate, in html format
+    """
+    if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
+        logging.error("No text provided. Make sure a summary has been generated first.")
+        return "Error: No text provided. Make sure a summary has been generated first."
+
+    try:
+        extracted_batches = extract_batches(summary_text)
+    except Exception as e:
+        logging.info(summary_text)
+        logging.info(f"the batches html is: {type(summary_text)}")
+        return f"Error: unable to extract batches - check input: {e}"
+    if not extracted_batches:
+        logging.error("unable to extract batches - check input")
+        return "Error: unable to extract batches - check input"
+
+    out_path = None
+    if text_file is not None:
+        out_path = text_file.name  # assuming name attribute stores the file path
+
+    content_batches = [batch["content"] for batch in extracted_batches]
+    full_summary = aggregator.infer_aggregate(content_batches)
+
+    # if a path that exists is provided, save the summary with markdown formatting
+    if out_path:
+        out_path = Path(out_path)
+
+        try:
+            with open(out_path, "a", encoding="utf-8") as f:
+                f.write("\n\n### Aggregate Summary\n\n")
+                f.write(
+                    "- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
+                )
+                f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
+                f.write(f"{full_summary}\n\n")
+            logging.info(f"Updated {out_path} with aggregate summary")
+        except Exception as e:
+            logging.error(f"unable to update {out_path} with aggregate summary: {e}")
+
+    full_summary_html = f"""
+    <div style="
+        margin-bottom: 20px;
+        font-size: 18px;
+        line-height: 1.5em;
+        color: #333;
+    ">
+    <h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
+    <p style="white-space: pre-line;">{full_summary}</p>
+    </div>
+    """
+    return full_summary_html
+
 
 def predict(
     input_text: str,
@@ -128,6 +197,7 @@ def proc_submission(
         str in HTML format, string of the summary, str of score
     """
 
+    remove_stagnant_files()  # clean up old files
    settings = {
        "length_penalty": float(length_penalty),
        "repetition_penalty": float(repetition_penalty),
@@ -208,7 +278,6 @@ def proc_submission(
    # save to file
    settings["model_name"] = model_name
    saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
-
    return html, full_summary, scores_out, saved_file
 
 
@@ -361,7 +430,7 @@ if __name__ == "__main__":
            summarize_button = gr.Button(
                "Summarize!",
                variant="primary",
-            )
+            )  # TODO: collapse button to be on same line as something else
            output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
        with gr.Column():
            gr.Markdown("#### Results & Scores")
@@ -384,11 +453,19 @@ if __name__ == "__main__":
                label="Summary Scores",
                placeholder="Summary scores will appear here",
            )
+        with gr.Column():
+            gr.Markdown("#### **Summary Output**")
+            summary_text = gr.HTML(
+                label="Summary", value="<i>Summary will appear here!</i>"
+            )
+        with gr.Column():
+            gr.Markdown("##### **Aggregate Summary Batches**")
+            aggregate_button = gr.Button(
+                "Aggregate!",
+                variant="primary",
+            )  # TODO: collapse button to be on same line as something else
+            aggregated_summary = gr.HTML(label="Aggregate Summary", value="")
 
-            gr.Markdown("#### **Summary Output**")
-            summary_text = gr.HTML(
-                label="Summary", value="<i>Summary will appear here!</i>"
-            )
        gr.Markdown("---")
        with gr.Column():
            gr.Markdown("### Advanced Settings")
@@ -456,5 +533,9 @@ if __name__ == "__main__":
        ],
        outputs=[output_text, summary_text, summary_scores, text_file],
    )
+    aggregate_button.click(
+        fn=aggregate_text,
+        inputs=[summary_text, text_file],
+        outputs=[aggregated_summary],
+    )
+    demo.launch(enable_queue=True, share=True)
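The new Aggregate! button simply forwards the summary HTML and the saved text file into aggregate_text. A minimal sketch of the same flow outside Gradio follows, assuming (as the list comprehension in aggregate_text implies) that extract_batches from utils.py, which this diff does not show, returns a list of dicts with a "content" key; the input file name is a placeholder:

from pathlib import Path

from aggregate import BatchAggregator
from utils import extract_batches

aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")

# placeholder path: the batch-summary HTML produced by the summarization step
summary_html = Path("summary_batches.html").read_text(encoding="utf-8")

batches = extract_batches(summary_html)  # assumed: list of {"content": ...} dicts
texts = [batch["content"] for batch in batches]
print(aggregator.infer_aggregate(texts))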