abidlabs's picture
abidlabs HF Staff
Upload folder using huggingface_hub
5ef7afe verified
"""The Media and Tables page for the Trackio UI."""
import re
from dataclasses import dataclass
import gradio as gr
import pandas as pd
try:
import trackio.utils as utils
from trackio.media import TrackioAudio, TrackioImage, TrackioVideo
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.ui import fns
from trackio.ui.components.colored_dropdown import ColoredDropdown
except ImportError:
import utils
from media import TrackioAudio, TrackioImage, TrackioVideo
from sqlite_storage import SQLiteStorage
from table import Table
from ui import fns
from ui.components.colored_dropdown import ColoredDropdown
def get_runs(project) -> list[str]:
    """Return the runs stored for `project`.

    Args:
        project: Name of the project to look up. Any falsy value (``None``,
            empty string) means "no project selected".

    Returns:
        The result of ``SQLiteStorage.get_runs(project)``, or an empty list
        when no project is given (storage is not queried in that case).
    """
    return SQLiteStorage.get_runs(project) if project else []
@dataclass
class MediaData:
    """A single logged media item, as collected by `extract_media`."""

    # Caption read from the log entry's "caption" field; None when absent.
    caption: str | None
    # Location of the media file.
    # NOTE(review): extract_media builds this as `utils.MEDIA_DIR / <file_path>`,
    # so it is presumably a path object rather than a plain str — confirm.
    file_path: str
    # The log entry's "_type" tag; extract_media only stores the image, video
    # and audio type constants here.
    type: str
def extract_media(logs: list[dict]) -> dict[str, list[MediaData]]:
    """Collect the media entries found in a run's logs, grouped by log key.

    Logs are visited in ascending "step" order (a log without a "step" key
    sorts as step 0). Every dict value whose "_type" equals the image, video
    or audio type constant becomes a `MediaData` whose path is resolved
    against `utils.MEDIA_DIR`. An entry that cannot be built is reported on
    stdout and skipped; its key still appears in the result.

    Args:
        logs: Log records, each mapping a logged key to its value.

    Returns:
        A dict mapping each media key to its `MediaData` items, in step order.
    """
    grouped: dict[str, list[MediaData]] = {}
    for record in sorted(logs, key=lambda entry: entry.get("step", 0)):
        for key, payload in record.items():
            # Only dict payloads can carry a "_type" tag.
            if not isinstance(payload, dict):
                continue
            media_type = payload.get("_type")
            if media_type not in (
                TrackioImage.TYPE,
                TrackioVideo.TYPE,
                TrackioAudio.TYPE,
            ):
                continue
            # Register the key before building the item so that a key whose
            # entries all fail is still present (with an empty list).
            items = grouped.setdefault(key, [])
            try:
                items.append(
                    MediaData(
                        file_path=utils.MEDIA_DIR / payload.get("file_path"),
                        type=media_type,
                        caption=payload.get("caption"),
                    )
                )
            except Exception as e:
                # Best effort: one bad entry must not hide the remaining media.
                print(f"Media currently unavailable: {key}: {e}")
    return grouped
def filter_metrics_by_regex(metrics: list[str], filter_pattern: str) -> list[str]:
    """
    Select the metric names that match a case-insensitive regex.

    A blank pattern disables filtering and returns `metrics` unchanged. A
    pattern that is not a valid regex is treated as a plain case-insensitive
    substring instead.

    Args:
        metrics: List of metric names to filter
        filter_pattern: Regex pattern to search for in each metric name

    Returns:
        List of metric names that match the pattern
    """
    if filter_pattern.strip() == "":
        return metrics

    try:
        matcher = re.compile(filter_pattern, re.IGNORECASE)
    except re.error:
        # Invalid regex: fall back to a literal substring search.
        needle = filter_pattern.lower()
        return [name for name in metrics if needle in name.lower()]

    return [name for name in metrics if matcher.search(name)]
def refresh_runs_dropdown(project: str | None):
    """Build the run dropdown update for the given project.

    Args:
        project: Selected project name, or ``None`` when nothing is selected.

    Returns:
        A `ColoredDropdown` listing the project's runs, each paired with a
        colour taken cyclically from `utils.get_color_palette()`. The first
        run is preselected (``None`` when there are no runs) and the
        placeholder shows the run count.
    """
    run_names: list[str] = [] if project is None else get_runs(project)

    palette = utils.get_color_palette()
    # Wrap around the palette when there are more runs than colours.
    run_colors = [
        palette[position % len(palette)] for position, _ in enumerate(run_names)
    ]
    default_run = run_names[0] if run_names else None

    return ColoredDropdown(
        choices=run_names,
        colors=run_colors,
        value=default_run,
        placeholder=f"Select a run ({len(run_names)})",
    )
def _make_table_getter(column_values, column_name):
    """Build the slider callback that shows one logged table of a column.

    The column's values and name are passed in as arguments so that every
    callback keeps its own column. A closure defined directly inside the
    rendering loop reads the loop variables only when the slider is moved,
    by which time they refer to the last column that was rendered.

    Args:
        column_values: The column's non-null logged values; each one is
            indexed with "_value" to obtain the table payload.
        column_name: Name of the column, used in the table label.

    Returns:
        A function that takes the 1-based slider index and returns the
        updated `gr.DataFrame`.
    """

    def get_table_at_index(index: int):
        processed_data = Table.to_display_format(
            column_values.iloc[index - 1]["_value"]
        )
        return gr.DataFrame(
            pd.DataFrame(processed_data),
            label=f"{column_name} (index {index})",
        )

    return get_table_at_index


with gr.Blocks() as media_page:
    with gr.Sidebar() as sidebar:
        logo_urls = utils.get_logo_urls()
        # The markup lines stay flush-left so the string literal is unchanged.
        logo = gr.Markdown(
            f"""
<img src='{logo_urls["light"]}' width='80%' class='logo-light'>
<img src='{logo_urls["dark"]}' width='80%' class='logo-dark'>
"""
        )
        project_dd = gr.Dropdown(label="Project", allow_custom_value=True)
        runs_dropdown = ColoredDropdown(choices=[], colors=[], label="Run")

    navbar = gr.Navbar(
        value=[
            ("Metrics", ""),
            ("Media & Tables", "/media"),
            ("Runs", "/runs"),
            ("Files", "/files"),
        ],
        main_page_name=False,
    )
    # Its ticks refresh the project dropdown's info text (wired up below).
    timer = gr.Timer(value=1)

    @gr.render(
        triggers=[
            media_page.load,
            runs_dropdown.change,
            project_dd.change,
        ],
        inputs=[project_dd, runs_dropdown],
        show_progress="hidden",
        queue=False,
    )
    def display_media_and_tables(project: str | None, selected_run: str | None):
        """Render the media and tables logged for the selected run.

        Re-rendered on page load and whenever the project or run dropdown
        changes. Images and videos of a key share one gallery, audio clips
        are laid out three per row inside an accordion, and each table
        column gets a slider to step through its logged tables.

        Args:
            project: Currently selected project, if any.
            selected_run: Currently selected run, if any.
        """
        if not project or not selected_run:
            gr.Markdown("*Select a project and run to view media and tables*")
            return

        logs = SQLiteStorage.get_logs(project, selected_run)
        if not logs:
            gr.Markdown("*No data found for this run*")
            return

        df = pd.DataFrame(logs)
        media_by_key = extract_media(logs)
        # A key whose media all failed to load maps to an empty list.
        has_media = any(media_by_key.values())

        # Table columns: non-reserved object columns whose first non-null
        # value is a dict tagged with the table type.
        table_cols = df.select_dtypes(include="object").columns
        table_cols = [c for c in table_cols if c not in utils.RESERVED_KEYS]
        table_cols = [
            c
            for c in table_cols
            if not (metric_df := df.dropna(subset=[c])).empty
            and isinstance(first_value := metric_df[c].iloc[0], dict)
            and first_value.get("_type") == Table.TYPE
        ]
        has_tables = len(table_cols) > 0

        if not has_media and not has_tables:
            gr.Markdown("*No media or tables found for this run*")
            return

        if has_media:
            for key, media_items in media_by_key.items():
                image_and_video = [
                    item
                    for item in media_items
                    if item.type in [TrackioImage.TYPE, TrackioVideo.TYPE]
                ]
                audio = [item for item in media_items if item.type == TrackioAudio.TYPE]
                if image_and_video:
                    gr.Gallery(
                        [(item.file_path, item.caption) for item in image_and_video],
                        label=key,
                        columns=6,
                        elem_classes=("media-gallery"),
                    )
                if audio:
                    with gr.Accordion(
                        label=key, elem_classes=("media-audio-accordion")
                    ):
                        # Three audio players per row.
                        for i in range(0, len(audio), 3):
                            with gr.Row(elem_classes=("media-audio-row")):
                                for item in audio[i : i + 3]:
                                    gr.Audio(
                                        value=item.file_path,
                                        label=item.caption,
                                        elem_classes=("media-audio-item"),
                                    )

        if has_tables:
            with gr.Accordion(f"Tables ({len(table_cols)})", open=True):
                with gr.Row(key="row"):
                    for metric_idx, metric_name in enumerate(table_cols):
                        # The table_cols filter above already guarantees a
                        # non-empty column whose first value is a table.
                        column_values = df.dropna(subset=[metric_name])[metric_name]
                        table_count = len(column_values)
                        try:
                            with gr.Column():
                                # The slider starts on the most recent table
                                # and is hidden when there is only one.
                                s = gr.Slider(
                                    value=table_count,
                                    minimum=1,
                                    maximum=table_count,
                                    step=1,
                                    container=False,
                                    visible=table_count > 1,
                                    interactive=True,
                                )
                                processed_data = Table.to_display_format(
                                    column_values.iloc[-1]["_value"]
                                )
                                df_table = pd.DataFrame(processed_data)
                                table = gr.DataFrame(
                                    df_table,
                                    label=f"{metric_name} (index {table_count})",
                                    key=f"table-{metric_idx}",
                                    wrap=True,
                                    datatype="markdown",
                                    preserved_by_key=None,
                                )
                                # Bind this column to its own callback; see
                                # _make_table_getter for why.
                                s.input(
                                    _make_table_getter(column_values, metric_name),
                                    inputs=s,
                                    outputs=table,
                                    show_progress="hidden",
                                )
                        except Exception as e:
                            # One broken column must not block the others.
                            gr.Warning(
                                f"Column {metric_name} failed to render as a table: {e}"
                            )

    # Timer tick: refresh the info text shown on the project dropdown.
    gr.on(
        [timer.tick],
        fn=lambda: gr.Dropdown(info=fns.get_project_info()),
        outputs=[project_dd],
        show_progress="hidden",
        api_visibility="private",
    )
    # Page load: fill the project dropdown, then update the navbar from it.
    gr.on(
        [media_page.load],
        fn=fns.get_projects,
        outputs=project_dd,
        show_progress="hidden",
        queue=False,
        api_visibility="private",
    ).then(
        fns.update_navbar_value,
        inputs=[project_dd],
        outputs=[navbar],
        show_progress="hidden",
        api_visibility="private",
        queue=False,
    )
    # Project change: reload the run dropdown, then update the navbar.
    gr.on(
        [project_dd.change],
        fn=refresh_runs_dropdown,
        inputs=[project_dd],
        outputs=[runs_dropdown],
        show_progress="hidden",
        queue=False,
        api_visibility="private",
    ).then(
        fns.update_navbar_value,
        inputs=[project_dd],
        outputs=[navbar],
        show_progress="hidden",
        api_visibility="private",
        queue=False,
    )