| """ | |
| app.py - the main file for the app. This creates the flask app and handles the routes. | |
| """ | |
| import argparse | |
| import logging | |
| import os | |
| import sys | |
| import time | |
| import warnings | |
| from os.path import dirname | |
| from pathlib import Path | |
| import gradio as gr | |
| import nltk | |
| import torch | |
| from cleantext import clean | |
| from gradio.inputs import Slider, Textbox | |
| from transformers import pipeline | |
| from converse import discussion | |
| from grammar_improve import ( | |
| build_symspell_obj, | |
| detect_propers, | |
| fix_punct_spacing, | |
| load_ns_checker, | |
| neuspell_correct, | |
| remove_repeated_words, | |
| remove_trailing_punctuation, | |
| symspeller, | |
| synthesize_grammar, | |
| ) | |
| from utils import corr | |
| nltk.download("stopwords") # download stopwords | |
| sys.path.append(dirname(dirname(os.path.abspath(__file__)))) | |
| warnings.filterwarnings(action="ignore", message=".*gradient_checkpointing*") | |
| import transformers | |
| transformers.logging.set_verbosity_error() | |
| logging.basicConfig() | |
| cwd = Path.cwd() | |
| my_cwd = str(cwd.resolve()) # string so it can be passed to os.path() objects | |


def chat(prompt_message, temperature=0.7, top_p=0.95, top_k=50):
    """
    chat - helper function that makes the whole gradio thing work.

    Args:
        prompt_message (str): the message to send to the bot
        temperature (float): the sampling temperature
        top_p (float): the nucleus sampling (top-p) threshold
        top_k (int): the top-k sampling threshold

    Returns:
        str: the prompt and the bot's response, rendered as HTML
    """
    response = ask_gpt(
        message=prompt_message,
        chat_pipe=my_chatbot,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    history = [prompt_message, response]
    html = ""
    for item in history:
        html += f"<b>{item}</b> <br><br>"
    return html
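
# A quick sanity check (sketch, not executed at import time): `chat` relies on the
# module-level `my_chatbot` pipeline that is only created in the __main__ block below.
#   html = chat("Hello there!", temperature=0.6, top_p=0.95, top_k=25)
#   # -> "<b>Hello there!</b> <br><br><b>...model response...</b> <br><br>"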


def ask_gpt(
    message: str,
    chat_pipe,
    speaker="person alpha",
    responder="person beta",
    max_len=96,
    top_p=0.95,
    top_k=25,
    temperature=0.6,
):
    """
    ask_gpt - a function that takes in a prompt and generates a response using the pipeline. This interacts with the discussion() function.

    Parameters:
        message (str): the question to ask the bot
        chat_pipe: the transformers text-generation pipeline to use for the bot
        speaker (str): the name of the speaker (default: "person alpha")
        responder (str): the name of the responder (default: "person beta")
        max_len (int): the maximum length of the response (default: 96)
        top_p (float): the top probability threshold (default: 0.95)
        top_k (int): the top k threshold (default: 25)
        temperature (float): the temperature of the response (default: 0.6)
    """
    st = time.perf_counter()
    prompt = clean(message)  # clean user input
    prompt = prompt.strip()  # get rid of any extra whitespace
    in_len = len(prompt)
    if in_len > 512:
        prompt = prompt[-512:]  # truncate to the last 512 chars
        print(f"Truncated prompt to last 512 chars: started with {in_len} chars")
    max_len = min(max_len, 512)
    resp = discussion(
        prompt_text=prompt,
        pipeline=chat_pipe,
        speaker=speaker,
        responder=responder,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_length=max_len,
    )
    gpt_et = time.perf_counter()
    gpt_rt = round(gpt_et - st, 2)
    rawtxt = resp["out_text"]
    # check for proper nouns; skip correction when any are found
    if basic_sc and not detect_propers(rawtxt):
        cln_resp = symspeller(rawtxt, sym_checker=schnellspell)
    elif not detect_propers(rawtxt):
        cln_resp = synthesize_grammar(corrector=grammarbot, message=rawtxt)
    else:
        # no correction needed
        cln_resp = rawtxt.strip()
    bot_resp_a = corr(remove_repeated_words(cln_resp))
    bot_resp = fix_punct_spacing(bot_resp_a)
    print(f"the prompt was:\n\t{message}\nand the response was:\n\t{bot_resp}\n")
    corr_rt = round(time.perf_counter() - gpt_et, 4)
    print(
        f"took {gpt_rt + corr_rt} sec to respond, {gpt_rt} for GPT, {corr_rt} for correction\n"
    )
    return remove_trailing_punctuation(bot_resp)
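
# Example call (sketch): with a loaded text-generation pipeline,
#   reply = ask_gpt("Do you know a joke?", chat_pipe=my_chatbot)
# Note that ask_gpt reads the module-level `basic_sc` flag and the matching
# corrector object (`schnellspell` or `grammarbot`) defined in the __main__ block,
# so it can only be called after that setup has run.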


def get_parser():
    """
    get_parser - a helper function for the argparse module
    """
    parser = argparse.ArgumentParser(
        description="submit a question, GPT model responds"
    )
    parser.add_argument(
        "-m",
        "--model",
        required=False,
        type=str,
        default="ethzanalytics/ai-msgbot-gpt2-XL",  # default model
        help="the model to use for the chatbot on https://huggingface.co/models OR a path to a local model",
    )
    parser.add_argument(
        "--gram-model",
        required=False,
        type=str,
        default="pszemraj/t5-v1_1-base-ft-jflAUG",
        help="text2text generation model ID from huggingface for the model to correct grammar",
    )
    parser.add_argument(
        "--basic-sc",
        required=False,
        default=False,  # TODO: revisit the default once the Neuspell issues are resolved
        action="store_true",
        help="turn on symspell (baseline) correction instead of the more advanced neural net models",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="turn on verbose logging",
    )
    return parser
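
# Typical invocations (illustrative; flags as defined in get_parser above):
#   python app.py                          # default chatbot + neural grammar corrector
#   python app.py --basic-sc               # use the symspell baseline corrector instead
#   python app.py -m path/to/local/model   # load a local model checkpoint directory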

if __name__ == "__main__":
    args = get_parser().parse_args()
    default_model = str(args.model)
    model_loc = Path(default_model)  # if the model is a path, use it
    basic_sc = args.basic_sc  # whether to use the baseline spellchecker
    gram_model = str(args.gram_model)
    device = 0 if torch.cuda.is_available() else -1
    print(f"CUDA avail is {torch.cuda.is_available()}")

    my_chatbot = (
        pipeline("text-generation", model=str(model_loc.resolve()), device=device)
        if model_loc.exists() and model_loc.is_dir()
        else pipeline("text-generation", model=default_model, device=device)
    )  # if the model is a name, use it. stays on CPU if no GPU available
    print(f"using model {my_chatbot.model}")

    if basic_sc:
        print("Using the baseline spellchecker")
        schnellspell = build_symspell_obj()
    else:
        print("using neural spell checker")
        grammarbot = pipeline("text2text-generation", gram_model, device=device)

    print(f"using model stored here: \n {model_loc} \n")

    iface = gr.Interface(
        chat,
        inputs=[
            Textbox(
                default="Why is everyone here eating chocolate cake?",
                label="prompt_message",
                placeholder="Enter a question",
                lines=2,
            ),
            Slider(
                minimum=0.0, maximum=1.0, step=0.01, default=0.6, label="temperature"
            ),
            Slider(minimum=0.0, maximum=1.0, step=0.01, default=0.95, label="top_p"),
            Slider(minimum=0, maximum=100, step=5, default=20, label="top_k"),
        ],
        outputs="html",
        examples_per_page=8,
        examples=[
            ["Point Break or Bad Boys II?", 0.75, 0.95, 50],
            ["So... you're saying this wasn't an accident?", 0.6, 0.95, 50],
            ["Hi, my name is Reginald", 0.6, 0.95, 100],
            ["Happy birthday!", 0.9, 0.95, 50],
            ["I have a question, can you help me?", 0.6, 0.95, 50],
            ["Do you know a joke?", 0.8, 0.85, 50],
            ["Will you marry me?", 0.9, 0.95, 100],
            ["Are you single?", 0.6, 0.95, 100],
            ["Do you like people?", 0.7, 0.95, 25],
            ["You never took a short cut before?", 0.7, 0.95, 100],
        ],
        title=f"GPT Chatbot Demo: {default_model} Model",
        description="A demo of a chatbot trained for conversation with humans. Size XL = 1.5B parameters.\n\n"
        "**Important Notes & About:**\n\n"
        "You can find a link to the model card **[here](https://huggingface.co/ethzanalytics/ai-msgbot-gpt2-XL-dialogue)**\n\n"
        "1. responses can sometimes take up to 60 seconds to generate; patience is a virtue.\n"
        "2. the model was trained on several different datasets; fact-check responses instead of taking them as true statements.\n"
        "3. try adjusting the **[generation parameters](https://huggingface.co/blog/how-to-generate)** to get a better understanding of how they work!\n",
        css="""
            .chatbox {display:flex;flex-direction:row}
            .user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
            .user_msg {background-color:cornflowerblue;color:white;align-self:start}
            .resp_msg {background-color:lightgray;align-self:self-end}
        """,
        allow_screenshot=True,
        allow_flagging="never",
        theme="dark",
    )

    # launch the gradio interface and start the server
    iface.launch(
        # prevent_thread_lock=True,
        enable_queue=True,  # also allows for dealing with multiple users simultaneously (per newer gradio version)
    )