# Test / api/seedvr_server.py
# (Hugging Face page header: author EuuIia, commit 55327ec "Update api/seedvr_server.py",
# verified, 5.25 kB)
import os
import shutil
import subprocess
import sys
import time
import mimetypes
from pathlib import Path
from typing import List, Optional, Tuple
from huggingface_hub import hf_hub_download
class SeedVRServer:
    """Prepare the SeedVR2 (FP16) environment and run upscaling jobs via ``torchrun``.

    Construction creates the working directories, clones the CLI repository
    if it is missing and downloads the FP16 checkpoints (see ``setup_dependencies``).
    """

    def __init__(self, **kwargs):
        """Resolve working paths (mostly from environment variables) and prepare dependencies.

        ``kwargs`` is accepted for call-site compatibility but is not used.
        """
        self.SEEDVR_ROOT = Path(os.getenv("SEEDVR_ROOT", "/data/SeedVR"))
        # Points at our custom FP16 checkpoint directory (not overridable via env).
        self.CKPTS_ROOT = Path("/data/seedvr_models_fp16")
        self.OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
        self.INPUT_ROOT = Path(os.getenv("INPUT_ROOT", "/app/inputs"))
        self.HF_HOME_CACHE = Path(os.getenv("HF_HOME", "/data/.cache/huggingface"))
        self.REPO_URL = os.getenv("SEEDVR_GIT_URL", "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler")
        self.NUM_GPUS_TOTAL = int(os.getenv("NUM_GPUS", "4"))
        print("🚀 SeedVRServer (FP16) inicializando e preparando o ambiente...")
        for p in [self.SEEDVR_ROOT.parent, self.CKPTS_ROOT, self.OUTPUT_ROOT, self.INPUT_ROOT, self.HF_HOME_CACHE]:
            p.mkdir(parents=True, exist_ok=True)
        self.setup_dependencies()
        print("✅ SeedVRServer (FP16) pronto.")

    def setup_dependencies(self):
        """Ensure the CLI repository and the model checkpoints are present."""
        self._ensure_repo()
        # The monkey patch is now applied by start_seedvr.sh, no longer here.
        self._ensure_model()

    def _ensure_repo(self) -> None:
        """Shallow-clone ``REPO_URL`` into ``SEEDVR_ROOT`` unless a ``.git`` dir already exists.

        Raises:
            subprocess.CalledProcessError: if ``git clone`` fails.
        """
        if not (self.SEEDVR_ROOT / ".git").exists():
            print(f"[SeedVRServer] Clonando repositório para {self.SEEDVR_ROOT}...")
            subprocess.run(["git", "clone", "--depth", "1", self.REPO_URL, str(self.SEEDVR_ROOT)], check=True)
        else:
            print("[SeedVRServer] Repositório SeedVR já existe.")

    def _ensure_model(self) -> None:
        """Download the optimized FP16 model files and their embeddings into ``CKPTS_ROOT``.

        Files already present in ``CKPTS_ROOT`` are skipped. ``HF_TOKEN`` (if set)
        is forwarded to the Hub for authenticated downloads.
        """
        print(f"[SeedVRServer] Verificando checkpoints (FP16) em {self.CKPTS_ROOT}...")
        # filename -> Hugging Face repo it is downloaded from
        model_files = {
            "seedvr2_ema_3b_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "ema_vae_fp16.safetensors": "MonsterMMORPG/SeedVR2_SECourses",
            "pos_emb.pt": "ByteDance-Seed/SeedVR2-3B",
            "neg_emb.pt": "ByteDance-Seed/SeedVR2-3B",
        }
        for filename, repo_id in model_files.items():
            if not (self.CKPTS_ROOT / filename).exists():
                print(f"Baixando {filename} de {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    local_dir=str(self.CKPTS_ROOT),
                    cache_dir=str(self.HF_HOME_CACHE),
                    token=os.getenv("HF_TOKEN"),
                )
        print("[SeedVRServer] Checkpoints (FP16) estão no local correto.")

    def _prepare_job(self, input_file: str) -> Tuple[Path, Path]:
        """Create unique per-job input/output dirs and copy ``input_file`` into the input dir.

        Returns:
            ``(job_input_dir, out_dir)``; both directories exist on return.
        """
        # Timestamp plus random hex keeps concurrent jobs from colliding.
        ts = f"{int(time.time())}_{os.urandom(4).hex()}"
        job_input_dir = self.INPUT_ROOT / f"job_{ts}"
        out_dir = self.OUTPUT_ROOT / f"run_{ts}"
        job_input_dir.mkdir(parents=True, exist_ok=True)
        out_dir.mkdir(parents=True, exist_ok=True)
        shutil.copy2(input_file, job_input_dir / Path(input_file).name)
        return job_input_dir, out_dir

    def run_inference(
        self,
        filepath: str,
        *,
        seed: int,
        resh: int,
        resw: int,
        spsize: int,
        fps: Optional[float] = None,
    ) -> Tuple[Optional[str], Optional[str], Optional[Path]]:
        """Upscale an image or a video with the SeedVR2 CLI launched through ``torchrun``.

        Args:
            filepath: Input media path. If its MIME type (guessed from the
                extension) is ``image/*``, the job runs on one process with batch
                size 1 and PNG output; anything else is treated as video.
            seed: Random seed forwarded as ``--seed``.
            resh: Target resolution forwarded as ``--resolution``.
            resw: Currently unused; only ``resh`` is forwarded to the CLI.
            spsize: Batch size forwarded as ``--batch_size`` for videos.
            fps: Currently unused.

        Returns:
            ``(image_dir, None, outdir)`` for images, ``(None, video_path, outdir)``
            for videos, or ``(None, None, None)`` if launching/running the CLI fails
            (the error is printed, not raised).
        """
        script = self.SEEDVR_ROOT / "inference_cli.py"
        # NOTE(review): the staged copy inside the job input dir is not what the CLI
        # receives; the original ``filepath`` is passed. Confirm this is intended.
        _job_input_dir, outdir = self._prepare_job(filepath)
        mediatype, _ = mimetypes.guess_type(filepath)
        is_image = bool(mediatype and mediatype.startswith("image"))
        effective_nproc = 1 if is_image else self.NUM_GPUS_TOTAL
        effective_spsize = 1 if is_image else spsize
        stem = Path(filepath).stem
        # Image output has no extension: the CLI treats it as the PNG directory.
        output_filename = f"{stem}_upscaled" if is_image else f"result_{stem}.mp4"
        output_filepath = outdir / output_filename
        cmd = [
            "torchrun", "--standalone", "--nnodes=1",
            f"--nproc-per-node={effective_nproc}",
            str(script),
            "--video_path", str(filepath),
            "--output", str(output_filepath),
            "--model_dir", str(self.CKPTS_ROOT),
            "--seed", str(seed),
            # NOTE(review): fixed to device 0 even when nproc > 1; presumably the CLI
            # remaps per LOCAL_RANK — confirm.
            "--cuda_device", "0",
            "--resolution", str(resh),
            "--batch_size", str(effective_spsize),
            "--model", "seedvr2_ema_3b_fp16.safetensors",
            "--preserve_vram",
            "--debug",
            "--output_format", "png" if is_image else "video",
        ]
        print("SeedVRServer Comando:", " ".join(cmd))
        try:
            subprocess.run(cmd, cwd=str(self.SEEDVR_ROOT), check=True, env=os.environ.copy(), stdout=sys.stdout, stderr=sys.stderr)
        except Exception as e:
            # Best-effort UI boundary: report and signal failure with an all-None tuple.
            print(f"[UI ERROR] A inferência falhou: {e}")
            return None, None, None
        return self._collect_results(output_filepath, outdir, is_image)

    @staticmethod
    def _collect_results(
        output_filepath: Path, outdir: Path, is_image: bool
    ) -> Tuple[Optional[str], Optional[str], Path]:
        """Build the ``(image_dir, video_path, outdir)`` tuple returned on success."""
        if is_image:
            # The CLI saves PNGs into the directory given by --output. That name is
            # built without an extension, so return it verbatim: stripping a
            # "suffix" would truncate stems that contain dots
            # (e.g. "my.photo_upscaled" -> "my").
            return str(output_filepath), None, outdir
        # The CLI saves the video exactly at output_filepath.
        return None, str(output_filepath), outdir