diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..df573b28e8a840f0092cd0e97609ace2c73500a4 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/trackio_logo_old.png filter=lfs diff=lfs merge=lfs -text
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..418c15791bf32416e59af408e082f0e310c347d9
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,51 @@
+# trackio
+
+## 0.8.0
+
+### Features
+
+- [#331](https://github.com/gradio-app/trackio/pull/331) [`2c02d0f`](https://github.com/gradio-app/trackio/commit/2c02d0fd0a5824160528782402bb0dd4083396d5) - Truncate table string values that are longer than 250 characters (configurable via env variable). Thanks @abidlabs!
+- [#324](https://github.com/gradio-app/trackio/pull/324) [`50b2122`](https://github.com/gradio-app/trackio/commit/50b2122e7965ac82a72e6cb3b7d048bc10a2a6b1) - Add log y-axis functionality to UI. Thanks @abidlabs!
+- [#326](https://github.com/gradio-app/trackio/pull/326) [`61dc1f4`](https://github.com/gradio-app/trackio/commit/61dc1f40af2f545f8e70395ddf0dbb8aee6b60d5) - Fix: improve table rendering for metrics in Trackio Dashboard. Thanks @vigneshwaran!
+- [#328](https://github.com/gradio-app/trackio/pull/328) [`6857cbb`](https://github.com/gradio-app/trackio/commit/6857cbbe557a59a4642f210ec42566d108294e63) - Support trackio.Table with trackio.Image columns. Thanks @abidlabs!
+
+## 0.7.0
+
+### Features
+
+- [#277](https://github.com/gradio-app/trackio/pull/277) [`db35601`](https://github.com/gradio-app/trackio/commit/db35601b9c023423c4654c9909b8ab73e58737de) - fix: make grouped runs view reflect live updates. Thanks @Saba9!
+- [#320](https://github.com/gradio-app/trackio/pull/320) [`24ae739`](https://github.com/gradio-app/trackio/commit/24ae73969b09fb3126acd2f91647cdfbf8cf72a1) - Add additional query params for xmin, xmax, and smoothing. Thanks @abidlabs!
+- [#270](https://github.com/gradio-app/trackio/pull/270) [`cd1dfc3`](https://github.com/gradio-app/trackio/commit/cd1dfc3dc641b4499ac6d4a1b066fa8e2b52c57b) - feature: add support for logging audio. Thanks @Saba9!
+
+## 0.6.0
+
+### Features
+
+- [#309](https://github.com/gradio-app/trackio/pull/309) [`1df2353`](https://github.com/gradio-app/trackio/commit/1df23534d6c01938c8db9c0f584ffa23e8d6021d) - Add histogram support with wandb-compatible API. Thanks @abidlabs!
+- [#315](https://github.com/gradio-app/trackio/pull/315) [`76ba060`](https://github.com/gradio-app/trackio/commit/76ba06055dc43ca8f03b79f3e72d761949bd19a8) - Add guards to avoid silent fails. Thanks @Xmaster6y!
+- [#313](https://github.com/gradio-app/trackio/pull/313) [`a606b3e`](https://github.com/gradio-app/trackio/commit/a606b3e1c5edf3d4cf9f31bd50605226a5a1c5d0) - No longer prevent certain keys from being used. Instead, dunderify them to prevent collisions with internal usage. Thanks @abidlabs!
+- [#317](https://github.com/gradio-app/trackio/pull/317) [`27370a5`](https://github.com/gradio-app/trackio/commit/27370a595d0dbdf7eebbe7159d2ba778f039da44) - quick fixes for trackio.histogram. Thanks @abidlabs!
+- [#312](https://github.com/gradio-app/trackio/pull/312) [`aa0f3bf`](https://github.com/gradio-app/trackio/commit/aa0f3bf372e7a0dd592a38af699c998363830eeb) - Fix video logging by adding TRACKIO_DIR to allowed_paths. Thanks @abidlabs!
+
+## 0.5.3
+
+### Features
+
+- [#300](https://github.com/gradio-app/trackio/pull/300) [`5e4cacf`](https://github.com/gradio-app/trackio/commit/5e4cacf2e7ce527b4ce60de3a5bc05d2c02c77fb) - Adds more environment variables to allow customization of Trackio dashboard. Thanks @abidlabs!
+
+## 0.5.2
+
+### Features
+
+- [#293](https://github.com/gradio-app/trackio/pull/293) [`64afc28`](https://github.com/gradio-app/trackio/commit/64afc28d3ea1dfd821472dc6bf0b8ed35a9b74be) - Ensures that the TRACKIO_DIR environment variable is respected. Thanks @abidlabs!
+- [#287](https://github.com/gradio-app/trackio/pull/287) [`cd3e929`](https://github.com/gradio-app/trackio/commit/cd3e9294320949e6b8b829239069a43d5d7ff4c1) - fix(sqlite): unify .sqlite extension, allow export when DBs exist, clean WAL sidecars on import. Thanks @vaibhav-research!
+
+### Fixes
+
+- [#291](https://github.com/gradio-app/trackio/pull/291) [`3b5adc3`](https://github.com/gradio-app/trackio/commit/3b5adc3d1f452dbab7a714d235f4974782f93730) - Fix the wheel build. Thanks @pngwn!
+
+## 0.5.1
+
+### Fixes
+
+- [#278](https://github.com/gradio-app/trackio/pull/278) [`314c054`](https://github.com/gradio-app/trackio/commit/314c05438007ddfea3383e06fd19143e27468e2d) - Fix row orientation of metrics plots. Thanks @abidlabs!
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e615a5186e4538874e2efc9051b9381a426f2d8f
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,339 @@
+import hashlib
+import json
+import logging
+import os
+import warnings
+import webbrowser
+from pathlib import Path
+from typing import Any
+
+from gradio.blocks import BUILT_IN_THEMES
+from gradio.themes import Default as DefaultTheme
+from gradio.themes import ThemeClass
+from gradio_client import Client
+from huggingface_hub import SpaceStorage
+
+from trackio import context_vars, deploy, utils
+from trackio.histogram import Histogram
+from trackio.imports import import_csv, import_tf_events
+from trackio.media import TrackioAudio, TrackioImage, TrackioVideo
+from trackio.run import Run
+from trackio.sqlite_storage import SQLiteStorage
+from trackio.table import Table
+from trackio.ui.main import demo
+from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR
+
+logging.getLogger("httpx").setLevel(logging.WARNING)
+
+warnings.filterwarnings(
+ "ignore",
+ message="Empty session being created. Install gradio\\[oauth\\]",
+ category=UserWarning,
+ module="gradio.helpers",
+)
+
+__version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
+ "version"
+]
+
+__all__ = [
+ "init",
+ "log",
+ "finish",
+ "show",
+ "import_csv",
+ "import_tf_events",
+ "Image",
+ "Video",
+ "Audio",
+ "Table",
+ "Histogram",
+]
+
+Image = TrackioImage
+Video = TrackioVideo
+Audio = TrackioAudio
+
+
+config = {}
+
+DEFAULT_THEME = "default"
+
+
+def init(
+ project: str,
+ name: str | None = None,
+ group: str | None = None,
+ space_id: str | None = None,
+ space_storage: SpaceStorage | None = None,
+ dataset_id: str | None = None,
+ config: dict | None = None,
+ resume: str = "never",
+ settings: Any = None,
+ private: bool | None = None,
+ embed: bool = True,
+) -> Run:
+ """
+ Creates a new Trackio project and returns a [`Run`] object.
+
+ Args:
+ project (`str`):
+ The name of the project (can be an existing project to continue tracking or
+ a new project to start tracking from scratch).
+ name (`str`, *optional*):
+ The name of the run (if not provided, a default name will be generated).
+ group (`str`, *optional*):
+ The name of the group which this run belongs to in order to help organize
+            related runs together. You can toggle the entire group's visibility in the
+ dashboard.
+ space_id (`str`, *optional*):
+ If provided, the project will be logged to a Hugging Face Space instead of
+ a local directory. Should be a complete Space name like
+ `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
+ case the Space will be created in the currently-logged-in Hugging Face
+ user's namespace. If the Space does not exist, it will be created. If the
+ Space already exists, the project will be logged to it.
+ space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
+ Choice of persistent storage tier.
+ dataset_id (`str`, *optional*):
+ If a `space_id` is provided, a persistent Hugging Face Dataset will be
+ created and the metrics will be synced to it every 5 minutes. Specify a
+ Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
+ or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
+ or `None` (uses the same name as the Space but with the `"_dataset"`
+ suffix). If the Dataset does not exist, it will be created. If the Dataset
+ already exists, the project will be appended to it.
+ config (`dict`, *optional*):
+ A dictionary of configuration options. Provided for compatibility with
+ `wandb.init()`.
+ resume (`str`, *optional*, defaults to `"never"`):
+ Controls how to handle resuming a run. Can be one of:
+
+ - `"must"`: Must resume the run with the given name, raises error if run
+ doesn't exist
+ - `"allow"`: Resume the run if it exists, otherwise create a new run
+ - `"never"`: Never resume a run, always create a new one
+ private (`bool`, *optional*):
+ Whether to make the Space private. If None (default), the repo will be
+ public unless the organization's default is private. This value is ignored
+ if the repo already exists.
+ settings (`Any`, *optional*):
+ Not used. Provided for compatibility with `wandb.init()`.
+ embed (`bool`, *optional*, defaults to `True`):
+            If running inside a Jupyter/Colab notebook, whether the dashboard should
+ automatically be embedded in the cell when trackio.init() is called.
+
+ Returns:
+ `Run`: A [`Run`] object that can be used to log metrics and finish the run.
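+
+    Example:
+    ```python
+    import trackio
+
+    # Minimal local sketch (illustrative; assumes no Space or Dataset syncing):
+    run = trackio.init(project="my-project", config={"lr": 1e-3})
+    trackio.log({"loss": 0.5})
+    trackio.finish()
+    ```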
+ """
+ if settings is not None:
+ warnings.warn(
+ "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
+ )
+
+ if space_id is None and dataset_id is not None:
+ raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ url = context_vars.current_server.get()
+ share_url = context_vars.current_share_server.get()
+
+ if url is None:
+ if space_id is None:
+ _, url, share_url = demo.launch(
+ show_api=False,
+ inline=False,
+ quiet=True,
+ prevent_thread_lock=True,
+ show_error=True,
+ favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
+ allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
+ )
+ else:
+ url = space_id
+ share_url = None
+ context_vars.current_server.set(url)
+ context_vars.current_share_server.set(share_url)
+ if (
+ context_vars.current_project.get() is None
+ or context_vars.current_project.get() != project
+ ):
+ print(f"* Trackio project initialized: {project}")
+
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(
+ f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
+ )
+ if space_id is None:
+ print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
+ if utils.is_in_notebook() and embed:
+ base_url = share_url + "/" if share_url else url
+ full_url = utils.get_full_url(
+ base_url, project=project, write_token=demo.write_token
+ )
+ utils.embed_url_in_notebook(full_url)
+ else:
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(
+ space_id, space_storage, dataset_id, private
+ )
+ user_name, space_name = space_id.split("/")
+ space_url = deploy.SPACE_HOST_URL.format(
+ user_name=user_name, space_name=space_name
+ )
+ print(f"* View dashboard by going to: {space_url}")
+ if utils.is_in_notebook() and embed:
+ utils.embed_url_in_notebook(space_url)
+ context_vars.current_project.set(project)
+
+ client = None
+ if not space_id:
+ client = Client(url, verbose=False)
+
+ if resume == "must":
+ if name is None:
+ raise ValueError("Must provide a run name when resume='must'")
+ if name not in SQLiteStorage.get_runs(project):
+ raise ValueError(f"Run '{name}' does not exist in project '{project}'")
+ resumed = True
+ elif resume == "allow":
+ resumed = name is not None and name in SQLiteStorage.get_runs(project)
+ elif resume == "never":
+ if name is not None and name in SQLiteStorage.get_runs(project):
+ warnings.warn(
+ f"* Warning: resume='never' but a run '{name}' already exists in "
+ f"project '{project}'. Generating a new name and instead. If you want "
+ "to resume this run, call init() with resume='must' or resume='allow'."
+ )
+ name = None
+ resumed = False
+ else:
+ raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
+
+ run = Run(
+ url=url,
+ project=project,
+ client=client,
+ name=name,
+ group=group,
+ config=config,
+ space_id=space_id,
+ )
+
+ if resumed:
+ print(f"* Resumed existing run: {run.name}")
+ else:
+ print(f"* Created new run: {run.name}")
+
+ context_vars.current_run.set(run)
+ globals()["config"] = run.config
+ return run
+
+
+def log(metrics: dict, step: int | None = None) -> None:
+ """
+ Logs metrics to the current run.
+
+ Args:
+ metrics (`dict`):
+ A dictionary of metrics to log.
+ step (`int`, *optional*):
+ The step number. If not provided, the step will be incremented
+ automatically.
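+
+    Example:
+    ```python
+    import trackio
+
+    # Illustrative; assumes trackio.init() has already been called.
+    trackio.log({"loss": 0.1, "accuracy": 0.9}, step=10)
+    ```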
+ """
+ run = context_vars.current_run.get()
+ if run is None:
+ raise RuntimeError("Call trackio.init() before trackio.log().")
+ run.log(
+ metrics=metrics,
+ step=step,
+ )
+
+
+def finish():
+ """
+ Finishes the current run.
+ """
+ run = context_vars.current_run.get()
+ if run is None:
+ raise RuntimeError("Call trackio.init() before trackio.finish().")
+ run.finish()
+
+
+def show(
+ project: str | None = None,
+ theme: str | ThemeClass | None = None,
+ mcp_server: bool | None = None,
+):
+ """
+ Launches the Trackio dashboard.
+
+ Args:
+ project (`str`, *optional*):
+ The name of the project whose runs to show. If not provided, all projects
+ will be shown and the user can select one.
+ theme (`str` or `ThemeClass`, *optional*):
+ A Gradio Theme to use for the dashboard instead of the default Gradio theme,
+ can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from the Hub
+ (e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided, the
+ `TRACKIO_THEME` environment variable will be used, or if that is not set, the
+ default Gradio theme will be used.
+ mcp_server (`bool`, *optional*):
+ If `True`, the Trackio dashboard will be set up as an MCP server and certain
+ functions will be added as MCP tools. If `None` (default behavior), then the
+ `GRADIO_MCP_SERVER` environment variable will be used to determine if the
+ MCP server should be enabled (which is `"True"` on Hugging Face Spaces).
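+
+    Example:
+    ```python
+    import trackio
+
+    # Illustrative: launch the dashboard for one project with a built-in theme.
+    trackio.show(project="my-project", theme="soft")
+    ```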
+ """
+ theme = theme or os.environ.get("TRACKIO_THEME", DEFAULT_THEME)
+
+ if theme != DEFAULT_THEME:
+ # TODO: It's a little hacky to reproduce this theme-setting logic from Gradio Blocks,
+ # but in Gradio 6.0, the theme will be set in `launch()` instead, which means that we
+ # will be able to remove this code.
+ if isinstance(theme, str):
+ if theme.lower() in BUILT_IN_THEMES:
+ theme = BUILT_IN_THEMES[theme.lower()]
+ else:
+ try:
+ theme = ThemeClass.from_hub(theme)
+ except Exception as e:
+ warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
+ theme = DefaultTheme()
+ if not isinstance(theme, ThemeClass):
+ warnings.warn("Theme should be a class loaded from gradio.themes")
+ theme = DefaultTheme()
+ demo.theme: ThemeClass = theme
+ demo.theme_css = theme._get_theme_css()
+ demo.stylesheets = theme._stylesheets
+ theme_hasher = hashlib.sha256()
+ theme_hasher.update(demo.theme_css.encode("utf-8"))
+ demo.theme_hash = theme_hasher.hexdigest()
+
+ _mcp_server = (
+ mcp_server
+ if mcp_server is not None
+ else os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
+ )
+
+ _, url, share_url = demo.launch(
+ show_api=_mcp_server,
+ quiet=True,
+ inline=False,
+ prevent_thread_lock=True,
+ favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
+ allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
+ mcp_server=_mcp_server,
+ )
+
+ base_url = share_url + "/" if share_url else url
+ full_url = utils.get_full_url(
+ base_url, project=project, write_token=demo.write_token
+ )
+
+ if not utils.is_in_notebook():
+ print(f"* Trackio UI launched at: {full_url}")
+ webbrowser.open(full_url)
+ utils.block_main_thread_until_keyboard_interrupt()
+ else:
+ utils.embed_url_in_notebook(full_url)
diff --git a/__pycache__/__init__.cpython-311.pyc b/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..529e64086605bfb97b9cce58dc12034ea0b48ed7
Binary files /dev/null and b/__pycache__/__init__.cpython-311.pyc differ
diff --git a/__pycache__/__init__.cpython-312.pyc b/__pycache__/__init__.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c25288eb96ee30ecd02aa4285989f91370f24bc
Binary files /dev/null and b/__pycache__/__init__.cpython-312.pyc differ
diff --git a/__pycache__/cli.cpython-311.pyc b/__pycache__/cli.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..87c265e8071c07da6c8b491cb201618e001744c0
Binary files /dev/null and b/__pycache__/cli.cpython-311.pyc differ
diff --git a/__pycache__/cli.cpython-312.pyc b/__pycache__/cli.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3db8b06e9857734b9728fb97a77d5baf15223a4
Binary files /dev/null and b/__pycache__/cli.cpython-312.pyc differ
diff --git a/__pycache__/commit_scheduler.cpython-311.pyc b/__pycache__/commit_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3bab6c98a40929b877dc234c2132feceae74cdcc
Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-311.pyc differ
diff --git a/__pycache__/commit_scheduler.cpython-312.pyc b/__pycache__/commit_scheduler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..53fb6dd9ef11f226b39023364781b7c790e0d7b4
Binary files /dev/null and b/__pycache__/commit_scheduler.cpython-312.pyc differ
diff --git a/__pycache__/context_vars.cpython-311.pyc b/__pycache__/context_vars.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1383e6d61245e80f439af53fb93ff3b328e92b68
Binary files /dev/null and b/__pycache__/context_vars.cpython-311.pyc differ
diff --git a/__pycache__/context_vars.cpython-312.pyc b/__pycache__/context_vars.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6d096ca6bcb420748ed9c71b699d03e4af8bd875
Binary files /dev/null and b/__pycache__/context_vars.cpython-312.pyc differ
diff --git a/__pycache__/deploy.cpython-311.pyc b/__pycache__/deploy.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1b1aca8da888cf07458ecbc8f1608478d53f8747
Binary files /dev/null and b/__pycache__/deploy.cpython-311.pyc differ
diff --git a/__pycache__/deploy.cpython-312.pyc b/__pycache__/deploy.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef14a407b566bd4691da238f4018152fa78d5f8c
Binary files /dev/null and b/__pycache__/deploy.cpython-312.pyc differ
diff --git a/__pycache__/dummy_commit_scheduler.cpython-311.pyc b/__pycache__/dummy_commit_scheduler.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7c30e3bc154a33be8ee59f621dffa9e606475e0d
Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-311.pyc differ
diff --git a/__pycache__/dummy_commit_scheduler.cpython-312.pyc b/__pycache__/dummy_commit_scheduler.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a4aa376d45d48127b31466259f83119330430052
Binary files /dev/null and b/__pycache__/dummy_commit_scheduler.cpython-312.pyc differ
diff --git a/__pycache__/file_storage.cpython-311.pyc b/__pycache__/file_storage.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05e753159bc5fb7ef31f6ea8fc63b462a9715a5b
Binary files /dev/null and b/__pycache__/file_storage.cpython-311.pyc differ
diff --git a/__pycache__/file_storage.cpython-312.pyc b/__pycache__/file_storage.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..54d02ecd9d5af1a0f023480a376b6716171e8f12
Binary files /dev/null and b/__pycache__/file_storage.cpython-312.pyc differ
diff --git a/__pycache__/histogram.cpython-311.pyc b/__pycache__/histogram.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..536e72c5c8f35c22648948db5c8ced0b42f3c764
Binary files /dev/null and b/__pycache__/histogram.cpython-311.pyc differ
diff --git a/__pycache__/histogram.cpython-312.pyc b/__pycache__/histogram.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f0589a07249ecd86543357ebc4675784d3299387
Binary files /dev/null and b/__pycache__/histogram.cpython-312.pyc differ
diff --git a/__pycache__/imports.cpython-311.pyc b/__pycache__/imports.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..173a28843c870068f2da2432a42db3a82e4a6687
Binary files /dev/null and b/__pycache__/imports.cpython-311.pyc differ
diff --git a/__pycache__/imports.cpython-312.pyc b/__pycache__/imports.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b9a5a7ca1d63f76d133b3146085a501e9beff498
Binary files /dev/null and b/__pycache__/imports.cpython-312.pyc differ
diff --git a/__pycache__/media.cpython-311.pyc b/__pycache__/media.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..102519ab6a8a02155762bdd89a85403131771248
Binary files /dev/null and b/__pycache__/media.cpython-311.pyc differ
diff --git a/__pycache__/media.cpython-312.pyc b/__pycache__/media.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c81fcb8b160fec6b48b0c99f8178f071167fbdd1
Binary files /dev/null and b/__pycache__/media.cpython-312.pyc differ
diff --git a/__pycache__/run.cpython-311.pyc b/__pycache__/run.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e1ae20de7ec65821aa88f456f82fde92a349ed46
Binary files /dev/null and b/__pycache__/run.cpython-311.pyc differ
diff --git a/__pycache__/run.cpython-312.pyc b/__pycache__/run.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..661c25c01f92371bf045472f33ba60c8e5df8eb8
Binary files /dev/null and b/__pycache__/run.cpython-312.pyc differ
diff --git a/__pycache__/sqlite_storage.cpython-311.pyc b/__pycache__/sqlite_storage.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..309285b71ff66c471e72ae8499b25b37d5a4c6e0
Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-311.pyc differ
diff --git a/__pycache__/sqlite_storage.cpython-312.pyc b/__pycache__/sqlite_storage.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9382e7b53dfe448d75b0a146366841c236e089f2
Binary files /dev/null and b/__pycache__/sqlite_storage.cpython-312.pyc differ
diff --git a/__pycache__/table.cpython-311.pyc b/__pycache__/table.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5051cb80fe2f2a3a60e04dfcdcaa13c77ea577c5
Binary files /dev/null and b/__pycache__/table.cpython-311.pyc differ
diff --git a/__pycache__/table.cpython-312.pyc b/__pycache__/table.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f3c8d1ed43cef667eb729d94056cd584d090ef20
Binary files /dev/null and b/__pycache__/table.cpython-312.pyc differ
diff --git a/__pycache__/typehints.cpython-311.pyc b/__pycache__/typehints.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..81f9ec9aff1620e18f1c2e432557880f3df83394
Binary files /dev/null and b/__pycache__/typehints.cpython-311.pyc differ
diff --git a/__pycache__/typehints.cpython-312.pyc b/__pycache__/typehints.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..edb0d8e5632c8e797a4c6dfdcf49dfc7fcd774d1
Binary files /dev/null and b/__pycache__/typehints.cpython-312.pyc differ
diff --git a/__pycache__/utils.cpython-311.pyc b/__pycache__/utils.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..422d7751bc33c25d8458067fe1935a6be493166d
Binary files /dev/null and b/__pycache__/utils.cpython-311.pyc differ
diff --git a/__pycache__/utils.cpython-312.pyc b/__pycache__/utils.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..134daa183d816686e8725c84864be680a34506d4
Binary files /dev/null and b/__pycache__/utils.cpython-312.pyc differ
diff --git a/__pycache__/video_writer.cpython-311.pyc b/__pycache__/video_writer.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..035e2e9973cd8a1e329d08f333e08bbe3a43478c
Binary files /dev/null and b/__pycache__/video_writer.cpython-311.pyc differ
diff --git a/__pycache__/video_writer.cpython-312.pyc b/__pycache__/video_writer.cpython-312.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d974430899c00641f54d0f23f006f2043abbc302
Binary files /dev/null and b/__pycache__/video_writer.cpython-312.pyc differ
diff --git a/assets/trackio_logo_dark.png b/assets/trackio_logo_dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..5c5c08b2387d23599f177477ef7482ff6a601df3
Binary files /dev/null and b/assets/trackio_logo_dark.png differ
diff --git a/assets/trackio_logo_light.png b/assets/trackio_logo_light.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3438262c61989e6c6d16df4801a8935136115b3
Binary files /dev/null and b/assets/trackio_logo_light.png differ
diff --git a/assets/trackio_logo_old.png b/assets/trackio_logo_old.png
new file mode 100644
index 0000000000000000000000000000000000000000..48a07d40b83e23c9cc9dc0cb6544a3c6ad62b57f
--- /dev/null
+++ b/assets/trackio_logo_old.png
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3922c4d1e465270ad4d8abb12023f3beed5d9f7f338528a4c0ac21dcf358a1c8
+size 487101
diff --git a/assets/trackio_logo_type_dark.png b/assets/trackio_logo_type_dark.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f80a3191e514a8a0beaa6ab2011e5baf8df5eda
Binary files /dev/null and b/assets/trackio_logo_type_dark.png differ
diff --git a/assets/trackio_logo_type_dark_transparent.png b/assets/trackio_logo_type_dark_transparent.png
new file mode 100644
index 0000000000000000000000000000000000000000..95b2c1f3499c502a81f2ec1094c0e09f827fb1fa
Binary files /dev/null and b/assets/trackio_logo_type_dark_transparent.png differ
diff --git a/assets/trackio_logo_type_light.png b/assets/trackio_logo_type_light.png
new file mode 100644
index 0000000000000000000000000000000000000000..f07866d245ea897b9aba417b29403f7f91cc8bbd
Binary files /dev/null and b/assets/trackio_logo_type_light.png differ
diff --git a/assets/trackio_logo_type_light_transparent.png b/assets/trackio_logo_type_light_transparent.png
new file mode 100644
index 0000000000000000000000000000000000000000..a20b4d5e64c61c91546577645310593fe3493508
Binary files /dev/null and b/assets/trackio_logo_type_light_transparent.png differ
diff --git a/cli.py b/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..d358ad74d05e5ea259697ac75ba34b41c7fa8ebf
--- /dev/null
+++ b/cli.py
@@ -0,0 +1,37 @@
+import argparse
+
+from trackio import show
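+
+# Example invocation (illustrative; run this module directly):
+#   python cli.py show --project my-project --theme soft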
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Trackio CLI")
+ subparsers = parser.add_subparsers(dest="command")
+
+ ui_parser = subparsers.add_parser(
+ "show", help="Show the Trackio dashboard UI for a project"
+ )
+ ui_parser.add_argument(
+ "--project", required=False, help="Project name to show in the dashboard"
+ )
+ ui_parser.add_argument(
+ "--theme",
+ required=False,
+ default="default",
+ help="A Gradio Theme to use for the dashboard instead of the default, can be a built-in theme (e.g. 'soft', 'citrus'), or a theme from the Hub (e.g. 'gstaff/xkcd').",
+ )
+ ui_parser.add_argument(
+ "--mcp-server",
+ action="store_true",
+ help="Enable MCP server functionality. The Trackio dashboard will be set up as an MCP server and certain functions will be exposed as MCP tools.",
+ )
+
+ args = parser.parse_args()
+
+ if args.command == "show":
+ show(args.project, args.theme, args.mcp_server)
+ else:
+ parser.print_help()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/commit_scheduler.py b/commit_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a904f9b1ff940718f0f137fc812b54f750ff589
--- /dev/null
+++ b/commit_scheduler.py
@@ -0,0 +1,391 @@
+# Originally copied from https://github.com/huggingface/huggingface_hub/blob/d0a948fc2a32ed6e557042a95ef3e4af97ec4a7c/src/huggingface_hub/_commit_scheduler.py
+
+import atexit
+import logging
+import os
+import time
+from concurrent.futures import Future
+from dataclasses import dataclass
+from io import SEEK_END, SEEK_SET, BytesIO
+from pathlib import Path
+from threading import Lock, Thread
+from typing import Callable, Dict, List, Union
+
+from huggingface_hub.hf_api import (
+ DEFAULT_IGNORE_PATTERNS,
+ CommitInfo,
+ CommitOperationAdd,
+ HfApi,
+)
+from huggingface_hub.utils import filter_repo_objects
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class _FileToUpload:
+ """Temporary dataclass to store info about files to upload. Not meant to be used directly."""
+
+ local_path: Path
+ path_in_repo: str
+ size_limit: int
+ last_modified: float
+
+
+class CommitScheduler:
+ """
+ Scheduler to upload a local folder to the Hub at regular intervals (e.g. push to hub every 5 minutes).
+
+ The recommended way to use the scheduler is to use it as a context manager. This ensures that the scheduler is
+ properly stopped and the last commit is triggered when the script ends. The scheduler can also be stopped manually
+ with the `stop` method. Checkout the [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#scheduled-uploads)
+ to learn more about how to use it.
+
+ Args:
+ repo_id (`str`):
+ The id of the repo to commit to.
+ folder_path (`str` or `Path`):
+ Path to the local folder to upload regularly.
+ every (`int` or `float`, *optional*):
+ The number of minutes between each commit. Defaults to 5 minutes.
+ path_in_repo (`str`, *optional*):
+ Relative path of the directory in the repo, for example: `"checkpoints/"`. Defaults to the root folder
+ of the repository.
+ repo_type (`str`, *optional*):
+ The type of the repo to commit to. Defaults to `model`.
+ revision (`str`, *optional*):
+ The revision of the repo to commit to. Defaults to `main`.
+ private (`bool`, *optional*):
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
+ token (`str`, *optional*):
+ The token to use to commit to the repo. Defaults to the token saved on the machine.
+ allow_patterns (`List[str]` or `str`, *optional*):
+ If provided, only files matching at least one pattern are uploaded.
+ ignore_patterns (`List[str]` or `str`, *optional*):
+ If provided, files matching any of the patterns are not uploaded.
+ squash_history (`bool`, *optional*):
+ Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
+            useful to avoid degraded performance on the repo when it grows too large.
+ hf_api (`HfApi`, *optional*):
+ The [`HfApi`] client to use to commit to the Hub. Can be set with custom settings (user agent, token,...).
+ on_before_commit (`Callable[[], None]`, *optional*):
+ If specified, a function that will be called before the CommitScheduler lists files to create a commit.
+
+ Example:
+ ```py
+ >>> from pathlib import Path
+ >>> from huggingface_hub import CommitScheduler
+
+ # Scheduler uploads every 10 minutes
+ >>> csv_path = Path("watched_folder/data.csv")
+ >>> CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path=csv_path.parent, every=10)
+
+ >>> with csv_path.open("a") as f:
+ ... f.write("first line")
+
+ # Some time later (...)
+ >>> with csv_path.open("a") as f:
+ ... f.write("second line")
+ ```
+
+ Example using a context manager:
+ ```py
+ >>> from pathlib import Path
+ >>> from huggingface_hub import CommitScheduler
+
+ >>> with CommitScheduler(repo_id="test_scheduler", repo_type="dataset", folder_path="watched_folder", every=10) as scheduler:
+ ... csv_path = Path("watched_folder/data.csv")
+ ... with csv_path.open("a") as f:
+ ... f.write("first line")
+ ... (...)
+ ... with csv_path.open("a") as f:
+ ... f.write("second line")
+
+    # Scheduler is now stopped and the last commit has been triggered
+ ```
+ """
+
+ def __init__(
+ self,
+ *,
+ repo_id: str,
+ folder_path: Union[str, Path],
+ every: Union[int, float] = 5,
+ path_in_repo: str | None = None,
+ repo_type: str | None = None,
+ revision: str | None = None,
+ private: bool | None = None,
+ token: str | None = None,
+ allow_patterns: list[str] | str | None = None,
+ ignore_patterns: list[str] | str | None = None,
+ squash_history: bool = False,
+ hf_api: HfApi | None = None,
+ on_before_commit: Callable[[], None] | None = None,
+ ) -> None:
+ self.api = hf_api or HfApi(token=token)
+ self.on_before_commit = on_before_commit
+
+ # Folder
+ self.folder_path = Path(folder_path).expanduser().resolve()
+ self.path_in_repo = path_in_repo or ""
+ self.allow_patterns = allow_patterns
+
+ if ignore_patterns is None:
+ ignore_patterns = []
+ elif isinstance(ignore_patterns, str):
+ ignore_patterns = [ignore_patterns]
+ self.ignore_patterns = ignore_patterns + DEFAULT_IGNORE_PATTERNS
+
+ if self.folder_path.is_file():
+ raise ValueError(
+ f"'folder_path' must be a directory, not a file: '{self.folder_path}'."
+ )
+ self.folder_path.mkdir(parents=True, exist_ok=True)
+
+ # Repository
+ repo_url = self.api.create_repo(
+ repo_id=repo_id, private=private, repo_type=repo_type, exist_ok=True
+ )
+ self.repo_id = repo_url.repo_id
+ self.repo_type = repo_type
+ self.revision = revision
+ self.token = token
+
+ self.last_uploaded: Dict[Path, float] = {}
+ self.last_push_time: float | None = None
+
+ if not every > 0:
+ raise ValueError(f"'every' must be a positive integer, not '{every}'.")
+ self.lock = Lock()
+ self.every = every
+ self.squash_history = squash_history
+
+ logger.info(
+ f"Scheduled job to push '{self.folder_path}' to '{self.repo_id}' every {self.every} minutes."
+ )
+ self._scheduler_thread = Thread(target=self._run_scheduler, daemon=True)
+ self._scheduler_thread.start()
+ atexit.register(self._push_to_hub)
+
+ self.__stopped = False
+
+ def stop(self) -> None:
+ """Stop the scheduler.
+
+        A stopped scheduler cannot be restarted. Mostly for testing purposes.
+ """
+ self.__stopped = True
+
+ def __enter__(self) -> "CommitScheduler":
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> None:
+ # Upload last changes before exiting
+ self.trigger().result()
+ self.stop()
+ return
+
+ def _run_scheduler(self) -> None:
+ """Dumb thread waiting between each scheduled push to Hub."""
+ while True:
+ self.last_future = self.trigger()
+ time.sleep(self.every * 60)
+ if self.__stopped:
+ break
+
+ def trigger(self) -> Future:
+ """Trigger a `push_to_hub` and return a future.
+
+ This method is automatically called every `every` minutes. You can also call it manually to trigger a commit
+ immediately, without waiting for the next scheduled commit.
+ """
+ return self.api.run_as_future(self._push_to_hub)
+
+ def _push_to_hub(self) -> CommitInfo | None:
+ if self.__stopped: # If stopped, already scheduled commits are ignored
+ return None
+
+ logger.info("(Background) scheduled commit triggered.")
+ try:
+ value = self.push_to_hub()
+ if self.squash_history:
+ logger.info("(Background) squashing repo history.")
+ self.api.super_squash_history(
+ repo_id=self.repo_id, repo_type=self.repo_type, branch=self.revision
+ )
+ return value
+ except Exception as e:
+ logger.error(
+ f"Error while pushing to Hub: {e}"
+ ) # Depending on the setup, error might be silenced
+ raise
+
+ def push_to_hub(self) -> CommitInfo | None:
+ """
+ Push folder to the Hub and return the commit info.
+
+ This method is not meant to be called directly. It is run in the background by the scheduler, respecting a
+ queue mechanism to avoid concurrent commits. Making a direct call to the method might lead to concurrency
+ issues.
+
+ The default behavior of `push_to_hub` is to assume an append-only folder. It lists all files in the folder and
+ uploads only changed files. If no changes are found, the method returns without committing anything. If you want
+ to change this behavior, you can inherit from [`CommitScheduler`] and override this method. This can be useful
+ for example to compress data together in a single file before committing. For more details and examples, check
+ out our [integration guide](https://huggingface.co/docs/huggingface_hub/main/en/guides/upload#scheduled-uploads).
+ """
+ # Check files to upload (with lock)
+ with self.lock:
+ if self.on_before_commit is not None:
+ self.on_before_commit()
+
+ logger.debug("Listing files to upload for scheduled commit.")
+
+ # List files from folder (taken from `_prepare_upload_folder_additions`)
+ relpath_to_abspath = {
+ path.relative_to(self.folder_path).as_posix(): path
+ for path in sorted(
+ self.folder_path.glob("**/*")
+ ) # sorted to be deterministic
+ if path.is_file()
+ }
+ prefix = f"{self.path_in_repo.strip('/')}/" if self.path_in_repo else ""
+
+ # Filter with pattern + filter out unchanged files + retrieve current file size
+ files_to_upload: List[_FileToUpload] = []
+ for relpath in filter_repo_objects(
+ relpath_to_abspath.keys(),
+ allow_patterns=self.allow_patterns,
+ ignore_patterns=self.ignore_patterns,
+ ):
+ local_path = relpath_to_abspath[relpath]
+ stat = local_path.stat()
+ if (
+ self.last_uploaded.get(local_path) is None
+ or self.last_uploaded[local_path] != stat.st_mtime
+ ):
+ files_to_upload.append(
+ _FileToUpload(
+ local_path=local_path,
+ path_in_repo=prefix + relpath,
+ size_limit=stat.st_size,
+ last_modified=stat.st_mtime,
+ )
+ )
+
+ # Return if nothing to upload
+ if len(files_to_upload) == 0:
+ logger.debug("Dropping schedule commit: no changed file to upload.")
+ return None
+
+ # Convert `_FileToUpload` as `CommitOperationAdd` (=> compute file shas + limit to file size)
+ logger.debug("Removing unchanged files since previous scheduled commit.")
+ add_operations = [
+ CommitOperationAdd(
+ # TODO: Cap the file to its current size, even if the user append data to it while a scheduled commit is happening
+ # (requires an upstream fix for XET-535: `hf_xet` should support `BinaryIO` for upload)
+ path_or_fileobj=file_to_upload.local_path,
+ path_in_repo=file_to_upload.path_in_repo,
+ )
+ for file_to_upload in files_to_upload
+ ]
+
+ # Upload files (append mode expected - no need for lock)
+ logger.debug("Uploading files for scheduled commit.")
+ commit_info = self.api.create_commit(
+ repo_id=self.repo_id,
+ repo_type=self.repo_type,
+ operations=add_operations,
+ commit_message="Scheduled Commit",
+ revision=self.revision,
+ )
+
+ for file in files_to_upload:
+ self.last_uploaded[file.local_path] = file.last_modified
+
+ self.last_push_time = time.time()
+
+ return commit_info
+
+
+class PartialFileIO(BytesIO):
+ """A file-like object that reads only the first part of a file.
+
+ Useful to upload a file to the Hub when the user might still be appending data to it. Only the first part of the
+ file is uploaded (i.e. the part that was available when the filesystem was first scanned).
+
+ In practice, only used internally by the CommitScheduler to regularly push a folder to the Hub with minimal
+ disturbance for the user. The object is passed to `CommitOperationAdd`.
+
+ Only supports `read`, `tell` and `seek` methods.
+
+ Args:
+ file_path (`str` or `Path`):
+ Path to the file to read.
+ size_limit (`int`):
+ The maximum number of bytes to read from the file. If the file is larger than this, only the first part
+ will be read (and uploaded).
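+
+    Example (illustrative):
+    ```py
+    >>> fileobj = PartialFileIO("data.csv", size_limit=1024)
+    >>> chunk = fileobj.read()  # reads at most the first 1024 bytes
+    ```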
+ """
+
+ def __init__(self, file_path: Union[str, Path], size_limit: int) -> None:
+ self._file_path = Path(file_path)
+ self._file = self._file_path.open("rb")
+ self._size_limit = min(size_limit, os.fstat(self._file.fileno()).st_size)
+
+ def __del__(self) -> None:
+ self._file.close()
+ return super().__del__()
+
+ def __repr__(self) -> str:
+ return (
+ f""
+ )
+
+ def __len__(self) -> int:
+ return self._size_limit
+
+ def __getattribute__(self, name: str):
+ if name.startswith("_") or name in (
+ "read",
+ "tell",
+ "seek",
+ ): # only 3 public methods supported
+ return super().__getattribute__(name)
+ raise NotImplementedError(f"PartialFileIO does not support '{name}'.")
+
+ def tell(self) -> int:
+ """Return the current file position."""
+ return self._file.tell()
+
+ def seek(self, __offset: int, __whence: int = SEEK_SET) -> int:
+ """Change the stream position to the given offset.
+
+ Behavior is the same as a regular file, except that the position is capped to the size limit.
+ """
+ if __whence == SEEK_END:
+ # SEEK_END => set from the truncated end
+ __offset = len(self) + __offset
+ __whence = SEEK_SET
+
+ pos = self._file.seek(__offset, __whence)
+ if pos > self._size_limit:
+ return self._file.seek(self._size_limit)
+ return pos
+
+ def read(self, __size: int | None = -1) -> bytes:
+ """Read at most `__size` bytes from the file.
+
+ Behavior is the same as a regular file, except that it is capped to the size limit.
+ """
+ current = self._file.tell()
+ if __size is None or __size < 0:
+ # Read until file limit
+ truncated_size = self._size_limit - current
+ else:
+ # Read until file limit or __size
+ truncated_size = min(__size, self._size_limit - current)
+ return self._file.read(truncated_size)
diff --git a/context_vars.py b/context_vars.py
new file mode 100644
index 0000000000000000000000000000000000000000..5670ac7fee3b6ce31816f67390bbb42bec045f3a
--- /dev/null
+++ b/context_vars.py
@@ -0,0 +1,18 @@
+import contextvars
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from trackio.run import Run
+
+current_run: contextvars.ContextVar["Run | None"] = contextvars.ContextVar(
+ "current_run", default=None
+)
+current_project: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "current_project", default=None
+)
+current_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "current_server", default=None
+)
+current_share_server: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "current_share_server", default=None
+)
diff --git a/deploy.py b/deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d7b6e037cd87dd7fc3c670c641b87f5ba92a25
--- /dev/null
+++ b/deploy.py
@@ -0,0 +1,258 @@
+import importlib.metadata
+import io
+import os
+import time
+from importlib.resources import files
+from pathlib import Path
+
+import gradio
+import huggingface_hub
+from gradio_client import Client, handle_file
+from httpx import ReadTimeout
+from huggingface_hub.errors import RepositoryNotFoundError
+from requests import HTTPError
+
+import trackio
+from trackio.sqlite_storage import SQLiteStorage
+
+SPACE_HOST_URL = "https://{user_name}-{space_name}.hf.space/"
+SPACE_URL = "https://huggingface.co/spaces/{space_id}"
+
+
+def _is_trackio_installed_from_source() -> bool:
+ """Check if trackio is installed from source/editable install vs PyPI."""
+ try:
+ trackio_file = trackio.__file__
+ if "site-packages" not in trackio_file:
+ return True
+
+ dist = importlib.metadata.distribution("trackio")
+ if dist.files:
+ files = list(dist.files)
+ has_pth = any(".pth" in str(f) for f in files)
+ if has_pth:
+ return True
+
+ return False
+ except (
+ AttributeError,
+ importlib.metadata.PackageNotFoundError,
+ importlib.metadata.MetadataError,
+ ValueError,
+ TypeError,
+ ):
+ return True
+
+
+def deploy_as_space(
+ space_id: str,
+ space_storage: huggingface_hub.SpaceStorage | None = None,
+ dataset_id: str | None = None,
+ private: bool | None = None,
+):
+ if (
+ os.getenv("SYSTEM") == "spaces"
+ ): # in case a repo with this function is uploaded to spaces
+ return
+
+ trackio_path = files("trackio")
+
+ hf_api = huggingface_hub.HfApi()
+
+ try:
+ huggingface_hub.create_repo(
+ space_id,
+ private=private,
+ space_sdk="gradio",
+ space_storage=space_storage,
+ repo_type="space",
+ exist_ok=True,
+ )
+ except HTTPError as e:
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
+ print("Need 'write' access token to create a Spaces repo.")
+ huggingface_hub.login(add_to_git_credential=False)
+ huggingface_hub.create_repo(
+ space_id,
+ private=private,
+ space_sdk="gradio",
+ space_storage=space_storage,
+ repo_type="space",
+ exist_ok=True,
+ )
+ else:
+ raise ValueError(f"Failed to create Space: {e}")
+
+ with open(Path(trackio_path, "README.md"), "r") as f:
+ readme_content = f.read()
+ readme_content = readme_content.replace("{GRADIO_VERSION}", gradio.__version__)
+ readme_buffer = io.BytesIO(readme_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=readme_buffer,
+ path_in_repo="README.md",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ # We can assume pandas, gradio, and huggingface-hub are already installed in a Gradio Space.
+ # Make sure necessary dependencies are installed by creating a requirements.txt.
+ is_source_install = _is_trackio_installed_from_source()
+
+ if is_source_install:
+ requirements_content = """pyarrow>=21.0
+plotly>=6.0.0,<7.0.0"""
+ else:
+ requirements_content = f"""pyarrow>=21.0
+trackio=={trackio.__version__}
+plotly>=6.0.0,<7.0.0"""
+
+ requirements_buffer = io.BytesIO(requirements_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=requirements_buffer,
+ path_in_repo="requirements.txt",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ huggingface_hub.utils.disable_progress_bars()
+
+ if is_source_install:
+ hf_api.upload_folder(
+ repo_id=space_id,
+ repo_type="space",
+ folder_path=trackio_path,
+ ignore_patterns=["README.md"],
+ )
+ else:
+ app_file_content = """import trackio
+trackio.show()"""
+ app_file_buffer = io.BytesIO(app_file_content.encode("utf-8"))
+ hf_api.upload_file(
+ path_or_fileobj=app_file_buffer,
+ path_in_repo="ui/main.py",
+ repo_id=space_id,
+ repo_type="space",
+ )
+
+ if hf_token := huggingface_hub.utils.get_token():
+ huggingface_hub.add_space_secret(space_id, "HF_TOKEN", hf_token)
+ if dataset_id is not None:
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_DATASET_ID", dataset_id)
+
+ if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
+ )
+ if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
+ )
+
+ if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_PLOT_ORDER", plot_order)
+
+ if theme := os.environ.get("TRACKIO_THEME"):
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)
+
+
+def create_space_if_not_exists(
+ space_id: str,
+ space_storage: huggingface_hub.SpaceStorage | None = None,
+ dataset_id: str | None = None,
+ private: bool | None = None,
+) -> None:
+ """
+ Creates a new Hugging Face Space if it does not exist. If a dataset_id is provided, it will be added as a space variable.
+
+ Args:
+ space_id: The ID of the Space to create.
+ dataset_id: The ID of the Dataset to add to the Space.
+ private: Whether to make the Space private. If None (default), the repo will be
+ public unless the organization's default is private. This value is ignored if
+ the repo already exists.
+ """
+ if "/" not in space_id:
+ raise ValueError(
+ f"Invalid space ID: {space_id}. Must be in the format: username/reponame or orgname/reponame."
+ )
+ if dataset_id is not None and "/" not in dataset_id:
+ raise ValueError(
+ f"Invalid dataset ID: {dataset_id}. Must be in the format: username/datasetname or orgname/datasetname."
+ )
+ try:
+ huggingface_hub.repo_info(space_id, repo_type="space")
+ print(f"* Found existing space: {SPACE_URL.format(space_id=space_id)}")
+ if dataset_id is not None:
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_DATASET_ID", dataset_id
+ )
+ if logo_light_url := os.environ.get("TRACKIO_LOGO_LIGHT_URL"):
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_LOGO_LIGHT_URL", logo_light_url
+ )
+ if logo_dark_url := os.environ.get("TRACKIO_LOGO_DARK_URL"):
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_LOGO_DARK_URL", logo_dark_url
+ )
+
+ if plot_order := os.environ.get("TRACKIO_PLOT_ORDER"):
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_PLOT_ORDER", plot_order
+ )
+
+ if theme := os.environ.get("TRACKIO_THEME"):
+ huggingface_hub.add_space_variable(space_id, "TRACKIO_THEME", theme)
+ return
+ except RepositoryNotFoundError:
+ pass
+ except HTTPError as e:
+ if e.response.status_code in [401, 403]: # unauthorized or forbidden
+ print("Need 'write' access token to create a Spaces repo.")
+ huggingface_hub.login(add_to_git_credential=False)
+ huggingface_hub.add_space_variable(
+ space_id, "TRACKIO_DATASET_ID", dataset_id
+ )
+ else:
+ raise ValueError(f"Failed to create Space: {e}")
+
+ print(f"* Creating new space: {SPACE_URL.format(space_id=space_id)}")
+ deploy_as_space(space_id, space_storage, dataset_id, private)
+
+
+def wait_until_space_exists(
+ space_id: str,
+) -> None:
+ """
+ Blocks the current thread until the space exists.
+    Raises a TimeoutError if the Space does not become available after several retries.
+
+ Args:
+ space_id: The ID of the Space to wait for.
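+
+    Example (illustrative):
+    ```py
+    >>> wait_until_space_exists("username/my-trackio-space")
+    ```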
+ """
+ delay = 1
+ for _ in range(10):
+ try:
+ Client(space_id, verbose=False)
+ return
+ except (ReadTimeout, ValueError):
+ time.sleep(delay)
+ delay = min(delay * 2, 30)
+ raise TimeoutError("Waiting for space to exist took longer than expected")
+
+
+def upload_db_to_space(project: str, space_id: str) -> None:
+ """
+ Uploads the database of a local Trackio project to a Hugging Face Space.
+
+ Args:
+ project: The name of the project to upload.
+ space_id: The ID of the Space to upload to.
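+
+    Example (illustrative):
+    ```py
+    >>> upload_db_to_space(project="my-project", space_id="username/my-trackio-space")
+    ```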
+ """
+ db_path = SQLiteStorage.get_project_db_path(project)
+ client = Client(space_id, verbose=False)
+ client.predict(
+ api_name="/upload_db_to_space",
+ project=project,
+ uploaded_db=handle_file(db_path),
+ hf_token=huggingface_hub.utils.get_token(),
+ )
diff --git a/dummy_commit_scheduler.py b/dummy_commit_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f5015e1479a175081080ef8908966979c1de179
--- /dev/null
+++ b/dummy_commit_scheduler.py
@@ -0,0 +1,12 @@
+# A dummy object to fit the interface of huggingface_hub's CommitScheduler
+class DummyCommitSchedulerLock:
+ def __enter__(self):
+ return None
+
+ def __exit__(self, exception_type, exception_value, exception_traceback):
+ pass
+
+
+class DummyCommitScheduler:
+ def __init__(self):
+ self.lock = DummyCommitSchedulerLock()
diff --git a/histogram.py b/histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..13922f4aabbd441086d8153e2e2b3721da014110
--- /dev/null
+++ b/histogram.py
@@ -0,0 +1,68 @@
+from typing import Any
+
+import numpy as np
+
+
+class Histogram:
+ """
+ Histogram data type for Trackio, compatible with wandb.Histogram.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+
+ # Create histogram from sequence
+ data = np.random.randn(1000)
+ trackio.log({"distribution": trackio.Histogram(data)})
+
+ # Create histogram from numpy histogram
+ hist, bins = np.histogram(data, bins=30)
+ trackio.log({"distribution": trackio.Histogram(np_histogram=(hist, bins))})
+
+ # Specify custom number of bins
+ trackio.log({"distribution": trackio.Histogram(data, num_bins=50)})
+ ```
+
+ Args:
+ sequence: Optional sequence of values to create histogram from
+ np_histogram: Optional pre-computed numpy histogram (hist, bins) tuple
+ num_bins: Number of bins for the histogram (default 64, max 512)
+ """
+
+ TYPE = "trackio.histogram"
+
+ def __init__(
+ self,
+ sequence: Any = None,
+ np_histogram: tuple | None = None,
+ num_bins: int = 64,
+ ):
+ if sequence is None and np_histogram is None:
+ raise ValueError("Must provide either sequence or np_histogram")
+
+ if sequence is not None and np_histogram is not None:
+ raise ValueError("Cannot provide both sequence and np_histogram")
+
+ num_bins = min(num_bins, 512)
+
+ if np_histogram is not None:
+ self.histogram, self.bins = np_histogram
+ self.histogram = np.asarray(self.histogram)
+ self.bins = np.asarray(self.bins)
+ else:
+ data = np.asarray(sequence).flatten()
+ data = data[np.isfinite(data)]
+ if len(data) == 0:
+ self.histogram = np.array([])
+ self.bins = np.array([])
+ else:
+ self.histogram, self.bins = np.histogram(data, bins=num_bins)
+
+ def _to_dict(self) -> dict:
+ """Convert histogram to dictionary for storage."""
+ return {
+ "_type": self.TYPE,
+ "bins": self.bins.tolist(),
+ "values": self.histogram.tolist(),
+ }
diff --git a/imports.py b/imports.py
new file mode 100644
index 0000000000000000000000000000000000000000..3dcc21e157919e974d7b0249292e287010e50074
--- /dev/null
+++ b/imports.py
@@ -0,0 +1,302 @@
+import os
+from pathlib import Path
+
+import pandas as pd
+
+from trackio import deploy, utils
+from trackio.sqlite_storage import SQLiteStorage
+
+
+def import_csv(
+ csv_path: str | Path,
+ project: str,
+ name: str | None = None,
+ space_id: str | None = None,
+ dataset_id: str | None = None,
+ private: bool | None = None,
+) -> None:
+ """
+ Imports a CSV file into a Trackio project. The CSV file must contain a `"step"`
+ column, may optionally contain a `"timestamp"` column, and any other columns will be
+ treated as metrics. It should also include a header row with the column names.
+
+ TODO: call init() and return a Run object so that the user can continue to log metrics to it.
+
+ Args:
+ csv_path (`str` or `Path`):
+ The str or Path to the CSV file to import.
+ project (`str`):
+ The name of the project to import the CSV file into. Must not be an existing
+ project.
+        name (`str`, *optional*):
+            The name of the run to import the CSV file into. If not provided, a
+            default name will be generated.
+ space_id (`str`, *optional*):
+ If provided, the project will be logged to a Hugging Face Space instead of a
+ local directory. Should be a complete Space name like `"username/reponame"`
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
+ be created in the currently-logged-in Hugging Face user's namespace. If the
+ Space does not exist, it will be created. If the Space already exists, the
+ project will be logged to it.
+ dataset_id (`str`, *optional*):
+ If provided, a persistent Hugging Face Dataset will be created and the
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
+ `"datasetname"` in which case the Dataset will be created in the
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
+ exist, it will be created. If the Dataset already exists, the project will
+ be appended to it. If not provided, the metrics will be logged to a local
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
+ will be automatically created with the same name as the Space but with the
+ `"_dataset"` suffix.
+ private (`bool`, *optional*):
+ Whether to make the Space private. If None (default), the repo will be
+ public unless the organization's default is private. This value is ignored
+ if the repo already exists.
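+
+    Example:
+    ```python
+    import trackio
+
+    # Illustrative: "metrics.csv" is a hypothetical file with a "step" column
+    # plus one column per metric.
+    trackio.import_csv("metrics.csv", project="imported-project")
+    ```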
+ """
+ if SQLiteStorage.get_runs(project):
+ raise ValueError(
+ f"Project '{project}' already exists. Cannot import CSV into existing project."
+ )
+
+ csv_path = Path(csv_path)
+ if not csv_path.exists():
+ raise FileNotFoundError(f"CSV file not found: {csv_path}")
+
+ df = pd.read_csv(csv_path)
+ if df.empty:
+ raise ValueError("CSV file is empty")
+
+ column_mapping = utils.simplify_column_names(df.columns.tolist())
+ df = df.rename(columns=column_mapping)
+
+ step_column = None
+ for col in df.columns:
+ if col.lower() == "step":
+ step_column = col
+ break
+
+ if step_column is None:
+ raise ValueError("CSV file must contain a 'step' or 'Step' column")
+
+ if name is None:
+ name = csv_path.stem
+
+ metrics_list = []
+ steps = []
+ timestamps = []
+
+ numeric_columns = []
+ for column in df.columns:
+ if column == step_column:
+ continue
+ if column == "timestamp":
+ continue
+
+ try:
+ pd.to_numeric(df[column], errors="raise")
+ numeric_columns.append(column)
+ except (ValueError, TypeError):
+ continue
+
+ for _, row in df.iterrows():
+ metrics = {}
+ for column in numeric_columns:
+ value = row[column]
+ if bool(pd.notna(value)):
+ metrics[column] = float(value)
+
+ if metrics:
+ metrics_list.append(metrics)
+ steps.append(int(row[step_column]))
+
+ if "timestamp" in df.columns and bool(pd.notna(row["timestamp"])):
+ timestamps.append(str(row["timestamp"]))
+ else:
+ timestamps.append("")
+
+ if metrics_list:
+ SQLiteStorage.bulk_log(
+ project=project,
+ run=name,
+ metrics_list=metrics_list,
+ steps=steps,
+ timestamps=timestamps,
+ )
+
+ print(
+ f"* Imported {len(metrics_list)} rows from {csv_path} into project '{project}' as run '{name}'"
+ )
+ print(f"* Metrics found: {', '.join(metrics_list[0].keys())}")
+
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
+
+ if space_id is None:
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(
+ space_id=space_id, dataset_id=dataset_id, private=private
+ )
+ deploy.wait_until_space_exists(space_id=space_id)
+ deploy.upload_db_to_space(project=project, space_id=space_id)
+ print(
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+ )
+
+
+def import_tf_events(
+ log_dir: str | Path,
+ project: str,
+ name: str | None = None,
+ space_id: str | None = None,
+ dataset_id: str | None = None,
+ private: bool | None = None,
+) -> None:
+ """
+ Imports TensorFlow Events files from a directory into a Trackio project. Each
+ subdirectory in the log directory will be imported as a separate run.
+
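+    Example:
+    ```python
+    import trackio
+
+    # Illustrative paths; each subdirectory of the log directory becomes a separate run.
+    trackio.import_tf_events("logs/", project="tensorboard-import-demo")
+    ```
+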
+ Args:
+ log_dir (`str` or `Path`):
+ The str or Path to the directory containing TensorFlow Events files.
+ project (`str`):
+ The name of the project to import the TensorFlow Events files into. Must not
+ be an existing project.
+ name (`str`, *optional*):
+ The name prefix for runs (if not provided, will use directory names). Each
+ subdirectory will create a separate run.
+ space_id (`str`, *optional*):
+ If provided, the project will be logged to a Hugging Face Space instead of a
+ local directory. Should be a complete Space name like `"username/reponame"`
+ or `"orgname/reponame"`, or just `"reponame"` in which case the Space will
+ be created in the currently-logged-in Hugging Face user's namespace. If the
+ Space does not exist, it will be created. If the Space already exists, the
+ project will be logged to it.
+ dataset_id (`str`, *optional*):
+ If provided, a persistent Hugging Face Dataset will be created and the
+ metrics will be synced to it every 5 minutes. Should be a complete Dataset
+ name like `"username/datasetname"` or `"orgname/datasetname"`, or just
+ `"datasetname"` in which case the Dataset will be created in the
+ currently-logged-in Hugging Face user's namespace. If the Dataset does not
+ exist, it will be created. If the Dataset already exists, the project will
+ be appended to it. If not provided, the metrics will be logged to a local
+ SQLite database, unless a `space_id` is provided, in which case a Dataset
+ will be automatically created with the same name as the Space but with the
+ `"_dataset"` suffix.
+ private (`bool`, *optional*):
+ Whether to make the Space private. If None (default), the repo will be
+ public unless the organization's default is private. This value is ignored
+ if the repo already exists.
+ """
+ try:
+ from tbparse import SummaryReader
+ except ImportError:
+ raise ImportError(
+ "The `tbparse` package is not installed but is required for `import_tf_events`. Please install trackio with the `tensorboard` extra: `pip install trackio[tensorboard]`."
+ )
+
+ if SQLiteStorage.get_runs(project):
+ raise ValueError(
+ f"Project '{project}' already exists. Cannot import TF events into existing project."
+ )
+
+ path = Path(log_dir)
+ if not path.exists():
+ raise FileNotFoundError(f"TF events directory not found: {path}")
+
+ # Use tbparse to read all tfevents files in the directory structure
+ reader = SummaryReader(str(path), extra_columns={"dir_name"})
+ df = reader.scalars
+
+ if df.empty:
+ raise ValueError(f"No TensorFlow events data found in {path}")
+
+ total_imported = 0
+ imported_runs = []
+
+ # Group by dir_name to create separate runs
+ for dir_name, group_df in df.groupby("dir_name"):
+ try:
+ # Determine run name based on directory name
+ if dir_name == "":
+ run_name = "main" # For files in the root directory
+ else:
+ run_name = dir_name # Use directory name
+
+ if name:
+ run_name = f"{name}_{run_name}"
+
+ if group_df.empty:
+ print(f"* Skipping directory {dir_name}: no scalar data found")
+ continue
+
+ metrics_list = []
+ steps = []
+ timestamps = []
+
+ for _, row in group_df.iterrows():
+ # Convert row values to appropriate types
+ tag = str(row["tag"])
+ value = float(row["value"])
+ step = int(row["step"])
+
+ metrics = {tag: value}
+ metrics_list.append(metrics)
+ steps.append(step)
+
+ # Use wall_time if present, else fallback
+ if "wall_time" in group_df.columns and not bool(
+ pd.isna(row["wall_time"])
+ ):
+ timestamps.append(str(row["wall_time"]))
+ else:
+ timestamps.append("")
+
+ if metrics_list:
+ SQLiteStorage.bulk_log(
+ project=project,
+ run=str(run_name),
+ metrics_list=metrics_list,
+ steps=steps,
+ timestamps=timestamps,
+ )
+
+ total_imported += len(metrics_list)
+ imported_runs.append(run_name)
+
+ print(
+ f"* Imported {len(metrics_list)} scalar events from directory '{dir_name}' as run '{run_name}'"
+ )
+ print(f"* Metrics in this run: {', '.join(set(group_df['tag']))}")
+
+ except Exception as e:
+ print(f"* Error processing directory {dir_name}: {e}")
+ continue
+
+ if not imported_runs:
+ raise ValueError("No valid TensorFlow events data could be imported")
+
+ print(f"* Total imported events: {total_imported}")
+ print(f"* Created runs: {', '.join(imported_runs)}")
+
+ space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
+ if dataset_id is not None:
+ os.environ["TRACKIO_DATASET_ID"] = dataset_id
+ print(f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}")
+
+ if space_id is None:
+ utils.print_dashboard_instructions(project)
+ else:
+ deploy.create_space_if_not_exists(
+ space_id, dataset_id=dataset_id, private=private
+ )
+ deploy.wait_until_space_exists(space_id)
+ deploy.upload_db_to_space(project, space_id)
+ print(
+ f"* View dashboard by going to: {deploy.SPACE_URL.format(space_id=space_id)}"
+ )
diff --git a/media/__init__.py b/media/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b8458238898a1428866681eb11245b038d05239
--- /dev/null
+++ b/media/__init__.py
@@ -0,0 +1,34 @@
+"""
+Media module for Trackio.
+
+This module contains all media-related functionality including:
+- TrackioImage, TrackioVideo, TrackioAudio classes
+- Video writing utilities
+- Audio conversion utilities
+"""
+
+try:
+ from trackio.media.audio_writer import write_audio
+ from trackio.media.file_storage import FileStorage
+ from trackio.media.media import (
+ TrackioAudio,
+ TrackioImage,
+ TrackioMedia,
+ TrackioVideo,
+ )
+ from trackio.media.video_writer import write_video
+except ImportError:
+ from media.audio_writer import write_audio
+ from media.file_storage import FileStorage
+ from media.media import TrackioAudio, TrackioImage, TrackioMedia, TrackioVideo
+ from media.video_writer import write_video
+
+__all__ = [
+ "TrackioMedia",
+ "TrackioImage",
+ "TrackioVideo",
+ "TrackioAudio",
+ "FileStorage",
+ "write_video",
+ "write_audio",
+]
diff --git a/media/audio_writer.py b/media/audio_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e45b5444e47f98a21ca6209046acfcc29854315
--- /dev/null
+++ b/media/audio_writer.py
@@ -0,0 +1,128 @@
+import warnings
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+
+try: # absolute imports when installed
+ from trackio.media.utils import check_ffmpeg_installed, check_path
+except ImportError: # relative imports for local execution on Spaces
+ from media.utils import check_ffmpeg_installed, check_path
+
+# Try to import pydub, but make it optional
+try:
+ from pydub import AudioSegment
+
+ PYDUB_AVAILABLE = True
+except ImportError:
+ PYDUB_AVAILABLE = False
+ AudioSegment = None
+
+SUPPORTED_FORMATS = ["wav", "mp3"]
+AudioFormatType = Literal["wav", "mp3"]
+
+
+def ensure_int16_pcm(data: np.ndarray) -> np.ndarray:
+ """
+ Convert input audio array to contiguous int16 PCM.
+ Peak normalization is applied to floating inputs.
+ """
+ arr = np.asarray(data)
+ if arr.ndim not in (1, 2):
+ raise ValueError("Audio data must be 1D (mono) or 2D ([samples, channels])")
+
+ if arr.dtype != np.int16:
+ warnings.warn(
+ f"Converting {arr.dtype} audio to int16 PCM; pass int16 to avoid conversion.",
+ stacklevel=2,
+ )
+
+ arr = np.nan_to_num(arr, copy=False)
+
+ # Floating types: normalize to peak 1.0, then scale to int16
+ if np.issubdtype(arr.dtype, np.floating):
+ max_abs = float(np.max(np.abs(arr))) if arr.size else 0.0
+ if max_abs > 0.0:
+ arr = arr / max_abs
+ out = (arr * 32767.0).clip(-32768, 32767).astype(np.int16, copy=False)
+ return np.ascontiguousarray(out)
+
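+    # Integer inputs are rescaled to the int16 range (int32 shifted down 16 bits, uint16/uint8 recentered, int8 expanded).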
+ converters: dict[np.dtype, callable] = {
+ np.dtype(np.int16): lambda a: a,
+ np.dtype(np.int32): lambda a: (
+ (a.astype(np.int32) // 65536).astype(np.int16, copy=False)
+ ),
+ np.dtype(np.uint16): lambda a: (
+ (a.astype(np.int32) - 32768).astype(np.int16, copy=False)
+ ),
+ np.dtype(np.uint8): lambda a: (
+ (a.astype(np.int32) * 257 - 32768).astype(np.int16, copy=False)
+ ),
+ np.dtype(np.int8): lambda a: (
+ (a.astype(np.int32) * 256).astype(np.int16, copy=False)
+ ),
+ }
+
+ conv = converters.get(arr.dtype)
+ if conv is not None:
+ out = conv(arr)
+ return np.ascontiguousarray(out)
+ raise TypeError(f"Unsupported audio dtype: {arr.dtype}")
+
+
+def write_audio(
+ data: np.ndarray,
+ sample_rate: int,
+ filename: str | Path,
+ format: AudioFormatType = "wav",
+) -> None:
+ if not isinstance(sample_rate, int) or sample_rate <= 0:
+ raise ValueError(f"Invalid sample_rate: {sample_rate}")
+ if format not in SUPPORTED_FORMATS:
+ raise ValueError(
+ f"Unsupported format: {format}. Supported: {SUPPORTED_FORMATS}"
+ )
+
+ check_path(filename)
+
+ pcm = ensure_int16_pcm(data)
+
+ # If pydub is missing, allow WAV fallback, otherwise require pydub
+ if not PYDUB_AVAILABLE:
+ if format == "wav":
+ write_wav_simple(filename, pcm, sample_rate)
+ return
+ raise ImportError(
+ "pydub is required for non-WAV formats. Install with: pip install pydub"
+ )
+
+ if format != "wav":
+ check_ffmpeg_installed()
+
+ channels = 1 if pcm.ndim == 1 else pcm.shape[1]
+ audio = AudioSegment(
+ pcm.tobytes(),
+ frame_rate=sample_rate,
+ sample_width=2, # int16
+ channels=channels,
+ )
+
+ file = audio.export(str(filename), format=format)
+ file.close()
+
+
+def write_wav_simple(
+ file_path: str | Path, data: np.ndarray, sample_rate: int = 44100
+) -> None:
+ """Fallback for basic WAV export when pydub is not available."""
+ import wave
+
+ pcm = ensure_int16_pcm(data)
+ if pcm.ndim > 2:
+ raise ValueError("Audio data must be 1D (mono) or 2D (stereo)")
+
+ with wave.open(str(file_path), "wb") as wav_file:
+ wav_file.setnchannels(1 if pcm.ndim == 1 else pcm.shape[1])
+ wav_file.setsampwidth(2) # 16-bit = 2 bytes
+ wav_file.setframerate(sample_rate)
+ wav_file.writeframes(pcm.tobytes())
diff --git a/media/file_storage.py b/media/file_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed947c0d8a5efde11b42c1c1f0cae05edb79f0a
--- /dev/null
+++ b/media/file_storage.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+
+try: # absolute imports when installed
+ from trackio.utils import MEDIA_DIR
+except ImportError: # relative imports for local execution on Spaces
+ from utils import MEDIA_DIR
+
+
+class FileStorage:
+ @staticmethod
+ def get_project_media_path(
+ project: str,
+ run: str | None = None,
+ step: int | None = None,
+ filename: str | None = None,
+ ) -> Path:
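+        """Return MEDIA_DIR/project[/run[/step[/filename]]]; each deeper component requires the shallower ones."""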
+ if filename is not None and step is None:
+ raise ValueError("filename requires step")
+ if step is not None and run is None:
+ raise ValueError("step requires run")
+
+ path = MEDIA_DIR / project
+ if run:
+ path /= run
+ if step is not None:
+ path /= str(step)
+ if filename:
+ path /= filename
+ return path
+
+ @staticmethod
+ def init_project_media_path(
+ project: str, run: str | None = None, step: int | None = None
+ ) -> Path:
+ path = FileStorage.get_project_media_path(project, run, step)
+ path.mkdir(parents=True, exist_ok=True)
+ return path
diff --git a/media/media.py b/media/media.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb7a8523b65c808b5333da12a99682910e5ae1d
--- /dev/null
+++ b/media/media.py
@@ -0,0 +1,378 @@
+import os
+import shutil
+import uuid
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+from PIL import Image as PILImage
+
+try: # absolute imports when installed
+ from trackio.media.audio_writer import AudioFormatType, write_audio
+ from trackio.media.file_storage import FileStorage
+ from trackio.media.video_writer import write_video
+ from trackio.utils import MEDIA_DIR
+except ImportError: # relative imports for local execution on Spaces
+ from media.audio_writer import AudioFormatType, write_audio
+ from media.file_storage import FileStorage
+ from media.video_writer import write_video
+ from utils import MEDIA_DIR
+
+
+class TrackioMedia(ABC):
+ """
+    Abstract base class for Trackio media objects.
+    Provides shared functionality for file handling and serialization.
+ """
+
+ TYPE: str
+
+ def __init_subclass__(cls, **kwargs):
+ """Ensure subclasses define the TYPE attribute."""
+ super().__init_subclass__(**kwargs)
+ if not hasattr(cls, "TYPE") or cls.TYPE is None:
+ raise TypeError(f"Class {cls.__name__} must define TYPE attribute")
+
+ def __init__(self, value, caption: str | None = None):
+ self.caption = caption
+ self._value = value
+ self._file_path: Path | None = None
+
+ if isinstance(self._value, str | Path):
+ if not os.path.isfile(self._value):
+ raise ValueError(f"File not found: {self._value}")
+
+ def _file_extension(self) -> str:
+ if self._file_path:
+ return self._file_path.suffix[1:].lower()
+ if isinstance(self._value, str | Path):
+ path = Path(self._value)
+ return path.suffix[1:].lower()
+ if hasattr(self, "_format") and self._format:
+ return self._format
+ return "unknown"
+
+ def _get_relative_file_path(self) -> Path | None:
+ return self._file_path
+
+ def _get_absolute_file_path(self) -> Path | None:
+ if self._file_path:
+ return MEDIA_DIR / self._file_path
+ return None
+
+ def _save(self, project: str, run: str, step: int = 0):
+ if self._file_path:
+ return
+
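+        # Write the media to MEDIA_DIR/<project>/<run>/<step>/<uuid>.<ext> and keep only the relative path.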
+ media_dir = FileStorage.init_project_media_path(project, run, step)
+ filename = f"{uuid.uuid4()}.{self._file_extension()}"
+ file_path = media_dir / filename
+
+ self._save_media(file_path)
+ self._file_path = file_path.relative_to(MEDIA_DIR)
+
+ @abstractmethod
+ def _save_media(self, file_path: Path):
+ """
+ Performs the actual media saving logic.
+ """
+ pass
+
+ def _to_dict(self) -> dict:
+ if not self._file_path:
+ raise ValueError("Media must be saved to file before serialization")
+ return {
+ "_type": self.TYPE,
+ "file_path": str(self._get_relative_file_path()),
+ "caption": self.caption,
+ }
+
+
+TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
+
+
+class TrackioImage(TrackioMedia):
+ """
+ Initializes an Image object.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+ from PIL import Image
+
+ # Create an image from numpy array
+ image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
+ image = trackio.Image(image_data, caption="Random image")
+ trackio.log({"my_image": image})
+
+ # Create an image from PIL Image
+ pil_image = Image.new('RGB', (100, 100), color='red')
+ image = trackio.Image(pil_image, caption="Red square")
+ trackio.log({"red_image": image})
+
+ # Create an image from file path
+ image = trackio.Image("path/to/image.jpg", caption="Photo from file")
+ trackio.log({"file_image": image})
+ ```
+
+ Args:
+ value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`, *optional*):
+ A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
+ If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`.
+ caption (`str`, *optional*):
+ A string caption for the image.
+ """
+
+ TYPE = "trackio.image"
+
+ def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
+ super().__init__(value, caption)
+ self._format: str | None = None
+
+ if not isinstance(self._value, TrackioImageSourceType):
+ raise ValueError(
+ f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
+ )
+ if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
+ raise ValueError(
+ f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
+ )
+ if (
+ isinstance(self._value, np.ndarray | PILImage.Image)
+ and self._format is None
+ ):
+ self._format = "png"
+
+ def _as_pil(self) -> PILImage.Image | None:
+ try:
+ if isinstance(self._value, np.ndarray):
+ arr = np.asarray(self._value).astype("uint8")
+ return PILImage.fromarray(arr).convert("RGBA")
+ if isinstance(self._value, PILImage.Image):
+ return self._value.convert("RGBA")
+ except Exception as e:
+ raise ValueError(f"Failed to process image data: {self._value}") from e
+ return None
+
+ def _save_media(self, file_path: Path):
+ if pil := self._as_pil():
+ pil.save(file_path, format=self._format)
+ elif isinstance(self._value, str | Path):
+ if os.path.isfile(self._value):
+ shutil.copy(self._value, file_path)
+ else:
+ raise ValueError(f"File not found: {self._value}")
+
+
+TrackioVideoSourceType = str | Path | np.ndarray
+TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
+
+
+class TrackioVideo(TrackioMedia):
+ """
+ Initializes a Video object.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+
+ # Create a simple video from numpy array
+ frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8)
+ video = trackio.Video(frames, caption="Random video", fps=30)
+
+ # Create a batch of videos
+ batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8)
+ batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15)
+
+ # Create video from file path
+ video = trackio.Video("path/to/video.mp4", caption="Video from file")
+ ```
+
+ Args:
+ value (`str`, `Path`, or `numpy.ndarray`, *optional*):
+ A path to a video file, or a numpy array.
+ If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`.
+ It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width).
+ For the latter, the videos will be tiled into a grid.
+ caption (`str`, *optional*):
+ A string caption for the video.
+ fps (`int`, *optional*):
+ Frames per second for the video. Only used when value is an ndarray. Default is `24`.
+ format (`Literal["gif", "mp4", "webm"]`, *optional*):
+ Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif".
+ """
+
+ TYPE = "trackio.video"
+
+ def __init__(
+ self,
+ value: TrackioVideoSourceType,
+ caption: str | None = None,
+ fps: int | None = None,
+ format: TrackioVideoFormatType | None = None,
+ ):
+ super().__init__(value, caption)
+
+ if not isinstance(self._value, TrackioVideoSourceType):
+ raise ValueError(
+ f"Invalid value type, expected {TrackioVideoSourceType}, got {type(self._value)}"
+ )
+ if isinstance(self._value, np.ndarray):
+ if self._value.dtype != np.uint8:
+ raise ValueError(
+ f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
+ )
+ if format is None:
+ format = "gif"
+ if fps is None:
+ fps = 24
+ self._fps = fps
+ self._format = format
+
+ @property
+ def _codec(self) -> str:
+ match self._format:
+ case "gif":
+ return "gif"
+ case "mp4":
+ return "h264"
+ case "webm":
+ return "vp9"
+ case _:
+ raise ValueError(f"Unsupported format: {self._format}")
+
+ def _save_media(self, file_path: Path):
+ if isinstance(self._value, np.ndarray):
+ video = TrackioVideo._process_ndarray(self._value)
+ write_video(file_path, video, fps=self._fps, codec=self._codec)
+ elif isinstance(self._value, str | Path):
+ if os.path.isfile(self._value):
+ shutil.copy(self._value, file_path)
+ else:
+ raise ValueError(f"File not found: {self._value}")
+
+ @staticmethod
+ def _process_ndarray(value: np.ndarray) -> np.ndarray:
+ # Verify value is either 4D (single video) or 5D array (batched videos).
+ # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width)
+ if value.ndim < 4:
+ raise ValueError(
+ "Video requires at least 4 dimensions (frames, channels, height, width)"
+ )
+ if value.ndim > 5:
+ raise ValueError(
+ "Videos can have at most 5 dimensions (batch, frames, channels, height, width)"
+ )
+ if value.ndim == 4:
+ # Reshape to 5D with single batch: (1, frames, channels, height, width)
+ value = value[np.newaxis, ...]
+
+ value = TrackioVideo._tile_batched_videos(value)
+ return value
+
+ @staticmethod
+ def _tile_batched_videos(video: np.ndarray) -> np.ndarray:
+ """
+ Tiles a batch of videos into a grid of videos.
+
+ Input format: (batch, frames, channels, height, width) - original FCHW format
+ Output format: (frames, total_height, total_width, channels)
+ """
+ batch_size, frames, channels, height, width = video.shape
+
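+        # Pad the batch with blank (zero) videos up to the next power of two so it tiles into an even grid.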
+ next_pow2 = 1 << (batch_size - 1).bit_length()
+ if batch_size != next_pow2:
+ pad_len = next_pow2 - batch_size
+ pad_shape = (pad_len, frames, channels, height, width)
+ padding = np.zeros(pad_shape, dtype=video.dtype)
+ video = np.concatenate((video, padding), axis=0)
+ batch_size = next_pow2
+
+ n_rows = 1 << ((batch_size.bit_length() - 1) // 2)
+ n_cols = batch_size // n_rows
+
+ # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width)
+ video = video.reshape(n_rows, n_cols, frames, channels, height, width)
+
+ # Rearrange dimensions to (frames, total_height, total_width, channels)
+ video = video.transpose(2, 0, 4, 1, 5, 3)
+ video = video.reshape(frames, n_rows * height, n_cols * width, channels)
+ return video
+
+
+TrackioAudioSourceType = str | Path | np.ndarray
+
+
+class TrackioAudio(TrackioMedia):
+ """
+ Initializes an Audio object.
+
+ Example:
+ ```python
+ import trackio
+ import numpy as np
+
+ # Generate a 1-second 440 Hz sine wave (mono)
+ sr = 16000
+ t = np.linspace(0, 1, sr, endpoint=False)
+ wave = 0.2 * np.sin(2 * np.pi * 440 * t)
+ audio = trackio.Audio(wave, caption="A4 sine", sample_rate=sr, format="wav")
+ trackio.log({"tone": audio})
+
+ # Stereo from numpy array (shape: samples, 2)
+ stereo = np.stack([wave, wave], axis=1)
+ audio = trackio.Audio(stereo, caption="Stereo", sample_rate=sr, format="mp3")
+ trackio.log({"stereo": audio})
+
+ # From an existing file
+ audio = trackio.Audio("path/to/audio.wav", caption="From file")
+ trackio.log({"file_audio": audio})
+ ```
+
+ Args:
+ value (`str`, `Path`, or `numpy.ndarray`, *optional*):
+ A path to an audio file, or a numpy array.
+ The array should be shaped `(samples,)` for mono or `(samples, 2)` for stereo.
+ Float arrays will be peak-normalized and converted to 16-bit PCM; integer arrays will be converted to 16-bit PCM as needed.
+ caption (`str`, *optional*):
+ A string caption for the audio.
+ sample_rate (`int`, *optional*):
+ Sample rate in Hz. Required when `value` is a numpy array.
+ format (`Literal["wav", "mp3"]`, *optional*):
+ Audio format used when `value` is a numpy array. Default is "wav".
+ """
+
+ TYPE = "trackio.audio"
+
+ def __init__(
+ self,
+ value: TrackioAudioSourceType,
+ caption: str | None = None,
+ sample_rate: int | None = None,
+ format: AudioFormatType | None = None,
+ ):
+ super().__init__(value, caption)
+ if isinstance(value, np.ndarray):
+ if sample_rate is None:
+ raise ValueError("Sample rate is required when value is an ndarray")
+ if format is None:
+ format = "wav"
+ self._format = format
+ self._sample_rate = sample_rate
+
+ def _save_media(self, file_path: Path):
+ if isinstance(self._value, np.ndarray):
+ write_audio(
+ data=self._value,
+ sample_rate=self._sample_rate,
+ filename=file_path,
+ format=self._format,
+ )
+ elif isinstance(self._value, str | Path):
+ if os.path.isfile(self._value):
+ shutil.copy(self._value, file_path)
+ else:
+ raise ValueError(f"File not found: {self._value}")
diff --git a/media/utils.py b/media/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..981d8d1203bb8a27ee248710904d101eeb768709
--- /dev/null
+++ b/media/utils.py
@@ -0,0 +1,23 @@
+import shutil
+from pathlib import Path
+
+
+def check_path(file_path: str | Path) -> None:
+ """Raise an error if the parent directory does not exist."""
+ file_path = Path(file_path)
+ if not file_path.parent.exists():
+ try:
+ file_path.parent.mkdir(parents=True, exist_ok=True)
+ except OSError as e:
+ raise ValueError(
+ f"Failed to create parent directory {file_path.parent}: {e}"
+ )
+
+
+def check_ffmpeg_installed() -> None:
+ """Raise an error if ffmpeg is not available on the system PATH."""
+ if shutil.which("ffmpeg") is None:
+ raise RuntimeError(
+ "ffmpeg is required to write video but was not found on your system. "
+ "Please install ffmpeg and ensure it is available on your PATH."
+ )
diff --git a/media/video_writer.py b/media/video_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7830b09355773e0b417143d8d122016acc241789
--- /dev/null
+++ b/media/video_writer.py
@@ -0,0 +1,109 @@
+import subprocess
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+
+try: # absolute imports when installed
+ from trackio.media.utils import check_ffmpeg_installed, check_path
+except ImportError: # relative imports for local execution on Spaces
+ from media.utils import check_ffmpeg_installed, check_path
+
+VideoCodec = Literal["h264", "vp9", "gif"]
+
+
+def _check_array_format(video: np.ndarray) -> None:
+ """Raise an error if the array is not in the expected format."""
+ if not (video.ndim == 4 and video.shape[-1] == 3):
+ raise ValueError(
+ f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. "
+ f"Input has {video.ndim} dimensions, expected 4."
+ )
+ if video.dtype != np.uint8:
+ raise TypeError(
+ f"Expected dtype=uint8, got {video.dtype}. "
+ "Please convert your video data to uint8 format."
+ )
+
+
+def write_video(
+ file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec
+) -> None:
+ """RGB uint8 only, shape (F, H, W, 3)."""
+ check_ffmpeg_installed()
+ check_path(file_path)
+
+ if codec not in {"h264", "vp9", "gif"}:
+ raise ValueError("Unsupported codec. Use h264, vp9, or gif.")
+
+ arr = np.asarray(video)
+ _check_array_format(arr)
+
+ frames = np.ascontiguousarray(arr)
+ _, height, width, _ = frames.shape
+ out_path = str(file_path)
+
+ cmd = [
+ "ffmpeg",
+ "-y",
+ "-f",
+ "rawvideo",
+ "-s",
+ f"{width}x{height}",
+ "-pix_fmt",
+ "rgb24",
+ "-r",
+ str(fps),
+ "-i",
+ "-",
+ "-an",
+ ]
+
+ if codec == "gif":
+ video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse"
+ cmd += [
+ "-vf",
+ video_filter,
+ "-loop",
+ "0",
+ ]
+ elif codec == "h264":
+ cmd += [
+ "-vcodec",
+ "libx264",
+ "-pix_fmt",
+ "yuv420p",
+ "-movflags",
+ "+faststart",
+ ]
+ elif codec == "vp9":
+ bpp = 0.08
+ bps = int(width * height * fps * bpp)
+ if bps >= 1_000_000:
+ bitrate = f"{round(bps / 1_000_000)}M"
+ elif bps >= 1_000:
+ bitrate = f"{round(bps / 1_000)}k"
+ else:
+ bitrate = str(max(bps, 1))
+ cmd += [
+ "-vcodec",
+ "libvpx-vp9",
+ "-b:v",
+ bitrate,
+ "-pix_fmt",
+ "yuv420p",
+ ]
+ cmd += [out_path]
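+    # Stream raw rgb24 frames to ffmpeg over stdin ("-i -" above) and fail loudly if ffmpeg errors out.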
+ proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
+ try:
+ for frame in frames:
+ proc.stdin.write(frame.tobytes())
+ finally:
+ if proc.stdin:
+ proc.stdin.close()
+ stderr = (
+ proc.stderr.read().decode("utf-8", errors="ignore") if proc.stderr else ""
+ )
+ ret = proc.wait()
+ if ret != 0:
+ raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}")
diff --git a/package.json b/package.json
new file mode 100644
index 0000000000000000000000000000000000000000..0829756d7f184ea5f38db48194b7a1b4592f9095
--- /dev/null
+++ b/package.json
@@ -0,0 +1,6 @@
+{
+ "name": "trackio",
+ "version": "0.8.0",
+ "description": "",
+ "python": "true"
+}
diff --git a/py.typed b/py.typed
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7be3fc189678b7488c7aba2e1586be5eff93b13
--- /dev/null
+++ b/run.py
@@ -0,0 +1,180 @@
+import threading
+import time
+import warnings
+from datetime import datetime, timezone
+
+import huggingface_hub
+from gradio_client import Client, handle_file
+
+from trackio import utils
+from trackio.histogram import Histogram
+from trackio.media import TrackioMedia
+from trackio.sqlite_storage import SQLiteStorage
+from trackio.table import Table
+from trackio.typehints import LogEntry, UploadEntry
+
+BATCH_SEND_INTERVAL = 0.5
+
+
+class Run:
+ def __init__(
+ self,
+ url: str,
+ project: str,
+ client: Client | None,
+ name: str | None = None,
+ group: str | None = None,
+ config: dict | None = None,
+ space_id: str | None = None,
+ ):
+ self.url = url
+ self.project = project
+ self._client_lock = threading.Lock()
+ self._client_thread = None
+ self._client = client
+ self._space_id = space_id
+ self.name = name or utils.generate_readable_name(
+ SQLiteStorage.get_runs(project), space_id
+ )
+ self.group = group
+ self.config = utils.to_json_safe(config or {})
+
+ if isinstance(self.config, dict):
+ for key in self.config:
+ if key.startswith("_"):
+ raise ValueError(
+ f"Config key '{key}' is reserved (keys starting with '_' are reserved for internal use)"
+ )
+
+ self.config["_Username"] = self._get_username()
+ self.config["_Created"] = datetime.now(timezone.utc).isoformat()
+ self.config["_Group"] = self.group
+
+ self._queued_logs: list[LogEntry] = []
+ self._queued_uploads: list[UploadEntry] = []
+ self._stop_flag = threading.Event()
+ self._config_logged = False
+
+ self._client_thread = threading.Thread(target=self._init_client_background)
+ self._client_thread.daemon = True
+ self._client_thread.start()
+
+ def _get_username(self) -> str | None:
+ """Get the current HuggingFace username if logged in, otherwise None."""
+ try:
+ who = huggingface_hub.whoami()
+ return who["name"] if who else None
+ except Exception:
+ return None
+
+ def _batch_sender(self):
+ """Send batched logs every BATCH_SEND_INTERVAL."""
+ while not self._stop_flag.is_set() or len(self._queued_logs) > 0:
+ if not self._stop_flag.is_set():
+ time.sleep(BATCH_SEND_INTERVAL)
+
+ with self._client_lock:
+ if self._client is None:
+ return
+ if self._queued_logs:
+ logs_to_send = self._queued_logs.copy()
+ self._queued_logs.clear()
+ self._client.predict(
+ api_name="/bulk_log",
+ logs=logs_to_send,
+ hf_token=huggingface_hub.utils.get_token(),
+ )
+ if self._queued_uploads:
+ uploads_to_send = self._queued_uploads.copy()
+ self._queued_uploads.clear()
+ self._client.predict(
+ api_name="/bulk_upload_media",
+ uploads=uploads_to_send,
+ hf_token=huggingface_hub.utils.get_token(),
+ )
+
+ def _init_client_background(self):
+ if self._client is None:
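+            # Retry connecting with Fibonacci backoff (0.1s * coefficient) until the dashboard responds.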
+ fib = utils.fibo()
+ for sleep_coefficient in fib:
+ try:
+ client = Client(self.url, verbose=False)
+
+ with self._client_lock:
+ self._client = client
+ break
+ except Exception:
+ pass
+ if sleep_coefficient is not None:
+ time.sleep(0.1 * sleep_coefficient)
+
+ self._batch_sender()
+
+ def _process_media(self, value: TrackioMedia, step: int | None) -> dict:
+ """
+ Serialize media in metrics and upload to space if needed.
+ """
+ value._save(self.project, self.name, step)
+ if self._space_id:
+ upload_entry: UploadEntry = {
+ "project": self.project,
+ "run": self.name,
+ "step": step,
+ "uploaded_file": handle_file(value._get_absolute_file_path()),
+ }
+ with self._client_lock:
+ self._queued_uploads.append(upload_entry)
+ return value._to_dict()
+
+ def log(self, metrics: dict, step: int | None = None):
+ renamed_keys = []
+ new_metrics = {}
+
+ for k, v in metrics.items():
+ if k in utils.RESERVED_KEYS or k.startswith("__"):
+ new_key = f"__{k}"
+ renamed_keys.append(k)
+ new_metrics[new_key] = v
+ else:
+ new_metrics[k] = v
+
+ if renamed_keys:
+ warnings.warn(f"Reserved keys renamed: {renamed_keys} → '__{{key}}'")
+
+ metrics = new_metrics
+ for key, value in metrics.items():
+ if isinstance(value, Table):
+ metrics[key] = value._to_dict(
+ project=self.project, run=self.name, step=step
+ )
+ elif isinstance(value, Histogram):
+ metrics[key] = value._to_dict()
+ elif isinstance(value, TrackioMedia):
+ metrics[key] = self._process_media(value, step)
+ metrics = utils.serialize_values(metrics)
+
+ config_to_log = None
+ if not self._config_logged and self.config:
+ config_to_log = utils.to_json_safe(self.config)
+ self._config_logged = True
+
+ log_entry: LogEntry = {
+ "project": self.project,
+ "run": self.name,
+ "metrics": metrics,
+ "step": step,
+ "config": config_to_log,
+ }
+
+ with self._client_lock:
+ self._queued_logs.append(log_entry)
+
+ def finish(self):
+ """Cleanup when run is finished."""
+ self._stop_flag.set()
+
+ time.sleep(2 * BATCH_SEND_INTERVAL)
+
+ if self._client_thread is not None:
+ print("* Run finished. Uploading logs to Trackio (please wait...)")
+ self._client_thread.join()
diff --git a/sqlite_storage.py b/sqlite_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfa4ad745efd29986e22bd895380364e1752cd51
--- /dev/null
+++ b/sqlite_storage.py
@@ -0,0 +1,677 @@
+import os
+import platform
+import sqlite3
+import time
+from datetime import datetime
+from pathlib import Path
+from threading import Lock
+
+try:
+ import fcntl
+except ImportError: # fcntl is not available on Windows
+ fcntl = None
+
+import huggingface_hub as hf
+import orjson
+import pandas as pd
+
+try: # absolute imports when installed from PyPI
+ from trackio.commit_scheduler import CommitScheduler
+ from trackio.dummy_commit_scheduler import DummyCommitScheduler
+ from trackio.utils import (
+ TRACKIO_DIR,
+ deserialize_values,
+ serialize_values,
+ )
+except ImportError: # relative imports when installed from source on Spaces
+ from commit_scheduler import CommitScheduler
+ from dummy_commit_scheduler import DummyCommitScheduler
+ from utils import TRACKIO_DIR, deserialize_values, serialize_values
+
+DB_EXT = ".db"
+
+
+class ProcessLock:
+ """A file-based lock that works across processes. Is a no-op on Windows."""
+
+ def __init__(self, lockfile_path: Path):
+ self.lockfile_path = lockfile_path
+ self.lockfile = None
+ self.is_windows = platform.system() == "Windows"
+
+ def __enter__(self):
+ """Acquire the lock with retry logic."""
+ if self.is_windows:
+ return self
+ self.lockfile_path.parent.mkdir(parents=True, exist_ok=True)
+ self.lockfile = open(self.lockfile_path, "w")
+
+ max_retries = 100
+ for attempt in range(max_retries):
+ try:
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
+ return self
+ except IOError:
+ if attempt < max_retries - 1:
+ time.sleep(0.1)
+ else:
+ raise IOError("Could not acquire database lock after 10 seconds")
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Release the lock."""
+ if self.is_windows:
+ return
+
+ if self.lockfile:
+ fcntl.flock(self.lockfile.fileno(), fcntl.LOCK_UN)
+ self.lockfile.close()
+
+
+class SQLiteStorage:
+ _dataset_import_attempted = False
+ _current_scheduler: CommitScheduler | DummyCommitScheduler | None = None
+ _scheduler_lock = Lock()
+
+ @staticmethod
+ def _get_connection(db_path: Path) -> sqlite3.Connection:
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
+ # Keep WAL for concurrency + performance on many small writes
+ conn.execute("PRAGMA journal_mode = WAL")
+ # ---- Minimal perf tweaks for many tiny transactions ----
+ # NORMAL = fsync at critical points only (safer than OFF, much faster than FULL)
+ conn.execute("PRAGMA synchronous = NORMAL")
+ # Keep temp data in memory to avoid disk hits during small writes
+ conn.execute("PRAGMA temp_store = MEMORY")
+ # Give SQLite a bit more room for cache (negative = KB, engine-managed)
+ conn.execute("PRAGMA cache_size = -20000")
+ # --------------------------------------------------------
+ conn.row_factory = sqlite3.Row
+ return conn
+
+ @staticmethod
+ def _get_process_lock(project: str) -> ProcessLock:
+ lockfile_path = TRACKIO_DIR / f"{project}.lock"
+ return ProcessLock(lockfile_path)
+
+ @staticmethod
+ def get_project_db_filename(project: str) -> str:
+ """Get the database filename for a specific project."""
+ safe_project_name = "".join(
+ c for c in project if c.isalnum() or c in ("-", "_")
+ ).rstrip()
+ if not safe_project_name:
+ safe_project_name = "default"
+ return f"{safe_project_name}{DB_EXT}"
+
+ @staticmethod
+ def get_project_db_path(project: str) -> Path:
+ """Get the database path for a specific project."""
+ filename = SQLiteStorage.get_project_db_filename(project)
+ return TRACKIO_DIR / filename
+
+ @staticmethod
+ def init_db(project: str) -> Path:
+ """
+ Initialize the SQLite database with required tables.
+ Returns the database path.
+ """
+ db_path = SQLiteStorage.get_project_db_path(project)
+ db_path.parent.mkdir(parents=True, exist_ok=True)
+ with SQLiteStorage._get_process_lock(project):
+ with sqlite3.connect(str(db_path), timeout=30.0) as conn:
+ conn.execute("PRAGMA journal_mode = WAL")
+ conn.execute("PRAGMA synchronous = NORMAL")
+ conn.execute("PRAGMA temp_store = MEMORY")
+ conn.execute("PRAGMA cache_size = -20000")
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS metrics (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ timestamp TEXT NOT NULL,
+ run_name TEXT NOT NULL,
+ step INTEGER NOT NULL,
+ metrics TEXT NOT NULL
+ )
+ """
+ )
+ cursor.execute(
+ """
+ CREATE TABLE IF NOT EXISTS configs (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ run_name TEXT NOT NULL,
+ config TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ UNIQUE(run_name)
+ )
+ """
+ )
+ cursor.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_metrics_run_step
+ ON metrics(run_name, step)
+ """
+ )
+ cursor.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_configs_run_name
+ ON configs(run_name)
+ """
+ )
+ cursor.execute(
+ """
+ CREATE INDEX IF NOT EXISTS idx_metrics_run_timestamp
+ ON metrics(run_name, timestamp)
+ """
+ )
+ conn.commit()
+ return db_path
+
+ @staticmethod
+ def export_to_parquet():
+ """
+ Exports all projects' DB files as Parquet under the same path but with extension ".parquet".
+ """
+ # don't attempt to export (potentially wrong/blank) data before importing for the first time
+ if not SQLiteStorage._dataset_import_attempted:
+ return
+ if not TRACKIO_DIR.exists():
+ return
+
+ all_paths = os.listdir(TRACKIO_DIR)
+ db_names = [f for f in all_paths if f.endswith(DB_EXT)]
+ for db_name in db_names:
+ db_path = TRACKIO_DIR / db_name
+ parquet_path = db_path.with_suffix(".parquet")
+ if (not parquet_path.exists()) or (
+ db_path.stat().st_mtime > parquet_path.stat().st_mtime
+ ):
+ with sqlite3.connect(str(db_path)) as conn:
+ df = pd.read_sql("SELECT * FROM metrics", conn)
+ # break out the single JSON metrics column into individual columns
+ metrics = df["metrics"].copy()
+ metrics = pd.DataFrame(
+ metrics.apply(
+ lambda x: deserialize_values(orjson.loads(x))
+ ).values.tolist(),
+ index=df.index,
+ )
+ del df["metrics"]
+ for col in metrics.columns:
+ df[col] = metrics[col]
+
+ df.to_parquet(parquet_path)
+
+ @staticmethod
+ def _cleanup_wal_sidecars(db_path: Path) -> None:
+ """Remove leftover -wal/-shm files for a DB basename (prevents disk I/O errors)."""
+ for suffix in ("-wal", "-shm"):
+ sidecar = Path(str(db_path) + suffix)
+ try:
+ if sidecar.exists():
+ sidecar.unlink()
+ except Exception:
+ pass
+
+ @staticmethod
+ def import_from_parquet():
+ """
+        Imports every ".parquet" file in the Trackio directory into the DB file at the same path with extension ".db".
+ """
+ if not TRACKIO_DIR.exists():
+ return
+
+ all_paths = os.listdir(TRACKIO_DIR)
+ parquet_names = [f for f in all_paths if f.endswith(".parquet")]
+ for pq_name in parquet_names:
+ parquet_path = TRACKIO_DIR / pq_name
+ db_path = parquet_path.with_suffix(DB_EXT)
+
+ SQLiteStorage._cleanup_wal_sidecars(db_path)
+
+ df = pd.read_parquet(parquet_path)
+ # fix up df to have a single JSON metrics column
+ if "metrics" not in df.columns:
+ # separate other columns from metrics
+ metrics = df.copy()
+ other_cols = ["id", "timestamp", "run_name", "step"]
+ df = df[other_cols]
+ for col in other_cols:
+ del metrics[col]
+ # combine them all into a single metrics col
+ metrics = orjson.loads(metrics.to_json(orient="records"))
+ df["metrics"] = [orjson.dumps(serialize_values(row)) for row in metrics]
+
+ with sqlite3.connect(str(db_path), timeout=30.0) as conn:
+ df.to_sql("metrics", conn, if_exists="replace", index=False)
+ conn.commit()
+
+ @staticmethod
+ def get_scheduler():
+ """
+ Get the scheduler for the database based on the environment variables.
+ This applies to both local and Spaces.
+ """
+ with SQLiteStorage._scheduler_lock:
+ if SQLiteStorage._current_scheduler is not None:
+ return SQLiteStorage._current_scheduler
+ hf_token = os.environ.get("HF_TOKEN")
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
+ if dataset_id is None or space_repo_name is None:
+ scheduler = DummyCommitScheduler()
+ else:
+ scheduler = CommitScheduler(
+ repo_id=dataset_id,
+ repo_type="dataset",
+ folder_path=TRACKIO_DIR,
+ private=True,
+ allow_patterns=["*.parquet", "media/**/*"],
+ squash_history=True,
+ token=hf_token,
+ on_before_commit=SQLiteStorage.export_to_parquet,
+ )
+ SQLiteStorage._current_scheduler = scheduler
+ return scheduler
+
+ @staticmethod
+ def log(project: str, run: str, metrics: dict, step: int | None = None):
+ """
+ Safely log metrics to the database. Before logging, this method will ensure the database exists
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
+ database locking errors when multiple processes access the same database.
+
+ This method is not used in the latest versions of Trackio (replaced by bulk_log) but
+ is kept for backwards compatibility for users who are connecting to a newer version of
+ a Trackio Spaces dashboard with an older version of Trackio installed locally.
+ """
+ db_path = SQLiteStorage.init_db(project)
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT MAX(step)
+ FROM metrics
+ WHERE run_name = ?
+ """,
+ (run,),
+ )
+ last_step = cursor.fetchone()[0]
+ current_step = (
+ 0
+ if step is None and last_step is None
+ else (step if step is not None else last_step + 1)
+ )
+ current_timestamp = datetime.now().isoformat()
+ cursor.execute(
+ """
+ INSERT INTO metrics
+ (timestamp, run_name, step, metrics)
+ VALUES (?, ?, ?, ?)
+ """,
+ (
+ current_timestamp,
+ run,
+ current_step,
+ orjson.dumps(serialize_values(metrics)),
+ ),
+ )
+ conn.commit()
+
+ @staticmethod
+ def bulk_log(
+ project: str,
+ run: str,
+ metrics_list: list[dict],
+ steps: list[int] | None = None,
+ timestamps: list[str] | None = None,
+ config: dict | None = None,
+ ):
+ """
+ Safely log bulk metrics to the database. Before logging, this method will ensure the database exists
+ and is set up with the correct tables. It also uses a cross-process lock to prevent
+ database locking errors when multiple processes access the same database.
+ """
+ if not metrics_list:
+ return
+
+ if timestamps is None:
+ timestamps = [datetime.now().isoformat()] * len(metrics_list)
+
+ db_path = SQLiteStorage.init_db(project)
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+
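+                # Fill in steps when omitted: default to 0..N-1, or continue from the run's last recorded step.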
+ if steps is None:
+ steps = list(range(len(metrics_list)))
+ elif any(s is None for s in steps):
+ cursor.execute(
+ "SELECT MAX(step) FROM metrics WHERE run_name = ?", (run,)
+ )
+ last_step = cursor.fetchone()[0]
+ current_step = 0 if last_step is None else last_step + 1
+ processed_steps = []
+ for step in steps:
+ if step is None:
+ processed_steps.append(current_step)
+ current_step += 1
+ else:
+ processed_steps.append(step)
+ steps = processed_steps
+
+ if len(metrics_list) != len(steps) or len(metrics_list) != len(
+ timestamps
+ ):
+ raise ValueError(
+ "metrics_list, steps, and timestamps must have the same length"
+ )
+
+ data = []
+ for i, metrics in enumerate(metrics_list):
+ data.append(
+ (
+ timestamps[i],
+ run,
+ steps[i],
+ orjson.dumps(serialize_values(metrics)),
+ )
+ )
+
+ cursor.executemany(
+ """
+ INSERT INTO metrics
+ (timestamp, run_name, step, metrics)
+ VALUES (?, ?, ?, ?)
+ """,
+ data,
+ )
+
+ if config:
+ current_timestamp = datetime.now().isoformat()
+ cursor.execute(
+ """
+ INSERT OR REPLACE INTO configs
+ (run_name, config, created_at)
+ VALUES (?, ?, ?)
+ """,
+ (
+ run,
+ orjson.dumps(serialize_values(config)),
+ current_timestamp,
+ ),
+ )
+
+ conn.commit()
+
+ @staticmethod
+ def get_logs(project: str, run: str) -> list[dict]:
+ """Retrieve logs for a specific run. Logs include the step count (int) and the timestamp (datetime object)."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT timestamp, step, metrics
+ FROM metrics
+ WHERE run_name = ?
+ ORDER BY timestamp
+ """,
+ (run,),
+ )
+
+ rows = cursor.fetchall()
+ results = []
+ for row in rows:
+ metrics = orjson.loads(row["metrics"])
+ metrics = deserialize_values(metrics)
+ metrics["timestamp"] = row["timestamp"]
+ metrics["step"] = row["step"]
+ results.append(metrics)
+ return results
+
+ @staticmethod
+ def load_from_dataset():
+ dataset_id = os.environ.get("TRACKIO_DATASET_ID")
+ space_repo_name = os.environ.get("SPACE_REPO_NAME")
+ if dataset_id is not None and space_repo_name is not None:
+ hfapi = hf.HfApi()
+ updated = False
+ if not TRACKIO_DIR.exists():
+ TRACKIO_DIR.mkdir(parents=True, exist_ok=True)
+ with SQLiteStorage.get_scheduler().lock:
+ try:
+ files = hfapi.list_repo_files(dataset_id, repo_type="dataset")
+ for file in files:
+ # Download parquet and media assets
+ if not (file.endswith(".parquet") or file.startswith("media/")):
+ continue
+ if (TRACKIO_DIR / file).exists():
+ continue
+ hf.hf_hub_download(
+ dataset_id, file, repo_type="dataset", local_dir=TRACKIO_DIR
+ )
+ updated = True
+ except hf.errors.EntryNotFoundError:
+ pass
+ except hf.errors.RepositoryNotFoundError:
+ pass
+ if updated:
+ SQLiteStorage.import_from_parquet()
+ SQLiteStorage._dataset_import_attempted = True
+
+ @staticmethod
+ def get_projects() -> list[str]:
+ """
+ Get list of all projects by scanning the database files in the trackio directory.
+ """
+ if not SQLiteStorage._dataset_import_attempted:
+ SQLiteStorage.load_from_dataset()
+
+ projects: set[str] = set()
+ if not TRACKIO_DIR.exists():
+ return []
+
+ for db_file in TRACKIO_DIR.glob(f"*{DB_EXT}"):
+ project_name = db_file.stem
+ projects.add(project_name)
+ return sorted(projects)
+
+ @staticmethod
+ def get_runs(project: str) -> list[str]:
+ """Get list of all runs for a project."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ "SELECT DISTINCT run_name FROM metrics",
+ )
+ return [row[0] for row in cursor.fetchall()]
+
+ @staticmethod
+ def get_max_steps_for_runs(project: str) -> dict[str, int]:
+ """Get the maximum step for each run in a project."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return {}
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT run_name, MAX(step) as max_step
+ FROM metrics
+ GROUP BY run_name
+ """
+ )
+
+ results = {}
+ for row in cursor.fetchall():
+ results[row["run_name"]] = row["max_step"]
+
+ return results
+
+ @staticmethod
+ def store_config(project: str, run: str, config: dict) -> None:
+ """Store configuration for a run."""
+ db_path = SQLiteStorage.init_db(project)
+
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ current_timestamp = datetime.now().isoformat()
+
+ cursor.execute(
+ """
+ INSERT OR REPLACE INTO configs
+ (run_name, config, created_at)
+ VALUES (?, ?, ?)
+ """,
+ (run, orjson.dumps(serialize_values(config)), current_timestamp),
+ )
+ conn.commit()
+
+ @staticmethod
+ def get_run_config(project: str, run: str) -> dict | None:
+ """Get configuration for a specific run."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return None
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ """
+ SELECT config FROM configs WHERE run_name = ?
+ """,
+ (run,),
+ )
+
+ row = cursor.fetchone()
+ if row:
+ config = orjson.loads(row["config"])
+ return deserialize_values(config)
+ return None
+ except sqlite3.OperationalError as e:
+ if "no such table: configs" in str(e):
+ return None
+ raise
+
+ @staticmethod
+ def delete_run(project: str, run: str) -> bool:
+ """Delete a run from the database (both metrics and config)."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return False
+
+ with SQLiteStorage._get_process_lock(project):
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute("DELETE FROM metrics WHERE run_name = ?", (run,))
+ cursor.execute("DELETE FROM configs WHERE run_name = ?", (run,))
+ conn.commit()
+ return True
+ except sqlite3.Error:
+ return False
+
+ @staticmethod
+ def get_all_run_configs(project: str) -> dict[str, dict]:
+ """Get configurations for all runs in a project."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return {}
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ try:
+ cursor.execute(
+ """
+ SELECT run_name, config FROM configs
+ """
+ )
+
+ results = {}
+ for row in cursor.fetchall():
+ config = orjson.loads(row["config"])
+ results[row["run_name"]] = deserialize_values(config)
+ return results
+ except sqlite3.OperationalError as e:
+ if "no such table: configs" in str(e):
+ return {}
+ raise
+
+ @staticmethod
+ def get_metric_values(project: str, run: str, metric_name: str) -> list[dict]:
+ """Get all values for a specific metric in a project/run."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT timestamp, step, metrics
+ FROM metrics
+ WHERE run_name = ?
+ ORDER BY timestamp
+ """,
+ (run,),
+ )
+
+ rows = cursor.fetchall()
+ results = []
+ for row in rows:
+ metrics = orjson.loads(row["metrics"])
+ metrics = deserialize_values(metrics)
+ if metric_name in metrics:
+ results.append(
+ {
+ "timestamp": row["timestamp"],
+ "step": row["step"],
+ "value": metrics[metric_name],
+ }
+ )
+ return results
+
+ @staticmethod
+ def get_all_metrics_for_run(project: str, run: str) -> list[str]:
+ """Get all metric names for a specific project/run."""
+ db_path = SQLiteStorage.get_project_db_path(project)
+ if not db_path.exists():
+ return []
+
+ with SQLiteStorage._get_connection(db_path) as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT metrics
+ FROM metrics
+ WHERE run_name = ?
+ ORDER BY timestamp
+ """,
+ (run,),
+ )
+
+ rows = cursor.fetchall()
+ all_metrics = set()
+ for row in rows:
+ metrics = orjson.loads(row["metrics"])
+ metrics = deserialize_values(metrics)
+ for key in metrics.keys():
+ if key not in ["timestamp", "step"]:
+ all_metrics.add(key)
+ return sorted(list(all_metrics))
+
+ def finish(self):
+ """Cleanup when run is finished."""
+ pass
diff --git a/table.py b/table.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dfa4da3a4ac21c600df3a32e9cc9c8882e39ccb
--- /dev/null
+++ b/table.py
@@ -0,0 +1,163 @@
+import os
+from typing import Any, Literal
+
+from pandas import DataFrame
+
+try:
+ from trackio.media.media import TrackioMedia
+ from trackio.utils import MEDIA_DIR
+except ImportError:
+ from media.media import TrackioMedia
+ from utils import MEDIA_DIR
+
+
+class Table:
+ """
+ Initializes a Table object. Tables can be used to log tabular data including images, numbers, and text.
+
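+    Example:
+    ```python
+    import pandas as pd
+    import trackio
+
+    # A small illustrative table logged from a DataFrame.
+    df = pd.DataFrame({"label": ["cat", "dog"], "score": [0.92, 0.87]})
+    trackio.log({"predictions": trackio.Table(dataframe=df)})
+    ```
+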
+ Args:
+ columns (`list[str]`, *optional*):
+ Names of the columns in the table. Optional if `data` is provided. Not
+ expected if `dataframe` is provided. Currently ignored.
+ data (`list[list[Any]]`, *optional*):
+ 2D row-oriented array of values. Each value can be: a number, a string (treated as Markdown and truncated if too long),
+            or a `trackio.Image` or list of `trackio.Image` objects.
+        dataframe (`pandas.DataFrame`, *optional*):
+ DataFrame object used to create the table. When set, `data` and `columns`
+ arguments are ignored.
+ rows (`list[list[Any]]`, *optional*):
+ Currently ignored.
+ optional (`bool` or `list[bool]`, *optional*, defaults to `True`):
+ Currently ignored.
+ allow_mixed_types (`bool`, *optional*, defaults to `False`):
+ Currently ignored.
+ log_mode (`Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"]` or `None`, *optional*, defaults to `"IMMUTABLE"`):
+ Currently ignored.
+ """
+
+ TYPE = "trackio.table"
+
+ def __init__(
+ self,
+ columns: list[str] | None = None,
+ data: list[list[Any]] | None = None,
+ dataframe: DataFrame | None = None,
+ rows: list[list[Any]] | None = None,
+ optional: bool | list[bool] = True,
+ allow_mixed_types: bool = False,
+ log_mode: Literal["IMMUTABLE", "MUTABLE", "INCREMENTAL"] | None = "IMMUTABLE",
+ ):
+ # TODO: implement support for columns, dtype, optional, allow_mixed_types, and log_mode.
+ # for now (like `rows`) they are included for API compat but don't do anything.
+ if dataframe is None:
+ self.data = DataFrame(data) if data is not None else DataFrame()
+ else:
+ self.data = dataframe
+
+ def _has_media_objects(self, dataframe: DataFrame) -> bool:
+ """Check if dataframe contains any TrackioMedia objects or lists of TrackioMedia objects."""
+ for col in dataframe.columns:
+ if dataframe[col].apply(lambda x: isinstance(x, TrackioMedia)).any():
+ return True
+ if (
+ dataframe[col]
+ .apply(
+ lambda x: isinstance(x, list)
+ and len(x) > 0
+ and isinstance(x[0], TrackioMedia)
+ )
+ .any()
+ ):
+ return True
+ return False
+
+ def _process_data(self, project: str, run: str, step: int = 0):
+ """Convert dataframe to dict format, processing any TrackioMedia objects if present."""
+ df = self.data
+ if not self._has_media_objects(df):
+ return df.to_dict(orient="records")
+
+ processed_df = df.copy()
+ for col in processed_df.columns:
+ for idx in processed_df.index:
+ value = processed_df.at[idx, col]
+ if isinstance(value, TrackioMedia):
+ value._save(project, run, step)
+ processed_df.at[idx, col] = value._to_dict()
+ if (
+ isinstance(value, list)
+ and len(value) > 0
+ and isinstance(value[0], TrackioMedia)
+ ):
+ [v._save(project, run, step) for v in value]
+ processed_df.at[idx, col] = [v._to_dict() for v in value]
+
+ return processed_df.to_dict(orient="records")
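+
+ # Sketch of the resulting structure (hypothetical values): each row of self.data
+ # becomes one dict, and any TrackioMedia cell is saved for the given
+ # project/run/step and replaced by its _to_dict() payload, e.g.
+ #     [{"name": "a", "img": {"_type": "trackio.image", "file_path": "...", "caption": "..."}}]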
+
+ @staticmethod
+ def to_display_format(table_data: list[dict]) -> list[dict]:
+ """Convert stored table data to display format for UI rendering. Note
+ that this does not use the self.data attribute, but instead uses the
+ table_data parameter, which is what the UI receives.
+
+ Args:
+ table_data: List of dictionaries representing table rows (from stored _value)
+
+ Returns:
+ Table data with images converted to markdown syntax and long text truncated.
+ """
+ truncate_length = int(os.getenv("TRACKIO_TABLE_TRUNCATE_LENGTH", "250"))
+
+ def convert_image_to_markdown(image_data: dict) -> str:
+ relative_path = image_data.get("file_path", "")
+ caption = image_data.get("caption", "")
+ absolute_path = MEDIA_DIR / relative_path
+ return f'![{caption}]({absolute_path})'
+
+ processed_data = []
+ for row in table_data:
+ processed_row = {}
+ for key, value in row.items():
+ if isinstance(value, dict) and value.get("_type") == "trackio.image":
+ processed_row[key] = convert_image_to_markdown(value)
+ elif (
+ isinstance(value, list)
+ and len(value) > 0
+ and isinstance(value[0], dict)
+ and value[0].get("_type") == "trackio.image"
+ ):
+ # This assumes that if the first item is an image, all items are images. Ok for now since we don't support mixed types in a single cell.
+ processed_row[key] = (
+                        '<div>'
+ + "".join([convert_image_to_markdown(item) for item in value])
+ + "