Spaces:
Running
Running
| import os | |
| import argparse | |
| from lats import run_lats | |
def get_args(argv=None):
    """Parse the command-line options for a LATS run.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with one attribute per option defined below.

    NOTE(review): only model, language, max_iters, verbose, instruction,
    n_samples and depth are forwarded to the runner by ``lats_main`` in this
    file; run_name, root_dir, dataset_path, strategy, pass_at_k and
    expansion_factor are parsed but not used there.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--root_dir", type=str,
                        help="The root logging directory", default="root")
    parser.add_argument("--dataset_path", type=str,
                        help="The path to the benchmark dataset", default="root")
    # NOTE(review): strategy_factory currently ignores this value and always
    # returns the LATS runner.
    parser.add_argument("--strategy", type=str,
                        help="Strategy: `simple`, `reflexion`")
    # Help text previously said "Strategy:" (copy-paste error); this option
    # selects the programming language.
    parser.add_argument("--language", type=str, help="Language: `py` or `rs`")
    parser.add_argument(
        "--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
    parser.add_argument("--pass_at_k", type=int,
                        help="Pass@k metric", default=1)
    parser.add_argument("--max_iters", type=int,
                        help="The maximum number of self-improvement iterations", default=10)
    parser.add_argument("--expansion_factor", type=int,
                        help="The expansion factor for the reflexion UCS and A* strategy", default=3)
    parser.add_argument("--verbose", action='store_true',
                        help="To print live logs")
    parser.add_argument("--instruction", type=str,
                        help="text string", default="")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=3)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=5)
    # TODO: implement run resumption (planned flags: --is_resume, --resume_dir).
    return parser.parse_args(argv)
def strategy_factory(strategy: str):
    """Return the callable that executes a run.

    Only one strategy is wired up. Whatever ``strategy`` is (including
    ``None``), the result wraps ``run_lats``. The parameter is accepted for
    interface compatibility but is not inspected.

    Returns:
        A function that accepts keyword arguments only. It forwards them to
        ``run_lats`` after deleting any keys listed in ``delete_keys``; none
        are listed at present.
    """
    def kwargs_wrapper_gen(func, delete_keys=()):
        # Tuple default instead of ``[]``: a mutable default is created once
        # and shared by every call, so any future mutation would leak.
        def kwargs_wrapper(**kwargs):
            # ``kwargs`` is a fresh dict per call, so deleting from it is safe.
            for key in delete_keys:
                del kwargs[key]  # raises KeyError if a listed key wasn't passed
            return func(**kwargs)
        return kwargs_wrapper

    return kwargs_wrapper_gen(run_lats, delete_keys=())
def lats_main(args):
    """Run LATS using the settings carried by the parsed CLI ``args``.

    ``args.strategy`` is handed to ``strategy_factory`` to obtain the runner.
    No validation of the strategy name happens here. The runner is then
    called with these values from ``args``:
    - model
    - language
    - iteration budget
    - verbosity flag
    - instruction
    - sample count
    - tree depth

    Other options on ``args`` (e.g. pass_at_k) are not forwarded.

    Returns:
        Whatever the runner returns.
    """
    runner = strategy_factory(args.strategy)
    run_kwargs = dict(
        model_name=args.model,
        language=args.language,
        max_iters=args.max_iters,
        verbose=args.verbose,
        instruction=args.instruction,
        n_samples=args.n_samples,
        depth=args.depth,
    )
    return runner(**run_kwargs)
def main(args):
    """Program entry point: run ``lats_main`` on ``args``; its result is discarded."""
    lats_main(args=args)
if __name__ == "__main__":
    # Script mode only: parse the real command line and hand the namespace to main().
    main(args=get_args())