Spaces:
Running
Running
modify loss
Browse files- __pycache__/inference.cpython-311.pyc +0 -0
- inference.py +4 -4
- modules/__pycache__/loss.cpython-311.pyc +0 -0
- modules/loss.py +26 -102
__pycache__/inference.cpython-311.pyc
CHANGED
|
Binary files a/__pycache__/inference.cpython-311.pyc and b/__pycache__/inference.cpython-311.pyc differ
|
|
|
inference.py
CHANGED
|
@@ -68,7 +68,7 @@ class MasteringStyleTransfer:
|
|
| 68 |
return output_audio, predicted_params
|
| 69 |
|
| 70 |
def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
|
| 71 |
-
fit_embedding = torch.nn.Parameter(initial_reference_feature)
|
| 72 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
| 73 |
|
| 74 |
min_loss = float('inf')
|
|
@@ -97,9 +97,7 @@ class MasteringStyleTransfer:
|
|
| 97 |
target = reference_tensor
|
| 98 |
else:
|
| 99 |
target = ito_config['clap_text_prompt']
|
| 100 |
-
print(f'ito_config clap_distance_fn: {ito_config["clap_distance_fn"]}')
|
| 101 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
| 102 |
-
print(f'total_loss: {total_loss}')
|
| 103 |
|
| 104 |
if total_loss < min_loss:
|
| 105 |
min_loss = total_loss.item()
|
|
@@ -122,6 +120,9 @@ class MasteringStyleTransfer:
|
|
| 122 |
total_loss.backward()
|
| 123 |
optimizer.step()
|
| 124 |
|
|
|
|
|
|
|
|
|
|
| 125 |
return all_results, min_loss_step
|
| 126 |
|
| 127 |
def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
|
|
@@ -290,7 +291,6 @@ class MasteringStyleTransfer:
|
|
| 290 |
|
| 291 |
return "\n".join(output)
|
| 292 |
|
| 293 |
-
|
| 294 |
def reload_weights(model, ckpt_path, device):
|
| 295 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 296 |
|
|
|
|
| 68 |
return output_audio, predicted_params
|
| 69 |
|
| 70 |
def inference_time_optimization(self, input_tensor, reference_tensor, ito_config, initial_reference_feature):
|
| 71 |
+
fit_embedding = torch.nn.Parameter(initial_reference_feature, requires_grad=True)
|
| 72 |
optimizer = getattr(torch.optim, ito_config['optimizer'])([fit_embedding], lr=ito_config['learning_rate'])
|
| 73 |
|
| 74 |
min_loss = float('inf')
|
|
|
|
| 97 |
target = reference_tensor
|
| 98 |
else:
|
| 99 |
target = ito_config['clap_text_prompt']
|
|
|
|
| 100 |
total_loss = self.clap_loss(output_audio, target, self.args.sample_rate, distance_fn=ito_config['clap_distance_fn'])
|
|
|
|
| 101 |
|
| 102 |
if total_loss < min_loss:
|
| 103 |
min_loss = total_loss.item()
|
|
|
|
| 120 |
total_loss.backward()
|
| 121 |
optimizer.step()
|
| 122 |
|
| 123 |
+
gc.collect()
|
| 124 |
+
torch.cuda.empty_cache()
|
| 125 |
+
|
| 126 |
return all_results, min_loss_step
|
| 127 |
|
| 128 |
def preprocess_audio(self, audio, target_sample_rate=44100, normalize=False):
|
|
|
|
| 291 |
|
| 292 |
return "\n".join(output)
|
| 293 |
|
|
|
|
| 294 |
def reload_weights(model, ckpt_path, device):
|
| 295 |
checkpoint = torch.load(ckpt_path, map_location=device)
|
| 296 |
|
modules/__pycache__/loss.cpython-311.pyc
CHANGED
|
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
|
|
|
modules/loss.py
CHANGED
|
@@ -185,35 +185,26 @@ class CLAPFeatureLoss(nn.Module):
|
|
| 185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
| 186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
| 187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
| 188 |
-
|
| 189 |
-
# Freeze the CLAP model parameters
|
| 190 |
-
for param in self.model.parameters():
|
| 191 |
-
param.requires_grad = False
|
| 192 |
|
| 193 |
-
def forward(self, input_audio, target, sample_rate, distance_fn='
|
| 194 |
# Process input audio
|
| 195 |
-
|
| 196 |
-
input_audio = self.preprocess_audio(input_audio, sample_rate)
|
| 197 |
-
|
| 198 |
-
with torch.enable_grad():
|
| 199 |
-
input_embed = self.model.get_audio_embedding_from_data(x=input_audio, use_tensor=True)
|
| 200 |
|
| 201 |
# Process target (audio or text)
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
else:
|
| 209 |
-
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
| 210 |
|
| 211 |
# Compute loss using the specified distance function
|
| 212 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
| 213 |
|
| 214 |
return loss
|
| 215 |
|
| 216 |
-
def
|
| 217 |
# Ensure input is in the correct shape (N, C, T)
|
| 218 |
if audio.dim() == 2:
|
| 219 |
audio = audio.unsqueeze(1)
|
|
@@ -221,15 +212,22 @@ class CLAPFeatureLoss(nn.Module):
|
|
| 221 |
# Convert to mono if stereo
|
| 222 |
if audio.shape[1] > 1:
|
| 223 |
audio = audio.mean(dim=1, keepdim=True)
|
| 224 |
-
|
| 225 |
# Resample if necessary
|
| 226 |
if sample_rate != self.target_sample_rate:
|
| 227 |
audio = self.resample(audio, sample_rate)
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
def compute_distance(self, x, y, distance_fn):
|
| 235 |
if distance_fn == 'mse':
|
|
@@ -241,86 +239,12 @@ class CLAPFeatureLoss(nn.Module):
|
|
| 241 |
else:
|
| 242 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
| 243 |
|
| 244 |
-
def
|
| 245 |
-
audio = audio.squeeze(1) # Remove channel dimension
|
| 246 |
-
audio = torch.clamp(audio, -1.0, 1.0)
|
| 247 |
-
audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
| 248 |
-
return audio
|
| 249 |
-
|
| 250 |
-
def resample(self, audio, orig_sample_rate):
|
| 251 |
resampler = torchaudio.transforms.Resample(
|
| 252 |
-
orig_freq=
|
| 253 |
).to(audio.device)
|
| 254 |
return resampler(audio)
|
| 255 |
-
|
| 256 |
-
# def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
|
| 257 |
-
# # Process input audio
|
| 258 |
-
# input_embed = self.process_audio(input_audio, sample_rate)
|
| 259 |
-
|
| 260 |
-
# # Process target (audio or text)
|
| 261 |
-
# if isinstance(target, torch.Tensor):
|
| 262 |
-
# target_embed = self.process_audio(target, sample_rate)
|
| 263 |
-
# elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
| 264 |
-
# target_embed = self.process_text(target)
|
| 265 |
-
# else:
|
| 266 |
-
# raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
| 267 |
-
|
| 268 |
-
# # Compute loss using the specified distance function
|
| 269 |
-
# loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
| 270 |
-
|
| 271 |
-
# return loss
|
| 272 |
-
|
| 273 |
-
# def process_audio(self, audio, sample_rate):
|
| 274 |
-
# # Ensure input is in the correct shape (N, C, T)
|
| 275 |
-
# if audio.dim() == 2:
|
| 276 |
-
# audio = audio.unsqueeze(1)
|
| 277 |
-
|
| 278 |
-
# # Convert to mono if stereo
|
| 279 |
-
# if audio.shape[1] > 1:
|
| 280 |
-
# audio = audio.mean(dim=1, keepdim=True)
|
| 281 |
-
|
| 282 |
-
# # Resample if necessary
|
| 283 |
-
# if sample_rate != self.target_sample_rate:
|
| 284 |
-
# audio = self.resample(audio, sample_rate)
|
| 285 |
-
|
| 286 |
-
# # Quantize audio data
|
| 287 |
-
# audio = self.quantize(audio)
|
| 288 |
-
|
| 289 |
-
# # Get CLAP embeddings
|
| 290 |
-
# with torch.no_grad():
|
| 291 |
-
# embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
| 292 |
-
# return embed
|
| 293 |
-
|
| 294 |
-
# def process_text(self, text):
|
| 295 |
-
# # Get CLAP embeddings for text
|
| 296 |
-
# # ensure input is a list of strings
|
| 297 |
-
# if not isinstance(text, list):
|
| 298 |
-
# text = [text]
|
| 299 |
-
# with torch.no_grad():
|
| 300 |
-
# embed = self.model.get_text_embedding(text, use_tensor=True)
|
| 301 |
-
# return embed
|
| 302 |
-
|
| 303 |
-
# def compute_distance(self, x, y, distance_fn):
|
| 304 |
-
# if distance_fn == 'mse':
|
| 305 |
-
# return F.mse_loss(x, y)
|
| 306 |
-
# elif distance_fn == 'l1':
|
| 307 |
-
# return F.l1_loss(x, y)
|
| 308 |
-
# elif distance_fn == 'cosine':
|
| 309 |
-
# return 1 - F.cosine_similarity(x, y).mean()
|
| 310 |
-
# else:
|
| 311 |
-
# raise ValueError(f"Unsupported distance function: {distance_fn}")
|
| 312 |
-
|
| 313 |
-
# def quantize(self, audio):
|
| 314 |
-
# audio = audio.squeeze(1) # Remove channel dimension
|
| 315 |
-
# audio = torch.clamp(audio, -1.0, 1.0)
|
| 316 |
-
# audio = (audio * 32767.0).to(torch.int16).to(torch.float32) / 32767.0
|
| 317 |
-
# return audio
|
| 318 |
-
|
| 319 |
-
# def resample(self, audio, input_sample_rate):
|
| 320 |
-
# resampler = torchaudio.transforms.Resample(
|
| 321 |
-
# orig_freq=input_sample_rate, new_freq=self.target_sample_rate
|
| 322 |
-
# ).to(audio.device)
|
| 323 |
-
# return resampler(audio)
|
| 324 |
|
| 325 |
|
| 326 |
"""
|
|
|
|
| 185 |
self.target_sample_rate = 48000 # CLAP expects 48kHz audio
|
| 186 |
self.model = laion_clap.CLAP_Module(enable_fusion=False)
|
| 187 |
self.model.load_ckpt() # download the default pretrained checkpoint
|
| 188 |
+
self.model.eval()
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
def forward(self, input_audio, target, sample_rate, distance_fn='cosine'):
|
| 191 |
# Process input audio
|
| 192 |
+
input_embed = self.process_audio(input_audio, sample_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
# Process target (audio or text)
|
| 195 |
+
if isinstance(target, torch.Tensor):
|
| 196 |
+
target_embed = self.process_audio(target, sample_rate)
|
| 197 |
+
elif isinstance(target, str) or (isinstance(target, list) and isinstance(target[0], str)):
|
| 198 |
+
target_embed = self.process_text(target)
|
| 199 |
+
else:
|
| 200 |
+
raise ValueError("Target must be either audio tensor or text (string or list of strings)")
|
|
|
|
|
|
|
| 201 |
|
| 202 |
# Compute loss using the specified distance function
|
| 203 |
loss = self.compute_distance(input_embed, target_embed, distance_fn)
|
| 204 |
|
| 205 |
return loss
|
| 206 |
|
| 207 |
+
def process_audio(self, audio, sample_rate):
|
| 208 |
# Ensure input is in the correct shape (N, C, T)
|
| 209 |
if audio.dim() == 2:
|
| 210 |
audio = audio.unsqueeze(1)
|
|
|
|
| 212 |
# Convert to mono if stereo
|
| 213 |
if audio.shape[1] > 1:
|
| 214 |
audio = audio.mean(dim=1, keepdim=True)
|
|
|
|
| 215 |
# Resample if necessary
|
| 216 |
if sample_rate != self.target_sample_rate:
|
| 217 |
audio = self.resample(audio, sample_rate)
|
| 218 |
+
audio = audio.squeeze(1)
|
| 219 |
+
|
| 220 |
+
# Get CLAP embeddings
|
| 221 |
+
embed = self.model.get_audio_embedding_from_data(x=audio, use_tensor=True)
|
| 222 |
+
return embed
|
| 223 |
+
|
| 224 |
+
def process_text(self, text):
|
| 225 |
+
# Get CLAP embeddings for text
|
| 226 |
+
# ensure input is a list of strings
|
| 227 |
+
if not isinstance(text, list):
|
| 228 |
+
text = [text]
|
| 229 |
+
embed = self.model.get_text_embedding(text, use_tensor=True)
|
| 230 |
+
return embed
|
| 231 |
|
| 232 |
def compute_distance(self, x, y, distance_fn):
|
| 233 |
if distance_fn == 'mse':
|
|
|
|
| 239 |
else:
|
| 240 |
raise ValueError(f"Unsupported distance function: {distance_fn}")
|
| 241 |
|
| 242 |
+
def resample(self, audio, input_sample_rate):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
resampler = torchaudio.transforms.Resample(
|
| 244 |
+
orig_freq=input_sample_rate, new_freq=self.target_sample_rate
|
| 245 |
).to(audio.device)
|
| 246 |
return resampler(audio)
|
| 247 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
"""
|