from argparse import Namespace
import torch
from motion.dataset.recover_joints import recover_from_ric
from motion.model.cfg_sampler import ClassifierFreeSampleModel
from motion.model_util import create_model_and_diffusion, load_model_wo_clip
import os
import numpy as np
from motion.dataset.recover_smr import *
import json
from motion.double_take import double_take


class Predictor(object):
    def __init__(self, **kargs):
        self.path = kargs["path"]
        self.handshake_size = 20
        self.blend_size = 10

        # Load the model configuration and expose each entry as an attribute.
        args = Namespace()
        with open(self.path["config"], 'r') as f:
            params1 = json.load(f)
        for key, value in params1.items():
            setattr(args, key, value)

        # Select the architecture and checkpoint according to the requested mode.
        mode = kargs.get("mode", "cadm")
        if mode == "cadm":
            args.arch = "refined_decoder"
            args.encode_full = 2
            args.txt_tokens = 1
            args.model_path = self.path["cadm"]
            args.rep = "smr"
        elif mode == "cadm-augment":
            args.arch = "refined_decoder"
            args.encode_full = 2
            args.txt_tokens = 1
            args.model_path = self.path["cadm-augment"]
            args.rep = "smr"
        elif mode == "mdm":
            args.arch = "trans_enc"
            args.encode_full = 0
            args.txt_tokens = 0
            args.model_path = self.path["mdm"]
            args.rep = "t2m"
        else:
            raise ValueError("Unknown mode: {}".format(mode))

        self.skip_steps = kargs.get("skip_steps", 0)
        self.device = kargs.get("device", "cpu")
        self.args = args
        self.rep = args.rep
        self.num_frames = args.num_frames
        self.condition = kargs.get("condition", "text")
        if self.condition == "uncond":
            self.args.guidance_param = 0

        # Dataset statistics differ between the "t2m" and "smr" representations.
        if self.rep == "t2m":
            extension = ""
        elif self.rep == "smr":
            extension = "_smr"
        self.mean = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Mean{}.npy'.format(extension)))).to(self.device)
        self.std = torch.from_numpy(np.load(os.path.join(self.path["dataset_dir"], 'Std{}.npy'.format(extension)))).to(self.device)

        print(f"Loading checkpoints from [{args.model_path}]...")
        self.model, self.diffusion = create_model_and_diffusion(args, args.control_signal, self.path)
        state_dict = torch.load(self.args.model_path, map_location='cpu')
        try:
            if self.args.ema:
                print("EMA Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["ema"])
            else:
                print("Normal Checkpoints Loading.")
                load_model_wo_clip(self.model, state_dict["model"])
        except (KeyError, AttributeError):
            # Older checkpoints store the weights at the top level of the state dict.
            load_model_wo_clip(self.model, state_dict)

        if self.args.guidance_param != 1 and not self.args.unconstrained:
            self.model = ClassifierFreeSampleModel(self.model)  # wrapping model with the classifier-free sampler
        self.model.to(self.device)
        self.model.eval()  # disable random masking
    def predict(self, prompt, num_repetitions=1, path=None):
        # Prompts joined with "|" are treated as a sequence of texts and stitched
        # into one long motion via double-take sampling.
        double_split = prompt.split("|")
        if len(double_split) > 1:
            print("sample mode - double_take long motion")
            sample, step_sizes = double_take(prompt, path, num_repetitions, self.model, self.diffusion, self.handshake_size,
                                             self.blend_size, self.num_frames, self.args.guidance_param, self.device)
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean  # de-normalize with the dataset statistics
            if self.rep == "t2m":
                sample = recover_from_ric(sample, 22)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)
        else:
            # A prompt of the form "<nframes>,<text>" overrides the default motion length.
            nframes = prompt.split(",")[0]
            try:
                nframes = int(nframes)
                prompt = prompt.split(",")[1:]
                prompt = ",".join(prompt)
            except ValueError:
                nframes = self.num_frames
            model_kwargs = {'y': {'text': str(prompt), 'lengths': nframes}}
            if self.args.guidance_param != 1:
                model_kwargs['y']['scale'] = torch.ones(num_repetitions, device=self.device) * self.args.guidance_param
            sample_fn = self.diffusion.p_sample_loop
            sample = sample_fn(
                self.model,
                (num_repetitions, self.model.njoints, self.model.nfeats, nframes),
                clip_denoised=False,
                model_kwargs=model_kwargs,
                skip_timesteps=self.skip_steps,  # 0 is the default value - i.e. don't skip any step
                init_image=None,
                progress=True,
                dump_steps=None,
                noise=None,
                const_noise=False,
            )
            sample = sample["output"]
            sample = sample.permute(0, 2, 3, 1).float()
            sample = sample * self.std + self.mean  # de-normalize with the dataset statistics
            if self.rep == "t2m":
                sample = recover_from_ric(sample, 22)
                sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
            elif self.rep == "smr":
                sample = sample.permute(0, 2, 3, 1)

        all_motions = sample.permute(0, 3, 1, 2)
        return all_motions
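For reference, a minimal usage sketch of the class above. The keys of the paths dictionary mirror the lookups in __init__ ("config", "dataset_dir", and one key per mode), but the concrete file locations shown here are hypothetical and the config file must provide the fields the code reads (num_frames, guidance_param, ema, unconstrained, control_signal, ...).

# Minimal usage sketch; file locations are placeholders, not the project's actual layout.
paths = {
    "config": "checkpoints/config.json",        # JSON with num_frames, guidance_param, ema, etc.
    "dataset_dir": "dataset/HumanML3D",         # contains Mean.npy / Std.npy (and *_smr variants)
    "cadm": "checkpoints/cadm.pt",
    "cadm-augment": "checkpoints/cadm_augment.pt",
    "mdm": "checkpoints/mdm.pt",
}
predictor = Predictor(path=paths, mode="cadm", device="cpu")

# Single clip: "<nframes>,<text>" fixes the length; a plain text prompt falls back to args.num_frames.
motions = predictor.predict("120,a person walks forward and waves", num_repetitions=1)

# Long-form motion: prompts joined with "|" are stitched together via double_take.
long_motion = predictor.predict("a person walks forward|the person sits down", num_repetitions=1)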