Eueuiaa committed
Commit 42dccc7 · verified · 1 Parent(s): 8fd9bdd

Update api/ltx_server_refactored.py

Files changed (1):
  api/ltx_server_refactored.py  +208 -6
api/ltx_server_refactored.py CHANGED
@@ -33,6 +33,7 @@ import shutil
 import contextlib
 import time
 import traceback
+from api.gpu_manager import gpu_manager
 from einops import rearrange
 import torch.nn.functional as F
 from managers.vae_manager import vae_manager_singleton
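The new import is the only place this diff shows `api.gpu_manager`; the module itself is not part of the commit. A minimal sketch of the interface the code below relies on, assuming only that `get_ltx_device()` returns a `torch.device` — the selection policy here is a guess, not the project's actual logic:

    import torch

    class GPUManager:
        """Hypothetical stand-in for api/gpu_manager.py -- only get_ltx_device()
        is implied by this commit; everything else is an assumption."""

        def get_ltx_device(self) -> torch.device:
            # Fall back to CPU when no CUDA device is visible.
            if not torch.cuda.is_available():
                return torch.device("cpu")
            # Naive policy: hand LTX the first GPU.
            return torch.device("cuda:0")

    gpu_manager = GPUManager()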
@@ -101,6 +102,8 @@ class VideoService:
     def __init__(self):
         t0 = time.perf_counter()
         print("[DEBUG] Initializing VideoService...")
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = gpu_manager.get_ltx_device()
+        print(f"[DEBUG] LTX allocated to device: {self.device}")
+
         self.config = self._load_config()
         self.pipeline, self.latent_upsampler = self._load_models()
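The old hard-coded assignment has to go (as removed above), not just be preceded by the manager call: keeping `self.device` as a `torch.device` rather than a string matters downstream, because the new generation code checks `self.device.type == 'cuda'` to enable autocast, and `str` has no `.type` attribute. A two-line illustration:

    import torch

    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(dev.type)   # "cuda" or "cpu" -- works on a torch.device
    # "cuda".type     # would raise AttributeError: str has no attribute 'type'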
@@ -139,7 +142,21 @@ class VideoService:
             self._log_gpu_memory("After finalize")
         except Exception as e:
             print(f"[DEBUG] Post-finalize GPU log failed: {e}")
-
+
+    def move_to_device(self, device):
+        """Move the pipeline's models to the specified device."""
+        print(f"[LTX] Moving models to {device}...")
+        self.pipeline.to(device)
+        if self.latent_upsampler:
+            self.latent_upsampler.to(device)
+        self.device = device
+
+    def move_to_cpu(self):
+        """Move the models to the CPU to free VRAM."""
+        self.move_to_device(torch.device("cpu"))
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
     def _load_models(self):
         t0 = time.perf_counter()
         LTX_REPO = "Lightricks/LTX-Video"
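The two new methods make the service's VRAM footprint controllable from outside: `move_to_cpu()` offloads both the pipeline and the latent upsampler and clears the CUDA cache, and `move_to_device()` brings them back. A usage sketch — the surrounding workload is an illustrative placeholder:

    # Offload LTX between requests so another model can use the GPU.
    service = video_generation_service

    service.move_to_cpu()                          # models to CPU + empty_cache()
    run_other_gpu_workload()                       # hypothetical placeholder
    service.move_to_device(torch.device("cuda"))   # restore LTX for the next job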
@@ -245,7 +262,7 @@ class VideoService:
         conditioning_items.append(ConditioningItem(tensor, safe_frame, float(weight)))
         return conditioning_items

-    def generate_low(self, prompt, negative_prompt, height, width, duration, guidance_scale, seed, conditioning_items=None):
+    def generate_low_old(self, prompt, negative_prompt, height, width, duration, guidance_scale, seed, conditioning_items=None):
         used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
         seed_everething(used_seed)
         FPS = 24.0
@@ -282,7 +299,191 @@ class VideoService:
         torch.cuda.empty_cache()
         torch.cuda.ipc_collect()
         self.finalize(keep_paths=[])
-
+
+    def _generate_single_chunk_low(self, prompt, negative_prompt, height, width, num_frames, guidance_scale, seed, initial_latent_condition=None, image_conditions=None, ltx_configs_override=None):
+        """
+        [GENERATION NODE]
+        Generates a SINGLE chunk of raw latents. This is the fundamental unit of work.
+        """
+        # (This helper remains the same as in our last version, including the override logic)
+        print("\n" + "-"*20 + " START: _generate_single_chunk_low " + "-"*20)
+        height_padded = ((height - 1) // 8 + 1) * 8
+        width_padded = ((width - 1) // 8 + 1) * 8
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+
+        downscale_factor = self.config.get("downscale_factor", 0.6666666)
+        vae_scale_factor = self.pipeline.vae_scale_factor
+
+        x_width = int(width_padded * downscale_factor)
+        downscaled_width = x_width - (x_width % vae_scale_factor)
+        x_height = int(height_padded * downscale_factor)
+        downscaled_height = x_height - (x_height % vae_scale_factor)
+
+        all_conditions = []
+        if image_conditions: all_conditions.extend(image_conditions)
+        if initial_latent_condition: all_conditions.append(initial_latent_condition)
+
+        first_pass_config = self.config.get("first_pass", {}).copy()
+
+        if ltx_configs_override:
+            print("[DEBUG] Overriding LTX settings with values from the UI...")
+            if "first_pass_num_inference_steps" in ltx_configs_override:
+                first_pass_config["num_inference_steps"] = ltx_configs_override["first_pass_num_inference_steps"]
+            if "first_pass_guidance_scale" in ltx_configs_override:
+                max_val = max(first_pass_config.get("guidance_scale", [1]))
+                new_max_val = ltx_configs_override["first_pass_guidance_scale"]
+                first_pass_config["guidance_scale"] = [new_max_val if x == max_val else x for x in first_pass_config["guidance_scale"]]
+
+        first_pass_kwargs = {
+            "prompt": prompt, "negative_prompt": negative_prompt, "height": downscaled_height, "width": downscaled_width,
+            "num_frames": num_frames, "frame_rate": 24, "generator": generator, "output_type": "latent",
+            "conditioning_items": all_conditions if all_conditions else None,
+            **first_pass_config
+        }
+        # Removed guidance_scale as an explicit kwarg here, since it now comes in via first_pass_config
+        if "guidance_scale" in first_pass_kwargs:
+            del first_pass_kwargs["guidance_scale"]
+
+        with torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype, enabled=self.device.type == 'cuda'):
+            latents_bruto = self.pipeline(**first_pass_kwargs).images
+        log_tensor_info(latents_bruto, f"Raw latents generated for: '{prompt[:40]}...'")
+
+        print("-"*20 + " END: _generate_single_chunk_low " + "-"*20)
+        return latents_bruto
+
+    def generate_narrative_low(self, prompt: str, negative_prompt, height, width, duration, guidance_scale, seed, initial_image_conditions=None, overlap_frames: int = 8, ltx_configs_override: dict = None):
+        """
+        [NARRATIVE ORCHESTRATOR]
+        Generates a video in multiple sequential chunks from a multi-line prompt.
+        """
+        print("\n" + "="*80)
+        print("====== STARTING NARRATIVE CHUNKED GENERATION (LOW-RES) ======")
+        print("="*80)
+
+        used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
+        seed_everething(used_seed)
+        FPS = 24.0
+
+        prompt_list = [p.strip() for p in prompt.splitlines() if p.strip()]
+        num_chunks = len(prompt_list)
+        if num_chunks == 0: raise ValueError("The prompt is empty or contains no valid lines.")
+
+        total_actual_frames = max(9, int(round((round(duration * FPS) - 1) / 8.0) * 8 + 1))
+
+        if num_chunks > 1:
+            total_blocks = (total_actual_frames - 1) // 8
+            blocks_per_chunk = total_blocks // num_chunks
+            blocks_last_chunk = total_blocks - (blocks_per_chunk * (num_chunks - 1))
+            frames_per_chunk = blocks_per_chunk * 8 + 1
+            frames_per_chunk_last = blocks_last_chunk * 8 + 1
+        else:
+            frames_per_chunk = total_actual_frames
+            frames_per_chunk_last = total_actual_frames
+
+        frames_per_chunk = max(9, frames_per_chunk)
+        frames_per_chunk_last = max(9, frames_per_chunk_last)
+
+        poda_latents_num = overlap_frames // self.pipeline.video_scale_factor if self.pipeline.video_scale_factor > 0 else 0
+
+        latentes_chunk_video = []
+        condition_item_latent_overlap = None
+        temp_dir = tempfile.mkdtemp(prefix="ltxv_narrative_"); self._register_tmp_dir(temp_dir)
+        results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
+
+        for i, chunk_prompt in enumerate(prompt_list):
+            print(f"\n--- Generating narrative chunk {i+1}/{num_chunks}: '{chunk_prompt}' ---")
+
+            current_image_conditions = []
+            if initial_image_conditions:
+                cond_item_original = initial_image_conditions[0]
+                if i == 0:
+                    current_image_conditions.append(cond_item_original)
+                else:
+                    cond_item_fraco = ConditioningItem(
+                        media_item=cond_item_original.media_item, media_frame_number=0, conditioning_strength=0.1
+                    )
+                    current_image_conditions.append(cond_item_fraco)
+
+            num_frames_para_gerar = frames_per_chunk_last if i == num_chunks - 1 else frames_per_chunk
+            if i > 0 and poda_latents_num > 0:
+                num_frames_para_gerar += overlap_frames
+
+            latentes_bruto = self._generate_single_chunk_low(
+                prompt=chunk_prompt, negative_prompt=negative_prompt, height=height, width=width,
+                num_frames=num_frames_para_gerar, guidance_scale=guidance_scale, seed=used_seed + i,
+                initial_latent_condition=condition_item_latent_overlap, image_conditions=current_image_conditions,
+                ltx_configs_override=ltx_configs_override
+            )
+
+            if i > 0 and poda_latents_num > 0:
+                latentes_bruto = latentes_bruto[:, :, poda_latents_num:, :, :]
+
+            latentes_podado = latentes_bruto.clone().detach()
+            if i < num_chunks - 1 and poda_latents_num > 0:
+                latentes_podado = latentes_bruto[:, :, :-poda_latents_num, :, :].clone()
+                overlap_latents = latentes_bruto[:, :, -poda_latents_num:, :, :].clone()
+                condition_item_latent_overlap = ConditioningItem(
+                    media_item=overlap_latents, media_frame_number=0, conditioning_strength=1.0
+                )
+            latentes_chunk_video.append(latentes_podado)
+
+        print("\n--- Finishing narrative: concatenating chunks ---")
+        final_latents = torch.cat(latentes_chunk_video, dim=2)
+        log_tensor_info(final_latents, "Final concatenated latent tensor")
+
+        tensor_path = os.path.join(results_dir, f"latents_narrative_{used_seed}.pt")
+        torch.save(final_latents.cpu(), tensor_path)
+
+        with torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype, enabled=self.device.type == 'cuda'):
+            pixel_tensor = vae_manager_singleton.decode(final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
+        video_path = self._save_and_log_video(pixel_tensor, "narrative_video", FPS, temp_dir, results_dir, used_seed)
+
+        self.finalize(keep_paths=[video_path, tensor_path])
+        return video_path, tensor_path, used_seed
+
+    def generate_single_low(self, prompt: str, negative_prompt, height, width, duration, guidance_scale, seed, initial_image_conditions=None, ltx_configs_override: dict = None):
+        """
+        [SIMPLE ORCHESTRATOR]
+        Generates a complete video in a single chunk. Ideal for short, simple prompts.
+        """
+        print("\n" + "="*80)
+        print("====== STARTING SIMPLE SINGLE-CHUNK GENERATION (LOW-RES) ======")
+        print("="*80)
+
+        used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
+        seed_everething(used_seed)
+        FPS = 24.0
+
+        total_actual_frames = max(9, int(round((round(duration * FPS) - 1) / 8.0) * 8 + 1))
+
+        temp_dir = tempfile.mkdtemp(prefix="ltxv_single_"); self._register_tmp_dir(temp_dir)
+        results_dir = "/app/output"; os.makedirs(results_dir, exist_ok=True)
+
+        # Call the single-chunk generation function to do all the work
+        final_latents = self._generate_single_chunk_low(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height, width=width,
+            num_frames=total_actual_frames,
+            guidance_scale=guidance_scale,
+            seed=used_seed,
+            image_conditions=initial_image_conditions,
+            ltx_configs_override=ltx_configs_override
+        )
+
+        print("\n--- Finishing simple generation: saving and decoding ---")
+        log_tensor_info(final_latents, "Final latent tensor")
+
+        tensor_path = os.path.join(results_dir, f"latents_single_{used_seed}.pt")
+        torch.save(final_latents.cpu(), tensor_path)
+
+        with torch.autocast(device_type="cuda", dtype=self.runtime_autocast_dtype, enabled=self.device.type == 'cuda'):
+            pixel_tensor = vae_manager_singleton.decode(final_latents, decode_timestep=float(self.config.get("decode_timestep", 0.05)))
+        video_path = self._save_and_log_video(pixel_tensor, "single_video", FPS, temp_dir, results_dir, used_seed)
+
+        self.finalize(keep_paths=[video_path, tensor_path])
+        return video_path, tensor_path, used_seed
+
     def generate_upscale_denoise(self, latents_path, prompt, negative_prompt, guidance_scale, seed):
         used_seed = random.randint(0, 2**32 - 1) if seed is None else int(seed)
         seed_everething(used_seed)
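Both orchestrators snap frame counts onto the 8·k + 1 grid that the VAE's temporal compression expects, and the narrative path splits that total into per-chunk counts plus an overlap that is later pruned in latent space. A standalone rerun of that arithmetic, assuming FPS = 24 and a temporal `video_scale_factor` of 8 (the values the surrounding code uses or implies):

    FPS = 24.0
    duration, num_chunks, overlap_frames = 10.0, 3, 8   # example inputs

    # Snap the total length onto the 8k+1 grid (minimum 9 frames).
    total_actual_frames = max(9, int(round((round(duration * FPS) - 1) / 8.0) * 8 + 1))
    assert total_actual_frames == 241                   # 10 s @ 24 fps -> 241 frames

    # Split whole 8-frame blocks across chunks; the last chunk takes the remainder.
    total_blocks = (total_actual_frames - 1) // 8       # 30
    blocks_per_chunk = total_blocks // num_chunks       # 10
    blocks_last = total_blocks - blocks_per_chunk * (num_chunks - 1)  # 10
    frames_per_chunk = blocks_per_chunk * 8 + 1         # 81
    frames_per_chunk_last = blocks_last * 8 + 1         # 81

    # Chunks after the first render overlap_frames extra pixel frames; the matching
    # latent frames (overlap_frames // 8 = 1 here) are pruned before concatenation.
    poda_latents_num = overlap_frames // 8              # assumes video_scale_factor == 8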
@@ -330,8 +531,6 @@ class VideoService:
         video_path = self._save_and_log_video(pixel_tensor, "refined_video", 24.0, temp_dir, results_dir, used_seed)
         return video_path, tensor_path

-
-
     def encode_mp4(self, latents_path: str, fps: int = 24):
         latents = torch.load(latents_path)
         seed = random.randint(0, 99999)
@@ -362,6 +561,6 @@


 # --- SERVICE INSTANTIATION ---
-print("Creating VideoService instance. Model loading will begin now...")
+print("Creating VideoService instance...")
 video_generation_service = VideoService()
-print("VideoService instance ready for use.")
+print("VideoService instance ready.")
 
 
 