LLM4Binary committed
Commit 4c55900 · verified · Parent: eee0ccf
Files changed (4)
  1. llm_server.py +96 -0
  2. normalize_pseudo.py +212 -0
  3. reverse_sample.json +16 -0
  4. sk2decompile_inf.py +179 -0
llm_server.py ADDED
@@ -0,0 +1,96 @@
+ from vllm import LLM, SamplingParams
+ from argparse import ArgumentParser, Namespace
+ import os
+ import json
+ from transformers import AutoTokenizer
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+
+ def parse_args() -> Namespace:
+     parser = ArgumentParser()
+     parser.add_argument("--model_path", type=str)
+     parser.add_argument("--gpus", type=int, default=1)
+     parser.add_argument("--max_num_seqs", type=int, default=1)
+     parser.add_argument("--gpu_memory_utilization", type=float, default=0.95)
+     parser.add_argument("--temperature", type=float, default=0)
+     parser.add_argument("--max_total_tokens", type=int, default=8192)
+     parser.add_argument("--max_new_tokens", type=int, default=512)
+     parser.add_argument("--stop_sequences", type=str, default=None)
+     parser.add_argument("--testset_path", type=str)
+     parser.add_argument("--output_path", type=str, default=None)
+     return parser.parse_args()
+
+
+ def llm_inference(inputs,
+                   model_path,
+                   gpus=1,
+                   max_total_tokens=8192,
+                   gpu_memory_utilization=0.95,
+                   temperature=0,
+                   max_new_tokens=512,
+                   stop_sequences=None):
+     llm = LLM(
+         model=model_path,
+         tensor_parallel_size=gpus,
+         max_model_len=max_total_tokens,
+         gpu_memory_utilization=gpu_memory_utilization,
+     )
+
+     sampling_params = SamplingParams(
+         temperature=temperature,
+         max_tokens=max_new_tokens,
+         stop=stop_sequences,
+     )
+
+     gen_results = llm.generate(inputs, sampling_params)
+     # Keep only the text of the first candidate for each prompt.
+     gen_results = [[output.outputs[0].text] for output in gen_results]
+
+     return gen_results
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     with open(args.testset_path, "r") as f:
+         samples = json.load(f)
+
+     # Wrap each assembly snippet in the decompilation prompt template.
+     before = "# This is the assembly code:\n"
+     after = "\n# What is the source code?\n"
+     inputs = []
+     for sample in samples:
+         inputs.append(before + sample["input_asm_prompt"].strip() + after)
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+     if args.stop_sequences is None:
+         args.stop_sequences = [tokenizer.eos_token]
+
+     gen_results = llm_inference(inputs, args.model_path,
+                                 args.gpus,
+                                 args.max_total_tokens,
+                                 args.gpu_memory_utilization,
+                                 args.temperature,
+                                 args.max_new_tokens,
+                                 args.stop_sequences)
+
+     os.makedirs(args.output_path, exist_ok=True)
+     for idx, gen_result in enumerate(gen_results):
+         with open(os.path.join(args.output_path, f"{idx}.c"), "w") as f:
+             f.write(gen_result[0])
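For quick experiments, llm_inference can also be called directly from Python rather than through the CLI. A minimal sketch; the model id is the default from sk2decompile_inf.py below, and the prompt content is an illustrative assumption:

# Sketch: programmatic use of llm_inference (prompt content is illustrative).
from llm_server import llm_inference

prompts = ["# This is the assembly code:\nmov eax, 1\nret\n# What is the source code?\n"]
results = llm_inference(prompts, model_path="LLM4Binary/sk2decompile-struct-6.7b")
print(results[0][0])  # each entry is a one-element list holding the generated text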
normalize_pseudo.py ADDED
@@ -0,0 +1,212 @@
+ import re
+ import json
+ import argparse
+ import subprocess
+ from multiprocessing import Pool, cpu_count
+ from functools import partial
+ from tqdm import tqdm  # progress bar
+
+
+ def good_func(func):
+     """Keep functions whose body has more than 3 and fewer than 300 non-trivial lines."""
+     func = '{'.join(func.split('{')[1:])  # drop the signature, keep the body
+     total = 0
+     for line in func.split('\n'):
+         if len(line.strip()) >= 3:
+             total += 1
+     return 3 < total < 300
+
+
+ def strip_empty(code):
+     return "\n".join(line for line in code.splitlines() if line.strip())
+
+
+ def format_with_clang(func: str, style: str = "Google") -> str:
+     if not func:
+         return None
+     cmd = ["clang-format", f"--style={style}"]
+     try:
+         proc = subprocess.run(
+             cmd,
+             input=func,
+             text=True,
+             capture_output=True,
+             check=True,
+             timeout=0.5,
+         )
+         return proc.stdout
+     except Exception:
+         # clang-format missing, timed out, or rejected the input
+         return None
+
+
+ # ----------------------------
+ # 1. Convert hexadecimal literals to decimal
+ # ----------------------------
+ def hex_to_dec(text):
+     pattern = re.compile(r'\b(0x[0-9a-fA-F]+)([uUlL]{1,3})?\b')
+     def convert(match):
+         hex_part = match.group(1)
+         suffix = match.group(2) or ""
+         return str(int(hex_part, 16)) + suffix
+     return pattern.sub(convert, text)
+
+
+ # ----------------------------
+ # 2. Remove decompiler-specific keywords
+ # ----------------------------
+ def remove_keywords(text):
+     patterns = [
+         r'\b__fastcall\b',
+         r'\b__cdecl\b',
+         r'\b__ptr32\b',
+         r'\b__noreturn\s+noreturn\b'
+     ]
+     combined_pattern = re.compile('|'.join(patterns))
+     return combined_pattern.sub('', text)
+
+
+ # ----------------------------
+ # 3. Replace typedef'd names with their underlying types
+ # ----------------------------
+ typedef_map = {
+     "cpu_set_t": "int",
+     "nl_item": "int",
+     "__time_t": "int",
+     "__mode_t": "unsigned short",
+     "__off64_t": "long long",
+     "__blksize_t": "long",
+     "__ino_t": "unsigned long",
+     "__blkcnt_t": "unsigned long long",
+     "__syscall_slong_t": "long",
+     "__ssize_t": "long int",
+     "wchar_t": "unsigned short int",
+     "wctype_t": "unsigned short int",
+     "__int64": "long long",
+     "__int32": "int",
+     "__int16": "short",
+     "__int8": "char",
+     "_QWORD": "uint64_t",
+     "_OWORD": "long double",
+     "_DWORD": "uint32_t",
+     "size_t": "unsigned int",
+     "_BYTE": "uint8_t",
+     "_TBYTE": "uint16_t",
+     "_BOOL8": "uint8_t",
+     "gcc_va_list": "va_list",
+     "_WORD": "unsigned short",
+     "_BOOL4": "int",
+     "__va_list_tag": "va_list",
+     "_IO_FILE": "FILE",
+     "DIR": "int",
+     "__fsword_t": "long",
+     "__kernel_ulong_t": "int",
+     "cc_t": "int",
+     "speed_t": "int",
+     "fd_set": "int",
+     "__suseconds_t": "int",
+     "_UNKNOWN": "void",
+     "__sighandler_t": "void (*)(int)",
+     "__compar_fn_t": "int (*)(const void *, const void *)",
+ }
+
+
+ def replace_typedefs(text):
+     for alias, original in typedef_map.items():
+         pattern = re.compile(rf'\b{re.escape(alias)}\b')
+         text = pattern.sub(original, text)
+     return text
+
+
+ # ----------------------------
+ # 4. Remove comments
+ # ----------------------------
+ def remove_comments(text):
+     text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL)
+     text = re.sub(r'//.*?$', '', text, flags=re.MULTILINE)
+     return text
+
+
+ # ----------------------------
+ # 5. Normalize a single pseudo-code string
+ # ----------------------------
+ def process_code(code_str):
+     code_str = remove_comments(code_str)
+     code_str = hex_to_dec(code_str)
+     code_str = remove_keywords(code_str)
+     code_str = replace_typedefs(code_str)
+     return code_str
+
+
+ # Wrap process_code so it accepts a record dict and processes one field.
+ def process_entry(entry, key_name='pseudo'):
+     result = process_code(entry.get(key_name, ''))
+     if not result.strip():
+         return ''
+     formatted = format_with_clang(result)
+     if formatted is None:
+         return None
+     return strip_empty(formatted)
+
+
+ # Main routine: normalize every record in parallel.
+ def normalize_code_list_parallel(input_json, output_json, key_name='pseudo', num_workers=None, remove=1):
+     with open(input_json, 'r', encoding='utf-8') as f:
+         data = json.load(f)
+
+     if not isinstance(data, list):
+         raise ValueError("Input JSON must be an array of objects")
+
+     num_workers = num_workers or cpu_count()
+     print(f"[+] Processing {len(data)} records with {num_workers} worker processes")
+
+     process_entry_key = partial(process_entry, key_name=key_name)
+
+     with Pool(processes=num_workers) as pool:
+         result = list(tqdm(pool.imap(process_entry_key, data), total=len(data), desc="Processing"))
+
+     data_good = []
+     for record, norm in zip(data, result):
+         if norm:
+             if not good_func(norm):
+                 continue
+             record[f"{key_name}_norm"] = norm
+             data_good.append(record)
+         elif norm is None and not remove:
+             # clang-format failed: keep the unformatted field instead of dropping the record
+             record[f"{key_name}_norm"] = record[key_name]
+             data_good.append(record)
+
+     with open(output_json, 'w', encoding='utf-8') as f:
+         json.dump(data_good, f, indent=2, ensure_ascii=False)
+
+     print(f"[done] {input_json}: {len(data)} -> {output_json}: {len(data_good)}")
+
+
+ # ----------------------------
+ # 6. Command-line entry point
+ # ----------------------------
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="Normalize a list of IDA pseudo-code strings in parallel")
+     parser.add_argument('--input_json', default="exebench_format_top1p.json", help='input JSON file path')
+     parser.add_argument('--output_json', default="exebench_format_pseudo_top1p.json", help='output JSON file path')
+     parser.add_argument('--key_name', default="pseudo", help='name of the JSON field holding the pseudo-code')
+     parser.add_argument('--workers', type=int, default=32, help='number of worker processes')
+     parser.add_argument('--remove', type=int, default=1, help='drop records where clang-format fails (0 keeps the raw field)')
+     args = parser.parse_args()
+
+     normalize_code_list_parallel(args.input_json, args.output_json, args.key_name, args.workers, args.remove)
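As a sanity check, the four normalization passes compose as follows; a hypothetical snippet (process_code alone needs no clang-format, which is only invoked by process_entry):

# Sketch: the normalization passes applied to a small IDA-style snippet.
from normalize_pseudo import process_code

raw = "__int64 __fastcall sub_1169(_DWORD *a1)\n{\n  return a1[0] + 0x10; // offset\n}"
print(process_code(raw))
# The comment is stripped, 0x10 becomes 16, __fastcall disappears,
# and __int64 / _DWORD are rewritten to long long / uint32_t.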
reverse_sample.json ADDED
@@ -0,0 +1,16 @@
+ [
+   {
+     "index": "0",
+     "func_name": "func0",
+     "opt": "O3",
+     "language": "c",
+     "ida_pseudo": "_BYTE *__fastcall sub_1370(const char *a1)\n{\n int v2; // eax\n int v3; // ebx\n int v4; // r13d\n int v5; // r15d\n _BYTE *v6; // rdi\n const char *v7; // rsi\n int v8; // ebp\n const char *v9; // rdx\n const char *v10; // rax\n const char *v11; // rax\n _BYTE *v12; // rdx\n char v13; // cl\n __int64 v15; // rax\n char *v16; // r14\n const char *v17; // rax\n const char *v18; // rdx\n size_t v19; // [rsp+8h] [rbp-40h]\n\n v2 = strlen(a1);\n v3 = 2 * v2;\n v4 = v2;\n v5 = v2;\n v19 = 2 * v2 + 1;\n v6 = malloc(v19);\n if ( v6 )\n {\n if ( v4 <= 0 )\n {\n v6 = (_BYTE *)__strncpy_chk(v6, a1, v4, v19);\nLABEL_11:\n v6[v3] = 0;\n }\n else\n {\n v7 = a1;\n v8 = 0;\n while ( (v4 - v8) >> 1 )\n {\n v9 = &a1[v4 - 1];\n v10 = v7;\n while ( *v10 == *v9 )\n {\n ++v10;\n --v9;\n if ( v10 == &v7[(v4 - v8) >> 1] )\n goto LABEL_13;\n }\n ++v8;\n ++v7;\n if ( v5 == v8 )\n {\n v6 = (_BYTE *)__strncpy_chk(v6, a1, v4, v19);\n v11 = &a1[v4 - 1];\n v12 = &v6[v4];\n do\n {\n v13 = *v11--;\n *v12++ = v13;\n }\n while ( &a1[v4 - 2 - (v4 - 1)] != v11 );\n goto LABEL_11;\n }\n }\nLABEL_13:\n v15 = __strncpy_chk(v6, a1, v4, v19);\n v6 = (_BYTE *)v15;\n if ( v8 )\n {\n v16 = (char *)(v15 + v4);\n v17 = &a1[v8 - 1];\n do\n {\n *v16++ = *v17;\n v18 = v17--;\n }\n while ( v18 != a1 );\n }\n v6[v8 + v4] = 0;\n }\n }\n return v6;\n}"
+   },
+   {
+     "index": "1",
+     "func_name": "func0",
+     "opt": "O2",
+     "language": "c",
+     "ida_pseudo": "__int64 __fastcall sub_1169(float *a1, int a2)\n{\n __int64 result; // rax\n float v3; // [rsp+Ch] [rbp-10h]\n float v4; // [rsp+10h] [rbp-Ch]\n int i; // [rsp+14h] [rbp-8h]\n int j; // [rsp+18h] [rbp-4h]\n\n v3 = *a1;\n v4 = *a1;\n for ( i = 1; i < a2; ++i )\n {\n if ( v3 > a1[i] )\n v3 = a1[i];\n if ( a1[i] > v4 )\n v4 = a1[i];\n }\n for ( j = 0; ; ++j )\n {\n result = (unsigned int)j;\n if ( j >= a2 )\n break;\n a1[j] = (float)(a1[j] - v3) / (float)(v4 - v3);\n }\n return result;\n}"
+   }
+ ]
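Note that both records carry the raw ida_pseudo field, while sk2decompile_inf.py below reads ida_pseudo_norm by default, so the sample is normalized first. A sketch of that step; the output file name is an assumption:

# Sketch: add the ida_pseudo_norm field that the inference script expects.
from normalize_pseudo import normalize_code_list_parallel

normalize_code_list_parallel(
    "reverse_sample.json",       # input shipped in this commit
    "reverse_sample_norm.json",  # hypothetical output path
    key_name="ida_pseudo",       # writes an "ida_pseudo_norm" field per record
    num_workers=2,
)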
sk2decompile_inf.py ADDED
@@ -0,0 +1,179 @@
+ from llm_server import llm_inference
+ from transformers import AutoTokenizer
+ import json
+ import argparse
+ import shutil
+ import os
+
+ opts = ["O0", "O1", "O2", "O3"]
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+
+ if __name__ == "__main__":
+     arg_parser = argparse.ArgumentParser()
+     arg_parser.add_argument("--model_path", type=str, default="LLM4Binary/sk2decompile-struct-6.7b")
+     arg_parser.add_argument("--dataset_path", type=str, default='reverse_sample.json')
+     arg_parser.add_argument("--decompiler", type=str, default='ida_pseudo_norm')
+     arg_parser.add_argument("--gpus", type=int, default=1)
+     arg_parser.add_argument("--max_num_seqs", type=int, default=1)
+     arg_parser.add_argument("--gpu_memory_utilization", type=float, default=0.8)
+     arg_parser.add_argument("--temperature", type=float, default=0)
+     arg_parser.add_argument("--max_total_tokens", type=int, default=32768)
+     arg_parser.add_argument("--max_new_tokens", type=int, default=4096)
+     arg_parser.add_argument("--stop_sequences", type=str, default=None)
+     arg_parser.add_argument("--recover_model_path", type=str, default='LLM4Binary/sk2decompile-ident-6.7', help="Path to the identifier-recovery model.")
+     arg_parser.add_argument("--output_path", type=str, default='./result/sk2decompile')
+     arg_parser.add_argument("--only_save", type=int, default=0)
+     arg_parser.add_argument("--strip", type=int, default=1)
+     arg_parser.add_argument("--language", type=str, default='c')
+     args = arg_parser.parse_args()
+
+     before = "# This is the assembly code:\n"
+     after = "\n# What is the source code?\n"
+
+     if args.dataset_path.endswith('.json'):
+         with open(args.dataset_path, "r") as f:
+             print("===========")
+             print(f"Loading dataset from {args.dataset_path}")
+             print("===========")
+             samples = json.load(f)
+     elif args.dataset_path.endswith('.jsonl'):
+         samples = []
+         with open(args.dataset_path, "r") as f:
+             for line in f:
+                 line = line.strip()
+                 if line:
+                     samples.append(json.loads(line))
+     else:
+         raise ValueError(f"Unsupported dataset format: {args.dataset_path}")
+
+     tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+     if args.stop_sequences is None:
+         args.stop_sequences = [tokenizer.eos_token]
+
+     inputs = []
+     infos = []
+     for sample in samples:
+         prompt = before + sample[args.decompiler].strip() + after
+         sample['prompt_model1'] = prompt
+         inputs.append(prompt)
+         infos.append({
+             "opt": sample["opt"],
+             "language": sample["language"],
+             "index": sample["index"],
+             "func_name": sample["func_name"]
+         })
+
+     print("Starting first model inference...")
+     gen_results = llm_inference(inputs, args.model_path,
+                                 args.gpus,
+                                 args.max_total_tokens,
+                                 args.gpu_memory_utilization,
+                                 args.temperature,
+                                 args.max_new_tokens,
+                                 args.stop_sequences)
+     gen_results = [gen_result[0] for gen_result in gen_results]
+
+     for idx in range(len(gen_results)):
+         samples[idx]['gen_result_model1'] = gen_results[idx]
+
+     # Second stage: recover identifiers from the structure-only output of stage one.
+     inputs_recovery = []
+     before_recovery = "# This is the normalized code:\n"
+     after_recovery = "\n# What is the source code?\n"
+
+     for idx, sample in enumerate(gen_results):
+         prompt_recovery = before_recovery + sample.strip() + after_recovery
+         samples[idx]['prompt_model2'] = prompt_recovery
+         inputs_recovery.append(prompt_recovery)
+
+     print("Starting recovery model inference...")
+     gen_results_recovery = llm_inference(inputs_recovery, args.recover_model_path,
+                                          args.gpus,
+                                          args.max_total_tokens,
+                                          args.gpu_memory_utilization,
+                                          args.temperature,
+                                          args.max_new_tokens,
+                                          args.stop_sequences)
+     gen_results_recovery = [gen_result[0] for gen_result in gen_results_recovery]
+
+     for idx in range(len(gen_results_recovery)):
+         samples[idx]['gen_result_model2'] = gen_results_recovery[idx]
+
+     if args.output_path:
+         if os.path.exists(args.output_path):
+             shutil.rmtree(args.output_path)
+         for opt in opts:
+             os.makedirs(os.path.join(args.output_path, opt))
+
+     if args.strip:
+         print("Processing function name stripping...")
+         for idx in range(len(gen_results_recovery)):
+             one = gen_results_recovery[idx]
+             # Take the token before '(' in the generated signature as the function name.
+             func_name_in_gen = one.split('(')[0].split(' ')[-1].strip()
+             if func_name_in_gen[:2] == '**':
+                 func_name_in_gen = func_name_in_gen[2:]
+             elif func_name_in_gen[:1] == '*':
+                 func_name_in_gen = func_name_in_gen[1:]
+
+             original_func_name = samples[idx]["func_name"]
+             if func_name_in_gen:  # guard: replace('') would corrupt the output
+                 gen_results_recovery[idx] = one.replace(func_name_in_gen, original_func_name)
+             samples[idx]["gen_result_model2_stripped"] = gen_results_recovery[idx]
+
+     print("Saving inference results and logs...")
+     for idx_sample, final_result in enumerate(gen_results_recovery):
+         opt = infos[idx_sample]['opt']
+         language = infos[idx_sample]['language']
+         original_index = samples[idx_sample]['index']
+
+         save_path = os.path.join(args.output_path, opt, f"{original_index}_{opt}.{language}")
+         with open(save_path, "w") as f:
+             f.write(final_result)
+
+         log_path = save_path + ".log"
+         log_data = {
+             "index": original_index,
+             "opt": opt,
+             "language": language,
+             "func_name": samples[idx_sample]["func_name"],
+             "decompiler": args.decompiler,
+             "input_asm": samples[idx_sample][args.decompiler].strip(),
+             "prompt_model1": samples[idx_sample]['prompt_model1'],
+             "gen_result_model1": samples[idx_sample]['gen_result_model1'],
+             "prompt_model2": samples[idx_sample]['prompt_model2'],
+             "gen_result_model2": samples[idx_sample]['gen_result_model2'],
+             "final_result": final_result,
+             "stripped": args.strip
+         }
+
+         if args.strip and "gen_result_model2_stripped" in samples[idx_sample]:
+             log_data["gen_result_model2_stripped"] = samples[idx_sample]["gen_result_model2_stripped"]
+
+         with open(log_path, "w") as f:
+             json.dump(log_data, f, indent=2, ensure_ascii=False)
+
+     json_path = os.path.join(args.output_path, 'inference_results.jsonl')
+     with open(json_path, 'w') as f:
+         for sample in samples:
+             f.write(json.dumps(sample) + '\n')
+
+     stats_path = os.path.join(args.output_path, 'inference_stats.txt')
+     with open(stats_path, 'w') as f:
+         f.write(f"Total samples processed: {len(samples)}\n")
+         f.write(f"Model path: {args.model_path}\n")
+         f.write(f"Recovery model path: {args.recover_model_path}\n")
+         f.write(f"Dataset path: {args.dataset_path}\n")
+         f.write(f"Language: {args.language}\n")
+         f.write(f"Decompiler: {args.decompiler}\n")
+         f.write(f"Strip function names: {bool(args.strip)}\n")
+
+         opt_counts = {"O0": 0, "O1": 0, "O2": 0, "O3": 0}
+         for sample in samples:
+             opt_counts[sample['opt']] += 1
+
+         f.write("\nSamples per optimization level:\n")
+         for opt, count in opt_counts.items():
+             f.write(f"  {opt}: {count}\n")
+
+     print(f"Inference completed! Results saved to {args.output_path}")
+     print(f"Total {len(samples)} samples processed.")
+ print(f"Total {len(samples)} samples processed.")