Spaces: Runtime error

Commit: modify app

inference.py CHANGED (+36 -17)
@@ -114,24 +114,43 @@ class MasteringStyleTransfer:
 
         return min_loss_output, min_loss_params, min_loss_embedding, min_loss_step + 1
 
+    def preprocess_audio(self, audio, target_sample_rate=44100):
+        sample_rate, data = audio
+
+        # Normalize audio to -1 to 1 range
+        if data.dtype == np.int16:
+            data = data.astype(np.float32) / 32768.0
+        elif data.dtype == np.float32:
+            data = np.clip(data, -1.0, 1.0)
+        else:
+            raise ValueError(f"Unsupported audio data type: {data.dtype}")
+
+        # Ensure stereo channels
+        if data.ndim == 1:
+            data = np.stack([data, data])
+        elif data.ndim == 2:
+            if data.shape[0] == 2:
+                pass  # Already in correct shape
+            elif data.shape[1] == 2:
+                data = data.T
+            else:
+                data = np.stack([data[:, 0], data[:, 0]])  # Duplicate mono channel
+        else:
+            raise ValueError(f"Unsupported audio shape: {data.shape}")
+
+        # Convert to torch tensor
+        data_tensor = torch.FloatTensor(data).unsqueeze(0)
+
+        # Resample if necessary
+        if sample_rate != target_sample_rate:
+            data_tensor = julius.resample_frac(data_tensor, sample_rate, target_sample_rate)
+
+        return data_tensor.to(self.device)
+
     def process_audio(self, input_audio, reference_audio, ito_reference_audio, params, perform_ito, log_ito=False):
-
-
-
-            for audio in [input_audio, reference_audio, ito_reference_audio]
-        ]
-
-        input_tensor = torch.FloatTensor(input_audio).unsqueeze(0).to(self.device)
-        reference_tensor = torch.FloatTensor(reference_audio).unsqueeze(0).to(self.device)
-        ito_reference_tensor = torch.FloatTensor(ito_reference_audio).unsqueeze(0).to(self.device)
-
-        #resample to 44.1kHz if necessary
-        if input_audio[0] != self.args.sample_rate:
-            input_tensor = convert_audio(input_tensor, input_audio[0], self.args.sample_rate, 2)
-        if reference_audio[0] != self.args.sample_rate:
-            reference_tensor = convert_audio(reference_tensor, reference_audio[0], self.args.sample_rate, 2)
-        if ito_reference_audio[0] != self.args.sample_rate:
-            ito_reference_tensor = convert_audio(ito_reference_tensor, ito_reference_audio[0], self.args.sample_rate, 2)
+        input_tensor = self.preprocess_audio(input_audio, self.args.sample_rate)
+        reference_tensor = self.preprocess_audio(reference_audio, self.args.sample_rate)
+        ito_reference_tensor = self.preprocess_audio(ito_reference_audio, self.args.sample_rate)
 
         reference_feature = self.get_reference_embedding(reference_tensor)
 
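For context, the new preprocess_audio expects the (sample_rate, numpy array) tuples that a Gradio Audio component returns with type="numpy". A minimal sketch of what the method does to a mono int16 clip, assuming a MasteringStyleTransfer instance is constructed the way the rest of this Space does it (the constructor call and args object below are illustrative, not part of this commit):

    import numpy as np

    # Hypothetical one-second 440 Hz test tone: mono, int16, 48 kHz --
    # the kind of (sample_rate, data) tuple Gradio's Audio component yields.
    sr = 48000
    tone = (np.sin(2 * np.pi * 440 * np.arange(sr) / sr) * 32767).astype(np.int16)

    mst = MasteringStyleTransfer(args)  # illustrative; construct as done elsewhere in the Space
    out = mst.preprocess_audio((sr, tone), target_sample_rate=44100)

    # int16 is scaled to [-1, 1], mono is duplicated to stereo, a batch dim
    # is added, and julius resamples 48 kHz -> 44.1 kHz along the last axis:
    print(out.shape)  # torch.Size([1, 2, 44100])

Centralizing this in one helper also removes the old path's fragility: the removed code passed the raw audio tuple straight into torch.FloatTensor and only afterwards checked input_audio[0] against the target sample rate, whereas the helper unpacks the tuple, normalizes dtype and channel layout, and resamples before anything touches the model.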