import os
import re
import warnings
from threading import Thread
from typing import Optional

import gradio as gr
from rich.console import Console
from transformers import TextIteratorStreamer

from utils.special_tok_llama2 import (
    B_CODE,
    E_CODE,
    B_RESULT,
    E_RESULT,
    E_INST,
    DEFAULT_EOS_TOKEN,
)

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

# Imported after the warning/env setup above so TensorFlow logging stays silenced.
from code_interpreter.LlamaCodeInterpreter import LlamaCodeInterpreter

console = Console()
class StreamingLlamaCodeInterpreter(LlamaCodeInterpreter):
    streamer: Optional[TextIteratorStreamer] = None

    # Override generate() so tokens are pushed through a TextIteratorStreamer
    # by a background thread instead of being returned in one block.
    def generate(
        self,
        prompt: str = "[INST]\n###User : hi\n###Assistant :",
        max_new_tokens=512,
        do_sample: bool = True,
        use_cache: bool = True,
        top_p: float = 0.95,
        temperature: float = 0.1,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
    ) -> str:
        # Tokenize the prompt and set up a streamer that skips echoing it back.
        self.streamer = TextIteratorStreamer(
            self.tokenizer, skip_prompt=True, timeout=5
        )

        inputs = self.tokenizer([prompt], return_tensors="pt")

        eos_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN)
        e_code_token_id = self.tokenizer.convert_tokens_to_ids(E_CODE)

        kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            eos_token_id=[
                eos_token_id,
                e_code_token_id,
            ],  # stop generation at either EOS or the end-of-code token
            streamer=self.streamer,
        )

        # Run generation in a worker thread; the caller drains self.streamer.
        thread = Thread(target=self.model.generate, kwargs=kwargs)
        thread.start()

        return ""
def change_markdown_image(text: str) -> str:
    # Drop markdown image tags whose path is wrapped in single quotes;
    # gr.Chatbot cannot render these local-file references.
    return re.sub(r"!\[(.*?)\]\('(.*?)'\)", "", text)
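# Example (hypothetical path):
#     change_markdown_image("Saved. ![plot]('./tmp/plot.png')")  # -> "Saved. "
# Image tags with unquoted paths are left untouched.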
def gradio_launch(model_path: str, load_in_4bit: bool = True, MAX_TRY: int = 5):
    with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
        # avatar_images expects a (user, bot) pair, not a single path.
        chatbot = gr.Chatbot(height=820, avatar_images=(None, "./assets/logo2.png"))
        msg = gr.Textbox()
        clear = gr.Button("Clear")

        interpreter = StreamingLlamaCodeInterpreter(
            model_path=model_path, load_in_4bit=load_in_4bit
        )
        def bot(history):
            user_message = history[-1][0]
            interpreter.dialog.append({"role": "user", "content": user_message})
            console.print(f"###User : [bold]{user_message}[/bold]")

            HAS_CODE = False
            full_generated_text = ""

            prompt = interpreter.dialog_to_prompt(dialog=interpreter.dialog)
            prompt = f"{prompt} {E_INST}"

            # Kick off generation in the background and stream the first response.
            _ = interpreter.generate(prompt)
            history[-1][1] = ""
            generated_text = ""
            for character in interpreter.streamer:
                history[-1][1] += character
                generated_text += character
                yield history

            full_generated_text += generated_text
            HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                generated_text
            )
            attempt = 1
            while HAS_CODE:
                if attempt > MAX_TRY:
                    break

                # Render the model's code block as fenced markdown in the chat.
                history[-1][1] = (
                    history[-1][1]
                    .replace(f"{B_CODE}", "\n```python\n")
                    .replace(f"{E_CODE}", "\n```\n")
                )
                history[-1][1] = change_markdown_image(history[-1][1])
                yield history

                # Strip stray <unk> artifacts before executing the code.
                generated_code_block = generated_code_block.replace(
                    "<unk>_", ""
                ).replace("<unk>", "")

                (
                    code_block_output,
                    error_flag,
                ) = interpreter.execute_code_and_return_output(
                    f"{generated_code_block}"
                )
                code_block_output = interpreter.clean_code_output(code_block_output)

                result_block = f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"
                full_generated_text += result_block

                # Show the execution result in the chat.
                history[-1][1] += f"\n```RESULT\n{code_block_output}\n```\n"
                history[-1][1] = change_markdown_image(history[-1][1])
                yield history

                # Feed the new segment plus its result back into the prompt, then
                # reset generated_text so the next round only tracks fresh output
                # (re-appending the full buffer would duplicate earlier context).
                prompt = f"{prompt} {generated_text}{result_block}"
                _ = interpreter.generate(prompt)
                generated_text = ""
                for character in interpreter.streamer:
                    history[-1][1] += character
                    generated_text += character
                    history[-1][1] = change_markdown_image(history[-1][1])
                    yield history

                full_generated_text += generated_text
                HAS_CODE, generated_code_block = interpreter.extract_code_blocks(
                    generated_text
                )

                if generated_text.endswith("</s>"):
                    break

                attempt += 1
            interpreter.dialog.append(
                {
                    "role": "assistant",
                    "content": full_generated_text.replace("<unk>_", "")
                    .replace("<unk>", "")
                    .replace("</s>", ""),
                }
            )

            console.print("----------\n" * 2)
            console.print(interpreter.dialog)
            console.print("----------\n" * 2)

            return history[-1][1]
        def user(user_message, history):
            return "", history + [[user_message, None]]

        msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
            bot, chatbot, chatbot
        )
        clear.click(lambda: None, None, chatbot, queue=False)

    demo.queue()
    demo.launch()
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Launch the Gradio demo for a finetuned LLAMA2 model."
    )
    parser.add_argument(
        "--path",
        type=str,
        default="./output/llama-2-7b-codellama-ci",
        help="Path to the finetuned LLAMA2 model.",
    )
    args = parser.parse_args()

    gradio_launch(model_path=args.path, load_in_4bit=True)
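# Example invocation (assuming this file is saved as chatbot.py; the model
# path is illustrative):
#     python chatbot.py --path ./output/llama-2-7b-codellama-ci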