Spaces:
Running
on
Zero
Running
on
Zero
| from __future__ import annotations | |
| from typing import TypedDict, Dict, Optional, Tuple | |
| from typing_extensions import override | |
| from PIL import Image | |
| from enum import Enum | |
| from abc import ABC | |
| from tqdm import tqdm | |
| from typing import TYPE_CHECKING | |
| if TYPE_CHECKING: | |
| from comfy_execution.graph import DynamicPrompt | |
| from protocol import BinaryEventTypes | |
| from comfy_api import feature_flags | |
# Shape of a preview-image payload handed to progress handlers: a string, a PIL
# image and an optional integer.
# NOTE(review): presumably (image format, image, max preview size) - confirm
# against the code that produces these tuples.
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
class NodeState(Enum):
    """
    Lifecycle stage of a node as tracked by the progress registry.

    The string values are what gets serialized when a state is sent to a client.
    """

    # Entry exists but the node has not started yet.
    Pending = "pending"
    # The node has started and may be reporting progress updates.
    Running = "running"
    # The node completed.
    Finished = "finished"
    # NOTE(review): nothing in this module assigns Error; presumably set by callers.
    Error = "error"
class NodeProgressState(TypedDict):
    """
    Dictionary describing how far along a single node is.
    """

    # Current lifecycle stage of the node.
    state: NodeState
    # Progress made so far, on the same scale as ``max``.
    value: float
    # Value at which the node counts as complete.
    max: float
class ProgressHandler(ABC):
    """
    Base class for objects that want to be told about node progress.

    Every hook is a no-op here; subclasses override the ones they care about
    to present progress in their own way (terminal, web client, ...).
    """

    def __init__(self, name: str):
        # ``enabled`` is checked by the registry before it dispatches an event.
        self.enabled = True
        # ``name`` is the key the registry stores this handler under.
        self.name = name

    def set_registry(self, registry: "ProgressRegistry") -> None:
        """Receive the registry this handler is being attached to (ignored by default)."""

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str) -> None:
        """Hook invoked when a node begins processing."""

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ) -> None:
        """Hook invoked whenever a node reports new progress, optionally with a preview image."""

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str) -> None:
        """Hook invoked when a node is done processing."""

    def reset(self) -> None:
        """Hook invoked when the progress registry is reset."""

    def enable(self) -> None:
        """Turn this handler on."""
        self.enabled = True

    def disable(self) -> None:
        """Turn this handler off."""
        self.enabled = False
class CLIProgressHandler(ProgressHandler):
    """
    Handler that displays progress using tqdm progress bars in the CLI.

    One bar is kept per node id. Each bar is pinned to a terminal row
    (tqdm's ``position``); rows are handed out lowest-free-first so that a bar
    opened after an earlier one has finished never lands on a row that is
    still occupied by another live bar.
    """

    def __init__(self):
        super().__init__("cli")
        # Open bars, keyed by node id.
        self.progress_bars: Dict[str, tqdm] = {}
        # Terminal row assigned to each open bar, keyed by node id.
        self._bar_positions: Dict[str, int] = {}

    def _open_bar(self, node_id: str, total: float) -> tqdm:
        """Create a bar for ``node_id`` on the lowest unused row, register it and return it."""
        # Only rows belonging to bars that are still open count as taken.
        taken = {
            row
            for owner, row in self._bar_positions.items()
            if owner in self.progress_bars
        }
        position = 0
        while position in taken:
            position += 1

        bar = tqdm(
            total=total,
            desc=f"Node {node_id}",
            unit="steps",
            leave=True,
            position=position,
        )
        self.progress_bars[node_id] = bar
        self._bar_positions[node_id] = position
        return bar

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Open a bar for ``node_id`` sized to ``state["max"]`` unless one already exists."""
        if node_id not in self.progress_bars:
            self._open_bar(node_id, state["max"])

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """
        Move the bar for ``node_id`` forward to ``value`` out of ``max_value``.

        A bar is created on the fly when ``start_handler`` was never called for
        this node. The bar is never moved backwards.
        """
        # Compare against None explicitly: tqdm objects define their own truthiness.
        bar = self.progress_bars.get(node_id)
        if bar is None:
            self._open_bar(node_id, max_value).update(value)
            return

        if max_value != bar.total:
            bar.total = max_value
            # Redraw so the new total shows even if the position does not advance.
            bar.refresh()

        # tqdm.update() takes an increment, so convert the absolute value to a delta.
        advance = value - bar.n
        if advance > 0:
            bar.update(advance)

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Fill the bar for ``node_id`` up to ``state["max"]``, close it and free its row."""
        bar = self.progress_bars.pop(node_id, None)
        self._bar_positions.pop(node_id, None)
        if bar is None:
            return

        # Ensure the bar shows 100% completion before it is closed.
        remaining = state["max"] - bar.n
        if remaining > 0:
            bar.update(remaining)
        bar.close()

    def reset(self):
        """Close every open bar and forget all row assignments."""
        for bar in self.progress_bars.values():
            bar.close()
        self.progress_bars.clear()
        self._bar_positions.clear()
class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.

    Every hook pushes a ``progress_state`` message describing all non-pending
    nodes of the attached registry. ``update_handler`` additionally forwards a
    preview image when one is supplied and the client supports preview metadata.
    All hooks are no-ops until ``set_registry`` has been called.
    """

    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance
        # Filled in by set_registry(); kept as None until then so the hooks can
        # test for it instead of failing on a missing attribute.
        self.registry: ProgressRegistry | None = None

    def set_registry(self, registry: "ProgressRegistry"):
        """Remember the registry whose node states and dynprompt are reported."""
        self.registry = registry

    def _node_ids(self, node_id: str) -> dict:
        """Return the display/parent/real node ids the registry's dynprompt reports for ``node_id``."""
        dynprompt = self.registry.dynprompt
        return {
            "display_node_id": dynprompt.get_display_node_id(node_id),
            "parent_node_id": dynprompt.get_parent_node_id(node_id),
            "real_node_id": dynprompt.get_real_node_id(node_id),
        }

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """Send the current progress state to the client"""
        # Nothing to send to, or no dynprompt to resolve node ids with.
        if self.server_instance is None or self.registry is None:
            return

        # Only send info for non-pending nodes
        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                "node_id": node_id,
                "prompt_id": prompt_id,
                **self._node_ids(node_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        # Send a combined progress_state message with all node states
        self.server_instance.send_sync(
            "progress_state", {"prompt_id": prompt_id, "nodes": active_nodes}
        )

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Send the progress state of all nodes when a node starts."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """
        Send the progress state of all nodes and, when given, a preview image.

        The image is sent together with node-id metadata, and only when the
        connected client advertises the ``supports_preview_metadata`` feature.
        """
        if self.registry is None:
            return
        self._send_progress_state(prompt_id, self.registry.nodes)

        # The preview path needs both an image and a server to send it through.
        if not image or self.server_instance is None:
            return

        # Only send new format if client supports it
        if not feature_flags.supports_feature(
            self.server_instance.sockets_metadata,
            self.server_instance.client_id,
            "supports_preview_metadata",
        ):
            return

        metadata = {
            "node_id": node_id,
            "prompt_id": prompt_id,
            **self._node_ids(node_id),
        }
        self.server_instance.send_sync(
            BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
            (image, metadata),
            self.server_instance.client_id,
        )

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Send the progress state of all nodes when a node finishes."""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)
class ProgressRegistry:
    """
    Keeps the progress state of every node for one prompt and fans each
    change out to the handlers registered with it.
    """

    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        # Per-node progress entries, keyed by node id.
        self.nodes: Dict[str, NodeProgressState] = {}
        # Registered handlers, keyed by handler name.
        self.handlers: Dict[str, ProgressHandler] = {}

    def _enabled_handlers(self):
        """Yield registered handlers whose ``enabled`` flag is set, checked lazily per handler."""
        for candidate in self.handlers.values():
            if candidate.enabled:
                yield candidate

    def register_handler(self, handler: ProgressHandler) -> None:
        """Store ``handler`` under its name, replacing any handler with the same name."""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Reset and drop the named handler; unknown names are ignored."""
        try:
            departing = self.handlers[handler_name]
        except KeyError:
            return
        # Give the handler a chance to release its resources before it goes.
        departing.reset()
        del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Switch the named handler on; unknown names are ignored."""
        try:
            target = self.handlers[handler_name]
        except KeyError:
            return
        target.enable()

    def disable_handler(self, handler_name: str) -> None:
        """Switch the named handler off; unknown names are ignored."""
        try:
            target = self.handlers[handler_name]
        except KeyError:
            return
        target.disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Return the entry for ``node_id``, creating a pending 0-of-1 entry if missing."""
        try:
            return self.nodes[node_id]
        except KeyError:
            created = NodeProgressState(state=NodeState.Pending, value=0, max=1)
            self.nodes[node_id] = created
            return created

    def start_progress(self, node_id: str) -> None:
        """Mark ``node_id`` as running at 0 of 1 and tell the enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=0.0, max=1.0)
        for listener in self._enabled_handlers():
            listener.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
    ) -> None:
        """Record new progress for ``node_id`` and pass it, with any image, to the enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry.update(state=NodeState.Running, value=value, max=max_value)
        for listener in self._enabled_handlers():
            listener.update_handler(
                node_id, value, max_value, entry, self.prompt_id, image
            )

    def finish_progress(self, node_id: str) -> None:
        """Mark ``node_id`` as finished with its value at max and tell the enabled handlers."""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]
        for listener in self._enabled_handlers():
            listener.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Call ``reset()`` on every registered handler, enabled or not."""
        for registered in self.handlers.values():
            registered.reset()
# Global registry instance.
# Starts as None; assigned by reset_progress_state() and, lazily, by
# get_progress_state().
global_progress_registry: ProgressRegistry | None = None
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
    """
    Replace the global registry with a fresh one for ``prompt_id``.

    Handlers on the outgoing registry, if there is one, are reset first. The
    new registry starts with no handlers registered.
    """
    global global_progress_registry

    outgoing = global_progress_registry
    if outgoing is not None:
        outgoing.reset_handlers()

    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
def add_progress_handler(handler: ProgressHandler) -> None:
    """Attach ``handler`` to the current global registry: hand it the registry, then register it."""
    current = get_progress_state()
    handler.set_registry(current)
    current.register_handler(handler)
def get_progress_state() -> ProgressRegistry:
    """
    Return the global registry, creating a placeholder one on first use.

    The placeholder has an empty prompt id and a ``DynamicPrompt`` built from
    an empty dict.
    """
    global global_progress_registry

    if global_progress_registry is not None:
        return global_progress_registry

    # Imported here rather than at module level; the top-level import of this
    # name is only performed under TYPE_CHECKING.
    from comfy_execution.graph import DynamicPrompt

    global_progress_registry = ProgressRegistry(
        prompt_id="", dynprompt=DynamicPrompt({})
    )
    return global_progress_registry