Eueuiaa committed on
Commit 007c224 · verified · 1 Parent(s): 8f0c470

Update api/ltx_server_refactored_complete.py

Files changed (1)
  1. api/ltx_server_refactored_complete.py +128 -61
api/ltx_server_refactored_complete.py CHANGED
@@ -126,104 +126,171 @@ class VideoService:
     """
 
     def __init__(self):
-        """Initializes the service, loads models, and configures the environment."""
         t0 = time.perf_counter()
-        logging.info("Initializing VideoService...")
-        RESULTS_DIR.mkdir(parents=True, exist_ok=True)
-
-        self.config = self._load_config(DEFAULT_CONFIG_FILE)
-        self._tmp_dirs = set()
+        print("[DEBUG] Initializing VideoService...")
 
-        self.pipeline, self.latent_upsampler = self._load_models_on_cpu()
-
+        # 1. Get the target device from the GPU manager.
+        # We do not set `self.device` yet, we only store the target.
        target_device = gpu_manager.get_ltx_device()
-        self.device = torch.device("cpu")  # Default device
-        self.move_to_device(target_device)
+        print(f"[DEBUG] LTX was allocated to device: {target_device}")
 
+        # 2. Load the configuration and the models (on the CPU, as _load_models does).
+        self.config = self._load_config()
+        self.pipeline, self.latent_upsampler = self._load_models()
+
+        # 3. Move the models to the target device and set `self.device`.
+        self.move_to_device(target_device)  # Using the function we already created.
+
+        # 4. Configure the remaining components with the correct device.
         self._apply_precision_policy()
         vae_manager_singleton.attach_pipeline(
             self.pipeline,
-            device=self.device,
+            device=self.device,  # `self.device` is now correct
             autocast_dtype=self.runtime_autocast_dtype
         )
+        self._tmp_dirs = set()
+        print(f"[DEBUG] VideoService ready. boot_time={time.perf_counter()-t0:.3f}s")
+
+    # The move_to_device function created earlier is essential here.
+    def move_to_device(self, device):
+        """Moves the pipeline models to the specified device."""
+        print(f"[LTX] Moving models to {device}...")
+        self.device = torch.device(device)  # Ensure it is a torch.device object
+        self.pipeline.to(self.device)
+        if self.latent_upsampler:
+            self.latent_upsampler.to(self.device)
+        print(f"[LTX] Models are now on {self.device}.")
 
-        logging.info(f"VideoService ready. Startup time: {time.perf_counter()-t0:.2f}s")
+    def move_to_cpu(self):
+        """Moves the models to the CPU to free VRAM."""
+        self.move_to_device(torch.device("cpu"))
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
+
     # ==========================================================================
     # --- LIFECYCLE & MODEL MANAGEMENT ---
     # ==========================================================================
 
-    def _load_config(self, config_path: Path) -> Dict:
-        """Loads the YAML configuration file."""
-        logging.info(f"Loading config from: {config_path}")
+    def _load_config(self):
+        base = LTX_VIDEO_REPO_DIR / "configs"
+        config_path = base / "ltxv-13b-0.9.8-distilled-fp8.yaml"
         with open(config_path, "r") as file:
             return yaml.safe_load(file)
 
-    def _load_models_on_cpu(self) -> Tuple[LTXMultiScalePipeline, Optional[torch.nn.Module]]:
-        """Downloads and loads the pipeline and upsampler checkpoints onto the CPU."""
+    def finalize(self, keep_paths=None, extra_paths=None, clear_gpu=True):
+        print("[DEBUG] Finalize: starting cleanup...")
+        keep = set(keep_paths or []); extras = set(extra_paths or [])
+        gc.collect()
+        try:
+            if clear_gpu and torch.cuda.is_available():
+                torch.cuda.empty_cache()
+                try:
+                    torch.cuda.ipc_collect()
+                except Exception:
+                    pass
+        except Exception as e:
+            print(f"[DEBUG] Finalize: GPU cleanup failed: {e}")
+
+    def _load_models(self):
         t0 = time.perf_counter()
-
-        logging.info("Downloading main checkpoint...")
+        LTX_REPO = "Lightricks/LTX-Video"
+        print("[DEBUG] Downloading main checkpoint...")
         distilled_model_path = hf_hub_download(
-            repo_id=LTX_REPO_ID,
+            repo_id=LTX_REPO,
             filename=self.config["checkpoint_path"],
+            local_dir=os.getenv("HF_HOME"),
+            cache_dir=os.getenv("HF_HOME_CACHE"),
             token=os.getenv("HF_TOKEN"),
         )
         self.config["checkpoint_path"] = distilled_model_path
+        print(f"[DEBUG] Checkpoint at: {distilled_model_path}")
+
+        print("[DEBUG] Downloading spatial upscaler...")
+        spatial_upscaler_path = hf_hub_download(
+            repo_id=LTX_REPO,
+            filename=self.config["spatial_upscaler_model_path"],
+            local_dir=os.getenv("HF_HOME"),
+            cache_dir=os.getenv("HF_HOME_CACHE"),
+            token=os.getenv("HF_TOKEN")
+        )
+        self.config["spatial_upscaler_model_path"] = spatial_upscaler_path
+        print(f"[DEBUG] Upscaler at: {spatial_upscaler_path}")
 
+        print("[DEBUG] Building pipeline...")
         pipeline = create_ltx_video_pipeline(
             ckpt_path=self.config["checkpoint_path"],
             precision=self.config["precision"],
-            device="cpu",  # Load on CPU first
-            # Pass other config values directly
-            **{k: v for k, v in self.config.items() if k in create_ltx_video_pipeline.__code__.co_varnames}
+            text_encoder_model_name_or_path=self.config["text_encoder_model_name_or_path"],
+            sampler=self.config["sampler"],
+            device="cpu",
+            enhance_prompt=False,
+            prompt_enhancer_image_caption_model_name_or_path=self.config["prompt_enhancer_image_caption_model_name_or_path"],
+            prompt_enhancer_llm_model_name_or_path=self.config["prompt_enhancer_llm_model_name_or_path"],
         )
-
+        print("[DEBUG] Pipeline ready.")
+
         latent_upsampler = None
         if self.config.get("spatial_upscaler_model_path"):
-            logging.info("Downloading spatial upscaler checkpoint...")
-            spatial_upscaler_path = hf_hub_download(
-                repo_id=LTX_REPO_ID,
-                filename=self.config["spatial_upscaler_model_path"],
-                token=os.getenv("HF_TOKEN")
-            )
-            self.config["spatial_upscaler_model_path"] = spatial_upscaler_path
+            print("[DEBUG] Building latent_upsampler...")
             latent_upsampler = create_latent_upsampler(self.config["spatial_upscaler_model_path"], device="cpu")
-
-        logging.info(f"Models loaded on CPU in {time.perf_counter()-t0:.2f}s")
+            print("[DEBUG] Upsampler ready.")
+        print(f"[DEBUG] _load_models() total time={time.perf_counter()-t0:.3f}s")
         return pipeline, latent_upsampler
 
-    def move_to_device(self, device_str: str):
-        """Moves all relevant models to the specified device (e.g., 'cuda:0' or 'cpu')."""
-        target_device = torch.device(device_str)
-        if self.device == target_device:
-            logging.info(f"Models are already on the target device: {device_str}")
-            return
-
-        logging.info(f"Moving models to {device_str}...")
-        self.device = target_device
-        self.pipeline.to(self.device)
-        if self.latent_upsampler:
-            self.latent_upsampler.to(self.device)
-
-        if device_str == "cpu" and torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
-        logging.info(f"Models successfully moved to {self.device}.")
+    def _apply_precision_policy(self):
+        prec = str(self.config.get("precision", "")).lower()
+        self.runtime_autocast_dtype = torch.float32
+        if prec in ["float8_e4m3fn", "bfloat16"]:
+            self.runtime_autocast_dtype = torch.bfloat16
+        elif prec == "mixed_precision":
+            self.runtime_autocast_dtype = torch.float16
 
-    def finalize(self, keep_paths: Optional[List[str]] = None):
-        """Cleans up GPU memory and temporary directories."""
-        logging.debug("Finalizing resources...")
-        gc.collect()
-        if torch.cuda.is_available():
+    def _register_tmp_dir(self, d: str):
+        if d and os.path.isdir(d):
+            self._tmp_dirs.add(d); print(f"[DEBUG] Registered tmp dir: {d}")
+
+    @torch.no_grad()
+    def _upsample_latents_internal(self, latents: torch.Tensor) -> torch.Tensor:
+        try:
+            if not self.latent_upsampler:
+                raise ValueError("Latent Upsampler is not loaded.")
+            latents_unnormalized = un_normalize_latents(latents, self.pipeline.vae, vae_per_channel_normalize=True)
+            upsampled_latents = self.latent_upsampler(latents_unnormalized)
+            return normalize_latents(upsampled_latents, self.pipeline.vae, vae_per_channel_normalize=True)
+        except Exception as e:
+            pass
+        finally:
             torch.cuda.empty_cache()
-        try:
-            torch.cuda.ipc_collect()
-        except Exception:
-            pass
-
-        # Optional: Clean up temporary directories if needed (logic can be added here)
+            torch.cuda.ipc_collect()
+            self.finalize(keep_paths=[])
+
+    def _prepare_conditioning_tensor(self, filepath, height, width, padding_values):
+        tensor = load_image_to_tensor_with_resize_and_crop(filepath, height, width)
+        tensor = torch.nn.functional.pad(tensor, padding_values)
+        log_tensor_info(tensor, f"_prepare_conditioning_tensor")
+        return tensor.to(self.device, dtype=self.runtime_autocast_dtype)
 
+
+    def _save_and_log_video(self, pixel_tensor, base_filename, fps, temp_dir, results_dir, used_seed, progress_callback=None):
+        output_path = os.path.join(temp_dir, f"{base_filename}_.mp4")
+        video_encode_tool_singleton.save_video_from_tensor(
+            pixel_tensor, output_path, fps=fps, progress_callback=progress_callback
+        )
+        final_path = os.path.join(results_dir, f"{base_filename}_.mp4")
+        shutil.move(output_path, final_path)
+        print(f"[DEBUG] Video saved to: {final_path}")
+        return final_path
+
+    def _load_tensor(self, caminho):
+        # If it is already a tensor, return it directly.
+        if isinstance(caminho, torch.Tensor):
+            return caminho
+        # If it is bytes, load it from the buffer.
+        if isinstance(caminho, (bytes, bytearray)):
+            return torch.load(io.BytesIO(caminho))
+        # Otherwise, assume it is a file path.
+        return torch.load(caminho)
 
     # ==========================================================================
     # --- PUBLIC ORCHESTRATORS ---
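
A minimal caller-side sketch of the lifecycle this commit introduces (hypothetical usage, not part of the diff; it assumes the bundled LTX-Video config and checkpoints resolve, HF_TOKEN is set, and gpu_manager returns a valid device):

    # Sketch only: construct the service, then release GPU memory between jobs.
    service = VideoService()          # loads config + models on CPU, then moves them to gpu_manager.get_ltx_device()
    # ... run generation through the public orchestrators ...
    service.move_to_cpu()             # offload weights back to the CPU and empty the CUDA cache
    service.finalize(keep_paths=[], clear_gpu=True)   # gc.collect() plus torch.cuda.empty_cache()/ipc_collect()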
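The new _apply_precision_policy maps the config's precision string to an autocast dtype (float8_e4m3fn and bfloat16 select torch.bfloat16, mixed_precision selects torch.float16, anything else falls back to float32). Continuing the sketch above, this is how such a dtype is commonly consumed with torch.autocast on a CUDA device; the attach_pipeline call passes the same value to the VAE manager (illustration only, not code from this commit):

    import torch

    # Hypothetical consumer of runtime_autocast_dtype.
    with torch.autocast(device_type="cuda", dtype=service.runtime_autocast_dtype):
        pass  # decode / denoise steps would run inside this reduced-precision context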