# NOTE: platform output residue from the original listing ("Spaces:" /
# "Runtime error" x2) — kept as a comment so the file stays parseable.
import pdb
import random

import numpy as np
import torch
import torch.nn as nn
import torch.optim

# from utils.attacks import attack_layer, mp3_attack_v2, butterworth_attack
from models.hinet import Hinet

# Run on the first GPU when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class Model(nn.Module):
    """Invertible audio-watermarking network.

    ``encode`` hides a bit message inside an audio signal in the STFT
    domain via an invertible network (Hinet); ``decode`` runs the same
    network in reverse to recover the message.

    Spectra are passed around as real tensors of shape
    (batch, freq_bins, time_frames, 2), i.e. real/imag in the last dim.
    """

    def __init__(self, num_point, num_bit, n_fft, hop_length, use_recover_layer, num_layers):
        """
        Args:
            num_point: number of samples the message is expanded to
                (must match the carrier's sample count so both STFTs align).
            num_bit: number of watermark bits per message.
            n_fft: FFT size for the STFT.
            hop_length: STFT hop length.
            use_recover_layer: if True, pass the spectrum through a 1x1
                conv (``recover_layer``) before decoding.
            num_layers: number of invertible blocks in Hinet.
        """
        super(Model, self).__init__()
        self.hinet = Hinet(num_layers=num_layers)
        self.watermark_fc = torch.nn.Linear(num_bit, num_point)
        self.watermark_fc_back = torch.nn.Linear(num_point, num_bit)
        self.n_fft = n_fft
        self.hop_length = hop_length
        # Unused in the forward passes below, but kept so state_dict keys
        # remain compatible with existing checkpoints.
        self.dropout1 = torch.nn.Dropout()
        self.identity = torch.nn.Identity()
        self.recover_layer = SameSizeConv2d(2, 2)
        self.use_recover_layer = use_recover_layer

    def stft(self, data):
        """Return the STFT of ``data`` as (batch, freq_bins, time_frames, 2).

        Bug fix: ``return_complex=False`` was removed from ``torch.stft``
        in torch >= 2.0.  Compute the complex spectrum and expose it via
        ``view_as_real`` so downstream code keeps its real (..., 2) layout.
        """
        window = torch.hann_window(self.n_fft).to(data.device)
        spec = torch.stft(data, n_fft=self.n_fft, hop_length=self.hop_length,
                          window=window, return_complex=True)
        return torch.view_as_real(spec)

    def istft(self, signal_wmd_fft):
        """Inverse STFT of a real (batch, freq_bins, time_frames, 2) tensor.

        Bug fix: since torch 2.0 ``torch.istft`` requires a complex input,
        so reassemble the complex spectrum first.  ``contiguous()`` is
        required because callers hand in permuted (non-contiguous) tensors
        and ``view_as_complex`` needs a stride-1 last dimension.
        """
        window = torch.hann_window(self.n_fft).to(signal_wmd_fft.device)
        spec = torch.view_as_complex(signal_wmd_fft.contiguous())
        return torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_length,
                           window=window, return_complex=False)

    def encode(self, signal, message, need_fft=False):
        """Embed ``message`` into ``signal``.

        Returns the watermarked waveform, plus the two input spectra when
        ``need_fft`` is True.
        """
        # 1. STFT of the carrier signal -> (batch, freq_bins, time_frames, 2).
        signal_fft = self.stft(signal)
        # 2. Expand the message to num_point samples, then take its STFT.
        message_expand = self.watermark_fc(message)
        message_fft = self.stft(message_expand)
        # 3. Forward pass of the invertible network mixes the two spectra.
        signal_wmd_fft, msg_remain = self.enc_dec(signal_fft, message_fft, rev=False)
        signal_wmd = self.istft(signal_wmd_fft)
        if need_fft:
            return signal_wmd, signal_fft, message_fft
        return signal_wmd

    def decode(self, signal):
        """Recover the message bits (clamped to [-1, 1]) from ``signal``."""
        signal_fft = self.stft(signal)
        if self.use_recover_layer:
            signal_fft = self.recover_layer(signal_fft)
        # The received spectrum itself seeds the reverse pass.
        watermark_fft = signal_fft
        _, message_restored_fft = self.enc_dec(signal_fft, watermark_fft, rev=True)
        message_restored_expanded = self.istft(message_restored_fft)
        message_restored_float = self.watermark_fc_back(message_restored_expanded).clamp(-1, 1)
        return message_restored_float

    def enc_dec(self, signal, watermark, rev):
        """Run Hinet on (batch, 2, time_frames, freq_bins); ``rev`` selects
        forward (embed) vs. reverse (extract) direction."""
        signal = signal.permute(0, 3, 2, 1)
        watermark = watermark.permute(0, 3, 2, 1)
        signal2, watermark2 = self.hinet(signal, watermark, rev)
        # Permute back to the (batch, freq_bins, time_frames, 2) layout.
        return signal2.permute(0, 3, 2, 1), watermark2.permute(0, 3, 2, 1)
class SameSizeConv2d(nn.Module):
    """1x1 convolution over a channel-last tensor.

    Accepts (batch, H, W, C_in), applies a pointwise Conv2d, and returns
    (batch, H, W, C_out) — spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels):
        super(SameSizeConv2d, self).__init__()
        # kernel_size=1 with default stride/padding keeps H and W intact.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        # Channel-last -> NCHW for Conv2d, convolve, then back to channel-last.
        nchw = x.permute(0, 3, 1, 2)
        out = self.conv(nchw)
        return out.permute(0, 2, 3, 1)