Spaces:
Runtime error
Runtime error
| #! /usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2023 Imperial College London (Pingchuan Ma) | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| import os | |
| import cv2 | |
| import hydra | |
| import torchvision | |
| from pipelines.detectors.mediapipe.detector import LandmarksDetector | |
| from pipelines.data.data_module import AVSRDataLoader | |
def save2vid(filename, vid, frames_per_second):
    """Write ``vid`` to ``filename`` as a video file, creating parent directories first.

    Args:
        filename: Destination path of the video file.
        vid: Frames handed unchanged to ``torchvision.io.write_video``. Per the
            torchvision docs that function expects a uint8 tensor laid out as
            (T, H, W, C). NOTE(review): confirm the data loader produces this layout.
        frames_per_second: Frame rate passed to ``torchvision.io.write_video``.
    """
    parent_dir = os.path.dirname(filename)
    # os.path.dirname() returns "" for a bare filename such as "out.mp4".
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torchvision.io.write_video(filename, vid, frames_per_second)
def _build_landmarks_detector(detector):
    """Return a ``LandmarksDetector`` from the backend named by ``detector``.

    The backend module is imported lazily, so only the selected backend's
    dependencies need to be installed.

    Raises:
        ValueError: If ``detector`` is neither "mediapipe" nor "retinaface".
    """
    if detector == "mediapipe":
        from pipelines.detectors.mediapipe.detector import LandmarksDetector
    elif detector == "retinaface":
        from pipelines.detectors.retinaface.detector import LandmarksDetector
    else:
        # Previously an unknown name left the detector unbound and failed
        # later with a confusing NameError.
        raise ValueError(
            f"Unsupported detector {detector!r}; expected 'mediapipe' or 'retinaface'"
        )
    return LandmarksDetector()


def _read_fps(filename):
    """Return the frame rate OpenCV reports for ``filename``, always releasing the capture."""
    capture = cv2.VideoCapture(filename)
    try:
        return capture.get(cv2.CAP_PROP_FPS)
    finally:
        capture.release()


# Hydra builds ``cfg`` from a YAML config plus command-line overrides. Without
# this decorator, the bare ``main()`` call below fails with a TypeError
# (missing ``cfg``).
# NOTE(review): config_path/config_name are assumed to point at
# hydra_configs/default.yaml next to this script, defining ``detector``,
# ``data_filename`` and ``dst_filename``. Confirm against the repository.
# ``version_base=None`` needs Hydra >= 1.2 and keeps the working directory
# unchanged, so relative input/output paths resolve as the user typed them.
@hydra.main(version_base=None, config_path="hydra_configs", config_name="default")
def main(cfg):
    """Crop the mouth region from a video and save it as a new video.

    Reads from ``cfg``:
        detector: "mediapipe" or "retinaface", the landmark backend to use.
        data_filename: Path of the input video.
        dst_filename: Path where the cropped video is written.

    The output is written at the frame rate OpenCV reports for the input.

    Raises:
        ValueError: If ``cfg.detector`` names an unsupported backend.
    """
    landmarks_detector = _build_landmarks_detector(cfg.detector)
    dataloader = AVSRDataLoader(
        modality="video",
        speed_rate=1,
        transform=False,
        detector=cfg.detector,
        convert_gray=False,
    )
    landmarks = landmarks_detector(cfg.data_filename)
    data = dataloader.load_data(cfg.data_filename, landmarks)
    fps = _read_fps(cfg.data_filename)
    save2vid(cfg.dst_filename, data, fps)
    print(f"The mouth images have been cropped and saved to {cfg.dst_filename}")


if __name__ == "__main__":
    main()