| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import logging |
| |
|
| | from PIL import Image |
| | from torchvision import transforms |
| |
|
| | from .transforms import ( |
| | GaussianBlur, |
| | MaybeToTensor, |
| | make_normalize_transform, |
| | ) |
| |
|
| |
|
# Module-level logger. It is looked up by the fixed name "dinov2" rather than
# ``__name__``, so every module requesting "dinov2" gets this same logger object
# (presumably configured once by the package's logging setup -- verify).
logger = logging.getLogger(name="dinov2")
| |
|
| |
|
class DataAugmentationDINO:
    """Multi-crop augmentation yielding global and local views of one image.

    Every call produces two "global" views and ``local_crops_number`` "local"
    views of the same input image. Each view is built as:

    1. geometric: bicubic ``RandomResizedCrop`` followed by a horizontal flip
       with p=0.5;
    2. photometric: ``ColorJitter(brightness=0.4, contrast=0.4,
       saturation=0.2, hue=0.1)`` applied with p=0.8, ``RandomGrayscale``
       with p=0.2, then a ``GaussianBlur`` whose ``p`` is 0.5 for the first
       global view and the local views, and 0.1 for the second global view
       (``p`` is presumably the probability of applying the blur -- see
       ``.transforms.GaussianBlur``);
    3. ``MaybeToTensor()`` followed by ``make_normalize_transform()``.

    Args:
        global_crops_scale: ``scale`` range handed to ``RandomResizedCrop``
            for the global views.
        local_crops_scale: ``scale`` range handed to ``RandomResizedCrop``
            for the local views.
        local_crops_number: number of local views produced per call.
        global_crops_size: output size of the global crops (default 224).
        local_crops_size: output size of the local crops (default 96).
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        self.global_crops_scale = global_crops_scale
        self.local_crops_scale = local_crops_scale
        self.local_crops_number = local_crops_number
        self.global_crops_size = global_crops_size
        self.local_crops_size = local_crops_size

        # Log every augmentation hyper-parameter between two banner lines.
        banner = "###################################"
        logger.info(banner)
        logger.info("Using data augmentation parameters:")
        for param_name, param_value in (
            ("global_crops_scale", global_crops_scale),
            ("local_crops_scale", local_crops_scale),
            ("local_crops_number", local_crops_number),
            ("global_crops_size", global_crops_size),
            ("local_crops_size", local_crops_size),
        ):
            logger.info(f"{param_name}: {param_value}")
        logger.info(banner)

        # Geometric stage: random resized crop + flip, one pipeline per crop kind.
        self.geometric_augmentation_global = self._crop_and_flip(global_crops_size, global_crops_scale)
        self.geometric_augmentation_local = self._crop_and_flip(local_crops_size, local_crops_scale)

        # Photometric stage shared by every view: color distortion, then blur.
        color_jittering = transforms.Compose(
            [
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )
        blur_first_global = GaussianBlur(p=0.5)
        # The second global view is blurred less often than the first.
        blur_second_global = transforms.Compose([GaussianBlur(p=0.1)])
        blur_local = GaussianBlur(p=0.5)

        # Final stage: tensor conversion followed by normalization.
        self.normalize = transforms.Compose([MaybeToTensor(), make_normalize_transform()])

        # One photometric+normalize pipeline per view type; all share the same
        # color_jittering and normalize objects.
        self.global_transfo1, self.global_transfo2, self.local_transfo = (
            transforms.Compose([color_jittering, blur, self.normalize])
            for blur in (blur_first_global, blur_second_global, blur_local)
        )

    @staticmethod
    def _crop_and_flip(size, scale):
        """Return a bicubic RandomResizedCrop to ``size`` using ``scale``, then a p=0.5 horizontal flip."""
        return transforms.Compose(
            [
                transforms.RandomResizedCrop(size, scale=scale, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
            ]
        )

    def __call__(self, image):
        """Produce all views of ``image``.

        Returns:
            dict with keys, in this order:
                "global_crops": list of the two global views;
                "global_crops_teacher": a separate list holding the same two
                    global view objects;
                "local_crops": list of ``local_crops_number`` local views;
                "offsets": an empty tuple.
        """
        # Each view is cropped/flipped first, then passed through its photometric pipeline.
        first_global = self.global_transfo1(self.geometric_augmentation_global(image))
        second_global = self.global_transfo2(self.geometric_augmentation_global(image))
        global_views = [first_global, second_global]

        local_views = [
            self.local_transfo(self.geometric_augmentation_local(image))
            for _ in range(self.local_crops_number)
        ]

        return {
            "global_crops": global_views,
            # The teacher sees exactly the same global crops as the student.
            "global_crops_teacher": list(global_views),
            "local_crops": local_views,
            "offsets": (),
        }
| |
|
| |
|
def get_online_classification_augmentation_from_config(cfg) -> transforms.Compose:
    """Build the augmentation pipeline used for online classification.

    Configuration read:
        * ``cfg.evaluation.online.augmentation.interpolation``: name of a member
          of ``PIL.Image.Resampling`` (e.g. ``"BICUBIC"``), used for both the
          resize and the affine transform;
        * ``cfg.evaluation.online.augmentation.degrees`` / ``.scale`` /
          ``.shear``: forwarded to ``transforms.RandomAffine``;
        * ``cfg.evaluation.online.augmentation.horizontal_flip``: truthy to add
          a ``RandomHorizontalFlip`` (torchvision default probability);
        * ``cfg.crops.global_crops_size``: target size for the resize and the
          center crop.

    Pipeline order: resize -> center crop -> random affine -> optional random
    horizontal flip -> ``MaybeToTensor()`` -> ``make_normalize_transform()``.

    Args:
        cfg: configuration object exposing the attributes listed above.

    Returns:
        A ``transforms.Compose`` applying the steps above in order.

    Raises:
        AttributeError: if the configured interpolation name is not an
            attribute of ``PIL.Image.Resampling``.
    """
    aug_cfg = cfg.evaluation.online.augmentation
    # The config stores the resampling filter by name; torchvision accepts the
    # Pillow integer constants for its ``interpolation`` arguments.
    interpolation = getattr(Image.Resampling, aug_cfg.interpolation)
    crop_size = cfg.crops.global_crops_size

    transforms_list = [
        # NOTE(review): with an int size, torchvision's Resize matches the
        # shorter edge, so the center crop below yields a square image.
        transforms.Resize(crop_size, interpolation=interpolation),
        transforms.CenterCrop(crop_size),
        transforms.RandomAffine(
            degrees=aug_cfg.degrees,
            scale=aug_cfg.scale,
            shear=aug_cfg.shear,
            interpolation=interpolation,
        ),
    ]
    # Keep all geometric augmentations ahead of tensor conversion/normalization,
    # matching the ordering used in DataAugmentationDINO. The flip still samples
    # its randomness right after the affine transform, as before, and mirroring
    # commutes with per-pixel conversion and per-channel normalization
    # (assuming MaybeToTensor / the normalize transform are such ops).
    if aug_cfg.horizontal_flip:
        transforms_list.append(transforms.RandomHorizontalFlip())
    transforms_list.extend([MaybeToTensor(), make_normalize_transform()])
    return transforms.Compose(transforms_list)
| |
|