| import os | |
| import shutil | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image as PILImage | |
| try: | |
| from trackio.media.media import TrackioMedia | |
| except ImportError: | |
| from media.media import TrackioMedia | |
# Input types accepted by TrackioImage: a filesystem path (``str`` or ``Path``),
# a NumPy array, or a PIL image. TrackioImage.__init__ passes this PEP 604 union
# directly to ``isinstance`` and interpolates it into its error message.
TrackioImageSourceType = str | Path | np.ndarray | PILImage.Image
class TrackioImage(TrackioMedia):
    """
    Initializes an Image object.

    Example:
    ```python
    import trackio
    import numpy as np
    from PIL import Image

    # Create an image from numpy array
    image_data = np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8)
    image = trackio.Image(image_data, caption="Random image")
    trackio.log({"my_image": image})

    # Create an image from PIL Image
    pil_image = Image.new('RGB', (100, 100), color='red')
    image = trackio.Image(pil_image, caption="Red square")
    trackio.log({"red_image": image})

    # Create an image from file path
    image = trackio.Image("path/to/image.jpg", caption="Photo from file")
    trackio.log({"file_image": image})
    ```

    Args:
        value (`str`, `Path`, `numpy.ndarray`, or `PIL.Image`):
            A path to an image, a PIL Image, or a numpy array of shape
            (height, width) or (height, width, channels). A trailing channel
            axis of size 1 is treated as grayscale. If numpy array, should be
            of type `np.uint8` with values in the range `[0, 255]`.
        caption (`str`, *optional*):
            A string caption for the image.

    Raises:
        ValueError: If `value` is not one of the accepted types, or is a numpy
            array whose dtype is not `np.uint8`.
    """

    TYPE = "trackio.image"

    def __init__(self, value: TrackioImageSourceType, caption: str | None = None):
        # NOTE(review): assumes TrackioMedia.__init__ stores `value` as
        # `self._value`; every check below reads that attribute.
        super().__init__(value, caption)
        # Encoding format for in-memory pixel data. Stays None for path
        # inputs, which _save_media copies verbatim instead of re-encoding.
        self._format: str | None = None
        if not isinstance(self._value, TrackioImageSourceType):
            raise ValueError(
                f"Invalid value type, expected {TrackioImageSourceType}, got {type(self._value)}"
            )
        if isinstance(self._value, np.ndarray) and self._value.dtype != np.uint8:
            raise ValueError(
                f"Invalid value dtype, expected np.uint8, got {self._value.dtype}"
            )
        if isinstance(self._value, np.ndarray | PILImage.Image):
            # PNG is lossless and can store the RGBA images produced by _as_pil.
            self._format = "png"

    def _as_pil(self) -> PILImage.Image | None:
        """Return the in-memory image data as an RGBA PIL image.

        Returns:
            An RGBA ``PIL.Image.Image`` for numpy or PIL inputs, or ``None``
            for path inputs (those are copied, not decoded).

        Raises:
            ValueError: If the pixel data cannot be converted, for example
                because the array has a shape PIL does not support.
        """
        try:
            if isinstance(self._value, np.ndarray):
                # __init__ already enforces uint8, so this normally returns
                # the array itself instead of making a copy.
                arr = np.asarray(self._value, dtype=np.uint8)
                # PIL cannot build an image from an (H, W, 1) array; drop the
                # singleton channel axis so it is handled as grayscale.
                if arr.ndim == 3 and arr.shape[-1] == 1:
                    arr = arr[:, :, 0]
                return PILImage.fromarray(arr).convert("RGBA")
            if isinstance(self._value, PILImage.Image):
                return self._value.convert("RGBA")
        except Exception as e:
            raise ValueError(f"Failed to process image data: {self._value}") from e
        return None

    def _save_media(self, file_path: Path):
        """Write this image to ``file_path``.

        In-memory images (numpy / PIL) are encoded with ``self._format``.
        Path inputs are copied to ``file_path`` unchanged.

        Raises:
            ValueError: If the pixel data cannot be converted, or if a path
                input does not point to an existing file.
        """
        pil = self._as_pil()
        # Compare against None explicitly rather than relying on the
        # truthiness of a PIL image object.
        if pil is not None:
            pil.save(file_path, format=self._format)
        elif isinstance(self._value, str | Path):
            if not os.path.isfile(self._value):
                raise ValueError(f"File not found: {self._value}")
            shutil.copy(self._value, file_path)