import sys
import os

prj_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(prj_root_path)

from code_interpreter.JuypyterClient import JupyterNotebook
from code_interpreter.BaseCodeInterpreter import BaseCodeInterpreter
from utils.const import *
from typing import List, Literal, Optional, Tuple, TypedDict, Dict
from colorama import init, Fore, Style

import copy
import re

import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from finetuning.conversation_template import msg_to_code_result_tok_temp
from utils.special_tok_llama2 import (
    B_CODE,
    E_CODE,
    B_RESULT,
    E_RESULT,
    B_INST,
    E_INST,
    B_SYS,
    E_SYS,
    DEFAULT_PAD_TOKEN,
    DEFAULT_BOS_TOKEN,
    DEFAULT_EOS_TOKEN,
    DEFAULT_UNK_TOKEN,
    IGNORE_INDEX,
)

import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="transformers")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
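
# Prompt layout sketch (illustrative only; the exact serialization is produced
# by msg_to_code_result_tok_temp below). The fine-tuned model interleaves code
# and execution results as tagged spans inside the assistant turn:
#
#   <B_SYS> system prompt <E_SYS>
#   <B_INST> ###User : <question> <E_INST>
#   ###Assistant : text <B_CODE> python source <E_CODE>
#                  <B_RESULT> stdout/stderr <E_RESULT> text ...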


class LlamaCodeInterpreter(BaseCodeInterpreter):
    def __init__(
        self,
        model_path: str,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        peft_model: Optional[str] = None,
    ):
        # build tokenizer
        self.tokenizer = LlamaTokenizer.from_pretrained(
            model_path,
            padding_side="right",
            use_fast=False,
        )

        # Handle special tokens
        special_tokens_dict = dict()
        if self.tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN  # 32000
        if self.tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN  # 2
        if self.tokenizer.bos_token is None:
            special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN  # 1
        if self.tokenizer.unk_token is None:
            special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN

        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.tokenizer.add_tokens(
            [B_CODE, E_CODE, B_RESULT, E_RESULT, B_INST, E_INST, B_SYS, E_SYS],
            special_tokens=True,
        )
        self.model = LlamaForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            torch_dtype=torch.float16,
        )

        # Grow the embedding matrix to cover the newly added special tokens
        self.model.resize_token_embeddings(len(self.tokenizer))

        if peft_model is not None:
            # Keep the adapter-wrapped model; assigning to a local variable
            # here would silently discard the LoRA weights
            self.model = PeftModel.from_pretrained(self.model, peft_model)

        self.model = self.model.eval()
        self.dialog = [
            {
                "role": "system",
                "content": CODE_INTERPRETER_SYSTEM_PROMPT + "\nUse code to answer",
            },
        ]

        self.nb = JupyterNotebook()
        self.MAX_CODE_OUTPUT_LENGTH = 3000
        out = self.nb.add_and_run(TOOLS_CODE)  # import default tools into the kernel
        print(out)
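
    # NOTE: execute_code_and_return_output(), used by chat() below, is assumed
    # to be inherited from BaseCodeInterpreter and to run code in self.nb,
    # returning an (output_text, error_flag) pair.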

    def dialog_to_prompt(self, dialog: List[Dict]) -> str:
        """Render the chat history into the fine-tuning prompt format."""
        full_str = msg_to_code_result_tok_temp(dialog)
        return full_str
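
    # Example (hypothetical dialog): right after __init__, dialog_to_prompt(self.dialog)
    # yields the system prompt wrapped in B_SYS/E_SYS followed by each turn, ending
    # with the assistant header that generate() continues from.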

    def generate(
        self,
        prompt: str = "[INST]\n###User : hi\n###Assistant :",
        max_new_tokens: int = 512,
        do_sample: bool = True,
        use_cache: bool = True,
        top_p: float = 0.95,
        temperature: float = 0.1,
        top_k: int = 50,
        repetition_penalty: float = 1.0,
    ) -> str:
        # Tokenize the prompt and remember its length so that only the
        # newly generated tokens are decoded and returned.
        inputs = self.tokenizer([prompt], return_tensors="pt")
        input_tokens_shape = inputs["input_ids"].shape[-1]

        eos_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_EOS_TOKEN)
        e_code_token_id = self.tokenizer.convert_tokens_to_ids(E_CODE)

        output = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            use_cache=use_cache,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            eos_token_id=[
                eos_token_id,
                e_code_token_id,
            ],  # Stop generation at either EOS or the closing E_CODE tag
        )[0]

        generated_tokens = output[input_tokens_shape:]
        generated_text = self.tokenizer.decode(generated_tokens)

        return generated_text
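
    # Usage sketch (assumes a loaded checkpoint):
    #   text = self.generate("[INST]\n###User : what is 1+1?\n###Assistant :")
    # `text` holds only the continuation; if it stops at E_CODE, chat() executes
    # the emitted code block before generation resumes.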

    def extract_code_blocks(self, prompt: str) -> Tuple[bool, str]:
        """Return (found, code) for the last B_CODE ... E_CODE span in `prompt`."""
        pattern = re.escape(B_CODE) + r"(.*?)" + re.escape(E_CODE)
        matches = re.findall(pattern, prompt, re.DOTALL)

        if matches:
            # Return the last matched code block
            return True, matches[-1].strip()
        else:
            return False, ""

    def clean_code_output(self, output: str) -> str:
        """Truncate long execution output, keeping its head and tail."""
        if self.MAX_CODE_OUTPUT_LENGTH < len(output):
            return (
                output[: self.MAX_CODE_OUTPUT_LENGTH // 5]
                + "...(skip)..."
                + output[-self.MAX_CODE_OUTPUT_LENGTH // 5 :]
            )

        return output
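
    # Example: with MAX_CODE_OUTPUT_LENGTH = 3000, a 10,000-char output keeps the
    # first 600 and last 600 characters (3000 // 5 each) around "...(skip)...".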

    def chat(self, user_message: str, VERBOSE: bool = False, MAX_TRY: int = 5):
        self.dialog.append({"role": "user", "content": user_message})

        if VERBOSE:
            print(
                "###User : " + Fore.BLUE + Style.BRIGHT + user_message + Style.RESET_ALL
            )
            print("\n###Assistant : ")

        # setup
        HAS_CODE = False
        full_generated_text = ""
        prompt = self.dialog_to_prompt(dialog=self.dialog)
        prompt = f"{prompt} {E_INST}"

        generated_text = self.generate(prompt)
        full_generated_text += generated_text
        HAS_CODE, generated_code_block = self.extract_code_blocks(generated_text)

        attempt = 1
        while HAS_CODE:
            if attempt > MAX_TRY:
                break

            # Strip stray <unk> artifacts before executing the code block
            generated_code_block = generated_code_block.replace("<unk>_", "").replace(
                "<unk>", ""
            )

            code_block_output, error_flag = self.execute_code_and_return_output(
                f"{generated_code_block}"
            )
            code_block_output = self.clean_code_output(code_block_output)

            # Feed the execution result back to the model inside B_RESULT/E_RESULT tags
            generated_text = (
                f"{generated_text}\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"
            )
            full_generated_text += f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"

            first_code_block_pos = (
                generated_text.find(generated_code_block)
                if generated_code_block
                else -1
            )
            text_before_first_code_block = (
                generated_text
                if first_code_block_pos == -1
                else generated_text[:first_code_block_pos]
            )
            if VERBOSE:
                print(Fore.GREEN + text_before_first_code_block + Style.RESET_ALL)
                print(Fore.GREEN + generated_code_block + Style.RESET_ALL)
                print(
                    Fore.YELLOW
                    + f"\n{B_RESULT}\n{code_block_output}\n{E_RESULT}\n"
                    + Style.RESET_ALL
                )

            # Continue generation from the prompt extended with the result
            prompt = f"{prompt}{generated_text}"
            generated_text = self.generate(prompt)
            HAS_CODE, generated_code_block = self.extract_code_blocks(generated_text)

            full_generated_text += generated_text

            attempt += 1

        if VERBOSE:
            print(Fore.GREEN + generated_text + Style.RESET_ALL)

        self.dialog.append(
            {
                "role": "assistant",
                "content": full_generated_text.replace("<unk>_", "")
                .replace("<unk>", "")
                .replace("</s>", ""),
            }
        )

        return self.dialog[-1]
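
    # Usage sketch (hypothetical checkpoint path; see __main__ below):
    #   ci = LlamaCodeInterpreter(model_path="./output/llama-2-7b-chat-ci")
    #   reply = ci.chat("plot sin(x) from 0 to 2*pi", VERBOSE=True)
    #   # reply["content"] interleaves text, code blocks, and execution results.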


if __name__ == "__main__":
    import random

    LLAMA2_FINETUNED_PATH = "./output/llama-2-7b-chat-ci"

    interpreter = LlamaCodeInterpreter(
        model_path=LLAMA2_FINETUNED_PATH, load_in_4bit=True
    )

    output = interpreter.chat(
        user_message=random.choice(
            [
                # "In a circle with center \( O \), \( AB \) is a chord such that the midpoint of \( AB \) is \( M \). A tangent at \( A \) intersects the extended segment \( OB \) at \( P \). If \( AM = 12 \) cm and \( MB = 12 \) cm, find the length of \( AP \).",
                # "A triangle \( ABC \) is inscribed in a circle (circumscribed). The sides \( AB \), \( BC \), and \( AC \) are tangent to the circle at points \( P \), \( Q \), and \( R \) respectively. If \( AP = 10 \) cm, \( BQ = 15 \) cm, and \( CR = 20 \) cm, find the radius of the circle.",
                # "Given an integer array nums, return the total number of contiguous subarrays that have a sum equal to 0.",
                "what is the second largest city in Japan?",
                # "Can you show me a 120-day chart of Tesla, from 120 days ago to today?",
            ]
        ),
        VERBOSE=True,
    )

    while True:
        input_char = input("Press 'q' to quit the dialog: ")
        if input_char.lower() == "q":
            break
        else:
            output = interpreter.chat(user_message=input_char, VERBOSE=True)