Spaces:
Sleeping
Sleeping
| import re | |
| import math | |
| import random | |
| from collections import defaultdict | |
def naive_parse(answer):
    """Return the last contiguous run of ASCII digits in *answer*, or '' if none.

    Fallback answer extractor used when no \\boxed{...} expression is found
    in the model output.
    """
    digit_runs = re.findall(r'[0-9]+', answer)
    return digit_runs[-1] if digit_runs else ''
| import re | |
| import sys | |
| import subprocess | |
def return_last_print(output, n):
    """Return the *n*-th line (negative indices allowed) of the stripped output.

    Returns "" when the requested line does not exist.

    Fix: the original `if lines:` / `else: return ""` branch was dead code —
    `str.split('\n')` always yields at least one element (`''.strip().split('\n')`
    is `['']`), so the empty-output branch could never trigger, and an
    out-of-range `n` raised IndexError instead of returning "".
    """
    lines = output.strip().split('\n')
    try:
        return lines[n]
    except IndexError:
        # Out-of-range index (e.g. asking for line -2 of a one-line output).
        return ""
def process_code(code, return_shell_output=False):
    """Run a (model-generated) SymPy snippet in a subprocess and parse the result.

    Writes the snippet to ``code.py`` and executes it with a 7-second timeout.
    With ``return_shell_output=False`` returns an int: the last printed line
    reduced via ``round(float(eval(...))) % 1000``, or -1 on any failure.
    With ``return_shell_output=True`` returns ``(value, CODE_STATUS)`` where
    ``value`` is either the last printed line (success) or the captured error
    message (failure), and ``CODE_STATUS`` is True on success.
    """

    def repl(match):
        # Force every symbols(...) declaration to real=True so sympy works over
        # the reals, unless the snippet already set a 'real' assumption itself.
        if "real" not in match.group():
            return "{}{}".format(match.group()[:-1], ', real=True)')
        else:
            return "{}{}".format(match.group()[:-1], ')')
    # Rewrite all symbols(...) calls in the snippet using repl above.
    code = re.sub(r"symbols\([^)]+\)", repl, code)

    if return_shell_output:
        # Indent the snippet one level so it nests inside the try: block below.
        code = code.replace('\n', '\n ')
        # Add a try...except block: on any error the subprocess prints the
        # exception message followed by the sentinel 'FAIL' as its last line.
        code = "\ntry:\n from sympy import *\n{}\nexcept Exception as e:\n print(e)\n print('FAIL')\n".format(code)

    if not return_shell_output:
        print(code)

    with open('code.py', 'w') as fout:
        fout.write(code)

    # NOTE(review): shell=True on a composed command string plus eval() on the
    # subprocess output are unsafe on untrusted input; presumably acceptable in
    # this sandboxed competition-notebook context — confirm before reuse.
    batcmd = 'timeout 7 ' + sys.executable + ' code.py'
    try:
        shell_output = subprocess.check_output(batcmd, shell=True).decode('utf8')
        # The snippet's answer is expected on the last printed line.
        return_value = return_last_print(shell_output, -1)
        print(shell_output)
        if return_shell_output:
            if return_value=='FAIL':
                CODE_STATUS = False
                # On failure the second-to-last line holds the error message.
                return_value = return_last_print(shell_output, -2)
                if "not defined" in return_value:
                    return_value+='\nTry checking the formatting and imports'
            else:
                CODE_STATUS = True
            return return_value, CODE_STATUS
        # Numeric path: reduce the printed result modulo 1000 (answer format).
        code_output = round(float(eval(return_value))) % 1000
    except Exception as e:
        # Covers subprocess failure/timeout as well as eval/parse errors.
        print(e,'shell_output')
        code_output = -1

    # Reached only via the exception path when return_shell_output is True
    # (the success path already returned inside the try block).
    if return_shell_output:
        if code_output==-1:
            CODE_STATUS = False
        else:
            CODE_STATUS = True
        return code_output, CODE_STATUS
    return code_output
def process_text_output(output):
    """Parse the final integer answer (mod 1000) from model text output.

    Prefers the last \\boxed{...} integer; falls back to naive_parse on the
    raw text. Returns -1 when nothing parseable is found or evaluation fails.
    """
    try:
        boxed = re.findall(r'\\boxed\{(\d+)\}', output)
        print('BOXED', boxed)
        # Take the last boxed value if any, otherwise the trailing digit run.
        candidate = boxed[-1] if boxed else naive_parse(output)
        print('BOXED FINAL', candidate)
        if not candidate:
            return -1
        return round(float(eval(candidate))) % 1000
    except Exception as e:
        print(e)
        print('ERROR PARSING TEXT')
        return -1
| from collections import defaultdict | |
| from collections import Counter | |
def predict(problem):
    """Answer one competition problem via repeated tool-integrated sampling.

    Runs up to ``n_repetitions`` generation attempts. Each attempt samples a
    prompt template, generates with the model, and — whenever generation stops
    on a tool stop word — executes the emitted Python block through
    ``process_code`` and feeds the result back into the prompt. Answers from
    code execution and from the final text are majority-voted across attempts.

    Relies on notebook-level globals (model, tokenizer, stop_words,
    stopping_criteria, TOTAL_TOKENS, n_repetitions, promplt_options, ...).
    Returns the most frequent answer seen (best_stats[0][0]).
    """
    # Sampling settings for plain reasoning vs. in-code-block generation.
    # NOTE(review): top_p=3.0 is outside the usual (0, 1] nucleus range —
    # presumably this effectively disables nucleus filtering; confirm intent.
    temperature = 0.9
    top_p = 3.0
    temperature_coding = 0.9
    top_p_coding = 3.0

    # Per-question bookkeeping keyed by question index i (only i == 0 is used
    # in this function; see the final return).
    total_results = {}
    total_answers = {}
    best_stats = {}
    total_outputs = {}
    question_type_counts = {}
    # Prior pseudo-counts (text_answers, code_answers) biasing prompt choice.
    starting_counts = (2,3)
    i = 0
    # All of these must be defined by the surrounding notebook/script.
    global n_repetitions,TOTAL_TOKENS,model,tokenizer,USE_PAST_KEY,NOTEBOOK_START_TIME,promplt_options,code,cot

    for jj in tqdm(range(n_repetitions)):
        best, best_count = best_stats.get(i,(-1,-1))
        # Early-exit heuristic: once one answer has been seen more than
        # sqrt(jj) times, skip further attempts for this question.
        if best_count>np.sqrt(jj):
            print("SKIPPING CAUSE ALREADY FOUND BEST")
            continue

        outputs = total_outputs.get(i,[])
        text_answers, code_answers = question_type_counts.get(i,starting_counts)
        results = total_results.get(i,[])
        answers = total_answers.get(i,[])

        # Aggressively release GPU memory between attempts.
        for _ in range(5):
            torch.cuda.empty_cache()
            gc.collect()
            time.sleep(0.2)

        try:
            ALREADY_GEN = 0
            code_error = None
            code_error_count = 0
            code_output = -1
            #initail_message = problem + tool_instruction

            # Pick a prompt template with probability proportional to how often
            # text vs. code answers have been produced so far.
            # NOTE(review): `choice` is presumably numpy.random.choice — verify import.
            counts = np.array([text_answers,code_answers])
            draw = choice(promplt_options, 1,
                          p=counts/counts.sum())
            initail_message = draw[0].format(problem,"{}")
            prompt = f"User: {initail_message}"
            # Track how much of the decoded text has already been printed.
            current_printed = len(prompt)
            print(f"{jj}_{prompt}\n")

            model_inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
            input_len = len(model_inputs['input_ids'][0])

            generation_output = model.generate(**model_inputs,
                                               max_new_tokens=TOTAL_TOKENS-ALREADY_GEN,
                                               return_dict_in_generate=USE_PAST_KEY,
                                               do_sample = True,
                                               temperature = temperature,
                                               top_p = top_p,
                                               num_return_sequences=1, stopping_criteria = stopping_criteria)
            # generate returns a dict-like object only when return_dict_in_generate is set.
            if USE_PAST_KEY:
                output_ids = generation_output.sequences[0]
            else:
                output_ids = generation_output[0]

            decoded_output = tokenizer.decode(output_ids, skip_special_tokens=True)
            # Print only the newly generated suffix.
            print(f"{decoded_output[current_printed:]}\n")
            current_printed += len(decoded_output[current_printed:])
            cummulative_code = ""

            # Did generation stop on one of the tool-use stop words?
            stop_word_cond = False
            for stop_word in stop_words:
                stop_word_cond = stop_word_cond or (decoded_output[-len(stop_word):]==stop_word)

            # Tool-integrated loop: alternate generation with code execution
            # until no stop word is hit or the token budget runs out.
            while (stop_word_cond) and (ALREADY_GEN<(TOTAL_TOKENS)):
                if (decoded_output[-len("```python"):]=="```python"):
                    # The model just opened a code block: continue generating
                    # its contents with the coding sampling settings.
                    temperature_inner=temperature_coding
                    top_p_inner = top_p_coding
                    prompt = decoded_output
                else:
                    temperature_inner=temperature
                    top_p_inner = top_p

                    try:
                        # Extract the most recent code block (two possible
                        # stop-word layouts for the closing fence).
                        if (decoded_output[-len("``````output"):]=="``````output"):
                            code_text = decoded_output.split('```python')[-1].split("``````")[0]
                        else:
                            code_text = decoded_output.split('```python')[-1].split("```")[0]

                        # Run all code generated so far, so later snippets can
                        # reuse earlier definitions.
                        cummulative_code+=code_text
                        code_output, CODE_STATUS = process_code(cummulative_code, return_shell_output=True)
                        print('CODE RESULTS', code_output)

                        # Count consecutive identical results/errors.
                        if code_error==code_output:
                            code_error_count+=1
                        else:
                            code_error=code_output
                            code_error_count = 0

                        if not CODE_STATUS:
                            # Drop the failing snippet from the cumulative program.
                            cummulative_code = cummulative_code[:-len(code_text)]
                            if code_error_count>=1:
                                # Same error twice in a row: abandon this attempt's tool loop.
                                print("REPEATED ERRORS")
                                break

                    except Exception as e:
                        print(e)
                        print('ERROR PARSING CODE')
                        code_output = -1

                    # Feed the execution result back to the model inside an
                    # ```output fenced block.
                    if code_output!=-1:
                        if (decoded_output[-len(")\n```"):]==")\n```"):
                            prompt = decoded_output+'```output\n'+str(code_output)+'\n```\n'
                        else:
                            prompt = decoded_output+'\n'+str(code_output)+'\n```\n'
                    else:
                        prompt = decoded_output
                        cummulative_code=""

                model_inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
                ALREADY_GEN = len(model_inputs['input_ids'][0])-input_len

                # Reuse the KV cache when available to avoid re-encoding the prompt.
                if USE_PAST_KEY:
                    old_values = generation_output.past_key_values
                else:
                    old_values = None

                generation_output = model.generate(**model_inputs,
                                                   max_new_tokens=TOTAL_TOKENS-ALREADY_GEN,
                                                   return_dict_in_generate=USE_PAST_KEY,
                                                   past_key_values=old_values,
                                                   do_sample = True,
                                                   temperature = temperature_inner,
                                                   top_p = top_p_inner,
                                                   num_return_sequences=1, stopping_criteria = stopping_criteria)
                if USE_PAST_KEY:
                    output_ids = generation_output.sequences[0]
                else:
                    output_ids = generation_output[0]

                decoded_output = tokenizer.decode(output_ids, skip_special_tokens=True)
                print(f"\nINTERMEDIATE OUT :\n{decoded_output[current_printed:]}\n")
                current_printed+=len(decoded_output[current_printed:])

                # Re-check the stop-word condition for the next iteration.
                stop_word_cond = False
                for stop_word in stop_words:
                    stop_word_cond = stop_word_cond or (decoded_output[-len(stop_word):]==stop_word)

            if USE_PAST_KEY:
                output_ids = generation_output.sequences[0]
            else:
                output_ids = generation_output[0]

            # Only the newly generated tokens (past the original prompt).
            raw_output = tokenizer.decode(output_ids[input_len:], skip_special_tokens=True)
            #print(f"\n\nOutput :\n{raw_output}\n")
            result_output = process_text_output(raw_output)

            # Normalise the last code result to an int mod 1000.
            try:
                code_output = round(float(eval(code_output))) % 1000
            except Exception as e:
                print(e,'final_eval')
                code_output = -1
        except Exception as e:
            # Any failure in this attempt yields no answers (-1 sentinels).
            print(e,"5")
            result_output, code_output = -1, -1

        # Record this attempt's answers (-1 means "no answer produced").
        if code_output!=-1:
            outputs.append(code_output)
            code_answers+=1

        if result_output!=-1:
            outputs.append(result_output)
            text_answers+=1

        # Majority vote over all answers collected so far for this question.
        if len(outputs) > 0:
            occurances = Counter(outputs).most_common()
            print(occurances)
            if occurances[0][1] > best_count:
                print("GOOD ANSWER UPDATED!")
                best = occurances[0][0]
                best_count = occurances[0][1]
            if occurances[0][1] > 5:
                # Answer confirmed by more than 5 agreeing attempts: stop early.
                print("ANSWER FOUND!")
                break

        results.append(result_output)
        answers.append(code_output)

        # Persist per-question state for the next repetition.
        best_stats[i] = (best, best_count)
        question_type_counts[i] = (text_answers, code_answers)
        total_outputs[i] = outputs
        total_results[i] = results
        total_answers[i] = answers
        # Subtract the starting pseudo-counts to report actual tallies.
        print("code_answers",code_answers-starting_counts[1],"text_answers",text_answers-starting_counts[0])
    return best_stats[0][0]