Spaces:
Runtime error
Runtime error
| from utils import silent_util | |
| import torch | |
| import numpy as np | |
| from utils import bin_util | |
# Fixed 100-entry 0/1 sequence. create_parcel_message() takes its leading
# entries as the start-bit prefix of every watermark it builds, so the
# values and their order must not change.
fix_pattern = [
    1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0,
    0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
    1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1,
    1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0,
    0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0,
]
def create_parcel_message(len_start_bit, num_bit, wm_text, verbose=False):
    """Assemble a watermark bit sequence: fixed start bits followed by a payload.

    Args:
        len_start_bit: How many leading entries of ``fix_pattern`` form the
            start-bit prefix.
        num_bit: Required total length of the assembled watermark.
        wm_text: Payload source handed to ``bin_util.hexStr2BinArray``
            (presumably a hex string -- confirm against ``bin_util``). When
            falsy, a random 0/1 payload is drawn instead.
        verbose: When true, print the prefix length together with the value
            ``2 ** len_start_bit / 10000``.

    Returns:
        Tuple ``(start_bit, msg_arr, watermark)``; ``watermark`` is the
        concatenation of the first two.

    Raises:
        AssertionError: If the assembled watermark is not exactly ``num_bit``
            entries long (for example when the converted ``wm_text`` has a
            different length than the space left after the prefix).
    """
    # Start bits: a fixed, known prefix taken from the shared pattern.
    prefix = fix_pattern[0:len_start_bit]
    reported_rate = 2 ** len_start_bit / 10000
    # TODO: should the error rate also take the detection threshold into account?
    if verbose:
        print("起始bit长度:%d,错误率:%.1f万" % (len(prefix), reported_rate))

    # Payload: whatever room is left after the prefix.
    payload_len = num_bit - len(prefix)
    payload = (
        bin_util.hexStr2BinArray(wm_text)
        if wm_text
        else np.random.choice([0, 1], size=payload_len)
    )

    # Pack prefix and payload into the final message.
    packed = np.concatenate([prefix, payload])
    assert len(packed) == num_bit
    return prefix, payload, packed
| import time | |
def add_watermark(bir_array, data, num_point, shift_range, device, model, silence_check=False):
    """Embed the watermark bits into every full-size chunk of ``data``.

    ``data`` is walked in consecutive chunks of
    ``num_point + int(num_point * shift_range)`` samples. In each full chunk
    the first ``num_point`` samples are handed to
    ``encode_trunck_with_silence_check`` and the remaining samples are kept
    as they are. A trailing chunk shorter than the chunk size is passed
    through unchanged.

    Args:
        bir_array: Watermark bits forwarded to the encoder.
        data: Sliceable sample array supporting ``.copy()`` and ``.shape``
            on its slices (presumably a 1-D NumPy array -- confirm with callers).
        num_point: Number of samples per chunk that receive the watermark.
        shift_range: Fraction of ``num_point`` appended to each chunk as an
            untouched gap.
        device: Device forwarded to the encoder.
        model: Model forwarded to the encoder.
        silence_check: Forwarded to ``encode_trunck_with_silence_check``.

    Returns:
        Tuple ``(data, reconstructed_array, time_cost)``: the input object
        itself, the concatenation of all processed chunks, and the elapsed
        wall-clock seconds measured with ``time.time()``.

    Raises:
        AssertionError: If no chunk was produced (e.g. ``data`` is empty) or
            if an encoded chunk does not keep the shape of the original chunk.
    """
    started_at = time.time()

    # Chunk size = watermarked area + untouched gap area.
    block_len = num_point + int(num_point * shift_range)

    pieces = []
    for block_no, offset in enumerate(range(0, len(data), block_len)):
        block = data[offset:offset + block_len].copy()

        if len(block) < block_len:
            # Trailing remainder is too short to carry a watermark.
            pieces.append(block)
            break

        # Chunk layout: [watermark area | gap area]
        cover_part = block[0:num_point]
        gap_part = block[num_point:]
        cover_part_marked = encode_trunck_with_silence_check(
            silence_check, block_no, cover_part, bir_array, device, model
        )

        rebuilt = np.concatenate([cover_part_marked, gap_part])
        assert rebuilt.shape == block.shape
        pieces.append(rebuilt)

    assert len(pieces) > 0
    reconstructed_array = np.concatenate(pieces)
    elapsed = time.time() - started_at
    return data, reconstructed_array, elapsed
def encode_trunck_with_silence_check(silence_check, trunck_idx, trunck, wm, device, model):
    """Watermark one chunk unless it is skipped as silent.

    Args:
        silence_check: When truthy, ``silent_util.is_silent(trunck)`` is
            consulted before encoding; when falsy that call is never made.
        trunck_idx: Chunk index, used only in the skip message.
        trunck: Samples of the chunk to watermark.
        wm: Watermark bits forwarded to ``encode_trunck``.
        device: Device forwarded to ``encode_trunck``.
        model: Model forwarded to ``encode_trunck``.

    Returns:
        ``trunck`` itself when the chunk is skipped, otherwise the result of
        ``encode_trunck``.
    """
    # Short-circuit keeps the silence detector from running when the check is off.
    skip_chunk = silence_check and silent_util.is_silent(trunck)
    if not skip_chunk:
        # Embed the watermark.
        return encode_trunck(trunck, wm, device, model)

    print("跳过静音区块:", trunck_idx)
    return trunck
def encode_trunck(trunck, wm, device, model):
    """Run ``model.encode`` on one chunk and return the result as a NumPy array.

    Args:
        trunck: Samples of the chunk; converted with ``torch.FloatTensor``.
        wm: Watermark bits; copied via ``np.array`` and converted with
            ``torch.FloatTensor``.
        device: Target passed to ``Tensor.to`` for both inputs.
        model: Object exposing ``encode(signal, message)`` that returns a tensor.

    Returns:
        The encoder output moved to the CPU, converted to a NumPy array, with
        every size-1 dimension removed by ``squeeze()``.
    """
    with torch.no_grad():
        # Both inputs get a leading dimension of size 1 before encoding.
        signal_batch = torch.FloatTensor(trunck).to(device).unsqueeze(0)
        message_batch = torch.FloatTensor(np.array(wm)).to(device).unsqueeze(0)
        encoded = model.encode(signal_batch, message_batch)
        return encoded.detach().cpu().numpy().squeeze()