Julian Bilcke
committed on
Commit
·
fc0385d
1
Parent(s):
ece1c33
fix for the custom prompt prefix
Browse files- vms/ui/project/services/importing/file_upload.py +4 -4
- vms/ui/project/services/importing/hub_dataset.py +5 -3
- vms/ui/project/services/importing/import_service.py +8 -6
- vms/ui/project/services/training.py +5 -10
- vms/ui/project/tabs/import_tab/hub_tab.py +6 -4
- vms/ui/project/tabs/import_tab/upload_tab.py +1 -1
- vms/ui/project/tabs/train_tab.py +9 -1
vms/ui/project/services/importing/file_upload.py
CHANGED
|
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|
| 22 |
class FileUploadHandler:
|
| 23 |
"""Handles processing of uploaded files"""
|
| 24 |
|
| 25 |
-
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
|
| 26 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
| 27 |
|
| 28 |
Args:
|
|
@@ -48,7 +48,7 @@ class FileUploadHandler:
|
|
| 48 |
file_ext = file_path.suffix.lower()
|
| 49 |
|
| 50 |
if file_ext == '.zip':
|
| 51 |
-
return self.process_zip_file(file_path, enable_splitting)
|
| 52 |
elif file_ext == '.tar':
|
| 53 |
return self.process_tar_file(file_path, enable_splitting)
|
| 54 |
elif file_ext == '.mp4' or file_ext == '.webm':
|
|
@@ -63,7 +63,7 @@ class FileUploadHandler:
|
|
| 63 |
logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
|
| 64 |
raise gr.Error(f"Error processing file: {str(e)}")
|
| 65 |
|
| 66 |
-
def process_zip_file(self, file_path: Path, enable_splitting: bool) -> str:
|
| 67 |
"""Process uploaded ZIP file containing media files or WebDataset tar files
|
| 68 |
|
| 69 |
Args:
|
|
@@ -138,7 +138,7 @@ class FileUploadHandler:
|
|
| 138 |
logger.info(f"Copied caption file for {file}")
|
| 139 |
elif is_image_file(file_path):
|
| 140 |
caption = txt_path.read_text()
|
| 141 |
-
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
|
| 142 |
target_path.with_suffix('.txt').write_text(caption)
|
| 143 |
logger.info(f"Processed caption for {file}")
|
| 144 |
|
|
|
|
| 22 |
class FileUploadHandler:
|
| 23 |
"""Handles processing of uploaded files"""
|
| 24 |
|
| 25 |
+
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
| 26 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
| 27 |
|
| 28 |
Args:
|
|
|
|
| 48 |
file_ext = file_path.suffix.lower()
|
| 49 |
|
| 50 |
if file_ext == '.zip':
|
| 51 |
+
return self.process_zip_file(file_path, enable_splitting, custom_prompt_prefix)
|
| 52 |
elif file_ext == '.tar':
|
| 53 |
return self.process_tar_file(file_path, enable_splitting)
|
| 54 |
elif file_ext == '.mp4' or file_ext == '.webm':
|
|
|
|
| 63 |
logger.error(f"Error processing file {file_path}: {str(e)}", exc_info=True)
|
| 64 |
raise gr.Error(f"Error processing file: {str(e)}")
|
| 65 |
|
| 66 |
+
def process_zip_file(self, file_path: Path, enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
| 67 |
"""Process uploaded ZIP file containing media files or WebDataset tar files
|
| 68 |
|
| 69 |
Args:
|
|
|
|
| 138 |
logger.info(f"Copied caption file for {file}")
|
| 139 |
elif is_image_file(file_path):
|
| 140 |
caption = txt_path.read_text()
|
| 141 |
+
caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
|
| 142 |
target_path.with_suffix('.txt').write_text(caption)
|
| 143 |
logger.info(f"Processed caption for {file}")
|
| 144 |
|
vms/ui/project/services/importing/hub_dataset.py
CHANGED
|
@@ -169,7 +169,8 @@ class HubDatasetBrowser:
|
|
| 169 |
dataset_id: str,
|
| 170 |
file_type: str,
|
| 171 |
enable_splitting: bool,
|
| 172 |
-
progress_callback: Optional[Callable] = None
|
|
|
|
| 173 |
) -> str:
|
| 174 |
"""Download all files of a specific type from the dataset
|
| 175 |
|
|
@@ -329,7 +330,8 @@ class HubDatasetBrowser:
|
|
| 329 |
self,
|
| 330 |
dataset_id: str,
|
| 331 |
enable_splitting: bool,
|
| 332 |
-
progress_callback: Optional[Callable] = None
|
|
|
|
| 333 |
) -> Tuple[str, str]:
|
| 334 |
"""Download a dataset and process its video/image content
|
| 335 |
|
|
@@ -555,7 +557,7 @@ class HubDatasetBrowser:
|
|
| 555 |
txt_path = file_path.with_suffix('.txt')
|
| 556 |
if txt_path.exists():
|
| 557 |
caption = txt_path.read_text()
|
| 558 |
-
caption = add_prefix_to_caption(caption, DEFAULT_PROMPT_PREFIX)
|
| 559 |
target_path.with_suffix('.txt').write_text(caption)
|
| 560 |
logger.info(f"Processed caption for {file_path.name}")
|
| 561 |
|
|
|
|
| 169 |
dataset_id: str,
|
| 170 |
file_type: str,
|
| 171 |
enable_splitting: bool,
|
| 172 |
+
progress_callback: Optional[Callable] = None,
|
| 173 |
+
custom_prompt_prefix: str = None
|
| 174 |
) -> str:
|
| 175 |
"""Download all files of a specific type from the dataset
|
| 176 |
|
|
|
|
| 330 |
self,
|
| 331 |
dataset_id: str,
|
| 332 |
enable_splitting: bool,
|
| 333 |
+
progress_callback: Optional[Callable] = None,
|
| 334 |
+
custom_prompt_prefix: str = None
|
| 335 |
) -> Tuple[str, str]:
|
| 336 |
"""Download a dataset and process its video/image content
|
| 337 |
|
|
|
|
| 557 |
txt_path = file_path.with_suffix('.txt')
|
| 558 |
if txt_path.exists():
|
| 559 |
caption = txt_path.read_text()
|
| 560 |
+
caption = add_prefix_to_caption(caption, custom_prompt_prefix or DEFAULT_PROMPT_PREFIX)
|
| 561 |
target_path.with_suffix('.txt').write_text(caption)
|
| 562 |
logger.info(f"Processed caption for {file_path.name}")
|
| 563 |
|
vms/ui/project/services/importing/import_service.py
CHANGED
|
@@ -28,7 +28,7 @@ class ImportingService:
|
|
| 28 |
self.youtube_handler = YouTubeDownloader()
|
| 29 |
self.hub_browser = HubDatasetBrowser(self.hf_api)
|
| 30 |
|
| 31 |
-
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool) -> str:
|
| 32 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
| 33 |
|
| 34 |
Args:
|
|
@@ -45,7 +45,7 @@ class ImportingService:
|
|
| 45 |
|
| 46 |
print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
|
| 47 |
print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
|
| 48 |
-
return self.file_handler.process_uploaded_files(file_paths, enable_splitting)
|
| 49 |
|
| 50 |
def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
|
| 51 |
"""Download a video from YouTube
|
|
@@ -86,7 +86,8 @@ class ImportingService:
|
|
| 86 |
self,
|
| 87 |
dataset_id: str,
|
| 88 |
enable_splitting: bool,
|
| 89 |
-
progress_callback: Optional[Callable] = None
|
|
|
|
| 90 |
) -> Tuple[str, str]:
|
| 91 |
"""Download a dataset and process its video/image content
|
| 92 |
|
|
@@ -98,14 +99,15 @@ class ImportingService:
|
|
| 98 |
Returns:
|
| 99 |
Tuple of (loading_msg, status_msg)
|
| 100 |
"""
|
| 101 |
-
return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback)
|
| 102 |
|
| 103 |
async def download_file_group(
|
| 104 |
self,
|
| 105 |
dataset_id: str,
|
| 106 |
file_type: str,
|
| 107 |
enable_splitting: bool,
|
| 108 |
-
progress_callback: Optional[Callable] = None
|
|
|
|
| 109 |
) -> str:
|
| 110 |
"""Download a group of files (videos or WebDatasets)
|
| 111 |
|
|
@@ -118,4 +120,4 @@ class ImportingService:
|
|
| 118 |
Returns:
|
| 119 |
Status message
|
| 120 |
"""
|
| 121 |
-
return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback)
|
|
|
|
| 28 |
self.youtube_handler = YouTubeDownloader()
|
| 29 |
self.hub_browser = HubDatasetBrowser(self.hf_api)
|
| 30 |
|
| 31 |
+
def process_uploaded_files(self, file_paths: List[str], enable_splitting: bool, custom_prompt_prefix: str = None) -> str:
|
| 32 |
"""Process uploaded file (ZIP, TAR, MP4, or image)
|
| 33 |
|
| 34 |
Args:
|
|
|
|
| 45 |
|
| 46 |
print(f"process_uploaded_files(..., enable_splitting = {enable_splitting:})")
|
| 47 |
print(f"process_uploaded_files: calling self.file_handler.process_uploaded_files")
|
| 48 |
+
return self.file_handler.process_uploaded_files(file_paths, enable_splitting, custom_prompt_prefix)
|
| 49 |
|
| 50 |
def download_youtube_video(self, url: str, enable_splitting: bool, progress=None) -> str:
|
| 51 |
"""Download a video from YouTube
|
|
|
|
| 86 |
self,
|
| 87 |
dataset_id: str,
|
| 88 |
enable_splitting: bool,
|
| 89 |
+
progress_callback: Optional[Callable] = None,
|
| 90 |
+
custom_prompt_prefix: str = None
|
| 91 |
) -> Tuple[str, str]:
|
| 92 |
"""Download a dataset and process its video/image content
|
| 93 |
|
|
|
|
| 99 |
Returns:
|
| 100 |
Tuple of (loading_msg, status_msg)
|
| 101 |
"""
|
| 102 |
+
return await self.hub_browser.download_dataset(dataset_id, enable_splitting, progress_callback, custom_prompt_prefix)
|
| 103 |
|
| 104 |
async def download_file_group(
|
| 105 |
self,
|
| 106 |
dataset_id: str,
|
| 107 |
file_type: str,
|
| 108 |
enable_splitting: bool,
|
| 109 |
+
progress_callback: Optional[Callable] = None,
|
| 110 |
+
custom_prompt_prefix: str = None
|
| 111 |
) -> str:
|
| 112 |
"""Download a group of files (videos or WebDatasets)
|
| 113 |
|
|
|
|
| 120 |
Returns:
|
| 121 |
Status message
|
| 122 |
"""
|
| 123 |
+
return await self.hub_browser.download_file_group(dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix)
|
vms/ui/project/services/training.py
CHANGED
|
@@ -579,6 +579,7 @@ class TrainingService:
|
|
| 579 |
precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
|
| 580 |
lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
|
| 581 |
progress: Optional[gr.Progress] = None,
|
|
|
|
| 582 |
) -> Tuple[str, str]:
|
| 583 |
"""Start training with finetrainers"""
|
| 584 |
|
|
@@ -669,16 +670,10 @@ class TrainingService:
|
|
| 669 |
else:
|
| 670 |
flow_weighting_scheme = "logit_normal"
|
| 671 |
|
| 672 |
-
#
|
| 673 |
-
|
| 674 |
-
if
|
| 675 |
-
|
| 676 |
-
if hasattr(self.app.tabs['caption_tab'], 'components') and 'custom_prompt_prefix' in self.app.tabs['caption_tab'].components:
|
| 677 |
-
# Get the value and clean it
|
| 678 |
-
prefix = self.app.tabs['caption_tab'].components['custom_prompt_prefix'].value
|
| 679 |
-
if prefix:
|
| 680 |
-
# Clean the prefix - remove trailing comma, space or comma+space
|
| 681 |
-
custom_prompt_prefix = prefix.rstrip(', ')
|
| 682 |
|
| 683 |
# Create a proper dataset configuration JSON file
|
| 684 |
dataset_config_file = self.app.output_path / "dataset_config.json"
|
|
|
|
| 579 |
precomputation_items: int = DEFAULT_PRECOMPUTATION_ITEMS,
|
| 580 |
lr_warmup_steps: int = DEFAULT_NB_LR_WARMUP_STEPS,
|
| 581 |
progress: Optional[gr.Progress] = None,
|
| 582 |
+
custom_prompt_prefix: Optional[str] = None,
|
| 583 |
) -> Tuple[str, str]:
|
| 584 |
"""Start training with finetrainers"""
|
| 585 |
|
|
|
|
| 670 |
else:
|
| 671 |
flow_weighting_scheme = "logit_normal"
|
| 672 |
|
| 673 |
+
# Use the custom prompt prefix passed as parameter
|
| 674 |
+
# Clean the prefix - remove trailing comma, space or comma+space
|
| 675 |
+
if custom_prompt_prefix:
|
| 676 |
+
custom_prompt_prefix = custom_prompt_prefix.rstrip(', ')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
# Create a proper dataset configuration JSON file
|
| 679 |
dataset_config_file = self.app.output_path / "dataset_config.json"
|
vms/ui/project/tabs/import_tab/hub_tab.py
CHANGED
|
@@ -267,7 +267,7 @@ class HubTab(BaseTab):
|
|
| 267 |
"" # status_output
|
| 268 |
)
|
| 269 |
|
| 270 |
-
async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback):
|
| 271 |
"""Wrapper for download_file_group that integrates with progress tracking"""
|
| 272 |
try:
|
| 273 |
# Set up the progress callback adapter
|
|
@@ -289,7 +289,8 @@ class HubTab(BaseTab):
|
|
| 289 |
dataset_id,
|
| 290 |
file_type,
|
| 291 |
enable_splitting,
|
| 292 |
-
progress_callback=progress_adapter
|
|
|
|
| 293 |
)
|
| 294 |
|
| 295 |
return result
|
|
@@ -298,7 +299,7 @@ class HubTab(BaseTab):
|
|
| 298 |
logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
|
| 299 |
return f"Error: {str(e)}"
|
| 300 |
|
| 301 |
-
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple:
|
| 302 |
"""Handle download of a group of files (videos or WebDatasets) with progress tracking"""
|
| 303 |
try:
|
| 304 |
if not dataset_id:
|
|
@@ -323,7 +324,8 @@ class HubTab(BaseTab):
|
|
| 323 |
dataset_id,
|
| 324 |
file_type,
|
| 325 |
enable_splitting,
|
| 326 |
-
progress
|
|
|
|
| 327 |
))
|
| 328 |
|
| 329 |
# When download is complete, update the UI
|
|
|
|
| 267 |
"" # status_output
|
| 268 |
)
|
| 269 |
|
| 270 |
+
async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback, custom_prompt_prefix=None):
|
| 271 |
"""Wrapper for download_file_group that integrates with progress tracking"""
|
| 272 |
try:
|
| 273 |
# Set up the progress callback adapter
|
|
|
|
| 289 |
dataset_id,
|
| 290 |
file_type,
|
| 291 |
enable_splitting,
|
| 292 |
+
progress_callback=progress_adapter,
|
| 293 |
+
custom_prompt_prefix=custom_prompt_prefix
|
| 294 |
)
|
| 295 |
|
| 296 |
return result
|
|
|
|
| 299 |
logger.error(f"Error in download with progress: {str(e)}", exc_info=True)
|
| 300 |
return f"Error: {str(e)}"
|
| 301 |
|
| 302 |
+
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, custom_prompt_prefix: str = None, progress=gr.Progress()) -> Tuple:
|
| 303 |
"""Handle download of a group of files (videos or WebDatasets) with progress tracking"""
|
| 304 |
try:
|
| 305 |
if not dataset_id:
|
|
|
|
| 324 |
dataset_id,
|
| 325 |
file_type,
|
| 326 |
enable_splitting,
|
| 327 |
+
progress,
|
| 328 |
+
custom_prompt_prefix
|
| 329 |
))
|
| 330 |
|
| 331 |
# When download is complete, update the UI
|
vms/ui/project/tabs/import_tab/upload_tab.py
CHANGED
|
@@ -65,7 +65,7 @@ class UploadTab(BaseTab):
|
|
| 65 |
# File upload event with enable_splitting parameter
|
| 66 |
upload_event = self.components["files"].upload(
|
| 67 |
fn=self.app.importing.process_uploaded_files,
|
| 68 |
-
inputs=[self.components["files"], self.components["enable_automatic_video_split"]],
|
| 69 |
outputs=[self.components["import_status"]]
|
| 70 |
).success(
|
| 71 |
fn=self.app.tabs["import_tab"].on_import_success,
|
|
|
|
| 65 |
# File upload event with enable_splitting parameter
|
| 66 |
upload_event = self.components["files"].upload(
|
| 67 |
fn=self.app.importing.process_uploaded_files,
|
| 68 |
+
inputs=[self.components["files"], self.components["enable_automatic_video_split"], self.app.tabs["caption_tab"].components["custom_prompt_prefix"]],
|
| 69 |
outputs=[self.components["import_status"]]
|
| 70 |
).success(
|
| 71 |
fn=self.app.tabs["import_tab"].on_import_success,
|
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -906,6 +906,13 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 906 |
precomputation_items = int(self.components["precomputation_items"].value)
|
| 907 |
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
| 908 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 909 |
# Start training (it will automatically use the checkpoint if provided)
|
| 910 |
try:
|
| 911 |
return self.app.training.start_training(
|
|
@@ -924,7 +931,8 @@ For image-to-video tasks, 'index' (usually with index 0) is most common as it co
|
|
| 924 |
precomputation_items=precomputation_items,
|
| 925 |
lr_warmup_steps=lr_warmup_steps,
|
| 926 |
progress=progress,
|
| 927 |
-
pretrained_lora_path=pretrained_lora_path
|
|
|
|
| 928 |
)
|
| 929 |
except Exception as e:
|
| 930 |
logger.exception("Error starting training")
|
|
|
|
| 906 |
precomputation_items = int(self.components["precomputation_items"].value)
|
| 907 |
lr_warmup_steps = int(self.components["lr_warmup_steps"].value)
|
| 908 |
|
| 909 |
+
# Get custom prompt prefix from caption tab
|
| 910 |
+
custom_prompt_prefix = None
|
| 911 |
+
if hasattr(self.app, 'tabs') and 'caption_tab' in self.app.tabs:
|
| 912 |
+
caption_tab = self.app.tabs['caption_tab']
|
| 913 |
+
if hasattr(caption_tab, 'components') and 'custom_prompt_prefix' in caption_tab.components:
|
| 914 |
+
custom_prompt_prefix = caption_tab.components['custom_prompt_prefix'].value
|
| 915 |
+
|
| 916 |
# Start training (it will automatically use the checkpoint if provided)
|
| 917 |
try:
|
| 918 |
return self.app.training.start_training(
|
|
|
|
| 931 |
precomputation_items=precomputation_items,
|
| 932 |
lr_warmup_steps=lr_warmup_steps,
|
| 933 |
progress=progress,
|
| 934 |
+
pretrained_lora_path=pretrained_lora_path,
|
| 935 |
+
custom_prompt_prefix=custom_prompt_prefix
|
| 936 |
)
|
| 937 |
except Exception as e:
|
| 938 |
logger.exception("Error starting training")
|