Spaces:
Configuration error
Configuration error
| """ | |
| Copyright (c) Microsoft Corporation. | |
| Licensed under the MIT license. | |
| """ | |
| import os.path as op | |
| import torch | |
| import logging | |
| import code | |
| from custom_mesh_graphormer.utils.comm import get_world_size | |
| from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) | |
| from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) | |
def build_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a MeshTSVYamlDataset from a dataset YAML file.

    Args:
        yaml_file: Path to the dataset YAML. If it is not an existing file
            as given, it is looked up relative to ``args.data_dir``.
        args: Namespace providing ``data_dir``.
        is_train: Forwarded to MeshTSVYamlDataset.
        scale_factor: Forwarded to MeshTSVYamlDataset.

    Returns:
        ``MeshTSVYamlDataset(resolved_yaml, is_train, False, scale_factor)``.
        The third positional argument is always passed as False; its meaning
        is defined by MeshTSVYamlDataset (NOTE(review): confirm).

    Raises:
        FileNotFoundError: if the YAML exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    resolved = yaml_file
    if not op.isfile(resolved):
        resolved = op.join(args.data_dir, yaml_file)
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a bad path reach the dataset constructor.
    if not op.isfile(resolved):
        raise FileNotFoundError(
            "Dataset yaml not found: {!r} (also tried {!r})".format(yaml_file, resolved)
        )
    return MeshTSVYamlDataset(resolved, is_train, False, scale_factor)
class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.

    The iteration counter starts at ``start_iter`` and batches are yielded
    until it reaches ``num_iterations``, so ``num_iterations - start_iter``
    batches are produced in total (none if ``start_iter >= num_iterations``).
    The wrapped batch sampler is re-iterated as many times as needed. If the
    wrapped sampler yields no batches at all, iteration stops immediately.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        # BatchSampler.__init__ is not called: all sampling state lives in the
        # wrapped ``batch_sampler``.
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # if the underlying sampler has a set_epoch method, like
            # DistributedSampler, used for making each process see
            # a different split of the dataset, then set it
            # (the current iteration count is used as the epoch value).
            inner_sampler = getattr(self.batch_sampler, "sampler", None)
            if hasattr(inner_sampler, "set_epoch"):
                inner_sampler.set_epoch(iteration)
            produced_batch = False
            for batch in self.batch_sampler:
                produced_batch = True
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch
            if not produced_batch:
                # An empty pass never advances `iteration`; re-iterating would
                # spin forever (e.g. on an empty dataset), so stop here.
                return

    def __len__(self):
        # The configured total budget, not the remaining
        # ``num_iterations - start_iter`` batches.
        return self.num_iterations
def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0):
    """Group indices from ``sampler`` into batches of ``images_per_gpu``.

    Args:
        sampler: Index sampler to batch (the last batch may be partial).
        images_per_gpu: Batch size per process.
        num_iters: When a non-negative number, cap the stream at this many
            iterations by wrapping it in IterationBasedBatchSampler; when
            None or negative, return the plain BatchSampler.
        start_iter: Iteration to resume from (only used when wrapping).

    Returns:
        A BatchSampler, or an IterationBasedBatchSampler wrapping one.
    """
    batches = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_gpu, drop_last=False
    )
    use_iteration_cap = num_iters is not None and num_iters >= 0
    if not use_iteration_cap:
        return batches
    return IterationBasedBatchSampler(batches, num_iters, start_iter)
def make_data_sampler(dataset, shuffle, distributed):
    """Choose the index sampler for ``dataset``.

    Args:
        dataset: Dataset to draw indices from.
        shuffle: Random order when truthy, sequential order otherwise.
        distributed: When truthy, return a DistributedSampler (so each process
            sees a different split of the dataset); ``shuffle`` is forwarded.

    Returns:
        A torch Sampler instance.
    """
    data_utils = torch.utils.data
    if distributed:
        return data_utils.distributed.DistributedSampler(dataset, shuffle=shuffle)
    sampler_cls = (
        data_utils.sampler.RandomSampler if shuffle
        else data_utils.sampler.SequentialSampler
    )
    return sampler_cls(dataset)
def make_data_loader(args, yaml_file, is_distributed=True,
                     is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the dataset described by ``yaml_file``.

    Training mode shuffles and batches through an iteration-capped batch
    sampler whose budget is
    ``(len(dataset) // total_batch_size) * args.num_train_epochs``, with
    ``total_batch_size = args.per_gpu_train_batch_size * get_world_size()``,
    resuming from ``start_iter``. Evaluation mode visits the data in order
    with no iteration cap, and ``start_iter`` is ignored.

    Args:
        args: Namespace read for ``per_gpu_train_batch_size`` /
            ``per_gpu_eval_batch_size``, ``num_train_epochs``,
            ``num_workers`` (and ``data_dir`` via build_dataset).
        yaml_file: Dataset YAML path, resolved by build_dataset.
        is_distributed: Use a DistributedSampler instead of a local sampler.
        is_train: Select training (True) or evaluation (False) behaviour.
        start_iter: Iteration to resume training from.
        scale_factor: Forwarded to build_dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a batch sampler, with
        ``pin_memory=True``.
    """
    dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        # Floor division: only complete global batches count toward each epoch.
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with %s images per GPU.", images_per_gpu)
        logger.info("Total batch size %s", images_per_batch)
        logger.info("Total training steps %s", num_iters)
        if num_iters == 0:
            # With a budget of 0 the iteration-based sampler yields no batches,
            # so training would silently do nothing.
            logger.warning(
                "No training steps scheduled: dataset has %s samples, total "
                "batch size is %s, num_train_epochs is %s.",
                len(dataset), images_per_batch, args.num_train_epochs,
            )
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None  # no iteration cap for evaluation
        start_iter = 0
    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    return torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )
| #============================================================================================== | |
def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1):
    """Build a HandMeshTSVYamlDataset from a dataset YAML file.

    Args:
        yaml_file: Path to the dataset YAML. If it is not an existing file
            as given, it is looked up relative to ``args.data_dir``.
        args: Namespace providing ``data_dir``; also forwarded to the dataset.
        is_train: Forwarded to HandMeshTSVYamlDataset.
        scale_factor: Forwarded to HandMeshTSVYamlDataset.

    Returns:
        ``HandMeshTSVYamlDataset(args, resolved_yaml, is_train, False,
        scale_factor)``. The fourth positional argument is always False; its
        meaning is defined by HandMeshTSVYamlDataset (NOTE(review): confirm).

    Raises:
        FileNotFoundError: if the YAML exists neither as given nor under
            ``args.data_dir``.
    """
    print(yaml_file)
    resolved = yaml_file
    if not op.isfile(resolved):
        resolved = op.join(args.data_dir, yaml_file)
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a bad path reach the dataset constructor.
    if not op.isfile(resolved):
        raise FileNotFoundError(
            "Dataset yaml not found: {!r} (also tried {!r})".format(yaml_file, resolved)
        )
    return HandMeshTSVYamlDataset(args, resolved, is_train, False, scale_factor)
def make_hand_data_loader(args, yaml_file, is_distributed=True,
                          is_train=True, start_iter=0, scale_factor=1):
    """Build a DataLoader over the hand-mesh dataset described by ``yaml_file``.

    Training mode shuffles and batches through an iteration-capped batch
    sampler whose budget is
    ``(len(dataset) // total_batch_size) * args.num_train_epochs``, with
    ``total_batch_size = args.per_gpu_train_batch_size * get_world_size()``,
    resuming from ``start_iter``. Evaluation mode visits the data in order
    with no iteration cap, and ``start_iter`` is ignored.

    Args:
        args: Namespace read for ``per_gpu_train_batch_size`` /
            ``per_gpu_eval_batch_size``, ``num_train_epochs``,
            ``num_workers`` (and ``data_dir`` via build_hand_dataset).
        yaml_file: Dataset YAML path, resolved by build_hand_dataset.
        is_distributed: Use a DistributedSampler instead of a local sampler.
        is_train: Select training (True) or evaluation (False) behaviour.
        start_iter: Iteration to resume training from.
        scale_factor: Forwarded to build_hand_dataset.

    Returns:
        A ``torch.utils.data.DataLoader`` driven by a batch sampler, with
        ``pin_memory=True``.
    """
    dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor)
    logger = logging.getLogger(__name__)
    if is_train:
        shuffle = True
        images_per_gpu = args.per_gpu_train_batch_size
        images_per_batch = images_per_gpu * get_world_size()
        # Floor division: only complete global batches count toward each epoch.
        iters_per_epoch = len(dataset) // images_per_batch
        num_iters = iters_per_epoch * args.num_train_epochs
        logger.info("Train with %s images per GPU.", images_per_gpu)
        logger.info("Total batch size %s", images_per_batch)
        logger.info("Total training steps %s", num_iters)
        if num_iters == 0:
            # With a budget of 0 the iteration-based sampler yields no batches,
            # so training would silently do nothing.
            logger.warning(
                "No training steps scheduled: dataset has %s samples, total "
                "batch size is %s, num_train_epochs is %s.",
                len(dataset), images_per_batch, args.num_train_epochs,
            )
    else:
        shuffle = False
        images_per_gpu = args.per_gpu_eval_batch_size
        num_iters = None  # no iteration cap for evaluation
        start_iter = 0
    sampler = make_data_sampler(dataset, shuffle, is_distributed)
    batch_sampler = make_batch_data_sampler(
        sampler, images_per_gpu, num_iters, start_iter
    )
    return torch.utils.data.DataLoader(
        dataset, num_workers=args.num_workers, batch_sampler=batch_sampler,
        pin_memory=True,
    )