Julian Bilcke
committed on
Commit
·
19db9a3
1
Parent(s):
6475195
workaround for Finetrainers
Browse files
finetrainers/data/dataset.py
CHANGED
|
@@ -970,9 +970,59 @@ def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
|
|
| 970 |
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
|
| 971 |
return image
|
| 972 |
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
|
| 978 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 970 |
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
|
| 971 |
return image
|
| 972 |
|
| 973 |
+
def _preprocess_video(video) -> torch.Tensor:
|
| 974 |
+
import torch
|
| 975 |
+
import numpy as np
|
| 976 |
+
|
| 977 |
+
# For decord VideoReader
|
| 978 |
+
if hasattr(video, 'get_batch') and 'decord' in str(type(video)):
|
| 979 |
+
video = video.get_batch(list(range(len(video))))
|
| 980 |
+
video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
|
| 981 |
+
return video
|
| 982 |
+
|
| 983 |
+
# For torchvision VideoReader
|
| 984 |
+
elif 'torchvision.io.video_reader' in str(type(video)):
|
| 985 |
+
# Use the correct iteration pattern for torchvision.io.VideoReader
|
| 986 |
+
frames = []
|
| 987 |
+
try:
|
| 988 |
+
# First seek to the beginning
|
| 989 |
+
video.seek(0)
|
| 990 |
+
|
| 991 |
+
# Then collect frames by iterating
|
| 992 |
+
for _ in range(30): # Try to get a reasonable number of frames
|
| 993 |
+
try:
|
| 994 |
+
frame_dict = next(video)
|
| 995 |
+
frame = frame_dict["data"] # Extract the tensor data from the dict
|
| 996 |
+
frames.append(frame)
|
| 997 |
+
except StopIteration:
|
| 998 |
+
break
|
| 999 |
+
except Exception as e:
|
| 1000 |
+
print(f"Error iterating VideoReader: {e}")
|
| 1001 |
+
|
| 1002 |
+
if frames:
|
| 1003 |
+
# In torchvision.io.VideoReader, frames are already in [C, H, W] format
|
| 1004 |
+
# We need to stack and convert to [B, C, H, W]
|
| 1005 |
+
stacked_frames = torch.stack(frames)
|
| 1006 |
+
# Normalize to [-1, 1]
|
| 1007 |
+
stacked_frames = stacked_frames.float() / 127.5 - 1.0
|
| 1008 |
+
return stacked_frames
|
| 1009 |
+
|
| 1010 |
+
# If we couldn't get frames, create a dummy tensor
|
| 1011 |
+
print("Failed to get frames, creating dummy tensor")
|
| 1012 |
+
return torch.zeros(16, 3, 512, 768).float()
|
| 1013 |
+
|
| 1014 |
+
# For list of PIL images
|
| 1015 |
+
elif isinstance(video, list) and len(video) > 0 and hasattr(video[0], 'convert'):
|
| 1016 |
+
frames = []
|
| 1017 |
+
for img in video:
|
| 1018 |
+
img_tensor = torch.from_numpy(np.array(img.convert("RGB"))).float()
|
| 1019 |
+
frames.append(img_tensor)
|
| 1020 |
+
|
| 1021 |
+
video = torch.stack(frames)
|
| 1022 |
+
video = video.permute(0, 3, 1, 2).contiguous() / 127.5 - 1.0
|
| 1023 |
+
return video
|
| 1024 |
+
|
| 1025 |
+
# Unknown type
|
| 1026 |
+
else:
|
| 1027 |
+
print(f"Unknown video type: {type(video)}")
|
| 1028 |
+
return torch.zeros(16, 3, 512, 768).float()
|
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
|
@@ -325,8 +325,21 @@ class SFTTrainer:
|
|
| 325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
| 326 |
if resume_from_checkpoint == "latest":
|
| 327 |
resume_from_checkpoint = -1
|
|
|
|
|
|
|
|
|
|
| 328 |
if resume_from_checkpoint is not None:
|
| 329 |
-
self.checkpointer.load(resume_from_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
def _train(self) -> None:
|
| 332 |
logger.info("Starting training")
|
|
|
|
| 325 |
resume_from_checkpoint = self.args.resume_from_checkpoint
|
| 326 |
if resume_from_checkpoint == "latest":
|
| 327 |
resume_from_checkpoint = -1
|
| 328 |
+
|
| 329 |
+
# Store the load result
|
| 330 |
+
load_successful = False
|
| 331 |
if resume_from_checkpoint is not None:
|
| 332 |
+
load_successful = self.checkpointer.load(resume_from_checkpoint)
|
| 333 |
+
|
| 334 |
+
# If loading succeeded and we have a specific checkpoint path
|
| 335 |
+
if load_successful and isinstance(resume_from_checkpoint, str) and resume_from_checkpoint != "latest":
|
| 336 |
+
try:
|
| 337 |
+
step = int(resume_from_checkpoint.split("_")[-1])
|
| 338 |
+
self.state.train_state.step = step
|
| 339 |
+
logger.info(f"Explicitly setting training step to {step} based on checkpoint path")
|
| 340 |
+
except (ValueError, IndexError):
|
| 341 |
+
logger.warning(f"Could not parse step number from checkpoint path: {resume_from_checkpoint}")
|
| 342 |
+
|
| 343 |
|
| 344 |
def _train(self) -> None:
|
| 345 |
logger.info("Starting training")
|
vms/ui/app_ui.py
CHANGED
|
@@ -146,7 +146,7 @@ class AppUI:
|
|
| 146 |
# Sidebar for navigation
|
| 147 |
with gr.Sidebar(position="left", open=True):
|
| 148 |
gr.Markdown("# ποΈ Video Model Studio")
|
| 149 |
-
self.components["current_project_btn"] = gr.Button("π
|
| 150 |
self.components["system_monitoring_btn"] = gr.Button("π‘οΈ System Monitoring")
|
| 151 |
|
| 152 |
# Main content area with tabs
|
|
@@ -156,7 +156,7 @@ class AppUI:
|
|
| 156 |
self.main_tabs = main_tabs
|
| 157 |
|
| 158 |
# Project View Tab
|
| 159 |
-
with gr.Tab("π
|
| 160 |
# Create project tabs
|
| 161 |
with gr.Tabs() as project_tabs:
|
| 162 |
# Store reference to project tabs component
|
|
@@ -551,20 +551,20 @@ class AppUI:
|
|
| 551 |
if is_training:
|
| 552 |
# Active training detected
|
| 553 |
start_btn_props = {"interactive": False, "variant": "secondary", "value": "π Start new training"}
|
| 554 |
-
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "
|
| 555 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
| 556 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 557 |
else:
|
| 558 |
# No active training
|
| 559 |
start_btn_props = {"interactive": True, "variant": "primary", "value": "π Start new training"}
|
| 560 |
-
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "
|
| 561 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
| 562 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 563 |
else:
|
| 564 |
# Use button states from recovery, adding the new resume button
|
| 565 |
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "π Start new training"})
|
| 566 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
| 567 |
-
"variant": "primary", "value": "
|
| 568 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
| 569 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
| 570 |
|
|
|
|
| 146 |
# Sidebar for navigation
|
| 147 |
with gr.Sidebar(position="left", open=True):
|
| 148 |
gr.Markdown("# ποΈ Video Model Studio")
|
| 149 |
+
self.components["current_project_btn"] = gr.Button("π New Project", variant="primary")
|
| 150 |
self.components["system_monitoring_btn"] = gr.Button("π‘οΈ System Monitoring")
|
| 151 |
|
| 152 |
# Main content area with tabs
|
|
|
|
| 156 |
self.main_tabs = main_tabs
|
| 157 |
|
| 158 |
# Project View Tab
|
| 159 |
+
with gr.Tab("π New Project", id=0) as project_view:
|
| 160 |
# Create project tabs
|
| 161 |
with gr.Tabs() as project_tabs:
|
| 162 |
# Store reference to project tabs component
|
|
|
|
| 551 |
if is_training:
|
| 552 |
# Active training detected
|
| 553 |
start_btn_props = {"interactive": False, "variant": "secondary", "value": "π Start new training"}
|
| 554 |
+
resume_btn_props = {"interactive": False, "variant": "secondary", "value": "πΈ Start from latest checkpoint"}
|
| 555 |
stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
|
| 556 |
delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 557 |
else:
|
| 558 |
# No active training
|
| 559 |
start_btn_props = {"interactive": True, "variant": "primary", "value": "π Start new training"}
|
| 560 |
+
resume_btn_props = {"interactive": has_checkpoints, "variant": "primary", "value": "πΈ Start from latest checkpoint"}
|
| 561 |
stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
|
| 562 |
delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
|
| 563 |
else:
|
| 564 |
# Use button states from recovery, adding the new resume button
|
| 565 |
start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "π Start new training"})
|
| 566 |
resume_btn_props = {"interactive": has_checkpoints and not self.training.is_training_running(),
|
| 567 |
+
"variant": "primary", "value": "πΈ Start from latest checkpoint"}
|
| 568 |
stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
|
| 569 |
delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
|
| 570 |
|
vms/ui/project/tabs/train_tab.py
CHANGED
|
@@ -187,8 +187,8 @@ class TrainTab(BaseTab):
|
|
| 187 |
# Add description of the training buttons
|
| 188 |
self.components["training_buttons_info"] = gr.Markdown("""
|
| 189 |
## βοΈ Train your model on your dataset
|
| 190 |
-
-
|
| 191 |
-
-
|
| 192 |
""")
|
| 193 |
|
| 194 |
with gr.Row():
|
|
@@ -204,7 +204,7 @@ class TrainTab(BaseTab):
|
|
| 204 |
|
| 205 |
# Add new button for continuing from checkpoint
|
| 206 |
self.components["resume_btn"] = gr.Button(
|
| 207 |
-
"
|
| 208 |
variant="primary",
|
| 209 |
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
|
| 210 |
)
|
|
@@ -972,7 +972,7 @@ class TrainTab(BaseTab):
|
|
| 972 |
)
|
| 973 |
|
| 974 |
resume_btn = gr.Button(
|
| 975 |
-
value="Start from latest checkpoint",
|
| 976 |
interactive=has_checkpoints and not is_training,
|
| 977 |
variant="primary" if not is_training else "secondary"
|
| 978 |
)
|
|
|
|
| 187 |
# Add description of the training buttons
|
| 188 |
self.components["training_buttons_info"] = gr.Markdown("""
|
| 189 |
## βοΈ Train your model on your dataset
|
| 190 |
+
- **π Start new training**: Begins training from scratch (clears previous checkpoints)
|
| 191 |
+
- **πΈ Start from latest checkpoint**: Continues training from the most recent checkpoint
|
| 192 |
""")
|
| 193 |
|
| 194 |
with gr.Row():
|
|
|
|
| 204 |
|
| 205 |
# Add new button for continuing from checkpoint
|
| 206 |
self.components["resume_btn"] = gr.Button(
|
| 207 |
+
"πΈ Start from latest checkpoint",
|
| 208 |
variant="primary",
|
| 209 |
interactive=has_checkpoints and not ASK_USER_TO_DUPLICATE_SPACE
|
| 210 |
)
|
|
|
|
| 972 |
)
|
| 973 |
|
| 974 |
resume_btn = gr.Button(
|
| 975 |
+
value="πΈ Start from latest checkpoint",
|
| 976 |
interactive=has_checkpoints and not is_training,
|
| 977 |
variant="primary" if not is_training else "secondary"
|
| 978 |
)
|