import platform
import subprocess
#import sys
#print("python = ", sys.version)

# can be "Linux", "Darwin"
if platform.system() == "Linux":
    # for some reason it says "pip not found"
    # and also "pip3 not found"
    # subprocess.run(
    #     "pip install flash-attn --no-build-isolation",
    #
    #     # hmm... this should be False, since we are in a CUDA environment, no?
    #     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    #
    #     shell=True,
    # )
    pass
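
# A minimal sketch of an alternative install path (not called anywhere):
# invoking pip through the current interpreter with `python -m pip` usually
# avoids the "pip not found" issue mentioned above. The environment handling
# below follows the commented-out code and is an untested assumption.
def _install_flash_attn_sketch():
    import os
    import sys
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # per the comment above, do not skip the CUDA build in a CUDA environment
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "FALSE"},
        check=True,
    )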
import gradio as gr
from pathlib import Path
import logging
import mimetypes
import shutil
import os
import traceback
import asyncio
import tempfile
import zipfile
from typing import Any, Optional, Dict, List, Union, Tuple
from typing import AsyncGenerator

from training_service import TrainingService
from captioning_service import CaptioningService
from splitting_service import SplittingService
from import_service import ImportService
from config import (
    STORAGE_PATH, VIDEOS_TO_SPLIT_PATH, STAGING_PATH,
    TRAINING_PATH, LOG_FILE_PATH, TRAINING_VIDEOS_PATH, MODEL_PATH, OUTPUT_PATH,
    DEFAULT_CAPTIONING_BOT_INSTRUCTIONS, DEFAULT_PROMPT_PREFIX, HF_API_TOKEN,
    ASK_USER_TO_DUPLICATE_SPACE, MODEL_TYPES, TRAINING_BUCKETS
)
from utils import (
    make_archive, count_media_files, format_media_title,
    is_image_file, is_video_file, validate_model_repo, format_time
)
from finetrainers_utils import copy_files_to_training_dir, prepare_finetrainers_dataset
from training_log_parser import TrainingLogParser

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

httpx_logger = logging.getLogger('httpx')
httpx_logger.setLevel(logging.WARN)
class VideoTrainerUI:
    def __init__(self):
        self.trainer = TrainingService()
        self.splitter = SplittingService()
        self.importer = ImportService()
        self.captioner = CaptioningService()
        self._should_stop_captioning = False
        self.log_parser = TrainingLogParser()
    def update_training_ui(self, training_state: Dict[str, Any]):
        """Update UI components based on training state"""
        updates = {}

        print("update_training_ui: training_state = ", training_state)

        # Update status box with high-level information
        status_text = []
        if training_state["status"] != "idle":
            status_text.extend([
                f"Status: {training_state['status']}",
                f"Progress: {training_state['progress']}",
                f"Step: {training_state['current_step']}/{training_state['total_steps']}",

                # Epoch information
                # there is an issue with how the epoch is reported, because we display:
                # Progress: 96.9%, Step: 872/900, Epoch: 12/50
                # we should probably just show the steps
                #f"Epoch: {training_state['current_epoch']}/{training_state['total_epochs']}",

                f"Time elapsed: {training_state['elapsed']}",
                f"Estimated remaining: {training_state['remaining']}",
                "",
                f"Current loss: {training_state['step_loss']}",
                f"Learning rate: {training_state['learning_rate']}",
                f"Gradient norm: {training_state['grad_norm']}",
                f"Memory usage: {training_state['memory']}"
            ])

        if training_state["error_message"]:
            status_text.append(f"\nError: {training_state['error_message']}")

        updates["status_box"] = "\n".join(status_text)

        # Update button states
        updates["start_btn"] = gr.Button(
            "Start training",
            interactive=(training_state["status"] in ["idle", "completed", "error", "stopped"]),
            variant="primary" if training_state["status"] == "idle" else "secondary"
        )

        updates["stop_btn"] = gr.Button(
            "Stop training",
            interactive=(training_state["status"] in ["training", "initializing"]),
            variant="stop"
        )

        return updates
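    # For reference, `update_training_ui` returns a component-name-keyed dict,
    # e.g. (values are illustrative, not taken from a real run):
    #   {
    #       "status_box": "Status: training\nProgress: 42.0%\n...",
    #       "start_btn": gr.Button("Start training", interactive=False, ...),
    #       "stop_btn": gr.Button("Stop training", interactive=True, ...),
    #   }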
    def stop_all_and_clear(self) -> Dict[str, str]:
        """Stop all running processes and clear data

        Returns:
            Dict with status messages for different components
        """
        status_messages = {}

        try:
            # Stop training if running
            if self.trainer.is_training_running():
                training_result = self.trainer.stop_training()
                status_messages["training"] = training_result["status"]

            # Stop captioning if running
            if self.captioner:
                self.captioner.stop_captioning()
                status_messages["captioning"] = "Captioning stopped"

            # Stop scene detection if running
            if self.splitter.is_processing():
                self.splitter.processing = False
                status_messages["splitting"] = "Scene detection stopped"

            # Properly close logging before clearing the log file
            if self.trainer.file_handler:
                self.trainer.file_handler.close()
                logger.removeHandler(self.trainer.file_handler)
                self.trainer.file_handler = None

            if LOG_FILE_PATH.exists():
                LOG_FILE_PATH.unlink()

            # Clear all data directories
            for path in [VIDEOS_TO_SPLIT_PATH, STAGING_PATH, TRAINING_VIDEOS_PATH, TRAINING_PATH,
                         MODEL_PATH, OUTPUT_PATH]:
                if path.exists():
                    try:
                        shutil.rmtree(path)
                        path.mkdir(parents=True, exist_ok=True)
                    except Exception as e:
                        status_messages[f"clear_{path.name}"] = f"Error clearing {path.name}: {str(e)}"
                    else:
                        status_messages[f"clear_{path.name}"] = f"Cleared {path.name}"

            # Reset any persistent state
            self._should_stop_captioning = True
            self.splitter.processing = False

            # Recreate logging setup
            self.trainer.setup_logging()

            return {
                "status": "All processes stopped and data cleared",
                "details": status_messages
            }

        except Exception as e:
            return {
                "status": f"Error during cleanup: {str(e)}",
                "details": status_messages
            }
    def update_titles(self) -> Tuple[Any, Any, Any]:
        """Update all dynamic titles with current counts

        Returns:
            Tuple of Gradio updates for the three title Markdown components
        """
        # Count files for splitting
        split_videos, _, split_size = count_media_files(VIDEOS_TO_SPLIT_PATH)
        split_title = format_media_title(
            "split", split_videos, 0, split_size
        )

        # Count files for captioning
        caption_videos, caption_images, caption_size = count_media_files(STAGING_PATH)
        caption_title = format_media_title(
            "caption", caption_videos, caption_images, caption_size
        )

        # Count files for training
        train_videos, train_images, train_size = count_media_files(TRAINING_VIDEOS_PATH)
        train_title = format_media_title(
            "train", train_videos, train_images, train_size
        )

        return (
            gr.Markdown(value=split_title),
            gr.Markdown(value=caption_title),
            gr.Markdown(value=f"{train_title} available for training")
        )
    def copy_files_to_training_dir(self, prompt_prefix: str):
        """Copy captioned assets into the training directory"""
        self._should_stop_captioning = False

        try:
            copy_files_to_training_dir(prompt_prefix)
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(f"Error copying assets to training dir: {str(e)}")
    async def start_caption_generation(self, captioning_bot_instructions: str, prompt_prefix: str) -> AsyncGenerator[gr.update, None]:
        """Run the auto-captioning process"""
        try:
            # Reset the stop flag before starting
            self._should_stop_captioning = False

            async for rows in self.captioner.start_caption_generation(captioning_bot_instructions, prompt_prefix):
                # Yield UI update
                yield gr.update(
                    value=rows,
                    headers=["name", "status"]
                )

            # Final update after completion
            yield gr.update(
                value=self.list_training_files_to_caption(),
                headers=["name", "status"]
            )
        except Exception as e:
            yield gr.update(
                value=[[str(e), "error"]],
                headers=["name", "status"]
            )
    def list_training_files_to_caption(self) -> List[List[str]]:
        """List all clips and images - both pending and captioned"""
        files = []
        already_listed: Dict[str, bool] = {}

        # Check files in STAGING_PATH
        for file in STAGING_PATH.glob("*.*"):
            if is_video_file(file) or is_image_file(file):
                txt_file = file.with_suffix('.txt')
                status = "captioned" if txt_file.exists() else "no caption"
                file_type = "video" if is_video_file(file) else "image"
                files.append([file.name, f"{status} ({file_type})", str(file)])
                already_listed[str(file.name)] = True

        # Check files in TRAINING_VIDEOS_PATH
        for file in TRAINING_VIDEOS_PATH.glob("*.*"):
            if str(file.name) not in already_listed:
                if is_video_file(file) or is_image_file(file):
                    txt_file = file.with_suffix('.txt')
                    if txt_file.exists():
                        file_type = "video" if is_video_file(file) else "image"
                        files.append([file.name, f"captioned ({file_type})", str(file)])

        # Sort by filename
        files.sort(key=lambda x: x[0])

        # Only return name and status columns for display
        return [[file[0], file[1]] for file in files]
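    # Example of the rows returned above, as consumed by the `training_dataset`
    # Dataframe component (filenames are illustrative):
    #   [["clip_0001.mp4", "captioned (video)"],
    #    ["photo_02.png", "no caption (image)"]]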
    def update_training_buttons(self, status: str) -> Dict:
        """Update training control buttons based on state"""
        is_training = status in ["training", "initializing"]
        is_paused = status == "paused"
        is_completed = status in ["completed", "error", "stopped"]
        return {
            "start_btn": gr.Button(
                interactive=not is_training and not is_paused,
                variant="primary" if not is_training else "secondary",
            ),
            "stop_btn": gr.Button(
                interactive=is_training or is_paused,
                variant="stop",
            ),
            "pause_resume_btn": gr.Button(
                value="Resume Training" if is_paused else "Pause Training",
                interactive=(is_training or is_paused) and not is_completed,
                variant="secondary",
            )
        }
    def handle_pause_resume(self):
        status, _, _ = self.get_latest_status_message_and_logs()

        if status == "paused":
            self.trainer.resume_training()
        else:
            self.trainer.pause_training()

        return self.get_latest_status_message_logs_and_button_labels()

    def handle_stop(self):
        self.trainer.stop_training()
        return self.get_latest_status_message_logs_and_button_labels()
    def handle_training_dataset_select(self, evt: gr.SelectData) -> List[Any]:
        """Handle selection of both video clips and images"""

        def hidden_preview(status_message: Optional[str]) -> List[Any]:
            """Hide all preview components and show only the status message"""
            return [
                gr.Image(interactive=False, visible=False),
                gr.Video(interactive=False, visible=False),
                gr.Textbox(visible=False),
                status_message
            ]

        try:
            if not evt:
                return hidden_preview("No file selected")

            file_name = evt.value
            if not file_name:
                return hidden_preview("No file selected")

            # Check both possible locations for the file
            possible_paths = [
                STAGING_PATH / file_name,

                # note: we used to look into this dir for already-captioned clips,
                # but we don't do this anymore
                #TRAINING_VIDEOS_PATH / file_name
            ]

            # Find the first existing file path
            file_path = None
            for path in possible_paths:
                if path.exists():
                    file_path = path
                    break

            if not file_path:
                return hidden_preview(f"File not found: {file_name}")

            txt_path = file_path.with_suffix('.txt')
            caption = txt_path.read_text() if txt_path.exists() else ""

            # Handle video files
            if is_video_file(file_path):
                return [
                    gr.Image(interactive=False, visible=False),
                    gr.Video(
                        label="Video Preview",
                        interactive=False,
                        visible=True,
                        value=str(file_path)
                    ),
                    gr.Textbox(
                        label="Caption",
                        lines=6,
                        interactive=True,
                        visible=True,
                        value=str(caption)
                    ),
                    None
                ]

            # Handle image files
            elif is_image_file(file_path):
                return [
                    gr.Image(
                        label="Image Preview",
                        interactive=False,
                        visible=True,
                        value=str(file_path)
                    ),
                    gr.Video(interactive=False, visible=False),
                    gr.Textbox(
                        label="Caption",
                        lines=6,
                        interactive=True,
                        visible=True,
                        value=str(caption)
                    ),
                    None
                ]

            else:
                return hidden_preview(f"Unsupported file type: {file_path.suffix}")

        except Exception as e:
            logger.error(f"Error handling selection: {str(e)}")
            return hidden_preview(f"Error handling selection: {str(e)}")
    def save_caption_changes(self, preview_caption: str, preview_image: str, preview_video: str, prompt_prefix: str):
        """Save changes to a caption"""
        try:
            # Add prefix if not already present
            if prompt_prefix and not preview_caption.startswith(prompt_prefix):
                full_caption = f"{prompt_prefix}{preview_caption}"
            else:
                full_caption = preview_caption

            path = Path(preview_video if preview_video else preview_image)
            if path.suffix == '.txt':
                self.trainer.update_file_caption(path.with_suffix(''), full_caption)
            else:
                self.trainer.update_file_caption(path, full_caption)
            return gr.update(value="Caption saved successfully!")
        except Exception as e:
            return gr.update(value=f"Error saving caption: {str(e)}")
    def get_model_info(self, model_type: str) -> str:
        """Get information about the selected model type"""
        if model_type == "hunyuan_video":
            return """### HunyuanVideo (LoRA)
- Best for learning complex video generation patterns
- Required VRAM: ~47GB minimum
- Recommended batch size: 1-2
- Typical training time: 2-4 hours
- Default resolution: 49x512x768
- Default LoRA rank: 128"""

        elif model_type == "ltx_video":
            return """### LTX-Video (LoRA)
- Lightweight video model
- Required VRAM: ~18GB minimum
- Recommended batch size: 1-4
- Typical training time: 1-3 hours
- Default resolution: 49x512x768
- Default LoRA rank: 128"""

        return ""
    def get_default_params(self, model_type: str) -> Dict[str, Any]:
        """Get default training parameters for model type"""
        if model_type == "hunyuan_video":
            return {
                "num_epochs": 70,
                "batch_size": 1,
                "learning_rate": 2e-5,
                "save_iterations": 500,
                "video_resolution_buckets": TRAINING_BUCKETS,
                "video_reshape_mode": "center",
                "caption_dropout_p": 0.05,
                "gradient_accumulation_steps": 1,
                "rank": 128,
                "lora_alpha": 128
            }
        else:  # ltx_video
            return {
                "num_epochs": 70,
                "batch_size": 1,
                "learning_rate": 3e-5,
                "save_iterations": 500,
                "video_resolution_buckets": TRAINING_BUCKETS,
                "video_reshape_mode": "center",
                "caption_dropout_p": 0.05,
                "gradient_accumulation_steps": 4,
                "rank": 128,
                "lora_alpha": 128
            }
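    # Note: the effective batch size is batch_size * gradient_accumulation_steps,
    # so the ltx_video defaults above give 1 * 4 = 4 samples per optimizer step,
    # while the hunyuan_video defaults stay at 1 * 1 = 1.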
    def preview_file(self, selected_text: str) -> Dict:
        """Generate preview based on selected file

        Args:
            selected_text: Text of the selected item containing filename

        Returns:
            Dict with preview content for each preview component
        """
        if not selected_text or "Caption:" in selected_text:
            return {
                "video": None,
                "image": None,
                "text": None
            }

        # Extract filename from the preview text (remove size info)
        filename = selected_text.split(" (")[0].strip()
        file_path = TRAINING_VIDEOS_PATH / filename

        if not file_path.exists():
            return {
                "video": None,
                "image": None,
                "text": f"File not found: {filename}"
            }

        # Detect file type
        mime_type, _ = mimetypes.guess_type(str(file_path))
        if not mime_type:
            return {
                "video": None,
                "image": None,
                "text": f"Unknown file type: {filename}"
            }

        # Return appropriate preview
        if mime_type.startswith('video/'):
            return {
                "video": str(file_path),
                "image": None,
                "text": None
            }
        elif mime_type.startswith('image/'):
            return {
                "video": None,
                "image": str(file_path),
                "text": None
            }
        elif mime_type.startswith('text/'):
            try:
                text_content = file_path.read_text()
                return {
                    "video": None,
                    "image": None,
                    "text": text_content
                }
            except Exception as e:
                return {
                    "video": None,
                    "image": None,
                    "text": f"Error reading file: {str(e)}"
                }
        else:
            return {
                "video": None,
                "image": None,
                "text": f"Unsupported file type: {mime_type}"
            }
    def list_unprocessed_videos(self) -> gr.Dataframe:
        """Update the list of unprocessed videos"""
        videos = self.splitter.list_unprocessed_videos()
        # videos is already in [[name, status]] format from splitting_service
        return gr.Dataframe(
            headers=["name", "status"],
            value=videos,
            interactive=False
        )
    async def start_scene_detection(self, enable_splitting: bool) -> str:
        """Start the background scene detection process

        Args:
            enable_splitting: Whether to split videos into scenes
        """
        if self.splitter.is_processing():
            return "Scene detection already running"

        try:
            await self.splitter.start_processing(enable_splitting)
            return "Scene detection completed"
        except Exception as e:
            return f"Error during scene detection: {str(e)}"
    def get_latest_status_message_and_logs(self) -> Tuple[str, str, str]:
        state = self.trainer.get_status()
        logs = self.trainer.get_logs()

        # Parse new log lines
        if logs:
            last_state = None
            for line in logs.splitlines():
                state_update = self.log_parser.parse_line(line)
                if state_update:
                    last_state = state_update

            if last_state:
                ui_updates = self.update_training_ui(last_state)
                state["message"] = ui_updates.get("status_box", state["message"])

        # Parse the status message for the training state
        if "completed" in state["message"].lower():
            state["status"] = "completed"

        return (state["status"], state["message"], logs)
    def get_latest_status_message_logs_and_button_labels(self) -> Tuple[str, str, Any, Any, Any]:
        status, message, logs = self.get_latest_status_message_and_logs()
        return (
            message,
            logs,
            *self.update_training_buttons(status).values()
        )

    def get_latest_button_labels(self) -> Tuple[Any, Any, Any]:
        status, message, logs = self.get_latest_status_message_and_logs()
        return tuple(self.update_training_buttons(status).values())
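    # Both helpers above rely on dict insertion order: update_training_buttons()
    # yields its values as (start_btn, stop_btn, pause_resume_btn), which is the
    # order expected by the outputs lists wired up in create_ui().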
    def refresh_dataset(self):
        """Refresh all dynamic lists and training state"""
        video_list = self.splitter.list_unprocessed_videos()
        training_dataset = self.list_training_files_to_caption()

        return (
            video_list,
            training_dataset
        )
    def create_ui(self):
        """Create the Gradio interface"""
        with gr.Blocks(title="🎥 Video Model Studio") as app:
            gr.Markdown("# 🎥 Video Model Studio")

            with gr.Tabs() as tabs:
                with gr.TabItem("1️⃣ Import", id="import_tab"):

                    with gr.Row():
                        gr.Markdown("## Automatic splitting and captioning")

                    with gr.Row():
                        enable_automatic_video_split = gr.Checkbox(
                            label="Automatically split videos into smaller clips",
                            info="Note: a clip is a single camera shot, usually a few seconds",
                            value=True,
                            visible=True
                        )
                        enable_automatic_content_captioning = gr.Checkbox(
                            label="Automatically caption photos and videos",
                            info="Note: this uses LLaVA and takes some extra time to load and process",
                            value=False,
                            visible=True,
                        )

                    with gr.Row():
                        with gr.Column(scale=3):
                            with gr.Row():
                                with gr.Column():
                                    gr.Markdown("## Import video files")
                                    gr.Markdown("You can upload either:")
                                    gr.Markdown("- A single MP4 video file")
                                    gr.Markdown("- A ZIP archive containing multiple videos and optional caption files")
                                    gr.Markdown("For ZIP files: create a folder containing videos (the video file names are not important) and, optionally, caption files with matching names (e.g. `some_video.txt` for `some_video.mp4`)")

                            with gr.Row():
                                files = gr.Files(
                                    label="Upload Images, Videos or ZIP",
                                    #file_count="multiple",
                                    file_types=[".jpg", ".jpeg", ".png", ".webp", ".avif", ".heic", ".mp4", ".zip"],
                                    type="filepath"
                                )

                        with gr.Column(scale=3):
                            with gr.Row():
                                with gr.Column():
                                    gr.Markdown("## Import a YouTube video")
                                    gr.Markdown("You can also use a YouTube video as reference, by pasting its URL here:")

                            with gr.Row():
                                youtube_url = gr.Textbox(
                                    label="Import YouTube Video",
                                    placeholder="https://www.youtube.com/watch?v=..."
                                )
                            with gr.Row():
                                youtube_download_btn = gr.Button("Download YouTube Video", variant="secondary")

                    with gr.Row():
                        import_status = gr.Textbox(label="Status", interactive=False)
| with gr.TabItem("2️⃣ Split", id="split_tab"): | |
| with gr.Row(): | |
| split_title = gr.Markdown("## Splitting of 0 videos (0 bytes)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| detect_btn = gr.Button("Split videos into single-camera shots", variant="primary") | |
| detect_status = gr.Textbox(label="Status", interactive=False) | |
| with gr.Column(): | |
| video_list = gr.Dataframe( | |
| headers=["name", "status"], | |
| label="Videos to split", | |
| interactive=False, | |
| wrap=True, | |
| #selection_mode="cell" # Enable cell selection | |
| ) | |
| with gr.TabItem("3️⃣ Caption"): | |
| with gr.Row(): | |
| caption_title = gr.Markdown("## Captioning of 0 files (0 bytes)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| custom_prompt_prefix = gr.Textbox( | |
| scale=3, | |
| label='Prefix to add to ALL captions (eg. "In the style of TOK, ")', | |
| placeholder="In the style of TOK, ", | |
| lines=2, | |
| value=DEFAULT_PROMPT_PREFIX | |
| ) | |
| captioning_bot_instructions = gr.Textbox( | |
| scale=6, | |
| label="System instructions for the automatic captioning model", | |
| placeholder="Please generate a full description of...", | |
| lines=5, | |
| value=DEFAULT_CAPTIONING_BOT_INSTRUCTIONS | |
| ) | |
| with gr.Row(): | |
| run_autocaption_btn = gr.Button( | |
| "Automatically fill missing captions", | |
| variant="primary" # Makes it green by default | |
| ) | |
| copy_files_to_training_dir_btn = gr.Button( | |
| "Copy assets to training directory", | |
| variant="primary" # Makes it green by default | |
| ) | |
| stop_autocaption_btn = gr.Button( | |
| "Stop Captioning", | |
| variant="stop", # Red when enabled | |
| interactive=False # Disabled by default | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| training_dataset = gr.Dataframe( | |
| headers=["name", "status"], | |
| interactive=False, | |
| wrap=True, | |
| value=self.list_training_files_to_caption(), | |
| row_count=10, # Optional: set a reasonable row count | |
| #selection_mode="cell" | |
| ) | |
| with gr.Column(): | |
| preview_video = gr.Video( | |
| label="Video Preview", | |
| interactive=False, | |
| visible=False | |
| ) | |
| preview_image = gr.Image( | |
| label="Image Preview", | |
| interactive=False, | |
| visible=False | |
| ) | |
| preview_caption = gr.Textbox( | |
| label="Caption", | |
| lines=6, | |
| interactive=True | |
| ) | |
| save_caption_btn = gr.Button("Save Caption") | |
| preview_status = gr.Textbox( | |
| label="Status", | |
| interactive=False, | |
| visible=True | |
| ) | |
| with gr.TabItem("4️⃣ Train"): | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| train_title = gr.Markdown("## 0 files available for training (0 bytes)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| model_type = gr.Dropdown( | |
| choices=list(MODEL_TYPES.keys()), | |
| label="Model Type", | |
| value=list(MODEL_TYPES.keys())[0] | |
| ) | |
| model_info = gr.Markdown( | |
| value=self.get_model_info(list(MODEL_TYPES.keys())[0]) | |
| ) | |
| with gr.Row(): | |
| lora_rank = gr.Dropdown( | |
| label="LoRA Rank", | |
| choices=["16", "32", "64", "128", "256"], | |
| value="128", | |
| type="value" | |
| ) | |
| lora_alpha = gr.Dropdown( | |
| label="LoRA Alpha", | |
| choices=["16", "32", "64", "128", "256"], | |
| value="128", | |
| type="value" | |
| ) | |
| with gr.Row(): | |
| num_epochs = gr.Number( | |
| label="Number of Epochs", | |
| value=70, | |
| minimum=1, | |
| precision=0 | |
| ) | |
| batch_size = gr.Number( | |
| label="Batch Size", | |
| value=1, | |
| minimum=1, | |
| precision=0 | |
| ) | |
| with gr.Row(): | |
| learning_rate = gr.Number( | |
| label="Learning Rate", | |
| value=2e-5, | |
| minimum=1e-7 | |
| ) | |
| save_iterations = gr.Number( | |
| label="Save checkpoint every N iterations", | |
| value=500, | |
| minimum=50, | |
| precision=0, | |
| info="Model will be saved periodically after these many steps" | |
| ) | |
| with gr.Column(): | |
| with gr.Row(): | |
| start_btn = gr.Button( | |
| "Start Training", | |
| variant="primary", | |
| interactive=not ASK_USER_TO_DUPLICATE_SPACE | |
| ) | |
| pause_resume_btn = gr.Button( | |
| "Resume Training", | |
| variant="secondary", | |
| interactive=False | |
| ) | |
| stop_btn = gr.Button( | |
| "Stop Training", | |
| variant="stop", | |
| interactive=False | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| status_box = gr.Textbox( | |
| label="Training Status", | |
| interactive=False, | |
| lines=4 | |
| ) | |
| with gr.Accordion("See training logs"): | |
| log_box = gr.TextArea( | |
| label="Finetrainers output (see HF Space logs for more details)", | |
| interactive=False, | |
| lines=40, | |
| max_lines=200, | |
| autoscroll=True | |
| ) | |
| with gr.TabItem("5️⃣ Manage"): | |
| with gr.Column(): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Publishing") | |
| gr.Markdown("You model can be pushed to Hugging Face (this will use HF_API_TOKEN)") | |
| with gr.Row(): | |
| with gr.Column(): | |
| repo_id = gr.Textbox( | |
| label="HuggingFace Model Repository", | |
| placeholder="username/model-name", | |
| info="The repository will be created if it doesn't exist" | |
| ) | |
| gr.Checkbox(label="Check this to make your model public (ie. visible and downloadable by anyone)", info="You model is private by default"), | |
| global_stop_btn = gr.Button( | |
| "Push my model", | |
| #variant="stop" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.Markdown("## Storage management") | |
| with gr.Row(): | |
| download_dataset_btn = gr.DownloadButton( | |
| "Download dataset", | |
| variant="secondary", | |
| size="lg" | |
| ) | |
| download_model_btn = gr.DownloadButton( | |
| "Download model", | |
| variant="secondary", | |
| size="lg" | |
| ) | |
| with gr.Row(): | |
| global_stop_btn = gr.Button( | |
| "Stop everything and delete my data", | |
| variant="stop" | |
| ) | |
| global_status = gr.Textbox( | |
| label="Global Status", | |
| interactive=False, | |
| visible=False | |
| ) | |
            # Event handlers
            def update_model_info(model):
                params = self.get_default_params(MODEL_TYPES[model])
                info = self.get_model_info(MODEL_TYPES[model])
                return {
                    model_info: info,
                    num_epochs: params["num_epochs"],
                    batch_size: params["batch_size"],
                    learning_rate: params["learning_rate"],
                    save_iterations: params["save_iterations"]
                }

            def validate_repo(repo_id: str) -> dict:
                validation = validate_model_repo(repo_id)
                if validation["error"]:
                    return gr.update(value=repo_id, error=validation["error"])
                return gr.update(value=repo_id, error=None)

            # Connect events
            model_type.change(
                fn=update_model_info,
                inputs=[model_type],
                outputs=[model_info, num_epochs, batch_size, learning_rate, save_iterations]
            )
            async def on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
                videos = self.list_unprocessed_videos()
                # If scene detection isn't already running, there are videos to process,
                # and auto-splitting is enabled, start the detection
                if videos and not self.splitter.is_processing() and enable_splitting:
                    await self.start_scene_detection(enable_splitting)
                    msg = "Starting automatic scene detection..."
                else:
                    # Just copy the files over without splitting when auto-split is disabled
                    for video_file in VIDEOS_TO_SPLIT_PATH.glob("*.mp4"):
                        await self.splitter.process_video(video_file, enable_splitting=False)
                    msg = "Copying videos without splitting..."

                copy_files_to_training_dir(prompt_prefix)

                # Start auto-captioning if enabled
                # (start_caption_generation is an async generator, so it has to be
                # iterated rather than awaited for the captioning to actually run)
                if enable_automatic_content_captioning:
                    async for _ in self.start_caption_generation(
                        DEFAULT_CAPTIONING_BOT_INSTRUCTIONS,
                        prompt_prefix
                    ):
                        pass

                return {
                    tabs: gr.Tabs(selected="split_tab"),
                    video_list: videos,
                    detect_status: msg
                }

            async def update_titles_after_import(enable_splitting, enable_automatic_content_captioning, prompt_prefix):
                """Handle post-import updates, including the dynamic titles"""
                import_result = await on_import_success(enable_splitting, enable_automatic_content_captioning, prompt_prefix)
                titles = self.update_titles()
                # on_import_success returns a dict keyed by component: unpack its
                # values in the order expected by the event's outputs list
                return (import_result[tabs], import_result[video_list], import_result[detect_status], *titles)
            files.upload(
                fn=lambda x: self.importer.process_uploaded_files(x),
                inputs=[files],
                outputs=[import_status]
            ).success(
                fn=update_titles_after_import,
                inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
                outputs=[
                    tabs, video_list, detect_status,
                    split_title, caption_title, train_title
                ]
            )

            youtube_download_btn.click(
                fn=self.importer.download_youtube_video,
                inputs=[youtube_url],
                outputs=[import_status]
            ).success(
                fn=on_import_success,
                inputs=[enable_automatic_video_split, enable_automatic_content_captioning, custom_prompt_prefix],
                outputs=[tabs, video_list, detect_status]
            )

            # Scene detection events
            detect_btn.click(
                fn=self.start_scene_detection,
                inputs=[enable_automatic_video_split],
                outputs=[detect_status]
            )
            # Update button states based on captioning status
            def update_button_states(is_running):
                return {
                    run_autocaption_btn: gr.Button(
                        interactive=not is_running,
                        variant="secondary" if is_running else "primary",
                    ),
                    stop_autocaption_btn: gr.Button(
                        interactive=is_running,
                        variant="secondary",
                    ),
                }

            run_autocaption_btn.click(
                fn=self.start_caption_generation,
                inputs=[captioning_bot_instructions, custom_prompt_prefix],
                outputs=[training_dataset],
            ).then(
                fn=lambda: update_button_states(True),
                outputs=[run_autocaption_btn, stop_autocaption_btn]
            )

            copy_files_to_training_dir_btn.click(
                fn=self.copy_files_to_training_dir,
                inputs=[custom_prompt_prefix]
            )

            def stop_captioning_and_reset_buttons():
                # stop the captioner first, then return the component-keyed dict
                # that Gradio maps onto the declared outputs
                if self.captioner:
                    self.captioner.stop_captioning()
                return update_button_states(False)

            stop_autocaption_btn.click(
                fn=stop_captioning_and_reset_buttons,
                outputs=[run_autocaption_btn, stop_autocaption_btn]
            )
            training_dataset.select(
                fn=self.handle_training_dataset_select,
                outputs=[preview_image, preview_video, preview_caption, preview_status]
            )

            save_caption_btn.click(
                fn=self.save_caption_changes,
                inputs=[preview_caption, preview_image, preview_video, custom_prompt_prefix],
                outputs=[preview_status]
            ).success(
                fn=self.list_training_files_to_caption,
                outputs=[training_dataset]
            )
            # Training control events
            start_btn.click(
                fn=lambda model_type, *args: (
                    self.log_parser.reset(),
                    self.trainer.start_training(
                        MODEL_TYPES[model_type],
                        *args
                    )
                ),
                inputs=[
                    model_type,
                    lora_rank,
                    lora_alpha,
                    num_epochs,
                    batch_size,
                    learning_rate,
                    save_iterations,
                    repo_id
                ],
                outputs=[status_box, log_box]
            ).success(
                fn=self.get_latest_status_message_logs_and_button_labels,
                outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
            )

            pause_resume_btn.click(
                fn=self.handle_pause_resume,
                outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
            )

            stop_btn.click(
                fn=self.handle_stop,
                outputs=[status_box, log_box, start_btn, stop_btn, pause_resume_btn]
            )
            def handle_global_stop():
                result = self.stop_all_and_clear()

                # Update all relevant UI components
                status = result["status"]
                details = "\n".join(f"{k}: {v}" for k, v in result["details"].items())
                full_status = f"{status}\n\nDetails:\n{details}"

                # Get fresh lists after cleanup
                videos = self.splitter.list_unprocessed_videos()
                clips = self.list_training_files_to_caption()

                return {
                    global_status: gr.update(value=full_status, visible=True),
                    video_list: videos,
                    training_dataset: clips,
                    status_box: "Training stopped and data cleared",
                    log_box: "",
                    detect_status: "Scene detection stopped",
                    import_status: "All data cleared",
                    preview_status: "Captioning stopped"
                }
            download_dataset_btn.click(
                fn=self.trainer.create_training_dataset_zip,
                outputs=[download_dataset_btn]
            )

            download_model_btn.click(
                fn=self.trainer.get_model_output_safetensors,
                outputs=[download_model_btn]
            )

            global_stop_btn.click(
                fn=handle_global_stop,
                outputs=[
                    global_status,
                    video_list,
                    training_dataset,
                    status_box,
                    log_box,
                    detect_status,
                    import_status,
                    preview_status
                ]
            )
            # Auto-refresh timers
            app.load(
                fn=lambda: self.refresh_dataset(),
                outputs=[
                    video_list, training_dataset
                ]
            )

            status_timer = gr.Timer(value=1)
            status_timer.tick(
                fn=lambda: self.get_latest_status_message_logs_and_button_labels(),
                outputs=[
                    status_box,
                    log_box,
                    start_btn,
                    stop_btn,
                    pause_resume_btn
                ]
            )

            dataset_timer = gr.Timer(value=5)
            dataset_timer.tick(
                fn=lambda: self.refresh_dataset(),
                outputs=[
                    video_list, training_dataset
                ]
            )

            titles_timer = gr.Timer(value=6)
            titles_timer.tick(
                fn=lambda: self.update_titles(),
                outputs=[
                    split_title, caption_title, train_title
                ]
            )

        return app
def create_app():
    if ASK_USER_TO_DUPLICATE_SPACE:
        with gr.Blocks() as app:
            gr.Markdown("""# Finetrainers UI

This Hugging Face Space needs to be duplicated to your own billing account to work.

Click the 'Duplicate Space' button at the top of the page to create your own copy.

It is recommended to use an Nvidia L40S and persistent storage.

To avoid overpaying for your Space, you can configure the auto-sleep settings to fit your personal budget.""")
        return app

    ui = VideoTrainerUI()
    return ui.create_ui()
if __name__ == "__main__":
    app = create_app()

    allowed_paths = [
        str(STORAGE_PATH),  # Base storage
        str(VIDEOS_TO_SPLIT_PATH),
        str(STAGING_PATH),
        str(TRAINING_PATH),
        str(TRAINING_VIDEOS_PATH),
        str(MODEL_PATH),
        str(OUTPUT_PATH)
    ]
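    # allowed_paths is what lets Gradio serve the previews and download archives
    # stored under STORAGE_PATH; files outside these directories are refused by
    # the Gradio file server.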
    app.queue(default_concurrency_limit=1).launch(
        server_name="0.0.0.0",
        allowed_paths=allowed_paths
    )