|
|
""" |
|
|
API Client Framework for api.comfy.org. |
|
|
|
|
|
This module provides a flexible framework for making API requests from ComfyUI nodes. |
|
|
It supports both synchronous and asynchronous API operations with proper type validation. |
|
|
|
|
|
Key Components: |
|
|
-------------- |
|
|
1. ApiClient - Handles HTTP requests with authentication and error handling |
|
|
2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models |
|
|
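3. SynchronousOperation - Executes a single request/response API operation
4. PollingOperation - Repeatedly polls an endpoint until a task reaches a completed or failed status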
|
|
|
|
|
Usage Examples: |
|
|
-------------- |
|
|
|
|
|
# Example 1: Synchronous API Operation |
|
|
# ------------------------------------ |
|
|
# For a simple API call that returns the result immediately: |
|
|
|
|
|
# 1. Create the API client |
|
|
api_client = ApiClient( |
|
|
base_url="https://api.example.com", |
|
|
auth_token="your_auth_token_here", |
|
|
comfy_api_key="your_comfy_api_key_here", |
|
|
timeout=30.0, |
|
|
verify_ssl=True |
|
|
) |
|
|
|
|
|
# 2. Define the endpoint |
|
|
user_info_endpoint = ApiEndpoint( |
|
|
path="/v1/users/me", |
|
|
method=HttpMethod.GET, |
|
|
request_model=EmptyRequest, # No request body needed |
|
|
response_model=UserProfile, # Pydantic model for the response |
|
|
query_params=None |
|
|
) |
|
|
|
|
|
# 3. Create the request object |
|
|
request = EmptyRequest() |
|
|
|
|
|
# 4. Create and execute the operation |
|
|
operation = SynchronousOperation(
|
|
endpoint=user_info_endpoint, |
|
|
request=request |
|
|
) |
|
|
user_profile = await operation.execute(client=api_client) # Returns immediately with the result |
|
|
|
|
|
|
|
|
# Example 2: Asynchronous API Operation with Polling |
|
|
# ------------------------------------------------- |
|
|
# For an API that starts a task and requires polling for completion: |
|
|
|
|
|
# 1. Define the endpoints (initial request and polling) |
|
|
generate_image_endpoint = ApiEndpoint( |
|
|
path="/v1/images/generate", |
|
|
method=HttpMethod.POST, |
|
|
request_model=ImageGenerationRequest, |
|
|
response_model=TaskCreatedResponse, |
|
|
query_params=None |
|
|
) |
|
|
|
|
|
check_task_endpoint = ApiEndpoint( |
|
|
path="/v1/tasks/{task_id}", |
|
|
method=HttpMethod.GET, |
|
|
request_model=EmptyRequest, |
|
|
response_model=ImageGenerationResult, |
|
|
query_params=None |
|
|
) |
|
|
|
|
|
# 2. Create the request object |
|
|
request = ImageGenerationRequest( |
|
|
prompt="a beautiful sunset over mountains", |
|
|
width=1024, |
|
|
height=1024, |
|
|
num_images=1 |
|
|
) |
|
|
|
|
|
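# 3. Start the task with a synchronous operation, then poll until it finishes.
#    PollingOperation only polls; it does not send the initial request.
#    Path placeholders such as "{task_id}" are not substituted automatically,
#    so format the polling path yourself before polling.
task = await SynchronousOperation(
    endpoint=generate_image_endpoint,
    request=request,
).execute(client=api_client)
check_task_endpoint.path = check_task_endpoint.path.format(task_id=task.task_id)

operation = PollingOperation(
    poll_endpoint=check_task_endpoint,
    completed_statuses=["completed"],
    failed_statuses=["failed", "error"],
    status_extractor=lambda response: response.status,
)

# This polls until a completed/failed status is seen (or attempts run out)
result = await operation.execute(client=api_client)  # Returns the final ImageGenerationResult when done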
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
import aiohttp |
|
|
import asyncio |
|
|
import logging |
|
|
import io |
|
|
import socket |
|
|
from aiohttp.client_exceptions import ClientError, ClientResponseError |
|
|
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple |
|
|
from enum import Enum |
|
|
import json |
|
|
from urllib.parse import urljoin, urlparse |
|
|
from pydantic import BaseModel, Field |
|
|
import uuid |
|
|
|
|
|
from server import PromptServer |
|
|
from comfy.cli_args import args |
|
|
from comfy import utils |
|
|
from . import request_logger |
|
|
|
|
|
# Generic type parameters for the endpoint/operation classes; every one is
# bounded by pydantic's BaseModel.
T = TypeVar("T", bound=BaseModel)  # request model type (ApiEndpoint.request_model)
R = TypeVar("R", bound=BaseModel)  # response model type (ApiEndpoint.response_model)
P = TypeVar("P", bound=BaseModel)

# Total used for the ComfyUI progress bar driven by PollingOperation.
PROGRESS_BAR_MAX: int = 100
|
|
|
|
|
|
|
|
class NetworkError(Exception):
    """Root of this module's network failures; subclasses say where the problem appears to be."""
|
|
|
|
|
|
|
|
class LocalNetworkError(NetworkError):
    """Signals that the connectivity probe could not reach the public internet,
    so the failure looks like a problem with the local network."""
|
|
|
|
|
|
|
|
class ApiServerError(NetworkError):
    """Signals that requests to the API server kept failing after retries and the
    connectivity probe did not attribute the failure to the local network."""
|
|
|
|
|
|
|
|
class EmptyRequest(BaseModel):
    """Request model for endpoints that take no request body.

    ``SynchronousOperation`` sends no body at all for instances of this model
    (including subclasses, via an ``isinstance`` check), so any fields declared
    on a subclass are not transmitted. Query parameters must be supplied through
    ``ApiEndpoint.query_params`` instead.
    """
|
|
|
|
|
|
|
|
class UploadRequest(BaseModel):
    # Describes a file the caller intends to upload. Documented with comments
    # rather than a docstring so the model's JSON-schema description is unchanged.

    # Required: name of the file.
    file_name: str = Field(description="Filename to upload")

    # Optional MIME type; None when the caller does not specify one.
    content_type: Optional[str] = Field(
        default=None,
        description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
    )
|
|
|
|
|
|
|
|
class UploadResponse(BaseModel):
    # Pair of URLs returned for an upload: one to PUT the bytes to and one to
    # fetch them from afterwards. Comments (not a docstring) keep the schema unchanged.
    download_url: str = Field(description="URL to GET uploaded file")
    upload_url: str = Field(description="URL to PUT file to upload")
|
|
|
|
|
|
|
|
class HttpMethod(str, Enum):
    """HTTP verbs accepted by ApiEndpoint; ``.value`` is passed verbatim to ApiClient.request."""

    # Members compare equal to their plain-string values thanks to the str mixin.
    GET = "GET"
    POST = "POST"
    PUT = "PUT"
    DELETE = "DELETE"
    PATCH = "PATCH"
|
|
|
|
|
|
|
|
class ApiClient:
    """
    Client for making HTTP requests to an API with authentication, error handling, and retry logic.

    Requests share one lazily created ``aiohttp.ClientSession``. Connection errors,
    and responses whose status is in ``retry_status_codes``, are retried with
    exponential backoff (``retry_delay * retry_backoff_factor ** attempt``) up to
    ``max_retries`` times. The client can be used as an async context manager so
    the session it created is closed on exit.
    """

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        timeout: float = 3600.0,
        verify_ssl: bool = True,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        retry_status_codes: Optional[Tuple[int, ...]] = None,
        session: Optional[aiohttp.ClientSession] = None,
    ):
        """
        Args:
            base_url: Base URL that request paths are joined onto (``urljoin``).
            auth_token: Bearer token; takes precedence over ``comfy_api_key``.
            comfy_api_key: Key sent as ``X-API-KEY`` when no auth token is set.
            timeout: Total timeout in seconds for sessions this client creates.
            verify_ssl: Passed as ``ssl=`` to every request.
            max_retries: Maximum number of retries for transient failures.
            retry_delay: Initial backoff delay in seconds.
            retry_backoff_factor: Multiplier applied to the delay on each retry.
            retry_status_codes: HTTP statuses treated as transient. ``None`` selects
                the defaults (408, 429, 500, 502, 503, 504). An empty tuple
                disables status-based retries.
            session: Optional externally managed session; :meth:`close` never
                closes a session that was passed in.
        """
        self.base_url = base_url
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        # Compare against None (not truthiness) so an explicit empty tuple is honoured.
        self.retry_status_codes = (
            retry_status_codes
            if retry_status_codes is not None
            else (408, 429, 500, 502, 503, 504)
        )
        self._session: Optional[aiohttp.ClientSession] = session
        # Only sessions created by this client are closed in close().
        self._owns_session = session is None

    @staticmethod
    def _generate_operation_id(path: str) -> str:
        """Generates a unique operation ID for logging."""
        return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _redact_headers(headers: Dict[str, str]) -> Dict[str, str]:
        """Return a copy of *headers* with credential values masked, for debug logging."""
        sensitive = {"authorization", "x-api-key"}
        return {k: ("***" if k.lower() in sensitive else v) for k, v in headers.items()}

    @staticmethod
    def _create_json_payload_args(
        data: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build aiohttp kwargs that send *data* as a JSON body."""
        return {
            "json": data,
            "headers": headers,
        }

    def _create_form_data_args(
        self,
        data: Dict[str, Any] | None,
        files: Dict[str, Any] | list[tuple[str, Any]] | None,
        headers: Optional[Dict[str, str]] = None,
        multipart_parser: Callable | None = None,
    ) -> Dict[str, Any]:
        """Build aiohttp kwargs for a multipart/form-data request.

        ``Content-Type`` is removed from *headers* (in place) so aiohttp can set the
        multipart boundary itself. ``None`` values in *data* and *files* are
        skipped, and non-bytes data values are sent as ``str(value)``. A file entry
        may be a bare file object/bytes or a ``(filename, file[, content_type])``
        tuple.
        """
        if headers and "Content-Type" in headers:
            del headers["Content-Type"]

        if multipart_parser and data:
            data = multipart_parser(data)

        form = aiohttp.FormData(default_to_multipart=True)
        if data:
            for k, v in data.items():
                if v is None:
                    continue
                form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v)

        if files:
            file_iter = files if isinstance(files, list) else files.items()
            for field_name, file_obj in file_iter:
                if file_obj is None:
                    continue

                if isinstance(file_obj, tuple):
                    filename, file_value, content_type = self._unpack_tuple(file_obj)
                else:
                    file_value = file_obj
                    filename = getattr(file_obj, "name", field_name)
                    content_type = "application/octet-stream"

                form.add_field(
                    name=field_name,
                    value=file_value,
                    filename=filename,
                    content_type=content_type,
                )
        return {"data": form, "headers": headers or {}}

    @staticmethod
    def _create_urlencoded_form_data_args(
        data: Dict[str, Any],
        headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Build aiohttp kwargs for an application/x-www-form-urlencoded body."""
        headers = headers or {}
        headers["Content-Type"] = "application/x-www-form-urlencoded"
        return {
            "data": data,
            "headers": headers,
        }

    def get_headers(self) -> Dict[str, str]:
        """Get headers for API requests, including authentication if available"""
        headers = {"Content-Type": "application/json", "Accept": "application/json"}

        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        elif self.comfy_api_key:
            headers["X-API-KEY"] = self.comfy_api_key

        return headers

    async def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
        """
        Check connectivity to determine if network issues are local or server-related.

        Args:
            target_url: URL whose scheme/host is probed at ``/health``.

        Returns:
            Dict with keys ``internet_accessible``, ``api_accessible``,
            ``is_local_issue`` and ``is_api_issue``. Probe failures never raise.
        """
        results = {
            "internet_accessible": False,
            "api_accessible": False,
            "is_local_issue": False,
            "is_api_issue": False,
        }
        timeout = aiohttp.ClientTimeout(total=5.0)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp:
                    results["internet_accessible"] = resp.status < 500
            except (ClientError, asyncio.TimeoutError, socket.gaierror):
                results["is_local_issue"] = True
                return results

            parsed = urlparse(target_url)
            health_url = f"{parsed.scheme}://{parsed.netloc}/health"
            try:
                async with session.get(health_url, ssl=self.verify_ssl) as resp:
                    results["api_accessible"] = resp.status < 500
            except (ClientError, asyncio.TimeoutError, socket.gaierror):
                # An unreachable or slow health endpoint just means "API not accessible".
                # Letting the timeout escape would mask the caller's original error,
                # because this runs inside request()'s exception handler.
                pass

        results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"]
        return results

    async def request(
        self,
        method: str,
        path: str,
        params: Optional[Dict[str, Any]] = None,
        data: Optional[Dict[str, Any]] = None,
        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
        headers: Optional[Dict[str, str]] = None,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        retry_count: int = 0,
    ) -> Dict[str, Any]:
        """
        Make an HTTP request to the API with automatic retries for transient errors.

        Args:
            method: HTTP method (GET, POST, etc.)
            path: API endpoint path (will be joined with base_url)
            params: Query parameters; entries whose value is None are dropped
            data: body data
            files: Files to upload
            headers: Additional headers (override the defaults from get_headers)
            content_type: Content type of the request. Defaults to application/json.
                "application/x-www-form-urlencoded" and "multipart/form-data" are
                also supported; anything else is sent as JSON.
            multipart_parser: Optional callable applied to *data* for multipart requests
            retry_count: Internal parameter for tracking retries, do not set manually

        Returns:
            Parsed JSON response, or ``{}`` when the body is not JSON.

        Raises:
            LocalNetworkError: If local network connectivity issues are detected
            ApiServerError: If the API server is unreachable but internet is working
            Exception: For other request failures (including unauthenticated clients
                and non-retryable HTTP errors)
        """
        relative_path = path.lstrip("/")
        url = urljoin(self.base_url, relative_path)
        self._check_auth(self.auth_token, self.comfy_api_key)

        request_headers = self.get_headers()
        if headers:
            request_headers.update(headers)
        if files:
            # aiohttp must set the multipart Content-Type (with boundary) itself.
            request_headers.pop("Content-Type", None)
        if params:
            params = {k: v for k, v in params.items() if v is not None}

        # Credentials are masked so debug logs do not leak the token / API key.
        logging.debug(f"[DEBUG] Request Headers: {self._redact_headers(request_headers)}")
        logging.debug(f"[DEBUG] Files: {files}")
        logging.debug(f"[DEBUG] Params: {params}")
        logging.debug(f"[DEBUG] Data: {data}")

        if content_type == "application/x-www-form-urlencoded":
            payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers)
        elif content_type == "multipart/form-data":
            payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser)
        else:
            payload_args = self._create_json_payload_args(data, request_headers)

        operation_id = self._generate_operation_id(path)
        # NOTE(review): request_logger receives the raw headers; confirm it redacts credentials.
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=method,
            request_url=url,
            request_headers=request_headers,
            request_params=params,
            request_data=data if content_type == "application/json" else "[form-data or other]",
        )

        session = await self._get_session()
        try:
            async with session.request(
                method,
                url,
                params=params,
                ssl=self.verify_ssl,
                **payload_args,
            ) as resp:
                if resp.status >= 400:
                    try:
                        error_data = await resp.json()
                    except (aiohttp.ContentTypeError, json.JSONDecodeError):
                        error_data = await resp.text()

                    return await self._handle_http_error(
                        ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data),
                        operation_id,
                        method,
                        url,
                        params,
                        data,
                        files,
                        headers,
                        content_type,
                        multipart_parser,
                        retry_count=retry_count,
                        response_content=error_data,
                        path=path,
                        response_headers=dict(resp.headers),
                    )

                try:
                    payload = await resp.json()
                    response_content_to_log = payload
                except (aiohttp.ContentTypeError, json.JSONDecodeError):
                    payload = {}
                    response_content_to_log = await resp.text()

                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method=method,
                    request_url=url,
                    response_status_code=resp.status,
                    response_headers=dict(resp.headers),
                    response_content=response_content_to_log,
                )
                return payload

        except (ClientError, asyncio.TimeoutError, socket.gaierror) as e:
            if retry_count < self.max_retries:
                delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
                logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1,
                                self.max_retries, str(e))
                await asyncio.sleep(delay)
                return await self.request(
                    method,
                    path,
                    params=params,
                    data=data,
                    files=files,
                    headers=headers,
                    content_type=content_type,
                    multipart_parser=multipart_parser,
                    retry_count=retry_count + 1,
                )

            # Retries exhausted: probe connectivity to report a more useful error.
            connectivity = await self._check_connectivity(self.base_url)
            if connectivity["is_local_issue"]:
                raise LocalNetworkError(
                    "Unable to connect to the API server due to local network issues. "
                    "Please check your internet connection and try again."
                ) from e
            raise ApiServerError(
                f"The API server at {self.base_url} is currently unreachable. "
                f"The service may be experiencing issues. Please try again later."
            ) from e

    @staticmethod
    def _check_auth(auth_token, comfy_api_key):
        """Verify that an auth token is present or comfy_api_key is present"""
        if auth_token is None and comfy_api_key is None:
            raise Exception("Unauthorized: Please login first to use this node.")
        return auth_token or comfy_api_key

    @staticmethod
    async def upload_file(
        upload_url: str,
        file: io.BytesIO | str,
        content_type: str | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> aiohttp.ClientResponse:
        """Upload a file to the API with retry logic.

        Args:
            upload_url: The URL to PUT the file to
            file: Either a file path string or a BytesIO object (read from position 0)
            content_type: Optional mime type to set for the upload; when omitted no
                Content-Type header is sent
            max_retries: Maximum number of retry attempts
            retry_delay: Initial delay between retries in seconds
            retry_backoff_factor: Multiplier for the delay after each retry

        Returns:
            The successful response. Its body is not read here, and the session
            that produced it is closed by the time it is returned.

        Raises:
            ValueError: If *file* is neither BytesIO nor str.
            NetworkError: If every attempt fails.
        """
        headers: Dict[str, str] = {}
        skip_auto_headers: set[str] = set()
        if content_type:
            headers["Content-Type"] = content_type
        else:
            # Stop aiohttp from inventing a Content-Type for the raw bytes.
            skip_auto_headers.add("Content-Type")

        if isinstance(file, io.BytesIO):
            file.seek(0)
            data = file.read()
        elif isinstance(file, str):
            with open(file, "rb") as f:
                data = f.read()
        else:
            raise ValueError("File must be BytesIO or str path")

        operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}"
        request_logger.log_request_response(
            operation_id=operation_id,
            request_method="PUT",
            request_url=upload_url,
            request_headers=headers,
            request_data=f"[File data {len(data)} bytes]",
        )

        delay = retry_delay
        for attempt in range(max_retries + 1):
            try:
                timeout = aiohttp.ClientTimeout(total=None)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.put(
                        upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers,
                    ) as resp:
                        resp.raise_for_status()
                        request_logger.log_request_response(
                            operation_id=operation_id,
                            request_method="PUT",
                            request_url=upload_url,
                            response_status_code=resp.status,
                            response_headers=dict(resp.headers),
                            response_content="File uploaded successfully.",
                        )
                        return resp
            except (ClientError, asyncio.TimeoutError) as e:
                # Timeouts and connection errors carry no status/headers, so read both
                # defensively; a plain getattr without a default raised AttributeError
                # here and aborted the retry loop.
                error_headers = getattr(e, "headers", None)
                request_logger.log_request_response(
                    operation_id=operation_id,
                    request_method="PUT",
                    request_url=upload_url,
                    response_status_code=getattr(e, "status", None),
                    response_headers=dict(error_headers) if error_headers else None,
                    response_content=None,
                    error_message=f"{type(e).__name__}: {str(e)}",
                )
                if attempt < max_retries:
                    logging.warning(
                        "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e)
                    )
                    await asyncio.sleep(delay)
                    delay *= retry_backoff_factor
                else:
                    raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e

    async def _handle_http_error(
        self,
        exc: ClientResponseError,
        operation_id: str,
        *req_meta,
        retry_count: int,
        response_content: dict | str = "",
        path: Optional[str] = None,
        response_headers: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Any]:
        """Translate an HTTP error response into either a retry or a user-facing exception.

        Args:
            exc: Error built from the failed response.
            operation_id: Identifier used to correlate request-log entries.
            *req_meta: Original request arguments in the order
                (method, url, params, data, files, headers, content_type, multipart_parser).
            retry_count: Number of retries already performed.
            response_content: Parsed JSON body or raw text of the error response.
            path: Original request path, reused for retries. When omitted, the path is
                recovered by stripping ``base_url`` from the URL.
            response_headers: Headers of the error response, used for logging.

        Returns:
            The parsed JSON of a successful retry.

        Raises:
            Exception: With a user-friendly message when the status is not retryable
                or retries are exhausted.
        """
        status_code = exc.status
        if status_code == 401:
            user_friendly = "Unauthorized: Please login first to use this node."
        elif status_code == 402:
            user_friendly = "Payment Required: Please add credits to your account to use this node."
        elif status_code == 409:
            user_friendly = "There is a problem with your account. Please contact [email protected]."
        elif status_code == 429:
            user_friendly = "Rate Limit Exceeded: Please try again later."
        else:
            # Only index into "error" when it is itself a mapping; a string "error"
            # value would otherwise crash on ['message'].
            error_obj = response_content.get("error") if isinstance(response_content, dict) else None
            if isinstance(error_obj, dict) and "message" in error_obj:
                user_friendly = f"API Error: {error_obj['message']}"
                if "type" in error_obj:
                    user_friendly += f" (Type: {error_obj['type']})"
            elif isinstance(response_content, dict):
                user_friendly = f"API Error: {json.dumps(response_content)}"
            elif len(response_content) < 200:
                user_friendly = f"API Error (raw): {response_content}"
            else:
                # Long raw bodies are omitted from the message; report the status instead.
                user_friendly = f"API Error (raw, status {status_code})"

        request_logger.log_request_response(
            operation_id=operation_id,
            request_method=req_meta[0],
            request_url=req_meta[1],
            response_status_code=exc.status,
            response_headers=dict(response_headers) if response_headers else None,
            response_content=response_content,
            error_message=f"HTTP Error {exc.status}",
        )

        logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})")
        if response_content:
            logging.debug(f"[DEBUG] Response content: {response_content}")

        if status_code in self.retry_status_codes and retry_count < self.max_retries:
            delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
            logging.warning(
                "HTTP error %s. Retrying in %.2fs (%s/%s)",
                status_code,
                delay,
                retry_count + 1,
                self.max_retries,
            )
            await asyncio.sleep(delay)
            # Reuse the caller's original path. Reconstructing it from the joined URL
            # is wrong whenever urljoin rewrote the base path.
            retry_path = path if path is not None else req_meta[1].replace(self.base_url, "")
            return await self.request(
                req_meta[0],
                retry_path,
                params=req_meta[2],
                data=req_meta[3],
                files=req_meta[4],
                headers=req_meta[5],
                content_type=req_meta[6],
                multipart_parser=req_meta[7],
                retry_count=retry_count + 1,
            )

        raise Exception(user_friendly) from exc

    @staticmethod
    def _unpack_tuple(t):
        """Helper to normalise (filename, file, content_type) tuples."""
        if len(t) == 3:
            return t
        elif len(t) == 2:
            return t[0], t[1], "application/octet-stream"
        else:
            raise ValueError("files tuple must be (filename, file[, content_type])")

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared session, creating (and taking ownership of) a new one if needed."""
        if self._session is None or self._session.closed:
            timeout = aiohttp.ClientTimeout(total=self.timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
            self._owns_session = True
        return self._session

    async def close(self) -> None:
        """Close the session if this client created it and it is still open."""
        if self._owns_session and self._session and not self._session.closed:
            await self._session.close()

    async def __aenter__(self) -> "ApiClient":
        """Allow usage as an async context manager - ensures clean teardown"""
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()
|
|
|
|
|
|
|
|
class ApiEndpoint(Generic[T, R]):
    """Static description of one HTTP endpoint: path, verb, and request/response models."""

    def __init__(
        self,
        path: str,
        method: HttpMethod,
        request_model: Type[T],
        response_model: Type[R],
        query_params: Optional[Dict[str, Any]] = None,
    ):
        """Record the endpoint definition.

        Args:
            path: URL path, joined onto the client's base URL. Placeholders such as
                ``{id}`` are not substituted by the request code in this module, so
                callers should format the path before use.
            method: HTTP verb; its ``.value`` is what gets sent.
            request_model: Pydantic model class describing the request payload.
            response_model: Pydantic model class the JSON response is validated into.
            query_params: Query parameters sent with every call; ``None`` (or any
                falsy value) is stored as a fresh empty dict.
        """
        self.path, self.method = path, method
        self.request_model, self.response_model = request_model, response_model
        self.query_params = query_params if query_params else {}
|
|
|
|
|
|
|
|
class SynchronousOperation(Generic[T, R]):
    """Represents a single synchronous API operation.

    Sends one request to ``endpoint`` and validates the JSON response with
    ``endpoint.response_model``.
    """

    def __init__(
        self,
        endpoint: ApiEndpoint[T, R],
        request: T,
        files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[Dict[str, str]] = None,
        timeout: float = 604800.0,
        verify_ssl: bool = True,
        content_type: str = "application/json",
        multipart_parser: Callable | None = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
    ) -> None:
        """
        Args:
            endpoint: Endpoint to call.
            request: Request model instance; an ``EmptyRequest`` sends no body.
            files: Files for multipart requests (dict or list of ``(field, file)`` pairs).
            api_base: Base URL; defaults to ``args.comfy_api_base``.
            auth_token: Bearer token for the client.
            comfy_api_key: API key for the client.
            auth_kwargs: Optional dict whose ``auth_token`` / ``comfy_api_key`` keys
                override the explicit credential arguments.
            timeout: Forwarded to the ``ApiClient`` built when ``execute`` gets no client.
            verify_ssl: Forwarded likewise.
            content_type: Request encoding passed to ``ApiClient.request``.
            multipart_parser: Optional callable applied to the body for multipart requests.
            max_retries: Forwarded to the internally built ``ApiClient``.
            retry_delay: Forwarded to the internally built ``ApiClient``.
            retry_backoff_factor: Forwarded to the internally built ``ApiClient``.
        """
        self.endpoint = endpoint
        self.request = request
        self.files = files
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.timeout = timeout
        self.verify_ssl = verify_ssl
        self.content_type = content_type
        self.multipart_parser = multipart_parser
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        """Send the request and return the validated response model.

        Args:
            client: Optional shared ``ApiClient``. When omitted, a client is built from
                this operation's settings and closed before returning.

        Returns:
            ``endpoint.response_model.model_validate(<response JSON>)``.
        """
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                timeout=self.timeout,
                verify_ssl=self.verify_ssl,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )

        try:
            request_dict: Optional[Dict[str, Any]]
            if isinstance(self.request, EmptyRequest):
                request_dict = None
            else:
                request_dict = self.request.model_dump(exclude_none=True)
                # Top-level enum members are sent as their raw values.
                for k, v in list(request_dict.items()):
                    if isinstance(v, Enum):
                        request_dict[k] = v.value

            logging.debug(
                f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
            )
            # default=str: the f-string is evaluated even when debug logging is off,
            # and multipart bodies may legitimately hold bytes or other non-JSON values.
            logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2, default=str)}")
            logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")

            response_json = await client.request(
                self.endpoint.method.value,
                self.endpoint.path,
                params=self.endpoint.query_params,
                data=request_dict,
                files=self.files,
                content_type=self.content_type,
                multipart_parser=self.multipart_parser,
            )

            logging.debug("=" * 50)
            logging.debug("[DEBUG] RESPONSE DETAILS:")
            logging.debug("[DEBUG] Status Code: 200 (Success)")
            logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}")
            logging.debug("=" * 50)

            parsed_response = self.endpoint.response_model.model_validate(response_json)
            logging.debug(f"[DEBUG] Parsed Response: {parsed_response}")
            return parsed_response
        finally:
            if owns_client:
                await client.close()
|
|
|
|
|
|
|
|
class TaskStatus(str, Enum):
    """Normalised outcome of one poll, as produced by PollingOperation._check_task_status."""

    COMPLETED = "completed"  # extracted status is in completed_statuses
    FAILED = "failed"  # extracted status is in failed_statuses
    PENDING = "pending"  # anything else, including extractor errors
|
|
|
|
|
|
|
|
class PollingOperation(Generic[T, R]):
    """Represents an asynchronous API operation that requires polling for completion.

    Repeatedly calls ``poll_endpoint`` until ``status_extractor`` yields a value in
    ``completed_statuses`` (the parsed response is returned) or in
    ``failed_statuses`` (an exception is raised), or until ``max_poll_attempts``
    is reached. This class only polls; starting the task is up to the caller.
    """

    def __init__(
        self,
        poll_endpoint: ApiEndpoint[EmptyRequest, R],
        completed_statuses: list[str],
        failed_statuses: list[str],
        status_extractor: Callable[[R], str],
        progress_extractor: Callable[[R], float] | None = None,
        result_url_extractor: Callable[[R], str] | None = None,
        request: Optional[T] = None,
        api_base: str | None = None,
        auth_token: Optional[str] = None,
        comfy_api_key: Optional[str] = None,
        auth_kwargs: Optional[Dict[str, str]] = None,
        poll_interval: float = 5.0,
        max_poll_attempts: int = 120,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        retry_backoff_factor: float = 2.0,
        estimated_duration: Optional[float] = None,
        node_id: Optional[str] = None,
    ) -> None:
        """
        Args:
            poll_endpoint: Endpoint queried on every poll.
            completed_statuses: Status values that end polling successfully.
            failed_statuses: Status values that end polling with an error.
            status_extractor: Returns the status value from a parsed response.
            progress_extractor: Optional; returns progress (0-PROGRESS_BAR_MAX) for the
                ComfyUI progress bar.
            result_url_extractor: Optional; returns a result URL shown on completion.
            request: Optional model dumped and sent as the body of each poll.
            api_base: Base URL; defaults to ``args.comfy_api_base``.
            auth_token: Bearer token for the client.
            comfy_api_key: API key for the client.
            auth_kwargs: Optional dict whose ``auth_token`` / ``comfy_api_key`` keys
                override the explicit credential arguments.
            poll_interval: Seconds to wait between polls (fractions are honoured).
            max_poll_attempts: Maximum number of polls before timing out.
            max_retries: Forwarded to the internally built ``ApiClient``; also sets the
                consecutive-error budget (``min(5, max_retries * 2)``).
            retry_delay: Forwarded to the internally built ``ApiClient``.
            retry_backoff_factor: Forwarded to the internally built ``ApiClient``.
            estimated_duration: Optional expected duration in seconds, used for the
                "remaining" hint on the node.
            node_id: ComfyUI node to send progress text to; nothing is shown when unset.
        """
        self.poll_endpoint = poll_endpoint
        self.request = request
        self.api_base: str = api_base or args.comfy_api_base
        self.auth_token = auth_token
        self.comfy_api_key = comfy_api_key
        if auth_kwargs is not None:
            self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
            self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
        self.poll_interval = poll_interval
        self.max_poll_attempts = max_poll_attempts
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.retry_backoff_factor = retry_backoff_factor
        self.estimated_duration = estimated_duration
        self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None))
        self.progress_extractor = progress_extractor
        self.result_url_extractor = result_url_extractor
        self.node_id = node_id
        self.completed_statuses = completed_statuses
        self.failed_statuses = failed_statuses
        self.final_response: Optional[R] = None

    async def execute(self, client: Optional[ApiClient] = None) -> R:
        """Poll until a terminal status and return the final parsed response.

        Args:
            client: Optional shared ``ApiClient``. When omitted, one is built from this
                operation's settings and closed before returning.
        """
        owns_client = client is None
        if owns_client:
            client = ApiClient(
                base_url=self.api_base,
                auth_token=self.auth_token,
                comfy_api_key=self.comfy_api_key,
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
                retry_backoff_factor=self.retry_backoff_factor,
            )
        try:
            return await self._poll_until_complete(client)
        finally:
            if owns_client:
                await client.close()

    def _display_text_on_node(self, text: str):
        """Send *text* to the ComfyUI node, if a node id was provided."""
        if not self.node_id:
            return
        PromptServer.instance.send_progress_text(text, self.node_id)

    def _display_time_progress_on_node(self, time_completed: int | float):
        """Show elapsed (and, if estimated_duration is set, remaining) whole seconds on the node."""
        if not self.node_id:
            return
        # Elapsed time is derived from a float poll interval; display whole seconds
        # rather than values like "5.0s".
        elapsed = int(time_completed)
        if self.estimated_duration is not None:
            remaining = max(0, int(self.estimated_duration) - elapsed)
            message = f"Task in progress: {elapsed}s (~{remaining}s remaining)"
        else:
            message = f"Task in progress: {elapsed}s"
        self._display_text_on_node(message)

    def _check_task_status(self, response: R) -> TaskStatus:
        """Map the extracted status onto TaskStatus; extractor errors count as PENDING."""
        try:
            status = self.status_extractor(response)
            if status in self.completed_statuses:
                return TaskStatus.COMPLETED
            if status in self.failed_statuses:
                return TaskStatus.FAILED
            return TaskStatus.PENDING
        except Exception as e:
            logging.error("Error extracting status: %s", e)
            return TaskStatus.PENDING

    async def _wait_for_next_poll(self, poll_count: int) -> None:
        """Sleep for ``poll_interval`` seconds, refreshing the node's elapsed-time text each second."""
        elapsed_before = (poll_count - 1) * self.poll_interval
        whole_seconds = int(self.poll_interval)
        for second in range(whole_seconds):
            self._display_time_progress_on_node(elapsed_before + second)
            await asyncio.sleep(1)
        # int() drops the fractional part. Sleep it too so sub-second intervals
        # don't degenerate into a tight loop that hammers the server.
        leftover = self.poll_interval - whole_seconds
        if leftover > 0:
            await asyncio.sleep(leftover)

    async def _poll_until_complete(self, client: ApiClient) -> R:
        """Poll until the task is complete.

        Raises:
            Exception: When the task reports a failed status, when too many
                consecutive errors occur, or when ``max_poll_attempts`` is exhausted.
        """
        consecutive_errors = 0
        max_consecutive_errors = min(5, self.max_retries * 2)

        progress = utils.ProgressBar(PROGRESS_BAR_MAX) if self.progress_extractor else None

        # The poll body does not change between attempts, so dump it once.
        request_dict = None if self.request is None else self.request.model_dump(exclude_none=True)
        logging.debug(
            f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
        )
        # default=str: this f-string is evaluated even when debug logging is disabled.
        logging.debug(
            f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2, default=str) if request_dict else 'None'}"
        )

        status = TaskStatus.PENDING
        for poll_count in range(1, self.max_poll_attempts + 1):
            try:
                logging.debug(f"[DEBUG] Polling attempt #{poll_count}")

                resp = await client.request(
                    self.poll_endpoint.method.value,
                    self.poll_endpoint.path,
                    params=self.poll_endpoint.query_params,
                    data=request_dict,
                )
                consecutive_errors = 0
                response_obj: R = self.poll_endpoint.response_model.model_validate(resp)

                status = self._check_task_status(response_obj)
                logging.debug(f"[DEBUG] Task Status: {status}")

                if progress is not None:
                    new_progress = self.progress_extractor(response_obj)
                    if new_progress is not None:
                        progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

                if status == TaskStatus.COMPLETED:
                    message = "Task completed successfully"
                    if self.result_url_extractor:
                        result_url = self.result_url_extractor(response_obj)
                        if result_url:
                            message = f"Result URL: {result_url}"
                    logging.debug(f"[DEBUG] {message}")
                    self._display_text_on_node(message)
                    self.final_response = response_obj
                    if progress is not None:
                        progress.update_absolute(PROGRESS_BAR_MAX, total=PROGRESS_BAR_MAX)
                    return self.final_response

                if status == TaskStatus.FAILED:
                    message = f"Task failed: {json.dumps(resp)}"
                    logging.error(f"[DEBUG] {message}")
                    raise Exception(message)

                logging.debug("[DEBUG] Task still pending, continuing to poll...")
                await self._wait_for_next_poll(poll_count)

            except (LocalNetworkError, ApiServerError, NetworkError) as e:
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} network errors: {str(e)}"
                    ) from e
                logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e))
                await asyncio.sleep(self.poll_interval)
            except Exception as e:
                # A FAILED status is final: re-raise immediately instead of retrying.
                consecutive_errors += 1
                if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
                    raise Exception(
                        f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
                    ) from e

                logging.error(f"[DEBUG] Polling error: {str(e)}")
                logging.warning(
                    f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
                    f"Will retry in {self.poll_interval} seconds."
                )
                await asyncio.sleep(self.poll_interval)

        raise Exception(
            f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). "
            "The operation may still be running on the server but is taking longer than expected."
        )
|
|
|