|
|
""" |
|
|
Models service for Video Model Studio |
|
|
|
|
|
Handles the model history tracking and management |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import json |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, List, Optional, Tuple |
|
|
from dataclasses import dataclass |
|
|
from datetime import datetime |
|
|
|
|
|
from vms.config import ( |
|
|
STORAGE_PATH, MODEL_TYPES, TRAINING_TYPES |
|
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
logger.setLevel(logging.INFO) |
|
|
|
|
|
@dataclass |
|
|
class Model: |
|
|
"""Class for tracking model metadata""" |
|
|
id: str |
|
|
status: str |
|
|
model_type: str |
|
|
model_display_name: str |
|
|
created_at: datetime |
|
|
updated_at: datetime |
|
|
training_progress: Optional[float] = 0.0 |
|
|
current_step: Optional[int] = 0 |
|
|
total_steps: Optional[int] = 0 |
|
|
|
|
|
@classmethod |
|
|
def from_dir(cls, model_dir: Path) -> 'Model': |
|
|
"""Create a Model instance from a directory""" |
|
|
model_id = model_dir.name |
|
|
|
|
|
|
|
|
status = 'draft' |
|
|
model_type = '' |
|
|
model_display_name = '' |
|
|
created_at = datetime.fromtimestamp(model_dir.stat().st_ctime) |
|
|
updated_at = datetime.fromtimestamp(model_dir.stat().st_mtime) |
|
|
training_progress = 0.0 |
|
|
current_step = 0 |
|
|
total_steps = 0 |
|
|
|
|
|
|
|
|
ui_state_file = model_dir / "output" / "ui_state.json" |
|
|
if ui_state_file.exists(): |
|
|
try: |
|
|
with open(ui_state_file, 'r') as f: |
|
|
ui_state = json.load(f) |
|
|
status = ui_state.get('project_status', 'draft') |
|
|
|
|
|
|
|
|
model_type_value = ui_state.get('model_type', '') |
|
|
|
|
|
|
|
|
display_name_found = False |
|
|
for display_name, internal_name in MODEL_TYPES.items(): |
|
|
if display_name == model_type_value: |
|
|
model_type = internal_name |
|
|
model_display_name = display_name |
|
|
display_name_found = True |
|
|
break |
|
|
|
|
|
|
|
|
if not display_name_found: |
|
|
for display_name, internal_name in MODEL_TYPES.items(): |
|
|
if internal_name == model_type_value: |
|
|
model_type = internal_name |
|
|
model_display_name = display_name |
|
|
break |
|
|
except Exception as e: |
|
|
logger.error(f"Error loading UI state for model {model_id}: {str(e)}") |
|
|
|
|
|
|
|
|
status_file = model_dir / "output" / "status.json" |
|
|
if status_file.exists(): |
|
|
try: |
|
|
with open(status_file, 'r') as f: |
|
|
status_data = json.load(f) |
|
|
if status_data.get('status') == 'training': |
|
|
status = 'training' |
|
|
current_step = status_data.get('step', 0) |
|
|
total_steps = status_data.get('total_steps', 0) |
|
|
if total_steps > 0: |
|
|
training_progress = (current_step / total_steps) * 100 |
|
|
elif status_data.get('status') == 'completed': |
|
|
status = 'trained' |
|
|
training_progress = 100.0 |
|
|
elif status_data.get('status') == 'error': |
|
|
status = 'error' |
|
|
except Exception as e: |
|
|
logger.error(f"Error loading status for model {model_id}: {str(e)}") |
|
|
|
|
|
|
|
|
pid_file = model_dir / "output" / "training.pid" |
|
|
if pid_file.exists(): |
|
|
status = 'training' |
|
|
|
|
|
|
|
|
model_weights = model_dir / "output" / "pytorch_lora_weights.safetensors" |
|
|
if model_weights.exists() and status != 'training': |
|
|
status = 'trained' |
|
|
training_progress = 100.0 |
|
|
|
|
|
return cls( |
|
|
id=model_id, |
|
|
status=status, |
|
|
model_type=model_type, |
|
|
model_display_name=model_display_name, |
|
|
created_at=created_at, |
|
|
updated_at=updated_at, |
|
|
training_progress=training_progress, |
|
|
current_step=current_step, |
|
|
total_steps=total_steps |
|
|
) |
|
|
|
|
|
class ModelsService: |
|
|
"""Service for tracking and managing model history""" |
|
|
|
|
|
def __init__(self, app_state=None): |
|
|
"""Initialize the models service |
|
|
|
|
|
Args: |
|
|
app_state: Reference to main application state |
|
|
""" |
|
|
self.app = app_state |
|
|
|
|
|
def get_all_models(self) -> List[Model]: |
|
|
"""Get a list of all models |
|
|
|
|
|
Returns: |
|
|
List of Model objects |
|
|
""" |
|
|
models_dir = STORAGE_PATH / "models" |
|
|
if not models_dir.exists(): |
|
|
return [] |
|
|
|
|
|
models = [] |
|
|
|
|
|
for model_dir in models_dir.iterdir(): |
|
|
if not model_dir.is_dir(): |
|
|
continue |
|
|
|
|
|
try: |
|
|
model = Model.from_dir(model_dir) |
|
|
models.append(model) |
|
|
except Exception as e: |
|
|
logger.error(f"Error loading model from {model_dir}: {str(e)}") |
|
|
|
|
|
|
|
|
return sorted(models, key=lambda m: m.updated_at, reverse=True) |
|
|
|
|
|
def get_draft_models(self) -> List[Model]: |
|
|
"""Get a list of draft models |
|
|
|
|
|
Returns: |
|
|
List of Model objects with 'draft' status |
|
|
""" |
|
|
return [m for m in self.get_all_models() if m.status == 'draft'] |
|
|
|
|
|
def get_training_models(self) -> List[Model]: |
|
|
"""Get a list of models currently in training |
|
|
|
|
|
Returns: |
|
|
List of Model objects with 'training' status |
|
|
""" |
|
|
return [m for m in self.get_all_models() if m.status == 'training'] |
|
|
|
|
|
def get_trained_models(self) -> List[Model]: |
|
|
"""Get a list of completed trained models |
|
|
|
|
|
Returns: |
|
|
List of Model objects with 'trained' status |
|
|
""" |
|
|
return [m for m in self.get_all_models() if m.status == 'trained'] |
|
|
|
|
|
def delete_model(self, model_id: str) -> bool: |
|
|
"""Delete a model by ID |
|
|
|
|
|
Args: |
|
|
model_id: The model ID to delete |
|
|
|
|
|
Returns: |
|
|
True if deletion was successful |
|
|
""" |
|
|
if not model_id: |
|
|
return False |
|
|
|
|
|
model_dir = STORAGE_PATH / "models" / model_id |
|
|
if not model_dir.exists(): |
|
|
return False |
|
|
|
|
|
try: |
|
|
import shutil |
|
|
shutil.rmtree(model_dir) |
|
|
return True |
|
|
except Exception as e: |
|
|
logger.error(f"Error deleting model {model_id}: {str(e)}") |
|
|
return False |