EuuIia commited on
Commit
33de423
·
verified ·
1 Parent(s): c14605c

Update video_service.py

Browse files
Files changed (1) hide show
  1. video_service.py +80 -10
video_service.py CHANGED
@@ -16,9 +16,9 @@ import sys
16
  import subprocess
17
  import gc
18
  import shutil
 
19
 
20
  # --- 2. GERENCIAMENTO DE DEPENDÊNCIAS E SETUP ---
21
-
22
  def _query_gpu_processes_via_nvml(device_index: int) -> List[Dict]:
23
  try:
24
  import psutil
@@ -156,11 +156,16 @@ class VideoService:
156
  self._tmp_dirs = set()
157
  self._tmp_files = set()
158
  self._last_outputs = []
 
159
  self.pipeline, self.latent_upsampler = self._load_models()
160
  print(f"Movendo modelos para o dispositivo de inferência: {self.device}")
161
  self.pipeline.to(self.device)
162
  if self.latent_upsampler:
163
  self.latent_upsampler.to(self.device)
 
 
 
 
164
  if self.device == "cuda":
165
  torch.cuda.empty_cache()
166
  self._log_gpu_memory("Após carregar modelos")
@@ -212,6 +217,7 @@ class VideoService:
212
  keep = set(keep_paths or [])
213
  extras = set(extra_paths or [])
214
 
 
215
  for f in list(self._tmp_files | extras):
216
  try:
217
  if f not in keep and os.path.isfile(f):
@@ -221,6 +227,7 @@ class VideoService:
221
  finally:
222
  self._tmp_files.discard(f)
223
 
 
224
  for d in list(self._tmp_dirs):
225
  try:
226
  if d not in keep and os.path.isdir(d):
@@ -230,6 +237,7 @@ class VideoService:
230
  finally:
231
  self._tmp_dirs.discard(d)
232
 
 
233
  gc.collect()
234
  try:
235
  if clear_gpu and torch.cuda.is_available():
@@ -241,19 +249,33 @@ class VideoService:
241
  except Exception:
242
  pass
243
 
 
244
  try:
245
  self._log_gpu_memory("Após finalize")
246
  except Exception:
247
  pass
248
 
249
  def _load_config(self):
250
- config_file_path = LTX_VIDEO_REPO_DIR / "configs" / "ltxv-13b-0.9.8-distilled-fp8.yaml"
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  with open(config_file_path, "r") as file:
252
  return yaml.safe_load(file)
253
 
254
  def _load_models(self):
255
-
256
  LTX_REPO = "Lightricks/LTX-Video"
 
257
  distilled_model_path = hf_hub_download(
258
  repo_id=LTX_REPO,
259
  filename=self.config["checkpoint_path"],
@@ -289,9 +311,47 @@ class VideoService:
289
 
290
  return pipeline, latent_upsampler
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
293
  tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
294
  tensor = torch.nn.functional.pad(tensor, padding_values)
 
 
295
  return tensor.to(self.device)
296
 
297
  def generate(
@@ -407,7 +467,12 @@ class VideoService:
407
  "second_pass": second_pass_args,
408
  }
409
  )
410
- result_tensor = multi_scale_pipeline(**multi_scale_call_kwargs).images
 
 
 
 
 
411
  log_tensor_info(result_tensor, "Resultado da Etapa 2 (Saída do Pipeline Multi-Scale)")
412
  else:
413
  single_pass_kwargs = call_kwargs.copy()
@@ -424,10 +489,14 @@ class VideoService:
424
  single_pass_kwargs["timesteps"] = [0.7]
425
  print("[INFO] Modo video-to-video (etapa única): definindo timesteps (força) para [0.7]")
426
  else:
427
- single_pass_kwargs["timesteps"] = first_pass_config.get("timesteps")
428
 
429
  print("\n[INFO] Executando pipeline de etapa única...")
430
- result_tensor = self.pipeline(**single_pass_kwargs).images
 
 
 
 
431
 
432
  pad_left, pad_right, pad_top, pad_bottom = padding_values
433
  slice_h_end = -pad_bottom if pad_bottom > 0 else None
@@ -437,17 +506,16 @@ class VideoService:
437
 
438
  video_np = (result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8)
439
 
 
440
  temp_dir = tempfile.mkdtemp(prefix="ltxv_")
441
  self._register_tmp_dir(temp_dir)
442
- results_dir = "/app/output"
443
  os.makedirs(results_dir, exist_ok=True)
444
 
445
  final_output_path = None
446
  output_video_path = os.path.join(temp_dir, f"output_{used_seed}.mp4")
447
  try:
448
- with imageio.get_writer(
449
- output_video_path, fps=call_kwargs["frame_rate"], codec="libx264", quality=8
450
- ) as writer:
451
  total_frames = len(video_np)
452
  for i, frame in enumerate(video_np):
453
  writer.append_data(frame)
@@ -465,6 +533,7 @@ class VideoService:
465
  self._log_gpu_memory("Fim da Geração")
466
  return final_output_path, used_seed
467
  finally:
 
468
  try:
469
  del result_tensor
470
  except Exception:
@@ -489,6 +558,7 @@ class VideoService:
489
  except Exception:
490
  pass
491
 
 
492
  try:
493
  self.finalize(keep_paths=[final_output_path] if final_output_path else [])
494
  except Exception:
 
16
  import subprocess
17
  import gc
18
  import shutil
19
+ import contextlib
20
 
21
  # --- 2. GERENCIAMENTO DE DEPENDÊNCIAS E SETUP ---
 
22
  def _query_gpu_processes_via_nvml(device_index: int) -> List[Dict]:
23
  try:
24
  import psutil
 
156
  self._tmp_dirs = set()
157
  self._tmp_files = set()
158
  self._last_outputs = []
159
+
160
  self.pipeline, self.latent_upsampler = self._load_models()
161
  print(f"Movendo modelos para o dispositivo de inferência: {self.device}")
162
  self.pipeline.to(self.device)
163
  if self.latent_upsampler:
164
  self.latent_upsampler.to(self.device)
165
+
166
+ # Política de precisão (inclui promoção FP8->BF16 e dtype de autocast)
167
+ self._apply_precision_policy()
168
+
169
  if self.device == "cuda":
170
  torch.cuda.empty_cache()
171
  self._log_gpu_memory("Após carregar modelos")
 
217
  keep = set(keep_paths or [])
218
  extras = set(extra_paths or [])
219
 
220
+ # Remoção de arquivos
221
  for f in list(self._tmp_files | extras):
222
  try:
223
  if f not in keep and os.path.isfile(f):
 
227
  finally:
228
  self._tmp_files.discard(f)
229
 
230
+ # Remoção de diretórios
231
  for d in list(self._tmp_dirs):
232
  try:
233
  if d not in keep and os.path.isdir(d):
 
237
  finally:
238
  self._tmp_dirs.discard(d)
239
 
240
+ # Coleta de GC e limpeza de VRAM
241
  gc.collect()
242
  try:
243
  if clear_gpu and torch.cuda.is_available():
 
249
  except Exception:
250
  pass
251
 
252
+ # Log opcional pós-limpeza
253
  try:
254
  self._log_gpu_memory("Após finalize")
255
  except Exception:
256
  pass
257
 
258
  def _load_config(self):
259
+ # Prioriza configs FP8 se presentes, mantendo compatibilidade
260
+ base = LTX_VIDEO_REPO_DIR / "configs"
261
+ candidates = [
262
+ base / "ltxv-13b-0.9.8-dev-fp8.yaml",
263
+ base / "ltxv-13b-0.9.8-distilled-fp8.yaml",
264
+ base / "ltxv-13b-0.9.8-dev-fp8.yaml.txt",
265
+ base / "ltxv-13b-0.9.8-distilled.yaml", # fallback não-FP8
266
+ ]
267
+ for cfg in candidates:
268
+ if cfg.exists():
269
+ with open(cfg, "r") as file:
270
+ return yaml.safe_load(file)
271
+ # Fallback rígido para caminho clássico se nada acima existir
272
+ config_file_path = base / "ltxv-13b-0.9.8-distilled.yaml"
273
  with open(config_file_path, "r") as file:
274
  return yaml.safe_load(file)
275
 
276
  def _load_models(self):
 
277
  LTX_REPO = "Lightricks/LTX-Video"
278
+
279
  distilled_model_path = hf_hub_download(
280
  repo_id=LTX_REPO,
281
  filename=self.config["checkpoint_path"],
 
311
 
312
  return pipeline, latent_upsampler
313
 
314
+ # Precisão: promove FP8->BF16 e define dtype de autocast
315
+ def _promote_fp8_weights_to_bf16(self, module):
316
+ f8 = getattr(torch, "float8_e4m3fn", None)
317
+ if f8 is None:
318
+ return
319
+ for _, p in module.named_parameters(recurse=True):
320
+ try:
321
+ if p.dtype == f8:
322
+ with torch.no_grad():
323
+ p.data = p.data.to(torch.bfloat16)
324
+ except Exception:
325
+ pass
326
+ for _, b in module.named_buffers(recurse=True):
327
+ try:
328
+ if hasattr(b, "dtype") and b.dtype == f8:
329
+ b.data = b.data.to(torch.bfloat16)
330
+ except Exception:
331
+ pass
332
+
333
+ def _apply_precision_policy(self):
334
+ prec = str(self.config.get("precision", "")).lower()
335
+ self.runtime_autocast_dtype = torch.float32
336
+ if prec == "float8_e4m3fn":
337
+ # FP8 experimental: promove pesos para BF16 e padroniza autocast em BF16
338
+ if hasattr(torch, "float8_e4m3fn"):
339
+ self._promote_fp8_weights_to_bf16(self.pipeline)
340
+ if self.latent_upsampler:
341
+ self._promote_fp8_weights_to_bf16(self.latent_upsampler)
342
+ self.runtime_autocast_dtype = torch.bfloat16
343
+ elif prec == "bfloat16":
344
+ self.runtime_autocast_dtype = torch.bfloat16
345
+ elif prec == "mixed_precision":
346
+ self.runtime_autocast_dtype = torch.float16
347
+ else:
348
+ self.runtime_autocast_dtype = torch.float32
349
+
350
  def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
351
  tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
352
  tensor = torch.nn.functional.pad(tensor, padding_values)
353
+ if self.device == "cuda":
354
+ return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
355
  return tensor.to(self.device)
356
 
357
  def generate(
 
467
  "second_pass": second_pass_args,
468
  }
469
  )
470
+
471
+ ctx = contextlib.nullcontext()
472
+ if self.device == "cuda":
473
+ ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
474
+ with ctx:
475
+ result_tensor = multi_scale_pipeline(**multi_scale_call_kwargs).images
476
  log_tensor_info(result_tensor, "Resultado da Etapa 2 (Saída do Pipeline Multi-Scale)")
477
  else:
478
  single_pass_kwargs = call_kwargs.copy()
 
489
  single_pass_kwargs["timesteps"] = [0.7]
490
  print("[INFO] Modo video-to-video (etapa única): definindo timesteps (força) para [0.7]")
491
  else:
492
+ single_pass_kwargs["timesteps"] = first_pass_config.get("guidance_timesteps") or first_pass_config.get("timesteps")
493
 
494
  print("\n[INFO] Executando pipeline de etapa única...")
495
+ ctx = contextlib.nullcontext()
496
+ if self.device == "cuda":
497
+ ctx = torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype)
498
+ with ctx:
499
+ result_tensor = self.pipeline(**single_pass_kwargs).images
500
 
501
  pad_left, pad_right, pad_top, pad_bottom = padding_values
502
  slice_h_end = -pad_bottom if pad_bottom > 0 else None
 
506
 
507
  video_np = (result_tensor[0].permute(1, 2, 3, 0).cpu().float().numpy() * 255).astype(np.uint8)
508
 
509
+ # Staging seguro em tmp e move para diretório persistente
510
  temp_dir = tempfile.mkdtemp(prefix="ltxv_")
511
  self._register_tmp_dir(temp_dir)
512
+ results_dir = "/data/results"
513
  os.makedirs(results_dir, exist_ok=True)
514
 
515
  final_output_path = None
516
  output_video_path = os.path.join(temp_dir, f"output_{used_seed}.mp4")
517
  try:
518
+ with imageio.get_writer(output_video_path, fps=call_kwargs["frame_rate"], codec="libx264", quality=8) as writer:
 
 
519
  total_frames = len(video_np)
520
  for i, frame in enumerate(video_np):
521
  writer.append_data(frame)
 
533
  self._log_gpu_memory("Fim da Geração")
534
  return final_output_path, used_seed
535
  finally:
536
+ # Libera tensores/objetos grandes antes de limpar VRAM
537
  try:
538
  del result_tensor
539
  except Exception:
 
558
  except Exception:
559
  pass
560
 
561
+ # Limpeza de temporários preservando o vídeo final
562
  try:
563
  self.finalize(keep_paths=[final_output_path] if final_output_path else [])
564
  except Exception: