import logging
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, set_seed

# Log to stderr at INFO level so the DEVICE and model-loading messages are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
# Configuration, overridable through environment variables.
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN", None)
DEVICE = os.environ.get("DEVICE", "cpu")  # e.g. "cuda:0"
if DEVICE != "cpu" and not torch.cuda.is_available():
    DEVICE = "cpu"
logger.info(f"DEVICE {DEVICE}")
DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16
MODEL_NAME = os.environ.get("MODEL_NAME", "bertin-project/bertin-gpt-j-6B")
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))
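# Example configuration (hypothetical values), set before launching the Space:
#   DEVICE=cuda:0 MAX_LENGTH=1024 python app.py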
HEADER_INFO = """
# BERTIN GPT-J-6B
Spanish BERTIN GPT-J-6B Model.
""".strip()
LOGO = "https://huggingface.co/bertin-project/bertin-roberta-base-spanish/resolve/main/images/bertin.png"
HEADER = f"""
<link href="https://fonts.googleapis.com/css2?family=Roboto:wght@300&display=swap" rel="stylesheet">
<style>
.ltr,
textarea {{
    font-family: Roboto !important;
    text-align: left;
    direction: ltr !important;
}}
.ltr-box {{
    border-bottom: 1px solid #ddd;
    padding-bottom: 20px;
}}
.rtl {{
    text-align: left;
    direction: ltr !important;
}}
span.result-text {{
    padding: 3px 3px;
    line-height: 32px;
}}
span.generated-text {{
    background-color: rgb(118 200 147 / 13%);
}}
</style>
<div align=center>
<img src="{LOGO}" width=150/>

# BERTIN GPT-J-6B

BERTIN proporciona una serie de modelos de lenguaje en Español entrenados en abierto.

Este modelo ha sido entrenado con [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax) en TPUs proporcionadas por Google a través del programa Tensor Research Cloud, a partir del modelo [GPT-J de EleutherAI](https://huggingface.co/EleutherAI/gpt-j-6B) con el corpus [mC4-es-sampled (gaussian)](https://huggingface.co/datasets/bertin-project/mc4-es-sampled). Esta demo funciona sobre una GPU proporcionada por HuggingFace.
</div>
"""
FOOTER = """
Para más información, visite el [repositorio del modelo](https://huggingface.co/bertin-project/bertin-gpt-j-6B).
""".strip()
class Normalizer:
    def remove_repetitions(self, text):
        """Remove repeated sentences, keeping the first occurrence of each."""
        # Compare stripped sentences so that "Hola. Hola." is detected as a
        # repetition despite the leading space left behind by `split(".")`.
        seen = set()
        first_occurrences = []
        for sentence in text.split("."):
            key = sentence.strip()
            if key not in seen:
                seen.add(key)
                first_occurrences.append(sentence)
        return ".".join(first_occurrences)

    def trim_last_sentence(self, text):
        """Trim the last sentence if incomplete.

        If the text contains no period at all this returns "", which makes
        the caller retry the generation.
        """
        return text[:text.rfind(".") + 1]

    def clean_txt(self, text):
        return self.trim_last_sentence(self.remove_repetitions(text))
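# Illustrative behavior of the cleaner (not executed by the app):
#   Normalizer().clean_txt("Las brujas vuelan. Las brujas vuelan. Y entonces")
#   -> "Las brujas vuelan."  # repetition removed, unfinished sentence trimmed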
class TextGeneration:
    def __init__(self):
        self.tokenizer = None
        self.model = None
        self.generator = None
        self.task = "text-generation"
        self.model_name_or_path = MODEL_NAME
        set_seed(42)

    def load(self):
        logger.info("Loading model...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, use_auth_token=HF_AUTH_TOKEN,
            pad_token_id=self.tokenizer.eos_token_id, eos_token_id=self.tokenizer.eos_token_id,
            torch_dtype=DTYPE, low_cpu_mem_usage=DEVICE != "cpu",
        ).to(device=DEVICE, non_blocking=False)
        self.model.eval()
        # The pipeline API expects -1 for CPU and the CUDA device index otherwise.
        device_number = -1 if DEVICE == "cpu" else int(DEVICE.split(":")[-1])
        self.generator = pipeline(self.task, model=self.model, tokenizer=self.tokenizer, device=device_number)
        logger.info("Loading model done.")
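    # Illustrative smoke test for `load` (assumes the 6B checkpoint, roughly
    # 24 GB in fp32 or 12 GB in fp16, fits in memory on DEVICE; the prompt is
    # hypothetical):
    #   tg = TextGeneration()
    #   tg.load()
    #   tg.generator("Las brujas vuelan", max_length=20)[0]["generated_text"]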
    def generate(self, text, generation_kwargs):
        # `do_clean` is an app-level flag, not a generation argument: pop it
        # here so it is not forwarded to `model.generate`, which rejects
        # unknown kwargs.
        do_clean = generation_kwargs.pop("do_clean", False)
        # The UI slider sets the number of tokens to add; the pipeline's
        # `max_length` counts the prompt too, capped at the context window.
        max_length = len(self.tokenizer(text)["input_ids"]) + generation_kwargs["max_length"]
        generation_kwargs["max_length"] = min(max_length, self.model.config.n_positions)
        generated_text = None
        if text:
            # Retry a few times: cleaning or prompt-stripping can leave an
            # empty string behind.
            for _ in range(10):
                generated_text = self.generator(
                    text,
                    **generation_kwargs,
                )[0]["generated_text"]
                if do_clean:
                    generated_text = cleaner.clean_txt(generated_text)
                if generated_text.strip().startswith(text):
                    generated_text = generated_text.replace(text, "", 1).strip()
                if generated_text:
                    return (
                        text + " " + generated_text,
                        [(text, None), (generated_text, "BERTIN")],
                    )
        return (
            "",
            [("Tras 10 intentos BERTIN no generó nada. Pruebe cambiando las opciones", "ERROR")],
        )
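# Illustrative call to `TextGeneration.generate` (hypothetical prompt and settings):
#   full, spans = generator.generate(
#       "Érase una vez",
#       {"max_length": 50, "top_k": 50, "top_p": 0.95, "temperature": 0.8,
#        "do_sample": True, "do_clean": True},
#   )
# `full` is the prompt plus its continuation; `spans` is the (text, label)
# list consumed by gr.HighlightedText below.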
def load_text_generator():
    text_generator = TextGeneration()
    text_generator.load()
    return text_generator
cleaner = Normalizer()
generator = load_text_generator()


def complete_with_gpt(text, max_length, top_k, top_p, temperature, do_sample, do_clean):
    generation_kwargs = {
        "max_length": max_length,
        "top_k": top_k,
        "top_p": top_p,
        "temperature": temperature,
        "do_sample": do_sample,
        "do_clean": do_clean,
    }
    return generator.generate(text, generation_kwargs)
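# Example invocation outside the UI, mirroring the slider defaults below
# (the prompt is hypothetical):
#   complete_with_gpt("Érase una vez", 50, 50, 0.95, 0.8, True, True)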
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    with gr.Row():
        with gr.Group():
            with gr.Box():
                gr.Markdown("Opciones")
                max_length = gr.Slider(
                    label='Longitud máxima',
                    # Approximate maximum number of words to generate.
                    minimum=1,
                    maximum=MAX_LENGTH,
                    value=50,
                    step=1
                )
                top_k = gr.Slider(
                    label='Top-k',
                    # Number of highest-probability tokens kept for top-k filtering.
                    minimum=40,
                    maximum=80,
                    value=50,
                    step=1
                )
                top_p = gr.Slider(
                    label='Top-p',
                    # Only the most probable tokens whose probabilities add up
                    # to `top_p` or higher are kept for generation.
                    minimum=0.0,
                    maximum=1.0,
                    value=0.95,
                    step=0.01
                )
                temperature = gr.Slider(
                    label='Temperatura',
                    # Value used to modulate the probabilities of the next tokens.
                    minimum=0.1,
                    maximum=10.0,
                    value=0.8,
                    step=0.05
                )
                do_sample = gr.Checkbox(
                    label='¿Muestrear?',
                    value=True,
                    # When unchecked, greedy decoding is used instead of sampling.
                )
                do_clean = gr.Checkbox(
                    label='¿Limpiar texto?',
                    value=True,
                    # Whether to remove repeated sentences and trim a trailing
                    # unfinished one.
                )
        with gr.Column():
            textbox = gr.Textbox(label="Texto", placeholder="Escriba algo y pulse 'Generar'...", lines=8)
            hidden = gr.Textbox(visible=False, show_label=False)
            with gr.Box():
                output = gr.HighlightedText(label="Resultado", combine_adjacent=True, color_map={"BERTIN": "green", "ERROR": "red"})
            with gr.Row():
                btn = gr.Button("Generar")
                btn.click(complete_with_gpt, inputs=[textbox, max_length, top_k, top_p, temperature, do_sample, do_clean], outputs=[hidden, output])
                edit_btn = gr.Button("Editar", variant="secondary")
                edit_btn.click(lambda x: (x, "", []), inputs=[hidden], outputs=[textbox, hidden, output])
                clean_btn = gr.Button("Limpiar", variant="secondary")
                clean_btn.click(lambda: ("", "", []), inputs=[], outputs=[textbox, hidden, output])
    gr.Markdown(FOOTER)
demo.launch()