Commit: Fix prompt_tokens.shape[-1] issue with max_prompt_length

Files changed:
- app.py (+1 / -1)
- audiocraft/models/genmodel.py (+16 / -10)
- audiocraft/models/lm.py (+1 / -1)
- audiocraft/models/musicgen.py (+2 / -2)
app.py
CHANGED

@@ -281,7 +281,7 @@ def predict(model, text, melody_filepath, duration, dimension, topk, topp, tempe
             cfg_coef=cfg_coef,
             duration=segment_duration,
             two_step_cfg=False,
-            extend_stride=
+            extend_stride=2,
             rep_penalty=0.5,
             cfg_coef_beta=None, # double CFG is only useful for text-and-style conditioning
         )
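Note: the removed line was a dangling `extend_stride=`, which is a Python syntax error and plausibly the source of the Space's runtime error. The new `extend_stride=2` sets, in seconds, how far the sliding generation window advances when a clip is longer than the model's maximum window; genmodel.py converts it to tokens via `stride_tokens = int(self.frame_rate * self.extend_stride)` (see the hunk at line 247 below). A rough, simplified sketch of that windowing arithmetic, with assumed example values:

```python
# Rough sketch of how extend_stride maps to the generation loop in
# genmodel.py; frame_rate and the durations are assumed example values,
# not the model's actual configuration.
frame_rate = 50          # audio tokens per second (assumption)
max_duration = 30.0      # model's maximum window in seconds (assumption)
duration = 45.0          # requested clip length in seconds
extend_stride = 2        # seconds the window advances per step (this commit)

total_gen_len = int(duration * frame_rate)       # 2250 tokens overall
stride_tokens = int(frame_rate * extend_stride)  # 100 tokens per step

# The first window covers max_duration; each later step extends the clip
# by the stride, so a smaller extend_stride means more overlap, more steps.
steps = 0
current_gen_offset = 0
prompt_length = int(max_duration * frame_rate)   # simplified: full window
while current_gen_offset + prompt_length < total_gen_len:
    current_gen_offset += stride_tokens
    steps += 1
print(steps)  # 8 extension steps for these numbers
```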
audiocraft/models/genmodel.py
CHANGED

@@ -16,6 +16,7 @@ import typing as tp
 
 import omegaconf
 import torch
+import gradio as gr
 
 from .encodec import CompressionModel
 from .lm import LMModel

@@ -191,11 +192,11 @@ class BaseGenModel(ABC):
         return self.generate_audio(tokens)
 
     def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
-                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False, progress_callback: gr.Progress = None) -> torch.Tensor:
         """Generate discrete audio tokens given audio prompt and/or conditions.
 
         Args:
-            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
+            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
             prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
             progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
         Returns:

@@ -207,20 +208,24 @@ class BaseGenModel(ABC):
 
         def _progress_callback(generated_tokens: int, tokens_to_generate: int):
             generated_tokens += current_gen_offset
+            generated_tokens /= ((tokens_to_generate) / self.duration)
+            tokens_to_generate /= ((tokens_to_generate) / self.duration)
             if self._progress_callback is not None:
                 # Note that total_gen_len might be quite wrong depending on the
                 # codebook pattern used, but with delay it is almost accurate.
-                self._progress_callback(generated_tokens, tokens_to_generate)
-            if progress:
-                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
+                self._progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress_callback is not None:
+                # Update Gradio progress bar
+                progress_callback((generated_tokens / tokens_to_generate), f"Generated {generated_tokens: 6.2f}/{tokens_to_generate: 6.2f} seconds")
+            if progress:
+                print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
         if prompt_tokens is not None:
-            assert max_prompt_len >= prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
-        callback = None
-        if progress:
-            callback = _progress_callback
+        # callback = None
+        callback = _progress_callback
 
         if self.duration <= self.max_duration:
             # generate by sampling from LM, simple case.
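Note: the two added divisions rescale raw token counts into seconds. Dividing both counts by `tokens_to_generate / self.duration` makes `tokens_to_generate` equal the clip duration, so the callbacks now receive a 0-1 fraction plus a human-readable "seconds" label instead of raw token counts. A minimal standalone sketch of that arithmetic, with made-up numbers:

```python
# Illustration of the token-to-seconds rescaling added above.
# All values are made up: 50 tokens/s, a 30 s clip, 600 tokens done so far.
duration = 30.0             # seconds requested
tokens_to_generate = 1500   # total tokens for the clip
generated_tokens = 600      # tokens produced so far

scale = tokens_to_generate / duration         # tokens per second: 50.0
generated_seconds = generated_tokens / scale  # 12.0 s
total_seconds = tokens_to_generate / scale    # 30.0 s == duration

fraction = generated_seconds / total_seconds  # 0.4, what the callback reports
print(f"Generated {generated_seconds: 6.2f}/{total_seconds: 6.2f} seconds ({fraction:.0%})")
```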
@@ -240,6 +245,7 @@ class BaseGenModel(ABC):
             prompt_length = prompt_tokens.shape[-1]
 
             stride_tokens = int(self.frame_rate * self.extend_stride)
+
             while current_gen_offset + prompt_length < total_gen_len:
                 time_offset = current_gen_offset / self.frame_rate
                 chunk_duration = min(self.duration - time_offset, self.max_duration)
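Note: the new `progress_callback: gr.Progress` parameter lets the Space surface generation progress in the UI. A hypothetical sketch of how a Gradio event handler might hand its tracker through to `_generate_tokens` (the handler name and model call are illustrative, not this Space's actual code):

```python
import gradio as gr

# Hypothetical wiring for the new progress_callback parameter.
def generate_handler(text, progress=gr.Progress()):
    # Gradio injects a tracker when a parameter defaults to gr.Progress().
    def on_progress(fraction, message):
        # Adapt the (fraction, message) convention used in genmodel.py onto
        # gr.Progress, which is callable as progress(fraction, desc=...).
        progress(fraction, desc=message)

    # Somewhere inside the model's generate call, genmodel.py would invoke
    # progress_callback(fraction, message) during token generation, e.g.
    # model._generate_tokens(attrs, None, progress=True,
    #                        progress_callback=on_progress)
    ...
```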
audiocraft/models/lm.py
CHANGED

@@ -517,7 +517,7 @@ class LMModel(StreamingModule):
         B, K, T = prompt.shape
         start_offset = T
         print(f"start_offset: {start_offset} | max_gen_len: {max_gen_len}")
-        assert start_offset < max_gen_len
+        assert start_offset <= max_gen_len
 
         pattern = self.pattern_provider.get_pattern(max_gen_len)
         # this token is used as default value for codes that are not generated yet
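Note: relaxing the assert from `<` to `<=` matters once prompts are clipped to `max_prompt_len`: a prompt that exactly fills the generation window (`start_offset == max_gen_len`) no longer raises, it simply leaves nothing new to sample. A small sketch of that boundary case, with illustrative numbers:

```python
# Boundary case the relaxed assert now allows (numbers are illustrative).
max_gen_len = 1500
start_offset = 1500  # prompt truncated to exactly the window size

assert start_offset <= max_gen_len  # passes; the old `<` would have raised
remaining = max_gen_len - start_offset
print(f"tokens left to generate: {remaining}")  # 0: nothing new to sample
```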
audiocraft/models/musicgen.py
CHANGED

@@ -453,8 +453,8 @@ class MusicGen:
             print(f'{generated_tokens: 6.2f} / {tokens_to_generate: 6.2f}', end='\r')
 
         if prompt_tokens is not None:
-            assert max_prompt_len >= prompt_tokens.shape[-1], \
-                "Prompt is longer than audio to generate"
+            if prompt_tokens.shape[-1] > max_prompt_len:
+                prompt_tokens = prompt_tokens[..., :max_prompt_len]
 
         # callback = None
         callback = _progress_callback