Spaces:
Runtime error
Runtime error
XXXXRT666
commited on
Commit
Β·
7bdf3c3
1
Parent(s):
d2e713a
Fix
Browse files- AR/models/structs.py +1 -1
- AR/models/t2s_model_flash_attn.py +21 -15
- inference_webui.py +1 -0
- pre-requirements.txt +2 -1
- requirements.txt +1 -2
AR/models/structs.py
CHANGED
|
@@ -68,7 +68,7 @@ class T2SSession:
|
|
| 68 |
self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
|
| 69 |
|
| 70 |
# EOS
|
| 71 |
-
self.completed = [False] * len(self.x)
|
| 72 |
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
|
| 73 |
|
| 74 |
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
|
|
|
|
| 68 |
self.xy_dec_ = torch.rand((bsz, 1, decoder.embedding_dim)).to(dtype)
|
| 69 |
|
| 70 |
# EOS
|
| 71 |
+
self.completed = torch.Tensor([False] * len(self.x)).bool().to(device)
|
| 72 |
self.y_results: List[Tensor] = [None] * len(self.x) # type: ignore
|
| 73 |
|
| 74 |
self.xy_pos = decoder.embed(self.x, self.y, self.bert_feature)
|
AR/models/t2s_model_flash_attn.py
CHANGED
|
@@ -245,7 +245,6 @@ class CUDAGraphRunner:
|
|
| 245 |
**kwds,
|
| 246 |
)
|
| 247 |
|
| 248 |
-
torch_profiler.start()
|
| 249 |
with torch_profiler.record("AR"):
|
| 250 |
if session.graph:
|
| 251 |
session.xy_pos_.copy_(session.xy_pos)
|
|
@@ -275,22 +274,28 @@ class CUDAGraphRunner:
|
|
| 275 |
top_p=request.top_p,
|
| 276 |
repetition_penalty=request.repetition_penalty,
|
| 277 |
temperature=request.temperature,
|
| 278 |
-
use_cuda_graph=
|
| 279 |
idx=idx,
|
| 280 |
)
|
| 281 |
|
| 282 |
session.y = torch.cat([session.y, samples], dim=1)
|
| 283 |
|
| 284 |
with torch_profiler.record("EOS"):
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
if session.y.size(1) == 0:
|
| 295 |
session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
|
| 296 |
tqdm.write("Bad Zero Prediction")
|
|
@@ -306,7 +311,7 @@ class CUDAGraphRunner:
|
|
| 306 |
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
| 307 |
):
|
| 308 |
for i in range(bsz):
|
| 309 |
-
if not session.completed[i]:
|
| 310 |
session.y_results[i] = session.y[i, session.y_len :]
|
| 311 |
session.completed[i] = True
|
| 312 |
break
|
|
@@ -316,10 +321,11 @@ class CUDAGraphRunner:
|
|
| 316 |
session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
|
| 317 |
|
| 318 |
if idx == 2:
|
|
|
|
| 319 |
t1 = time.perf_counter()
|
| 320 |
|
| 321 |
-
if idx == 51:
|
| 322 |
-
|
| 323 |
|
| 324 |
match session.device.type:
|
| 325 |
case "cuda":
|
|
@@ -331,7 +337,7 @@ class CUDAGraphRunner:
|
|
| 331 |
case "mtia":
|
| 332 |
torch.mtia.empty_cache()
|
| 333 |
gc.collect()
|
| 334 |
-
|
| 335 |
return session.y_results[: request.valid_length]
|
| 336 |
|
| 337 |
def generate(self, request: T2SRequest):
|
|
|
|
| 245 |
**kwds,
|
| 246 |
)
|
| 247 |
|
|
|
|
| 248 |
with torch_profiler.record("AR"):
|
| 249 |
if session.graph:
|
| 250 |
session.xy_pos_.copy_(session.xy_pos)
|
|
|
|
| 274 |
top_p=request.top_p,
|
| 275 |
repetition_penalty=request.repetition_penalty,
|
| 276 |
temperature=request.temperature,
|
| 277 |
+
use_cuda_graph=request.use_cuda_graph,
|
| 278 |
idx=idx,
|
| 279 |
)
|
| 280 |
|
| 281 |
session.y = torch.cat([session.y, samples], dim=1)
|
| 282 |
|
| 283 |
with torch_profiler.record("EOS"):
|
| 284 |
+
argmax_token = torch.argmax(logits, dim=-1)
|
| 285 |
+
sample_token = samples.squeeze(1)
|
| 286 |
+
EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
|
| 287 |
+
with torch_profiler.record("EOS1"):
|
| 288 |
+
newly_done_mask = EOS_mask & (~session.completed)
|
| 289 |
+
with torch_profiler.record("EOS2"):
|
| 290 |
+
newly_done_indices = newly_done_mask.nonzero()
|
| 291 |
+
with torch_profiler.record("EOS3"):
|
| 292 |
+
if newly_done_indices.numel() > 0:
|
| 293 |
+
session.y_results[newly_done_indices[0]] = session.y[
|
| 294 |
+
newly_done_indices[0], session.y_len : -1
|
| 295 |
+
].squeeze(0)
|
| 296 |
+
session.completed[newly_done_indices] = True
|
| 297 |
+
with torch_profiler.record("EOS4"):
|
| 298 |
+
if torch.all(session.completed).item():
|
| 299 |
if session.y.size(1) == 0:
|
| 300 |
session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
|
| 301 |
tqdm.write("Bad Zero Prediction")
|
|
|
|
| 311 |
and (session.y.size(1) - session.y_len) > request.early_stop_num
|
| 312 |
):
|
| 313 |
for i in range(bsz):
|
| 314 |
+
if not session.completed[i].item():
|
| 315 |
session.y_results[i] = session.y[i, session.y_len :]
|
| 316 |
session.completed[i] = True
|
| 317 |
break
|
|
|
|
| 321 |
session.xy_pos = decoder.ar_audio_position.forward(session.input_pos - session.x_lens, y_emb)
|
| 322 |
|
| 323 |
if idx == 2:
|
| 324 |
+
torch_profiler.start()
|
| 325 |
t1 = time.perf_counter()
|
| 326 |
|
| 327 |
+
# if idx == 51:
|
| 328 |
+
# torch_profiler.end()
|
| 329 |
|
| 330 |
match session.device.type:
|
| 331 |
case "cuda":
|
|
|
|
| 337 |
case "mtia":
|
| 338 |
torch.mtia.empty_cache()
|
| 339 |
gc.collect()
|
| 340 |
+
torch_profiler.end()
|
| 341 |
return session.y_results[: request.valid_length]
|
| 342 |
|
| 343 |
def generate(self, request: T2SRequest):
|
inference_webui.py
CHANGED
|
@@ -836,4 +836,5 @@ if __name__ == "__main__":
|
|
| 836 |
server_name="0.0.0.0",
|
| 837 |
inbrowser=True,
|
| 838 |
show_api=False,
|
|
|
|
| 839 |
)
|
|
|
|
| 836 |
server_name="0.0.0.0",
|
| 837 |
inbrowser=True,
|
| 838 |
show_api=False,
|
| 839 |
+
server_port=1111,
|
| 840 |
)
|
pre-requirements.txt
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
-
torch==2.5.1
|
|
|
|
|
|
| 1 |
+
torch==2.5.1
|
| 2 |
+
torchaudio
|
requirements.txt
CHANGED
|
@@ -3,7 +3,6 @@ scipy>=1.11.3
|
|
| 3 |
tensorboard==2.15.1
|
| 4 |
librosa==0.9.2
|
| 5 |
numba==0.56.4
|
| 6 |
-
torchaudio
|
| 7 |
pytorch-lightning>=2.4
|
| 8 |
gradio==4.44.1
|
| 9 |
gradio_client==1.3.0
|
|
@@ -36,4 +35,4 @@ nltk==3.8.1
|
|
| 36 |
fast_langdetect==0.3.1
|
| 37 |
split_lang==2.1.0
|
| 38 |
ToJyutping==3.2.0
|
| 39 |
-
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.
|
|
|
|
| 3 |
tensorboard==2.15.1
|
| 4 |
librosa==0.9.2
|
| 5 |
numba==0.56.4
|
|
|
|
| 6 |
pytorch-lightning>=2.4
|
| 7 |
gradio==4.44.1
|
| 8 |
gradio_client==1.3.0
|
|
|
|
| 35 |
fast_langdetect==0.3.1
|
| 36 |
split_lang==2.1.0
|
| 37 |
ToJyutping==3.2.0
|
| 38 |
+
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|