Spaces:
Runtime error
Runtime error
update loss
Browse files- app.py +3 -3
- inference.py +2 -1
- modules/__pycache__/loss.cpython-311.pyc +0 -0
- modules/loss.py +1 -0
app.py
CHANGED
|
@@ -243,9 +243,9 @@ with gr.Blocks() as demo:
|
|
| 243 |
|
| 244 |
def update_clap_options(loss_function):
|
| 245 |
if loss_function == "CLAPFeatureLoss":
|
| 246 |
-
return gr.update(visible=
|
| 247 |
else:
|
| 248 |
-
return gr.update(visible=
|
| 249 |
|
| 250 |
loss_function.change(
|
| 251 |
update_clap_options,
|
|
@@ -261,7 +261,7 @@ with gr.Blocks() as demo:
|
|
| 261 |
inputs=[clap_target_type],
|
| 262 |
outputs=[clap_text_prompt]
|
| 263 |
)
|
| 264 |
-
|
| 265 |
ito_button = gr.Button("Perform ITO")
|
| 266 |
|
| 267 |
with gr.Row():
|
|
|
|
| 243 |
|
| 244 |
def update_clap_options(loss_function):
|
| 245 |
if loss_function == "CLAPFeatureLoss":
|
| 246 |
+
return gr.update(visible=False), gr.update(visible=True)
|
| 247 |
else:
|
| 248 |
+
return gr.update(visible=True), gr.update(visible=False)
|
| 249 |
|
| 250 |
loss_function.change(
|
| 251 |
update_clap_options,
|
|
|
|
| 261 |
inputs=[clap_target_type],
|
| 262 |
outputs=[clap_text_prompt]
|
| 263 |
)
|
| 264 |
+
|
| 265 |
ito_button = gr.Button("Perform ITO")
|
| 266 |
|
| 267 |
with gr.Row():
|
inference.py
CHANGED
|
@@ -91,13 +91,14 @@ class MasteringStyleTransfer:
|
|
| 91 |
# Compute loss
|
| 92 |
if ito_config['loss_function'] == 'AudioFeatureLoss':
|
| 93 |
losses = af_loss(output_audio, reference_tensor)
|
|
|
|
| 94 |
elif ito_config['loss_function'] == 'CLAPFeatureLoss':
|
| 95 |
if ito_config['clap_target_type'] == 'Audio':
|
| 96 |
target = reference_tensor
|
| 97 |
else:
|
| 98 |
target = ito_config['clap_text_prompt']
|
| 99 |
losses = self.clap_loss(output_audio, target, self.args.sample_rate)
|
| 100 |
-
|
| 101 |
|
| 102 |
if total_loss < min_loss:
|
| 103 |
min_loss = total_loss.item()
|
|
|
|
| 91 |
# Compute loss
|
| 92 |
if ito_config['loss_function'] == 'AudioFeatureLoss':
|
| 93 |
losses = af_loss(output_audio, reference_tensor)
|
| 94 |
+
total_loss = sum(losses.values())
|
| 95 |
elif ito_config['loss_function'] == 'CLAPFeatureLoss':
|
| 96 |
if ito_config['clap_target_type'] == 'Audio':
|
| 97 |
target = reference_tensor
|
| 98 |
else:
|
| 99 |
target = ito_config['clap_text_prompt']
|
| 100 |
losses = self.clap_loss(output_audio, target, self.args.sample_rate)
|
| 101 |
+
total_loss = losses
|
| 102 |
|
| 103 |
if total_loss < min_loss:
|
| 104 |
min_loss = total_loss.item()
|
modules/__pycache__/loss.cpython-311.pyc
CHANGED
|
Binary files a/modules/__pycache__/loss.cpython-311.pyc and b/modules/__pycache__/loss.cpython-311.pyc differ
|
|
|
modules/loss.py
CHANGED
|
@@ -520,3 +520,4 @@ if __name__ == "__main__":
|
|
| 520 |
loss = clap_loss(input_audio, target_text, sample_rate)
|
| 521 |
print(loss)
|
| 522 |
|
|
|
|
|
|
| 520 |
loss = clap_loss(input_audio, target_text, sample_rate)
|
| 521 |
print(loss)
|
| 522 |
|
| 523 |
+
print(loss.item())
|