Commit
·
745eaaf
1
Parent(s):
17d970d
download result video & log file in zip
Browse files- app.py +30 -7
- cosmos_transfer1/utils/log.py +15 -2
app.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
|
|
|
| 3 |
import time
|
|
|
|
| 4 |
from typing import List, Tuple
|
| 5 |
|
| 6 |
import gradio as gr
|
|
@@ -11,6 +14,8 @@ from gpu_info import watch_gpu_memory
|
|
| 11 |
PWD = os.path.dirname(__file__)
|
| 12 |
CHECKPOINTS_PATH = "/data/checkpoints"
|
| 13 |
# CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
|
|
|
|
|
|
|
| 14 |
|
| 15 |
try:
|
| 16 |
import os
|
|
@@ -30,8 +35,8 @@ except Exception as e:
|
|
| 30 |
# download checkpoints
|
| 31 |
from download_checkpoints import main as download_checkpoints
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
|
| 36 |
|
| 37 |
from test_environment import main as check_environment
|
|
@@ -271,6 +276,18 @@ def inference(cfg, control_inputs, chunking) -> Tuple[List[str], List[str]]:
|
|
| 271 |
return video_paths, prompt_paths
|
| 272 |
|
| 273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
@spaces.GPU()
|
| 275 |
def generate_video(
|
| 276 |
rgb_video_path,
|
|
@@ -283,6 +300,10 @@ def generate_video(
|
|
| 283 |
chunking=False,
|
| 284 |
progress=gr.Progress(track_tqdm=True),
|
| 285 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
if randomize_seed:
|
| 287 |
actual_seed = random.randint(0, 1000000)
|
| 288 |
else:
|
|
@@ -290,8 +311,8 @@ def generate_video(
|
|
| 290 |
|
| 291 |
log.info(f"actual_seed: {actual_seed}")
|
| 292 |
|
| 293 |
-
if not os.path.isfile(rgb_video_path):
|
| 294 |
-
log.warning(f"File {rgb_video_path} does not exist")
|
| 295 |
rgb_video_path = ""
|
| 296 |
|
| 297 |
# add timer to calculate the generation time
|
|
@@ -315,7 +336,7 @@ def generate_video(
|
|
| 315 |
)
|
| 316 |
|
| 317 |
# watch gpu memory
|
| 318 |
-
watcher = watch_gpu_memory(10)
|
| 319 |
|
| 320 |
# start inference
|
| 321 |
videos, prompts = inference(args, control_inputs, chunking)
|
|
@@ -328,7 +349,9 @@ def generate_video(
|
|
| 328 |
watcher.cancel()
|
| 329 |
|
| 330 |
video = videos[0]
|
| 331 |
-
|
|
|
|
|
|
|
| 332 |
|
| 333 |
|
| 334 |
# Define the Gradio Blocks interface
|
|
@@ -369,7 +392,7 @@ with gr.Blocks() as demo:
|
|
| 369 |
|
| 370 |
with gr.Column():
|
| 371 |
output_video = gr.Video(label="Generated Video", format="mp4")
|
| 372 |
-
output_file = gr.File(label="Download
|
| 373 |
|
| 374 |
generate_button.click(
|
| 375 |
fn=generate_video,
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
import os
|
| 3 |
import sys
|
| 4 |
+
import tempfile
|
| 5 |
import time
|
| 6 |
+
import zipfile
|
| 7 |
from typing import List, Tuple
|
| 8 |
|
| 9 |
import gradio as gr
|
|
|
|
| 14 |
PWD = os.path.dirname(__file__)
|
| 15 |
CHECKPOINTS_PATH = "/data/checkpoints"
|
| 16 |
# CHECKPOINTS_PATH = os.path.join(PWD, "checkpoints")
|
| 17 |
+
LOG_DIR = os.path.join(PWD, "logs")
|
| 18 |
+
os.makedirs(LOG_DIR, exist_ok=True)
|
| 19 |
|
| 20 |
try:
|
| 21 |
import os
|
|
|
|
| 35 |
# download checkpoints
|
| 36 |
from download_checkpoints import main as download_checkpoints
|
| 37 |
|
| 38 |
+
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
|
| 39 |
+
download_checkpoints(hf_token="", output_dir=CHECKPOINTS_PATH, model="7b_av")
|
| 40 |
|
| 41 |
|
| 42 |
from test_environment import main as check_environment
|
|
|
|
| 276 |
return video_paths, prompt_paths
|
| 277 |
|
| 278 |
|
| 279 |
+
def create_zip_for_download(filename, files_to_zip):
|
| 280 |
+
temp_dir = tempfile.mkdtemp()
|
| 281 |
+
zip_path = os.path.join(temp_dir, f"{os.path.splitext(filename)[0]}.zip")
|
| 282 |
+
|
| 283 |
+
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
|
| 284 |
+
for file_path in files_to_zip:
|
| 285 |
+
arcname = os.path.basename(file_path)
|
| 286 |
+
zipf.write(file_path, arcname)
|
| 287 |
+
|
| 288 |
+
return zip_path
|
| 289 |
+
|
| 290 |
+
|
| 291 |
@spaces.GPU()
|
| 292 |
def generate_video(
|
| 293 |
rgb_video_path,
|
|
|
|
| 300 |
chunking=False,
|
| 301 |
progress=gr.Progress(track_tqdm=True),
|
| 302 |
):
|
| 303 |
+
_dt = datetime.datetime.now(tz=datetime.timezone(datetime.timedelta(hours=8))).strftime("%Y-%m-%d_%H.%M.%S")
|
| 304 |
+
logfile_path = os.path.join(LOG_DIR, f"{_dt}.log")
|
| 305 |
+
log_handler = log.init_dev_loguru_file(logfile_path)
|
| 306 |
+
|
| 307 |
if randomize_seed:
|
| 308 |
actual_seed = random.randint(0, 1000000)
|
| 309 |
else:
|
|
|
|
| 311 |
|
| 312 |
log.info(f"actual_seed: {actual_seed}")
|
| 313 |
|
| 314 |
+
if rgb_video_path is None or not os.path.isfile(rgb_video_path):
|
| 315 |
+
log.warning(f"File `{rgb_video_path}` does not exist")
|
| 316 |
rgb_video_path = ""
|
| 317 |
|
| 318 |
# add timer to calculate the generation time
|
|
|
|
| 336 |
)
|
| 337 |
|
| 338 |
# watch gpu memory
|
| 339 |
+
watcher = watch_gpu_memory(10, lambda x: log.debug(f"GPU memory usage: {x} (MiB)"))
|
| 340 |
|
| 341 |
# start inference
|
| 342 |
videos, prompts = inference(args, control_inputs, chunking)
|
|
|
|
| 349 |
watcher.cancel()
|
| 350 |
|
| 351 |
video = videos[0]
|
| 352 |
+
|
| 353 |
+
log.logger.remove(log_handler)
|
| 354 |
+
return video, create_zip_for_download(filename=logfile_path, files_to_zip=[video, logfile_path]), actual_seed
|
| 355 |
|
| 356 |
|
| 357 |
# Define the Gradio Blocks interface
|
|
|
|
| 392 |
|
| 393 |
with gr.Column():
|
| 394 |
output_video = gr.Video(label="Generated Video", format="mp4")
|
| 395 |
+
output_file = gr.File(label="Download Results")
|
| 396 |
|
| 397 |
generate_button.click(
|
| 398 |
fn=generate_video,
|
cosmos_transfer1/utils/log.py
CHANGED
|
@@ -76,10 +76,10 @@ def get_machine_format() -> str:
|
|
| 76 |
return machine_format
|
| 77 |
|
| 78 |
|
| 79 |
-
def init_loguru_file(path: str) ->
|
| 80 |
machine_format = get_machine_format()
|
| 81 |
message_format = get_message_format()
|
| 82 |
-
logger.add(
|
| 83 |
path,
|
| 84 |
encoding="utf8",
|
| 85 |
level=LEVEL,
|
|
@@ -89,6 +89,19 @@ def init_loguru_file(path: str) -> None:
|
|
| 89 |
enqueue=True,
|
| 90 |
)
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def get_message_format() -> str:
|
| 94 |
message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"
|
|
|
|
| 76 |
return machine_format
|
| 77 |
|
| 78 |
|
| 79 |
+
def init_loguru_file(path: str) -> int:
|
| 80 |
machine_format = get_machine_format()
|
| 81 |
message_format = get_message_format()
|
| 82 |
+
return logger.add(
|
| 83 |
path,
|
| 84 |
encoding="utf8",
|
| 85 |
level=LEVEL,
|
|
|
|
| 89 |
enqueue=True,
|
| 90 |
)
|
| 91 |
|
| 92 |
+
def init_dev_loguru_file(path: str) -> int:
|
| 93 |
+
machine_format = get_machine_format()
|
| 94 |
+
message_format = get_message_format()
|
| 95 |
+
return logger.add(
|
| 96 |
+
path,
|
| 97 |
+
encoding="utf8",
|
| 98 |
+
level="DEBUG",
|
| 99 |
+
format="[<green>{time:MM-DD HH:mm:ss}</green>|" f"{machine_format}" f"{message_format}",
|
| 100 |
+
rotation="100 MB",
|
| 101 |
+
filter=lambda result: _rank0_only_filter(result) or not RANK0_ONLY,
|
| 102 |
+
enqueue=True,
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
|
| 106 |
def get_message_format() -> str:
|
| 107 |
message_format = "<level>{level}</level>|<cyan>{extra[relative_path]}:{line}:{function}</cyan>] {message}"
|