inference
Browse files- llm_server.py +96 -0
- normalize_pseudo.py +212 -0
- reverse_sample.json +16 -0
- sk2decompile_inf.py +179 -0
llm_server.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vllm import LLM, SamplingParams
|
| 2 |
+
from argparse import ArgumentParser
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 7 |
+
|
| 8 |
+
inputs = []  # Prompts built in the __main__ block below and passed to llm_inference.


def parse_args():
    """Parse command-line options for batch decompilation inference.

    Returns:
        argparse.Namespace with model settings (path, parallelism, memory),
        sampling settings (temperature, token budgets, stop sequences) and
        I/O paths (testset_path, output_path).

    NOTE(review): the original annotated the return type as ArgumentParser,
    but ``parser.parse_args()`` returns an argparse.Namespace; the incorrect
    annotation has been dropped.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_path", type=str, help="HF hub name or local path of the model")
    parser.add_argument("--gpus", type=int, default=1, help="tensor parallel size")
    parser.add_argument("--max_num_seqs", type=int, default=1)
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--max_total_tokens", type=int, default=8192, help="max model context length")
    parser.add_argument("--max_new_tokens", type=int, default=512)
    parser.add_argument("--stop_sequences", type=str, default=None)
    parser.add_argument("--testset_path", type=str, help="JSON file of samples with input_asm_prompt")
    parser.add_argument("--output_path", type=str, default=None, help="directory for generated .c files")
    return parser.parse_args()
|
| 22 |
+
|
| 23 |
+
def llm_inference(inputs,
                  model_path,
                  gpus=1,
                  max_total_tokens=8192,
                  gpu_memory_utilization=0.95,
                  temperature=0,
                  max_new_tokens=512,
                  stop_sequences=None):
    """Run batch text generation over *inputs* with a vLLM engine.

    Args:
        inputs: list of prompt strings.
        model_path: HF hub name or local path of the model to load.
        gpus: tensor parallel size passed to vLLM.
        max_total_tokens: maximum model context length (prompt + generation).
        gpu_memory_utilization: fraction of GPU memory vLLM may reserve.
        temperature: sampling temperature (0 = greedy).
        max_new_tokens: generation budget per prompt.
        stop_sequences: list of stop strings, or None for no explicit stops.

    Returns:
        A list aligned with *inputs*; each element is a one-element list
        containing the first completion's text for that prompt.
    """
    # NOTE(review): an earlier commented-out variant of this function that
    # took an `args` namespace was removed — this keyword-argument version
    # is the one actually used by callers (see sk2decompile_inf.py).
    llm = LLM(
        model=model_path,
        tensor_parallel_size=gpus,
        max_model_len=max_total_tokens,
        gpu_memory_utilization=gpu_memory_utilization,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_new_tokens,
        stop=stop_sequences,
    )

    gen_results = llm.generate(inputs, sampling_params)
    # Keep only the first candidate's text, wrapped in a list to preserve
    # the original return shape relied on by callers.
    gen_results = [[output.outputs[0].text] for output in gen_results]

    return gen_results
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
    args = parse_args()
    with open(args.testset_path, "r") as f:
        samples = json.load(f)

    # Wrap each assembly snippet in the instruction template the model expects.
    before = "# This is the assembly code:\n"
    after = "\n# What is the source code?\n"
    for sample in samples:
        prompt = before + sample["input_asm_prompt"].strip() + after
        inputs.append(prompt)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.stop_sequences is None:
        # Fall back to the model's EOS token so generation terminates cleanly.
        args.stop_sequences = [tokenizer.eos_token]
    gen_results = llm_inference(inputs, args.model_path,
                                args.gpus,
                                args.max_total_tokens,
                                args.gpu_memory_utilization,
                                args.temperature,
                                args.max_new_tokens,
                                args.stop_sequences)

    # BUG FIX: --output_path defaults to None, and the original crashed with
    # a TypeError in os.path.exists(None); os.mkdir also failed for nested
    # paths. Fail fast with a clear message and create the tree as needed.
    if args.output_path is None:
        raise SystemExit("--output_path is required to save generated files")
    os.makedirs(args.output_path, exist_ok=True)
    for idx, gen_result in enumerate(gen_results):
        with open(os.path.join(args.output_path, f"{idx}.c"), "w") as f:
            f.write(gen_result[0])
|
normalize_pseudo.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
from multiprocessing import Pool, cpu_count
|
| 5 |
+
from tqdm import tqdm # ✅ 添加进度条模块
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
import subprocess
|
| 9 |
+
|
| 10 |
+
def good_func(func):
    """Return True when the function body has a sane amount of real code.

    Everything up to and including the first '{' (the signature) is dropped,
    lines whose stripped text is at least 3 characters long are counted, and
    the body qualifies only when that count is strictly between 3 and 300.
    """
    _, _, body = func.partition('{')
    meaningful = sum(1 for line in body.split('\n') if len(line.strip()) >= 3)
    return 3 < meaningful < 300
|
| 20 |
+
def strip_empty(code):
    """Drop blank (whitespace-only) lines from *code*, keeping the rest verbatim."""
    kept = [line for line in code.splitlines() if line.strip()]
    return "\n".join(kept)
|
| 22 |
+
def format_with_clang(func: str, style: str = "Google") -> "str | None":
    """Reformat C source text with the external clang-format binary.

    Args:
        func: source text to format; falsy input short-circuits to None.
        style: clang-format style preset name.

    Returns:
        The formatted source, or None when the input is empty or when
        clang-format is missing, exits non-zero, or exceeds the timeout.
        (The original annotated the return as plain ``str`` even though it
        returns None on failure — annotation corrected.)
    """
    if not func:
        return None
    cmd = ["clang-format", f"--style={style}"]
    try:
        proc = subprocess.run(
            cmd,
            input=func,
            text=True,
            capture_output=True,
            check=True,
            timeout=0.5,  # single functions are small; anything slower is pathological
        )
        return proc.stdout
    except (subprocess.CalledProcessError, subprocess.TimeoutExpired, OSError):
        # BUG FIX: the original caught bare Exception, which also swallowed
        # programming errors (TypeError, KeyboardInterrupt-adjacent bugs).
        # Only process-level failures (bad exit, timeout, missing binary —
        # FileNotFoundError is an OSError) map to the None sentinel.
        return None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ----------------------------
|
| 45 |
+
# 1. 十六进制转十进制
|
| 46 |
+
# ----------------------------
|
| 47 |
+
def hex_to_dec(text):
    """Rewrite every hex literal (with optional u/U/l/L suffix) as decimal."""
    def to_decimal(match):
        value = int(match.group(1), 16)
        suffix = match.group(2) or ""
        return f"{value}{suffix}"

    return re.sub(r'\b(0x[0-9a-fA-F]+)([uUlL]{1,3})?\b', to_decimal, text)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ----------------------------
|
| 58 |
+
# 2. 删除特定关键字
|
| 59 |
+
# ----------------------------
|
| 60 |
+
def remove_keywords(text):
    """Strip decompiler-specific calling-convention / annotation keywords.

    The keywords are deleted in place (not replaced by a space), so the
    surrounding whitespace is left exactly as it was.
    """
    pattern = (
        r'\b__fastcall\b'
        r'|\b__cdecl\b'
        r'|\b__ptr32\b'
        r'|\b__noreturn\s+noreturn\b'
    )
    return re.sub(pattern, '', text)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ----------------------------
|
| 72 |
+
# 3. 替换 typedef 类型为原始类型
|
| 73 |
+
# ----------------------------
|
| 74 |
+
# Map from IDA/Hex-Rays typedef aliases to plain C type spellings, consumed
# by replace_typedefs() below when normalizing decompiled pseudo-code.
# NOTE(review): several widths are platform assumptions (e.g. size_t ->
# "unsigned int", wchar_t -> "unsigned short int", DIR -> "int") rather than
# portable truths — confirm they match the target ABI the corpus was built on.
typedef_map = {
    "cpu_set_t": "int",
    "nl_item": "int",
    "__time_t": "int",
    "__mode_t": "unsigned short",
    "__off64_t": "long long",
    "__blksize_t": "long",
    "__ino_t": "unsigned long",
    "__blkcnt_t": "unsigned long long",
    "__syscall_slong_t": "long",
    "__ssize_t": "long int",
    "wchar_t": "unsigned short int",
    "wctype_t": "unsigned short int",
    "__int64": "long long",
    "__int32": "int",
    "__int16": "short",
    "__int8": "char",
    "_QWORD": "uint64_t",
    "_OWORD": "long double",
    "_DWORD": "uint32_t",
    "size_t": "unsigned int",
    "_BYTE": "uint8_t",
    "_TBYTE": "uint16_t",
    "_BOOL8": "uint8_t",
    "gcc_va_list": "va_list",
    "_WORD": "unsigned short",
    "_BOOL4": "int",
    "__va_list_tag": "va_list",
    "_IO_FILE": "FILE",
    "DIR": "int",
    "__fsword_t": "long",
    "__kernel_ulong_t": "int",
    "cc_t": "int",
    "speed_t": "int",
    "fd_set": "int",
    "__suseconds_t": "int",
    "_UNKNOWN": "void",
    "__sighandler_t": "void (*)(int)",
    "__compar_fn_t": "int (*)(const void *, const void *)",
}
|
| 114 |
+
|
| 115 |
+
def replace_typedefs(text):
    """Replace decompiler typedef aliases in *text* with plain C types.

    Builds a single alternation regex over all keys of ``typedef_map`` and
    substitutes every alias in one pass. The original compiled and applied
    ~40 separate patterns per call (quadratic in practice) and, because the
    substitutions were sequential, a later pattern could in principle match
    text produced by an earlier replacement; a single pass cannot.
    """
    pattern = re.compile('|'.join(rf'\b{re.escape(alias)}\b' for alias in typedef_map))
    return pattern.sub(lambda m: typedef_map[m.group(0)], text)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
# ----------------------------
|
| 123 |
+
# 4. 删除注释
|
| 124 |
+
# ----------------------------
|
| 125 |
+
def remove_comments(text):
    """Delete C block comments (/* ... */) and line comments (// ...).

    NOTE(review): the patterns do not respect string literals, so a "//" or
    "/*" inside a string is also stripped — same limitation as the original.
    """
    without_block = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
    return re.sub(r'//.*?$', '', without_block, flags=re.MULTILINE)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
# ----------------------------
|
| 132 |
+
# 5. 单条伪代码处理
|
| 133 |
+
# ----------------------------
|
| 134 |
+
def process_code(code_str):
    """Apply the full normalization pipeline to one pseudo-code string.

    Order matters: comments are removed first so later passes never touch
    commented-out text, then hex literals become decimal, decompiler
    keywords are dropped, and typedef aliases are mapped to plain C types.
    """
    for transform in (remove_comments, hex_to_dec, remove_keywords, replace_typedefs):
        code_str = transform(code_str)
    return code_str
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# 包装 process_code,使其接受一个 dict 并处理字段
|
| 143 |
+
def process_entry(entry, key_name='pseudo'):
    """Normalize the *key_name* field of one dataset record.

    Args:
        entry: dict-like record holding a pseudo-code string under key_name.
        key_name: which field to normalize (default 'pseudo').

    Returns:
        '' when the field is missing or blank after normalization,
        None when clang-format failed (caller decides whether to keep the
        raw record), otherwise the formatted code with blank lines removed.
    """
    # NOTE(review): a commented-out older variant that processed both
    # 'ida_pseudo' and 'ida_strip_pseudo' fields at once was removed; the
    # key_name parameter covers both uses.
    result = process_code(entry.get(key_name, ''))
    if not result.strip():
        return ''
    formatted = format_with_clang(result)
    if formatted is None:
        return None
    return strip_empty(formatted)
|
| 163 |
+
|
| 164 |
+
# 主函数
|
| 165 |
+
def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
    """Normalize the *key_name* field of every record in *input_json* in parallel.

    Each record is run through process_entry (normalize + clang-format +
    blank-line stripping) across a multiprocessing pool; records that pass
    the good_func size filter get a new "<key_name>_norm" field and are
    written to *output_json*.

    Args:
        input_json: path to a JSON file containing a list of record objects.
        output_json: path for the filtered/annotated output JSON.
        key_name: which field of each record to normalize.
        num_workers: pool size; defaults to cpu_count() when falsy.
        remove: when truthy, records whose clang-format step failed
            (process_entry returned None) are dropped; when falsy they are
            kept with the raw field copied into "<key_name>_norm".

    Raises:
        ValueError: when the input JSON is not a list of objects.
    """
    with open(input_json, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError("输入 JSON 应为对象数组")

    num_workers = num_workers or cpu_count()
    print(f"[+] 开始处理 {len(data)} 条记录,使用 {num_workers} 个进程")

    # Bind key_name so the pool can pickle a single-argument callable.
    from functools import partial
    process_entry_key = partial(process_entry, key_name=key_name)

    # imap preserves input order, so zip(data, result) below stays aligned.
    with Pool(processes=num_workers) as pool:
        result = list(tqdm(pool.imap(process_entry_key, data), total=len(data), desc="Processing"))

    data_good = []
    for record, norm in zip(data, result):
        if norm:
            # Non-empty normalized code: keep only bodies of reasonable size.
            if not good_func(norm):
                continue
            record[f"{key_name}_norm"] = norm
            data_good.append(record)
        elif norm is None:
            # clang-format failed for this record; keep the raw field only
            # when remove is falsy. (Empty-string results — blank after
            # normalization — are always dropped.)
            if not remove:
                record[f"{key_name}_norm"] = record[f"{key_name}"]
                data_good.append(record)

    with open(output_json, 'w', encoding='utf-8') as f:
        json.dump(data_good, f, indent=2, ensure_ascii=False)

    print(f"[✓] 完成处理:{input_json}:{len(data)} → {output_json}:{len(data_good)}")
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ----------------------------
|
| 201 |
+
# 7. 命令行入口
|
| 202 |
+
# ----------------------------
|
| 203 |
+
# Command-line entry point: normalize a JSON list of pseudo-code records.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="并行处理 IDA 伪代码字符串列表")
    parser.add_argument('--input_json', default="exebench_format_top1p.json", help='输入 JSON 文件路径(每项为字符串)')
    parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='输出 JSON 文件路径')
    # BUG FIX: --key_name's help text was a copy-paste of --output_json's
    # ("输出 JSON 文件路径"), and --workers' help claimed a default of 8
    # while the actual default is 32; both corrected below.
    parser.add_argument('--key_name', default="pseudo", help='待规范化的伪代码字段名')
    parser.add_argument('--workers', type=int, default=32, help='进程数,默认 32')
    parser.add_argument('--remove', type=int, default=1, help='remove fail cases')
    args = parser.parse_args()

    normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)
|
reverse_sample.json
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"index": "0",
|
| 4 |
+
"func_name": "func0",
|
| 5 |
+
"opt": "O3",
|
| 6 |
+
"language": "c",
|
| 7 |
+
"ida_pseudo": "_BYTE *__fastcall sub_1370(const char *a1)\n{\n int v2; // eax\n int v3; // ebx\n int v4; // r13d\n int v5; // r15d\n _BYTE *v6; // rdi\n const char *v7; // rsi\n int v8; // ebp\n const char *v9; // rdx\n const char *v10; // rax\n const char *v11; // rax\n _BYTE *v12; // rdx\n char v13; // cl\n __int64 v15; // rax\n char *v16; // r14\n const char *v17; // rax\n const char *v18; // rdx\n size_t v19; // [rsp+8h] [rbp-40h]\n\n v2 = strlen(a1);\n v3 = 2 * v2;\n v4 = v2;\n v5 = v2;\n v19 = 2 * v2 + 1;\n v6 = malloc(v19);\n if ( v6 )\n {\n if ( v4 <= 0 )\n {\n v6 = (_BYTE *)__strncpy_chk(v6, a1, v4, v19);\nLABEL_11:\n v6[v3] = 0;\n }\n else\n {\n v7 = a1;\n v8 = 0;\n while ( (v4 - v8) >> 1 )\n {\n v9 = &a1[v4 - 1];\n v10 = v7;\n while ( *v10 == *v9 )\n {\n ++v10;\n --v9;\n if ( v10 == &v7[(v4 - v8) >> 1] )\n goto LABEL_13;\n }\n ++v8;\n ++v7;\n if ( v5 == v8 )\n {\n v6 = (_BYTE *)__strncpy_chk(v6, a1, v4, v19);\n v11 = &a1[v4 - 1];\n v12 = &v6[v4];\n do\n {\n v13 = *v11--;\n *v12++ = v13;\n }\n while ( &a1[v4 - 2 - (v4 - 1)] != v11 );\n goto LABEL_11;\n }\n }\nLABEL_13:\n v15 = __strncpy_chk(v6, a1, v4, v19);\n v6 = (_BYTE *)v15;\n if ( v8 )\n {\n v16 = (char *)(v15 + v4);\n v17 = &a1[v8 - 1];\n do\n {\n *v16++ = *v17;\n v18 = v17--;\n }\n while ( v18 != a1 );\n }\n v6[v8 + v4] = 0;\n }\n }\n return v6;\n}"
|
| 8 |
+
},
|
| 9 |
+
{
|
| 10 |
+
"index": "1",
|
| 11 |
+
"func_name": "func0",
|
| 12 |
+
"opt": "O2",
|
| 13 |
+
"language": "c",
|
| 14 |
+
"ida_pseudo": "__int64 __fastcall sub_1169(float *a1, int a2)\n{\n __int64 result; // rax\n float v3; // [rsp+Ch] [rbp-10h]\n float v4; // [rsp+10h] [rbp-Ch]\n int i; // [rsp+14h] [rbp-8h]\n int j; // [rsp+18h] [rbp-4h]\n\n v3 = *a1;\n v4 = *a1;\n for ( i = 1; i < a2; ++i )\n {\n if ( v3 > a1[i] )\n v3 = a1[i];\n if ( a1[i] > v4 )\n v4 = a1[i];\n }\n for ( j = 0; ; ++j )\n {\n result = (unsigned int)j;\n if ( j >= a2 )\n break;\n a1[j] = (float)(a1[j] - v3) / (float)(v4 - v3);\n }\n return result;\n}"
|
| 15 |
+
}
|
| 16 |
+
]
|
sk2decompile_inf.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from llm_server import llm_inference
|
| 2 |
+
from transformers import AutoTokenizer
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
import shutil
|
| 6 |
+
import os
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# Compiler optimization levels used to bucket output files into directories.
opts = ["O0", "O1", "O2", "O3"]
# Absolute directory containing this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--model_path", type=str, default="LLM4Binary/sk2decompile-struct-6.7b")
    arg_parser.add_argument("--dataset_path", type=str, default='reverse_sample.json')
    arg_parser.add_argument("--decompiler", type=str, default='ida_pseudo_norm')
    arg_parser.add_argument("--gpus", type=int, default=1)
    arg_parser.add_argument("--max_num_seqs", type=int, default=1)
    arg_parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
    arg_parser.add_argument("--temperature", type=float, default=0)
    arg_parser.add_argument("--max_total_tokens", type=int, default=32768)
    arg_parser.add_argument("--max_new_tokens", type=int, default=4096)
    arg_parser.add_argument("--stop_sequences", type=str, default=None)
    arg_parser.add_argument("--recover_model_path", type=str, default='LLM4Binary/sk2decompile-ident-6.7', help="Path to the model to recover from, if any.")
    arg_parser.add_argument("--output_path", type=str, default='./result/sk2decompile')
    arg_parser.add_argument("--only_save", type=int, default=0)
    arg_parser.add_argument("--strip", type=int, default=1)
    arg_parser.add_argument("--language", type=str, default='c')
    args = arg_parser.parse_args()

    # Prompt template for the first (structure) model.
    before = "# This is the assembly code:\n"
    after = "\n# What is the source code?\n"

    if args.dataset_path.endswith('.json'):
        with open(args.dataset_path, "r") as f:
            print("===========")
            print(f"Loading dataset from {args.dataset_path}")
            print("===========")
            samples = json.load(f)
    elif args.dataset_path.endswith('.jsonl'):
        samples = []
        with open(args.dataset_path, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    samples.append(json.loads(line))
    else:
        # BUG FIX: an unrecognized extension previously left `samples`
        # undefined and crashed later with a confusing NameError.
        raise ValueError(f"Unsupported dataset file format: {args.dataset_path}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.stop_sequences is None:
        # Stop on the model's EOS token by default.
        args.stop_sequences = [tokenizer.eos_token]

    inputs = []
    infos = []
    for sample in samples:
        prompt = before + sample[args.decompiler].strip() + after
        sample['prompt_model1'] = prompt
        inputs.append(prompt)
        infos.append({
            "opt": sample["opt"],
            "language": sample["language"],
            "index": sample["index"],
            "func_name": sample["func_name"]
        })

    print("Starting first model inference...")
    gen_results = llm_inference(inputs, args.model_path,
                                args.gpus,
                                args.max_total_tokens,
                                args.gpu_memory_utilization,
                                args.temperature,
                                args.max_new_tokens,
                                args.stop_sequences)
    gen_results = [gen_result[0] for gen_result in gen_results]

    for idx, text in enumerate(gen_results):
        samples[idx]['gen_result_model1'] = text

    # Stage 2: feed stage-1 output through the identifier-recovery model.
    inputs_recovery = []
    before_recovery = "# This is the normalized code:\n"
    after_recovery = "\n# What is the source code?\n"

    for idx, text in enumerate(gen_results):
        prompt_recovery = before_recovery + text.strip() + after_recovery
        samples[idx]['prompt_model2'] = prompt_recovery
        inputs_recovery.append(prompt_recovery)

    print("Starting recovery model inference...")
    gen_results_recovery = llm_inference(inputs_recovery, args.recover_model_path,
                                         args.gpus,
                                         args.max_total_tokens,
                                         args.gpu_memory_utilization,
                                         args.temperature,
                                         args.max_new_tokens,
                                         args.stop_sequences)
    gen_results_recovery = [gen_result[0] for gen_result in gen_results_recovery]

    for idx, text in enumerate(gen_results_recovery):
        samples[idx]['gen_result_model2'] = text

    if args.output_path:
        # Rebuild the output tree from scratch, one directory per opt level.
        if os.path.exists(args.output_path):
            shutil.rmtree(args.output_path)
        for opt in opts:
            os.makedirs(os.path.join(args.output_path, opt))

    if args.strip:
        print("Processing function name stripping...")
        for idx, one in enumerate(gen_results_recovery):
            # Heuristic: the token just before the first '(' is the
            # generated function name, possibly with pointer stars attached.
            func_name_in_gen = one.split('(')[0].split(' ')[-1].strip()
            if func_name_in_gen[:2] == '**':
                func_name_in_gen = func_name_in_gen[2:]
            elif func_name_in_gen[:1] == '*':
                func_name_in_gen = func_name_in_gen[1:]

            original_func_name = samples[idx]["func_name"]
            # BUG FIX: str.replace("", name) inserts `name` between every
            # character of the string; only substitute when a non-empty
            # name was actually extracted.
            if func_name_in_gen:
                gen_results_recovery[idx] = one.replace(func_name_in_gen, original_func_name)
            samples[idx]["gen_result_model2_stripped"] = gen_results_recovery[idx]

    print("Saving inference results and logs...")
    for idx_sample, final_result in enumerate(gen_results_recovery):
        opt = infos[idx_sample]['opt']
        language = infos[idx_sample]['language']
        original_index = samples[idx_sample]['index']

        save_path = os.path.join(args.output_path, opt, f"{original_index}_{opt}.{language}")
        with open(save_path, "w") as f:
            f.write(final_result)

        # One JSON log per output file with the full prompt/result trail.
        log_path = save_path + ".log"
        log_data = {
            "index": original_index,
            "opt": opt,
            "language": language,
            "func_name": samples[idx_sample]["func_name"],
            "decompiler": args.decompiler,
            "input_asm": samples[idx_sample][args.decompiler].strip(),
            "prompt_model1": samples[idx_sample]['prompt_model1'],
            "gen_result_model1": samples[idx_sample]['gen_result_model1'],
            "prompt_model2": samples[idx_sample]['prompt_model2'],
            "gen_result_model2": samples[idx_sample]['gen_result_model2'],
            "final_result": final_result,
            "stripped": args.strip
        }

        if args.strip and "gen_result_model2_stripped" in samples[idx_sample]:
            log_data["gen_result_model2_stripped"] = samples[idx_sample]["gen_result_model2_stripped"]

        with open(log_path, "w") as f:
            json.dump(log_data, f, indent=2, ensure_ascii=False)

    json_path = os.path.join(args.output_path, 'inference_results.jsonl')
    with open(json_path, 'w') as f:
        for sample in samples:
            f.write(json.dumps(sample) + '\n')

    stats_path = os.path.join(args.output_path, 'inference_stats.txt')
    with open(stats_path, 'w') as f:
        f.write(f"Total samples processed: {len(samples)}\n")
        f.write(f"Model path: {args.model_path}\n")
        f.write(f"Recovery model path: {args.recover_model_path}\n")
        f.write(f"Dataset path: {args.dataset_path}\n")
        f.write(f"Language: {args.language}\n")
        f.write(f"Decompiler: {args.decompiler}\n")
        f.write(f"Strip function names: {bool(args.strip)}\n")

        # NOTE(review): assumes every sample's 'opt' is one of O0–O3;
        # anything else raises KeyError here — confirm upstream data.
        opt_counts = {"O0": 0, "O1": 0, "O2": 0, "O3": 0}
        for sample in samples:
            opt_counts[sample['opt']] += 1

        f.write("\nSamples per optimization level:\n")
        for opt, count in opt_counts.items():
            f.write(f"  {opt}: {count}\n")

    print(f"Inference completed! Results saved to {args.output_path}")
    print(f"Total {len(samples)} samples processed.")
|