abidlabs's picture
abidlabs HF Staff
Upload folder using huggingface_hub
5ef7afe verified
import os
import shutil
from pathlib import Path
import numpy as np
from PIL import Image as PILImage
try:
from trackio.media.media import TrackioMedia
except ImportError:
from media.media import TrackioMedia
# Accepted input types for `TrackioImage`. This union serves both as the
# `value` annotation and as the runtime `isinstance` target in
# `TrackioImage.__init__`, whose error message also embeds its repr, so
# the member order here is user-visible.
TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
class TrackioImage(TrackioMedia):
    """
    An image that can be logged to trackio.

    Example:
    ```python
    import trackio
    import numpy as np
    from PIL import Image

    # Create an image from numpy array
    image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    image = trackio.Image(image_data, caption="Random image")
    trackio.log({"my_image": image})

    # Create an image from PIL Image
    pil_image = Image.new('RGB', (100, 100), color='red')
    image = trackio.Image(pil_image, caption="Red square")
    trackio.log({"red_image": image})

    # Create an image from file path
    image = trackio.Image("path/to/image.jpg", caption="Photo from file")
    trackio.log({"file_image": image})
    ```

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
            A path to an image, a PIL Image, or a numpy array of shape (height, width, channels).
            If numpy array, it must be of type `np.uint8` with RGB values in the range `[0, 255]`.
            In-memory images (numpy arrays and PIL Images) are converted to RGBA and saved
            as PNG; file paths are copied unchanged.
        caption (`str`, *optional*):
            A string caption for the image.

    Raises:
        ValueError: If `value` is not one of the supported types, or is a numpy
            array whose dtype is not `np.uint8`.
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        super().__init__(value, caption)
        # Encoding format passed to PIL when saving in-memory images. It stays
        # None for file-path inputs, which `_save_media` copies as-is.
        self._format: str | None = None

        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )
        # In-memory images have no source file to copy, so they are encoded as PNG.
        # (`_format` was just reset to None above, so no further check is needed.)
        if isinstance(self._value, np.ndarray | PILImage.Image):
            self._format = "png"

    def _as_pil(self) -> PILImage.Image | None:
        """
        Return the stored value as an RGBA PIL Image.

        Returns:
            An RGBA copy for numpy-array and PIL-Image values, or None for any
            other value (i.e. file paths, given the validation in `__init__`).

        Raises:
            ValueError: If the array or image cannot be converted.
        """
        try:
            if isinstance(self._value, np.ndarray):
                # `__init__` already enforces uint8, so `copy=False` avoids a
                # redundant full copy while still coercing any other dtype.
                arr = np.asarray(self._value).astype("uint8", copy=False)
                return PILImage.fromarray(arr).convert("RGBA")
            if isinstance(self._value, PILImage.Image):
                return self._value.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e
        return None

    def _save_media(self, file_path: Path):
        """
        Write the image to `file_path`.

        In-memory images are encoded with `self._format`; file-path values are
        copied to `file_path` without re-encoding.

        Raises:
            ValueError: If in-memory image data cannot be converted, or a
                file-path value does not point to an existing file.
        """
        pil = self._as_pil()
        # Compare against None explicitly rather than relying on the truthiness
        # of a PIL Image object.
        if pil is not None:
            pil.save(file_path, format=self._format)
        elif isinstance(self._value, str | Path):
            if not os.path.isfile(self._value):
                raise ValueError(f"File not found: {self._value}")
            shutil.copy(self._value, file_path)