import os
import os.path as osp
import numpy as np
import argparse
import pickle
from tqdm import tqdm
import time
import random
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from lib.utils.tools import *
from lib.utils.learning import *
from lib.utils.utils_data import flip_data
from lib.utils.utils_mesh import flip_thetas_batch
from lib.data.dataset_wild import WildDetDataset
# from lib.model.loss import *
from lib.model.model_mesh import MeshRegressor
from lib.utils.vismo import render_and_save, motion2video_mesh
from lib.utils.utils_smpl import *
from scipy.optimize import least_squares
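
# In-the-wild mesh recovery: read AlphaPose 2D detections and the source video,
# run the MeshRegressor to predict per-frame SMPL pose and shape (with a
# flip-augmented second pass), optionally align the result to a reference 3D
# motion, and render the mesh sequence to <out_path>/mesh.mp4.
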
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="configs/mesh/MB_ft_pw3d.yaml", help="Path to the config file.")
    parser.add_argument('-e', '--evaluate', default='checkpoint/mesh/FT_MB_release_MB_ft_pw3d/best_epoch.bin', type=str, metavar='FILENAME', help='checkpoint to evaluate (file name)')
    parser.add_argument('-j', '--json_path', type=str, help='alphapose detection result json path')
    parser.add_argument('-v', '--vid_path', type=str, help='video path')
    parser.add_argument('-o', '--out_path', type=str, help='output path')
    parser.add_argument('--ref_3d_motion_path', type=str, default=None, help='3D motion path')
    parser.add_argument('--pixel', action='store_true', help='align with pixel coordinates')
    parser.add_argument('--focus', type=int, default=None, help='target person id')
    parser.add_argument('--clip_len', type=int, default=243, help='clip length for network input')
    opts = parser.parse_args()
    return opts
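
# The two helpers below fit a camera alignment between two 3D joint sequences:
# a scalar scale s = p[0] and a translation t = (p[1], p[2], p[3]) minimizing
# the mean per-joint distance ||s*x + t - y||. Note that err() returns a single
# scalar, so least_squares minimizes that mean directly rather than a
# per-residual vector. A quick sanity check (hypothetical, not part of the
# pipeline):
#   x = np.random.randn(100, 17, 3)
#   y = 2.0 * x + np.array([0.1, -0.2, 0.3])
#   solve_scale(x, y)  # should recover a scale close to 2.0
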
def err(p, x, y):
    return np.linalg.norm(p[0] * x + np.array([p[1], p[2], p[3]]) - y, axis=-1).mean()

def solve_scale(x, y):
    print('Estimating camera transformation.')
    best_res = 100000
    best_scale = None
    for init_scale in tqdm(range(0, 2000, 5)):
        p0 = [init_scale, 0.0, 0.0, 0.0]
        est = least_squares(err, p0, args=(x.reshape(-1, 3), y.reshape(-1, 3)))
        if est['fun'] < best_res:
            best_res = est['fun']
            best_scale = est['x'][0]
    print('Pose matching error = %.2f mm.' % best_res)
    return best_scale
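
# Note: solve_scale() returns only the scale; the fitted translation is
# discarded. The global translation is re-derived later by snapping the mesh
# root to the scaled reference root trajectory (ref_3d_motion_path block below).
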
opts = parse_args()
args = get_config(opts.config)

# root_rel
# args.rootrel = True

smpl = SMPL(args.data_root, batch_size=1).cuda()
J_regressor = smpl.J_regressor_h36m
end = time.time()
model_backbone = load_backbone(args)
print(f'init backbone time: {(time.time()-end):.2f}s')

end = time.time()
model = MeshRegressor(args, backbone=model_backbone, dim_rep=args.dim_rep, hidden_dim=args.hidden_dim, dropout_ratio=args.dropout)
print(f'init whole model time: {(time.time()-end):.2f}s')
if torch.cuda.is_available():
    model = nn.DataParallel(model)
    model = model.cuda()

chk_filename = opts.evaluate  # parse_args() defines no --resume option
print('Loading checkpoint', chk_filename)
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['model'], strict=True)
model.eval()
testloader_params = {
    'batch_size': 1,
    'shuffle': False,
    'num_workers': 8,
    'pin_memory': True,
    'prefetch_factor': 4,
    'persistent_workers': True,
    'drop_last': False
}
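
# batch_size stays at 1: the last clip of a video is typically shorter than
# clip_len, so clips are not guaranteed to batch uniformly (an assumption
# based on how WildDetDataset slices the detection sequence).
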
vid = imageio.get_reader(opts.vid_path, 'ffmpeg')
fps_in = vid.get_meta_data()['fps']
vid_size = vid.get_meta_data()['size']
os.makedirs(opts.out_path, exist_ok=True)
if opts.pixel:
    # Keep relative scale with pixel coordinates
    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, vid_size=vid_size, scale_range=None, focus=opts.focus)
else:
    # Scale to [-1,1]
    wild_dataset = WildDetDataset(opts.json_path, clip_len=opts.clip_len, scale_range=[1,1], focus=opts.focus)

test_loader = DataLoader(wild_dataset, **testloader_params)

verts_all = []
reg3d_all = []
with torch.no_grad():
    for batch_input in tqdm(test_loader):
        batch_size, clip_frames = batch_input.shape[:2]
        if torch.cuda.is_available():
            batch_input = batch_input.cuda().float()
        output = model(batch_input)
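        # Test-time flip augmentation: run the model again on the horizontally
        # flipped 2D input, un-flip the predicted SMPL thetas, rebuild the
        # mesh, and average the two predictions.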
        batch_input_flip = flip_data(batch_input)
        output_flip = model(batch_input_flip)
        output_flip_pose = output_flip[0]['theta'][:, :, :72]   # SMPL pose (24 joints x 3 axis-angle)
        output_flip_shape = output_flip[0]['theta'][:, :, 72:]  # SMPL shape (10 betas)
        output_flip_pose = flip_thetas_batch(output_flip_pose)
        output_flip_pose = output_flip_pose.reshape(-1, 72)
        output_flip_shape = output_flip_shape.reshape(-1, 10)
        output_flip_smpl = smpl(
            betas=output_flip_shape,
            body_pose=output_flip_pose[:, 3:],
            global_orient=output_flip_pose[:, :3],
            pose2rot=True
        )
        output_flip_verts = output_flip_smpl.vertices.detach()
        J_regressor_batch = J_regressor[None, :].expand(output_flip_verts.shape[0], -1, -1).to(output_flip_verts.device)
        output_flip_kp3d = torch.matmul(J_regressor_batch, output_flip_verts)  # (NT,17,3)
        output_flip_back = [{
            'verts': output_flip_verts.reshape(batch_size, clip_frames, -1, 3) * 1000.0,  # m -> mm
            'kp_3d': output_flip_kp3d.reshape(batch_size, clip_frames, -1, 3),
        }]
        output_final = [{}]
        for k, v in output_flip_back[0].items():
            output_final[0][k] = (output[0][k] + output_flip_back[0][k]) / 2.0
        output = output_final
        verts_all.append(output[0]['verts'].cpu().numpy())
        reg3d_all.append(output[0]['kp_3d'].cpu().numpy())
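
# With batch_size=1 each appended array has shape (1, T_clip, ...): np.hstack
# joins clips along the time axis, and np.concatenate then drops the leading
# batch dimension, leaving (T_total, 6890, 3) vertices and (T_total, 17, 3)
# regressed joints.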
verts_all = np.hstack(verts_all)
verts_all = np.concatenate(verts_all)
reg3d_all = np.hstack(reg3d_all)
reg3d_all = np.concatenate(reg3d_all)
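
# Optional global alignment: if a reference 3D motion is provided, estimate a
# single scale between the root-relative reference joints and the regressed
# joints, then translate the mesh so its root follows the scaled reference
# root trajectory.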
if opts.ref_3d_motion_path:
    ref_pose = np.load(opts.ref_3d_motion_path)
    x = ref_pose - ref_pose[:, :1]
    y = reg3d_all - reg3d_all[:, :1]
    scale = solve_scale(x, y)
    root_cam = ref_pose[:, :1] * scale
    verts_all = verts_all - reg3d_all[:, :1] + root_cam
render_and_save(verts_all, osp.join(opts.out_path, 'mesh.mp4'), keep_imgs=False, fps=fps_in, draw_face=True)
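
# Example invocation (paths and the script filename are illustrative):
#   python infer_wild_mesh.py \
#       --json_path alphapose-results.json \
#       --vid_path input.mp4 \
#       --out_path output/
# The defaults for --config and --evaluate point at the PW3D fine-tuned mesh
# checkpoint; pass --ref_3d_motion_path to enable the camera-space alignment.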