# Source: Hugging Face Space file page (abidlabs, HF Staff)
# "Upload folder using huggingface_hub" — commit 04dca9c (verified)
import threading
import time
import warnings
from datetime import datetime, timezone
import huggingface_hub
from gradio_client import Client, handle_file
from trackio import utils
from trackio.histogram import Histogram
from trackio.media import TrackioMedia
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.typehints import LogEntry, UploadEntry
BATCH_SEND_INTERVAL = 0.5
class Run:
    """A single tracked run within a Trackio project.

    Metric logs and media uploads are queued locally and flushed to the
    Trackio server (a Gradio app) by a background daemon thread every
    ``BATCH_SEND_INTERVAL`` seconds, so that ``log()`` never blocks on
    the network.
    """

    def __init__(
        self,
        url: str,
        project: str,
        client: Client | None,
        name: str | None = None,
        group: str | None = None,
        config: dict | None = None,
        space_id: str | None = None,
    ):
        """
        Args:
            url: URL of the Trackio server to send logs to.
            project: Name of the project this run belongs to.
            client: Pre-connected gradio client, or None to connect lazily
                in the background thread.
            name: Run name; a readable name is generated if not provided.
            group: Optional group name used to cluster related runs.
            config: User-supplied hyperparameters/metadata. Keys starting
                with "_" are reserved for internal use and rejected.
            space_id: Hugging Face Space id when the server runs in a
                Space; when set, media files are also queued for upload.
        """
        self.url = url
        self.project = project
        # Guards _client, _queued_logs, and _queued_uploads across the
        # caller's thread and the background sender thread.
        self._client_lock = threading.Lock()
        self._client_thread = None
        self._client = client
        self._space_id = space_id
        self.name = name or utils.generate_readable_name(
            SQLiteStorage.get_runs(project), space_id
        )
        self.group = group
        self.config = utils.to_json_safe(config or {})

        # Reject user keys in the reserved "_" namespace *before* injecting
        # our own internal "_"-prefixed metadata.
        if isinstance(self.config, dict):
            for key in self.config:
                if key.startswith("_"):
                    raise ValueError(
                        f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
                    )

            self.config["_Username"] = self._get_username()
            self.config["_Created"] = datetime.now(timezone.utc).isoformat()
            self.config["_Group"] = self.group

        self._queued_logs: list[LogEntry] = []
        self._queued_uploads: list[UploadEntry] = []
        self._stop_flag = threading.Event()
        # The config is sent along with the first log entry only.
        self._config_logged = False

        # Daemon thread: connects the client (if needed) and then runs the
        # batch sender loop until finish() sets the stop flag.
        self._client_thread = threading.Thread(target=self._init_client_background)
        self._client_thread.daemon = True
        self._client_thread.start()

    def _get_username(self) -> str | None:
        """Get the current HuggingFace username if logged in, otherwise None."""
        try:
            who = huggingface_hub.whoami()
            return who["name"] if who else None
        except Exception:
            # Best-effort: not being logged in is a normal condition.
            return None

    def _batch_sender(self):
        """Send batched logs every BATCH_SEND_INTERVAL.

        After the stop flag is set, keeps looping until BOTH queues are
        drained so nothing queued before finish() is lost.
        """
        # BUG FIX: the loop previously exited once _queued_logs was empty,
        # even if _queued_uploads still held pending media uploads, which
        # were then silently dropped at shutdown.
        while (
            not self._stop_flag.is_set()
            or self._queued_logs
            or self._queued_uploads
        ):
            # Skip the sleep during shutdown so the final flush is prompt.
            if not self._stop_flag.is_set():
                time.sleep(BATCH_SEND_INTERVAL)

            with self._client_lock:
                if self._client is None:
                    return
                if self._queued_logs:
                    logs_to_send = self._queued_logs.copy()
                    self._queued_logs.clear()
                    self._client.predict(
                        api_name="/bulk_log",
                        logs=logs_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )
                if self._queued_uploads:
                    uploads_to_send = self._queued_uploads.copy()
                    self._queued_uploads.clear()
                    self._client.predict(
                        api_name="/bulk_upload_media",
                        uploads=uploads_to_send,
                        hf_token=huggingface_hub.utils.get_token(),
                    )

    def _init_client_background(self):
        """Connect the gradio client (retrying with Fibonacci backoff),
        then hand off to the batch-sender loop."""
        if self._client is None:
            fib = utils.fibo()
            for sleep_coefficient in fib:
                try:
                    client = Client(self.url, verbose=False)
                    with self._client_lock:
                        self._client = client
                    break
                except Exception:
                    # Server may not be up yet; keep retrying.
                    pass
                if sleep_coefficient is not None:
                    time.sleep(0.1 * sleep_coefficient)
        self._batch_sender()

    def _queue_upload(self, file_path, step: int | None):
        """Queue a media file for upload to space."""
        upload_entry: UploadEntry = {
            "project": self.project,
            "run": self.name,
            "step": step,
            "uploaded_file": handle_file(file_path),
        }
        with self._client_lock:
            self._queued_uploads.append(upload_entry)

    def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
        """
        Serialize media in metrics and upload to space if needed.
        """
        value._save(self.project, self.name, step)
        if self._space_id:
            self._queue_upload(value._get_absolute_file_path(), step)
        return value._to_dict()

    def _scan_and_queue_media_uploads(self, table_dict: dict, step: int | None):
        """
        Scan a serialized table for media objects and queue them for upload to space.
        """
        if not self._space_id:
            return

        def maybe_queue(entry):
            # Serialized media dicts carry a "_type" tag and a file path
            # relative to MEDIA_DIR.
            if isinstance(entry, dict) and entry.get("_type") in [
                "trackio.image",
                "trackio.video",
                "trackio.audio",
            ]:
                file_path = entry.get("file_path")
                if file_path:
                    from trackio.utils import MEDIA_DIR

                    self._queue_upload(MEDIA_DIR / file_path, step)

        # Table cells may hold a single media dict or a list of them.
        for row in table_dict.get("_value", []):
            for value in row.values():
                maybe_queue(value)
                if isinstance(value, list):
                    for item in value:
                        maybe_queue(item)

    def log(self, metrics: dict, step: int | None = None):
        """Queue a metrics dict for asynchronous logging.

        Reserved keys are renamed with a "__" prefix (with a warning);
        Table/Histogram/media values are serialized, and media files are
        queued for upload when running against a Space. The config is
        attached to the first log entry only.
        """
        # Rename reserved keys instead of rejecting them so logging never
        # fails mid-training.
        renamed_keys = []
        new_metrics = {}
        for k, v in metrics.items():
            if k in utils.RESERVED_KEYS or k.startswith("__"):
                new_key = f"__{k}"
                renamed_keys.append(k)
                new_metrics[new_key] = v
            else:
                new_metrics[k] = v
        if renamed_keys:
            warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")
        metrics = new_metrics

        # Serialize rich value types in place (safe: keys are not changed).
        for key, value in metrics.items():
            if isinstance(value, Table):
                metrics[key] = value._to_dict(
                    project=self.project, run=self.name, step=step
                )
                self._scan_and_queue_media_uploads(metrics[key], step)
            elif isinstance(value, Histogram):
                metrics[key] = value._to_dict()
            elif isinstance(value, TrackioMedia):
                metrics[key] = self._process_media(value, step)
        metrics = utils.serialize_values(metrics)

        config_to_log = None
        if not self._config_logged and self.config:
            config_to_log = utils.to_json_safe(self.config)
            self._config_logged = True

        log_entry: LogEntry = {
            "project": self.project,
            "run": self.name,
            "metrics": metrics,
            "step": step,
            "config": config_to_log,
        }
        with self._client_lock:
            self._queued_logs.append(log_entry)

    def finish(self):
        """Cleanup when run is finished."""
        self._stop_flag.set()

        # Give the sender thread a chance to flush its current batch.
        # NOTE(review): join() has no timeout, so finish() blocks until the
        # client connects and the queues drain — confirm this is intended
        # for unreachable servers.
        time.sleep(2 * BATCH_SEND_INTERVAL)
        if self._client_thread is not None:
            print("* Run finished. Uploading logs to Trackio (please wait...)")
            self._client_thread.join()