Spaces:
Runtime error
Runtime error
| import pdb | |
| import torch | |
| import numpy as np | |
| from utils import bin_util | |
def decode_trunck(trunck, model, device):
    """Decode one chunk of signal into a bit array.

    The chunk is converted to a float32 tensor on ``device``, given a leading
    batch dimension of size 1, and passed to ``model.decode`` with autograd
    disabled. Every output score ``>= 0.5`` becomes 1, every other score 0.

    Args:
        trunck: array-like chunk of samples to decode.
        model: object exposing ``decode(tensor) -> tensor``.
        device: torch device the chunk is moved to before decoding.

    Returns:
        numpy integer array of 0/1 values with all singleton dimensions
        squeezed away.
    """
    with torch.no_grad():
        batch = torch.as_tensor(trunck, dtype=torch.float32, device=device).unsqueeze(0)
        scores = model.decode(batch)
        bits = scores.ge(0.5).int()
    host_bits = bits.detach().cpu()
    return host_bits.numpy().squeeze()
def is_start_bit_match(start_bit, decoded_start_bit, start_bit_ber_threshold):
    """Check whether decoded header bits are close enough to the expected ones.

    Args:
        start_bit: numpy array of expected header bits.
        decoded_start_bit: numpy array of decoded bits; must have the same
            shape as ``start_bit`` (checked with ``assert``).
        start_bit_ber_threshold: maximum tolerated bit error rate (exclusive).

    Returns:
        True when the fraction of differing bits is strictly below
        ``start_bit_ber_threshold``.
    """
    assert start_bit.shape == decoded_start_bit.shape
    agreement = np.mean(start_bit == decoded_start_bit)
    bit_error_rate = 1 - agreement
    return bit_error_rate < start_bit_ber_threshold
def extract_watermark(data, start_bit, shift_range, num_point, start_bit_ber_threshold, model, device,
                      verbose=False):
    """Slide a window over ``data`` and majority-vote every detected watermark.

    Each window of ``num_point`` samples is decoded with ``decode_trunck``.
    When its leading ``len(start_bit)`` bits match ``start_bit`` (per
    ``is_start_bit_match``), the full bit array is recorded and the window
    jumps past the matched chunk. Otherwise the window advances by a small
    search step of ``shift_range * num_point`` samples (at least one sample).

    Args:
        data: sliceable 1-D signal; presumably a numpy array of audio
            samples (TODO confirm with callers).
        start_bit: numpy array of the expected leading (header) bits.
        shift_range: search step expressed as a fraction of ``num_point``.
        num_point: number of samples decoded per window; must be positive.
        start_bit_ber_threshold: a window matches when its header bit error
            rate is strictly below this value.
        model: object exposing ``decode``; forwarded to ``decode_trunck``.
        device: torch device passed to ``decode_trunck``.
        verbose: if True, print the position and hex payload of each match.

    Returns:
        ``(support_count, exist_prob, mean_result, first_result)``:
        the number of matched windows; the fraction of header bits in the
        majority-vote result that equal ``start_bit``; the per-bit majority
        vote (a 0.5 tie votes 1); and the first matched bit array. The last
        three are ``None`` when nothing matched.

    Raises:
        ValueError: if ``num_point`` is not positive (the scan could never end).
    """
    if num_point <= 0:
        raise ValueError(f"num_point must be positive, got {num_point}")
    # int() truncates small products to 0. A zero (or negative) step would
    # leave the window in place on every mismatch and loop forever.
    search_step = max(1, int(shift_range * num_point))
    n_start = len(start_bit)
    results = []
    pos = 0  # current window start position in ``data``
    while True:
        trunck = data[pos:pos + num_point]
        if len(trunck) < num_point:
            break
        bit_array = decode_trunck(trunck, model, device)
        if not is_start_bit_match(start_bit, bit_array[:n_start], start_bit_ber_threshold):
            pos += search_step
            continue
        # Found a start position.
        if verbose:
            msg_str = bin_util.binArray2HexStr(bit_array[n_start:])
            print(pos, "解码信息:", msg_str)
        results.append(bit_array)
        # Jump past the matched chunk; presumably this avoids re-detecting the
        # same watermark in overlapping windows.
        pos += num_point + search_step

    support_count = len(results)
    if support_count == 0:
        return support_count, None, None, None
    # Per-bit majority vote across all matched windows.
    mean_result = (np.array(results).mean(axis=0) >= 0.5).astype(int)
    exist_prob = (mean_result[:n_start] == start_bit).mean()
    return support_count, exist_prob, mean_result, results[0]
def extract_watermark_v2(data, start_bit, shift_range, num_point,
                         start_bit_ber_threshold, model, device,
                         merge_type,
                         shift_range_p=0.5, ):
    """Scan ``data`` densely for watermark windows and merge the candidates.

    Unlike ``extract_watermark``, the window always advances by the small
    search step, even after a match. Every accepted window is kept together
    with its header similarity.

    Args:
        data: sliceable 1-D signal; presumably a numpy array of audio
            samples (TODO confirm with callers).
        start_bit: numpy array of the expected leading (header) bits.
        shift_range: search step as a fraction of ``num_point``.
        num_point: number of samples decoded per window; must be positive.
        start_bit_ber_threshold: a window is accepted when its header bit
            error rate is <= this value (inclusive, unlike the strict ``<``
            in ``is_start_bit_match``).
        model: object exposing ``decode``; forwarded to ``decode_trunck``.
        device: torch device passed to ``decode_trunck``.
        merge_type: merge strategy; only ``"best"`` is implemented.
        shift_range_p: extra factor on the step, which is
            ``shift_range * num_point * shift_range_p`` samples (at least 1).

    Returns:
        ``(support_count, mean_result, results)``. ``results`` is a list of
        dicts ``{"sim": 1 - header BER, "msg": decoded bit array}``.
        ``mean_result`` is ``None`` if nothing matched. Otherwise: if the best
        similarity is ~1.0, it is the per-bit majority vote over all ~1.0
        candidates; if not, it is the ``msg`` of the first highest-similarity
        candidate.

    Raises:
        ValueError: if ``num_point`` is not positive, or (only when at least
            one window matched) ``merge_type`` is not supported.
        NotImplementedError: if ``merge_type == "weighted"`` and at least one
            window matched.
    """
    if num_point <= 0:
        raise ValueError(f"num_point must be positive, got {num_point}")
    # int() truncates small products to 0. A zero (or negative) step would
    # keep the window in place forever.
    search_step = max(1, int(shift_range * num_point * shift_range_p))
    n_start = len(start_bit)
    results = []
    pos = 0  # current window start position in ``data``
    while True:
        trunck = data[pos:pos + num_point]
        if len(trunck) < num_point:
            break
        bit_array = decode_trunck(trunck, model, device)
        ber_start_bit = 1 - np.mean(start_bit == bit_array[:n_start])
        if ber_start_bit > start_bit_ber_threshold:
            pos += search_step
            continue
        # Found a candidate start position.
        results.append({
            "sim": 1 - ber_start_bit,
            "msg": bit_array,
        })
        # Advance only by the small step even after a match: jumping a whole
        # chunk could skip other candidate start positions.
        pos += search_step

    support_count = len(results)
    if support_count == 0:
        return support_count, None, results

    if merge_type == "weighted":
        raise NotImplementedError('merge_type "weighted" is not implemented')
    if merge_type != "best":
        raise ValueError(f'unsupported merge_type {merge_type!r}; expected "best"')

    # max() returns the first maximal element, i.e. the earliest best window.
    best = max(results, key=lambda r: r["sim"])
    if np.isclose(1.0, best["sim"]):
        # Several perfect headers: majority-vote their payloads.
        perfect = [r["msg"] for r in results if np.isclose(r["sim"], 1.0)]
        mean_result = (np.array(perfect).mean(axis=0) >= 0.5).astype(int)
    else:
        mean_result = best["msg"]
    return support_count, mean_result, results