| import dataclasses |
| import glob |
| import importlib |
| import random |
| import numpy as np |
| import torch |
| import warnings |
| import os |
| import time |
| import torch.utils.tensorboard as tensorboard |
| from torch import distributed as dist |
| import sys |
| import yaml |
| import json |
| import re |
| import pathlib |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pylab as plt |
|
|
|
|
def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colour bar and return the figure.

    The array is drawn on a 10x2-inch canvas with its first row at the bottom
    (``origin="lower"``), auto aspect ratio and no interpolation. The canvas is
    rendered once, then the figure is closed in pyplot so pyplot stops
    tracking it; the (already drawn) figure object is returned to the caller.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image_opts = dict(aspect="auto", origin="lower", interpolation='none')
    image = axes.imshow(spectrogram, **image_opts)
    # ``figure`` is pyplot's current figure here, so this is what
    # ``plt.colorbar(image, ax=axes)`` would do.
    figure.colorbar(image, ax=axes)

    figure.canvas.draw()
    plt.close(figure)
    return figure
|
|
|
|
def seed_everything(seed, cudnn_deterministic=False):
    """
    Seed the global pseudo-random number generators of Python ``random``,
    NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed: the integer value seed for global random state; ``None`` leaves
            every generator untouched.
        cudnn_deterministic: if true, make cuDNN reproducible by setting
            ``torch.backends.cudnn.deterministic = True`` and disabling
            ``torch.backends.cudnn.benchmark``, then emit a warning about the
            possible slowdown. This is applied even when ``seed`` is ``None``.
    """
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
        # Benchmark mode auto-tunes and may select different (possibly
        # non-deterministic) algorithms from run to run, so it must be off
        # for reproducible results.
        torch.backends.cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
|
|
def is_primary():
    """Return ``True`` when this process has rank 0 according to ``get_rank()``."""
    return not get_rank()
|
|
|
|
def get_rank():
    """Return this process's rank in the default ``torch.distributed`` group.

    Falls back to ``0`` when ``torch.distributed`` is unavailable in this
    build or the process group has not been initialized.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
def load_yaml_config(path):
    """Parse the YAML file at ``path`` with ``yaml.full_load`` and return the result.

    The file is opened in binary mode so PyYAML detects the stream encoding
    itself (UTF-8, or UTF-16 with a BOM) instead of decoding with the
    platform's locale-dependent default codec.

    NOTE(review): PyYAML recommends ``SafeLoader`` (``yaml.safe_load``) for
    untrusted input; ``full_load`` is kept here for compatibility with
    existing configs and assumes the file is trusted.
    """
    with open(path, 'rb') as f:
        return yaml.full_load(f)
|
|
|
|
def save_config_to_yaml(config, path):
    """Serialize ``config`` with ``yaml.dump`` and write it to ``path``.

    Args:
        config: any object ``yaml.dump`` can represent (typically a dict).
        path: destination as ``str`` or ``os.PathLike``; must end in
            ``.yaml``. An existing file is overwritten.

    Raises:
        AssertionError: if ``path`` does not end with ``.yaml`` (the check is
            an ``assert``, so it is skipped under ``python -O``).
    """
    # Normalise PathLike objects so ``.endswith`` works on them too.
    path = os.fspath(path)
    assert path.endswith('.yaml'), f'expected a .yaml path, got {path!r}'
    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)
|
|
|
|
def save_dict_to_json(d, path, indent=None):
    """Write ``d`` to ``path`` as JSON, overwriting any existing file.

    Args:
        d: JSON-serializable object (typically a dict).
        path: destination file path.
        indent: forwarded to ``json.dump``; ``None`` writes everything on a
            single line.
    """
    # The context manager guarantees the data is flushed and the handle is
    # closed deterministically, even if serialization fails midway.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(d, f, indent=indent)
|
|
|
|
def load_dict_from_json(path):
    """Read the JSON document at ``path`` and return the decoded object.

    The file is decoded as UTF-8 (the encoding JSON interchange requires)
    rather than with the platform's locale default, and the handle is closed
    before returning.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
|
def write_args(args, path):
    """Append a run summary to the text file at ``path``.

    The summary holds the torch and cuDNN versions, ``sys.argv``, and every
    attribute reported by ``dir(args)`` whose name does not start with an
    underscore, one ``key: value`` line each, in sorted key order.
    """
    settings = {key: getattr(args, key) for key in dir(args) if not key.startswith('_')}

    with open(path, 'a') as out:
        out.write('==> torch version: {}\n'.format(torch.__version__))
        out.write('==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        out.write('==> Cmd:\n' + str(sys.argv) + '\n==> args:\n')
        out.writelines(
            ' %s: %s\n' % (str(key), str(settings[key])) for key in sorted(settings)
        )
|
|
|
|
class Logger(object):
    """Rank-aware run logger: text log file plus optional TensorBoard writer.

    Only the primary process (``is_primary()``) touches the filesystem. It
    creates ``<args.log_dir>/configs`` (holding ``args.txt`` and, via
    ``save_config``, ``config.yaml``) and ``<args.log_dir>/logs`` (holding
    ``log.txt`` and, when ``args.tensorboard`` is truthy, the TensorBoard
    event files). On other ranks every method is a no-op, except
    ``log_info(..., check_primary=False)``, which still prints.
    """

    def __init__(self, args):
        """
        Args:
            args: options object; ``args.log_dir`` and ``args.tensorboard``
                are read, and all public attributes are dumped to
                ``configs/args.txt``.
        """
        self.args = args
        self.save_dir = args.log_dir
        self.is_primary = is_primary()
        self.config_dir = os.path.join(self.save_dir, 'configs')
        # Defined on every rank so later methods never hit a missing attribute.
        self.text_writer = None
        self.tb_writer = None

        if self.is_primary:
            os.makedirs(self.save_dir, exist_ok=True)

            # Record the arguments this run was started with.
            os.makedirs(self.config_dir, exist_ok=True)
            file_name = os.path.join(self.config_dir, 'args.txt')
            write_args(args, file_name)

            log_dir = os.path.join(self.save_dir, 'logs')
            os.makedirs(log_dir, exist_ok=True)
            # Append mode keeps the history of earlier (e.g. resumed) runs.
            self.text_writer = open(os.path.join(log_dir, 'log.txt'), 'a', encoding='utf-8')
            if args.tensorboard:
                self.log_info('using tensorboard')
                self.tb_writer = torch.utils.tensorboard.SummaryWriter(log_dir=log_dir)

    def save_config(self, config):
        """On the primary rank, write ``config`` to ``configs/config.yaml``."""
        if self.is_primary:
            save_config_to_yaml(config, os.path.join(self.config_dir, 'config.yaml'))

    def log_info(self, info, check_primary=True):
        """Print ``info`` and, on the primary rank, append it to ``log.txt``.

        Args:
            info: any object; it is converted with ``str`` and prefixed with a
                ``%Y-%m-%d-%H-%M`` local timestamp in the log file.
            check_primary: when false, non-primary ranks also print ``info``
                (they still never write to the log file).
        """
        if self.is_primary or not check_primary:
            print(info)
        if self.is_primary:
            line = '{}: {}'.format(time.strftime('%Y-%m-%d-%H-%M'), str(info))
            if not line.endswith('\n'):
                line += '\n'
            self.text_writer.write(line)
            # Flush each entry so the log survives a crash.
            self.text_writer.flush()

    def add_scalar(self, **kwargs):
        """Forward to ``SummaryWriter.add_scalar`` (one scalar) if TensorBoard is enabled."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalar(**kwargs)

    def add_scalars(self, **kwargs):
        """Forward to ``SummaryWriter.add_scalars`` (several scalars under one main tag)."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_scalars(**kwargs)

    def add_image(self, **kwargs):
        """Forward to ``SummaryWriter.add_image`` (a single image) if TensorBoard is enabled."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_image(**kwargs)

    def add_images(self, **kwargs):
        """Forward to ``SummaryWriter.add_images`` (a batch of images) if TensorBoard is enabled."""
        if self.is_primary and self.tb_writer is not None:
            self.tb_writer.add_images(**kwargs)

    def close(self):
        """Close the text log and, if one was created, the TensorBoard writer."""
        if self.is_primary:
            if self.text_writer is not None:
                self.text_writer.close()
            # tb_writer is None when args.tensorboard was false.
            if self.tb_writer is not None:
                self.tb_writer.close()
|
|
|
|
def cal_model_size(model, name="", in_bytes=False):
    """Return a one-line size summary for ``model``.

    By default the reported figure is the total number of parameter elements
    divided by 1024**2, i.e. a parameter count in units of 2**20, even though
    the message labels it "MB". Pass ``in_bytes=True`` to report actual
    storage instead: parameter *and* buffer bytes
    (``nelement() * element_size()``) divided by 1024**2.

    Args:
        model: object exposing ``parameters()`` (and ``buffers()`` when
            ``in_bytes`` is true), e.g. a ``torch.nn.Module``.
        name: label inserted into the message.
        in_bytes: measure memory footprint rather than element count.

    Returns:
        ``'Model size of {name}: {size:.3f} MB'``.
    """
    if in_bytes:
        total = sum(p.nelement() * p.element_size() for p in model.parameters())
        total += sum(b.nelement() * b.element_size() for b in model.buffers())
    else:
        total = sum(p.numel() for p in model.parameters())
    all_size = total / 1024.0 / 1024.0
    return f'Model size of {name}: {all_size:.3f} MB'
| |
| |
|
|
|
|
def load_obj(obj_path: str, default_obj_path: str = ''):
    """Import a module and return one of its attributes, addressed by dotted path.

    Args:
        obj_path: dotted path whose last component names the attribute, e.g.
            ``src.trainers.meta_trainer.MetaTrainer``. Without any dot the
            whole string is the attribute name and ``default_obj_path`` is
            the module.
        default_obj_path: module path used when ``obj_path`` has no dot.

    Returns:
        The requested attribute of the imported module.

    Raises:
        AttributeError: when the module does not define that attribute.
    """
    module_path, dot, obj_name = obj_path.rpartition('.')
    if not dot:
        module_path = default_obj_path

    module_obj = importlib.import_module(module_path)
    missing = object()
    found = getattr(module_obj, obj_name, missing)
    if found is missing:
        raise AttributeError(f'Object `{obj_name}` cannot be loaded from `{module_path}`.')
    return found
|
|
|
|
def to_device(data, device=None, dtype=None, non_blocking=False, copy=False):
    """Change the device (and/or dtype) of every tensor inside ``data``, recursively.

    Supported containers:
    - dicts, rebuilt as plain ``dict``;
    - dataclass instances, rebuilt with their own class;
    - namedtuples and other tuple subclasses, rebuilt positionally;
    - lists and tuples.

    NumPy arrays are wrapped with ``torch.from_numpy`` first. Tensors go
    through ``Tensor.to(device, dtype, non_blocking, copy)``. Any other value
    is returned unchanged.
    """
    def _recurse(value):
        return to_device(value, device, dtype, non_blocking, copy)

    if isinstance(data, dict):
        return {k: _recurse(v) for k, v in data.items()}
    elif dataclasses.is_dataclass(data) and not isinstance(data, type):
        # Read each field directly. ``dataclasses.astuple`` would deep-copy
        # every leaf and turn nested dataclasses into plain tuples, so nested
        # fields would come back with the wrong type.
        return type(data)(
            *[_recurse(getattr(data, f.name)) for f in dataclasses.fields(data)]
        )
    elif isinstance(data, tuple) and type(data) is not tuple:
        # namedtuple: the constructor takes the fields positionally.
        return type(data)(*[_recurse(o) for o in data])
    elif isinstance(data, (list, tuple)):
        return type(data)(_recurse(v) for v in data)
    elif isinstance(data, np.ndarray):
        return _recurse(torch.from_numpy(data))
    elif isinstance(data, torch.Tensor):
        return data.to(device, dtype, non_blocking, copy)
    else:
        return data
|
|
|
|
def save_checkpoint(filepath, obj, ext='pth', num_ckpt_keep=10):
    """Prune old checkpoints next to ``filepath``, then ``torch.save`` ``obj`` there.

    Every ``*.{ext}`` file in ``filepath``'s directory is treated as a
    checkpoint. The files are ordered by name, and all but the last
    ``num_ckpt_keep`` are deleted before the new file is written, so the
    directory ends up with at most ``num_ckpt_keep`` older files plus the new
    one. Names should therefore sort chronologically (e.g. zero-padded step
    numbers).

    NOTE(review): any unrelated ``*.{ext}`` file in that directory is also a
    deletion candidate.

    Args:
        filepath: destination path of the new checkpoint.
        obj: object to serialize with ``torch.save``.
        ext: extension (without the dot) identifying checkpoint files.
        num_ckpt_keep: number of existing checkpoints to retain; ``0`` removes
            all of them and ``None`` disables pruning.
    """
    if num_ckpt_keep is not None:
        ckpts = sorted(pathlib.Path(filepath).parent.glob(f'*.{ext}'))
        # Slice by an explicit count: ``ckpts[:-0]`` is an empty list, which
        # would silently skip pruning when ``num_ckpt_keep`` is 0.
        excess = len(ckpts) - max(num_ckpt_keep, 0)
        for stale in ckpts[:max(excess, 0)]:
            os.remove(stale)
    torch.save(obj, filepath)
|
|
|
|
def scan_checkpoint(cp_dir, prefix='ckpt_'):
    """Return the path of the latest checkpoint in ``cp_dir``, or ``None`` if there is none.

    Candidates are files named ``prefix`` + exactly eight characters +
    ``.pth`` (e.g. ``ckpt_00012000.pth``). "Latest" is the lexicographically
    greatest path, which matches step order for zero-padded numbers.
    ``cp_dir`` is matched literally: glob metacharacters such as ``[`` in
    directory names are escaped. ``prefix`` is still used as a glob fragment.
    """
    pattern = os.path.join(glob.escape(os.fspath(cp_dir)), prefix + '????????.pth')
    matches = glob.glob(pattern)
    return max(matches) if matches else None
|
|