File size: 3,823 Bytes
ed88963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
#!/usr/bin/env python3
"""
SeedVR UI (Gradio) — CLI torchrun

- Upload único: vídeo (.mp4) ou imagem (.png/.jpg/.jpeg/.webp).
- Parâmetros: seed, res_h, res_w, sp_size.
- Executa via torchrun com NUM_GPUS (do ambiente).
- Exibe vídeo se a entrada for vídeo; imagem se for imagem.
"""

import os
import mimetypes
from pathlib import Path
from typing import Optional

import gradio as gr

from services.seed_server import SeedVRServer

# Instância única do servidor (clona repo, baixa modelo, cria symlink)
server = SeedVRServer()

# Paths padrão (para allowed_paths e debug)
OUTPUT_ROOT = Path(os.getenv("OUTPUT_ROOT", "/app/outputs"))
CKPTS_ROOT = Path(os.getenv("CKPTS_ROOT", "/app/ckpts/SeedVR2-3B"))

def _is_video(path: str) -> bool:
    mime, _ = mimetypes.guess_type(path)
    return (mime or "").startswith("video") or str(path).lower().endswith(".mp4")

def _is_image(path: str) -> bool:
    mime, _ = mimetypes.guess_type(path)
    if mime and mime.startswith("image"):
        return True
    return str(path).lower().endswith((".png", ".jpg", ".jpeg", ".webp"))

def ui_infer(
    input_path: Optional[str],
    seed: int,
    res_h: int,
    res_w: int,
    sp_size: int,
):
    if not input_path or not Path(input_path).exists():
        gr.Warning("Arquivo de entrada ausente ou inválido.")
        return None, None

    is_vid = _is_video(input_path)
    is_img = _is_image(input_path)
    if not (is_vid or is_img):
        gr.Warning("Tipo de arquivo não suportado. Envie .mp4, .png, .jpg, .jpeg ou .webp.")
        return None, None

    try:
        video_out, image_out, _ = server.run_inference(
            file_path=input_path,
            seed=int(seed),
            res_h=int(res_h),
            res_w=int(res_w),
            sp_size=int(sp_size),
        )
    except Exception as e:
        gr.Warning(f"Erro na inferência: {e}")
        return None, None

    if is_vid:
        if video_out and Path(video_out).exists():
            return None, video_out
        if image_out and Path(image_out).exists():
            return image_out, None
        gr.Warning("Nenhum resultado encontrado.")
        return None, None
    else:
        if image_out and Path(image_out).exists():
            return image_out, None
        if video_out and Path(video_out).exists():
            return None, video_out
        gr.Warning("Nenhum resultado encontrado.")
        return None, None

with gr.Blocks(title="SeedVR (CLI torchrun)") as demo:
    gr.Markdown(
        "\n".join([
            "# SeedVR — Restauração (CLI torchrun)",
            "- Envie um vídeo (.mp4) ou uma imagem (.png/.jpg/.jpeg/.webp).",
            "- A execução utiliza torchrun com múltiplas GPUs.",
        ])
    )

    with gr.Row():
        inp = gr.File(label="Entrada (vídeo .mp4 ou imagem)", type="filepath")

    with gr.Row():
        seed = gr.Number(label="Seed", value=int(os.getenv("SEED", "42")), precision=0)
        res_h = gr.Number(label="Altura (res_h)", value=int(os.getenv("RES_H", "720")), precision=0)
        res_w = gr.Number(label="Largura (res_w)", value=int(os.getenv("RES_W", "1280")), precision=0)
        sp_size = gr.Number(label="sp_size", value=int(os.getenv("SP_SIZE", "4")), precision=0)

    run = gr.Button("Restaurar", variant="primary")

    out_image = gr.Image(label="Resultado (imagem)")
    out_video = gr.Video(label="Resultado (vídeo)")

    run.click(
        ui_infer,
        inputs=[inp, seed, res_h, res_w, sp_size],
        outputs=[out_image, out_video],
    )

if __name__ == "__main__":
    demo.launch(
        server_name=os.getenv("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.getenv("GRADIO_SERVER_PORT", os.getenv("PORT", "7860"))),
        allowed_paths=[str(OUTPUT_ROOT), str(CKPTS_ROOT)],
        show_error=True,
    )