import os
import cv2
import json
import time
import decord
import librosa
import torch
import random
import argparse
import traceback
import numpy as np
from tqdm import tqdm
from PIL import Image
from einops import rearrange

def get_facemask(ref_image, align_instance, area=1.25):
    # ref_image: (b, f, c, h, w) tensor of reference frames.
    bsz, f, c, h, w = ref_image.shape
    images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
    face_masks = []
    for image in images:
        image_pil = Image.fromarray(image).convert("RGB")
        # Flip RGB to BGR for the face detector.
        _, _, bboxes_list = align_instance(np.array(image_pil)[:, :, [2, 1, 0]], maxface=True)
        try:
            bboxSrc = bboxes_list[0]
        except (IndexError, TypeError):
            # No face detected: fall back to a box covering the whole frame.
            bboxSrc = [0, 0, w, h]
        x1, y1, ww, hh = bboxSrc
        x2, y2 = x1 + ww, y1 + hh
        # Expand the detected box by `area` around its center, clamped to the frame.
        ww, hh = (x2 - x1) * area, (y2 - y1) * area
        center = [(x2 + x1) // 2, (y2 + y1) // 2]
        x1 = max(center[0] - ww // 2, 0)
        y1 = max(center[1] - hh // 2, 0)
        x2 = min(center[0] + ww // 2, w)
        y2 = min(center[1] + hh // 2, h)
        face_mask = np.zeros_like(np.array(image_pil))
        face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
        face_masks.append(torch.from_numpy(face_mask[..., :1]))
    face_masks = torch.stack(face_masks, dim=0)  # (b*f, h, w, c)
    face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
    face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
    return face_masks
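
# Usage sketch (not part of the pipeline): get_facemask only needs a callable
# that returns (landmarks, scores, bboxes_list) with boxes in (x, y, w, h)
# form, so it can be exercised with a stub detector. `dummy_align_instance`
# below is a hypothetical stand-in for illustration, not the real aligner.
def _demo_get_facemask():
    def dummy_align_instance(image_bgr, maxface=True):
        # Pretend the detector found one face covering the whole frame.
        h, w = image_bgr.shape[:2]
        return None, None, [[0, 0, w, h]]

    ref_image = torch.randint(0, 256, (1, 8, 3, 256, 256)).float()  # (b, f, c, h, w)
    masks = get_facemask(ref_image, dummy_align_instance, area=1.25)
    print(masks.shape)  # torch.Size([1, 1, 8, 256, 256])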

def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
    # Choose the per-frame stride over audio timesteps for the given frame rate.
    if fps == 25:
        start_ts = [0]
        step_ts = [1]
    elif fps == 12.5:
        start_ts = [0]
        step_ts = [2]
    else:
        raise ValueError(f"Unsupported fps: {fps} (expected 25 or 12.5)")
    num_frames = min(num_frames, 400)
    # Encode up to 3000 feature timesteps and keep the hidden states of every layer.
    audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
    audio_feats = torch.stack(audio_feats, dim=2)  # (b, t, layers, dim)
    # Left-pad with four zero timesteps so the first frames get full windows.
    audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :4]), audio_feats], 1)

    audio_prompts = []
    for bb in range(1):
        audio_feats_list = []
        for f in range(num_frames):
            # Each video frame attends to a 10-timestep window of audio features.
            cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
            audio_clip = audio_feats[bb:bb + 1, cur_t:cur_t + 10]
            audio_feats_list.append(audio_clip)
        audio_feats_list = torch.stack(audio_feats_list, 1)
        audio_prompts.append(audio_feats_list)
    audio_prompts = torch.cat(audio_prompts)  # (b, num_frames, 10, layers, dim)
    return audio_prompts
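
# Usage sketch (not part of the pipeline): despite the `wav2vec` name, the
# slicing above (mel features truncated to 3000 timesteps, hidden states at
# half that length) matches a Whisper-style encoder. The stub below mimics
# only the interface encode_audio relies on; its layer count (5) and feature
# dim (384) are assumptions for illustration.
def _demo_encode_audio():
    from types import SimpleNamespace

    class DummyEncoder(torch.nn.Module):
        def forward(self, x, output_hidden_states=True):
            b, t = x.shape[0], x.shape[-1] // 2  # 2x temporal downsampling
            # One hidden state per layer (plus embeddings), each (b, t, dim).
            return SimpleNamespace(hidden_states=tuple(torch.randn(b, t, 384) for _ in range(5)))

    model = SimpleNamespace(encoder=DummyEncoder())
    mel = torch.randn(80, 3000)  # (n_mels, time) log-mel features
    prompts = encode_audio(model, mel, fps=25, num_frames=129)
    print(prompts.shape)  # torch.Size([1, 129, 10, 5, 384])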