Spaces:
Runtime error
Runtime error
| #! /usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| # Copyright 2023 Imperial College London (Pingchuan Ma) | |
| # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) | |
| import os | |
| import torch | |
| import hydra | |
| from pipelines.metrics.measures import get_wer | |
| from pipelines.metrics.measures import get_cer | |
| from pipelines.pipeline import InferencePipeline | |
class AverageMeter:
    """Track the most recent value and a running weighted average.

    Each call to :meth:`update` records ``val`` as the current value and adds
    ``val * n`` to a running total, so ``avg`` is the ``n``-weighted mean of
    every value passed in since the last :meth:`reset`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0  # most recent value passed to update()
        self.avg = 0  # total / count, or 0 while count is 0
        self.total = 0  # sum of val * n over all updates
        self.count = 0  # sum of n over all updates

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average.

        Args:
            val: The new observation.
            n: Weight of the observation. A weight of 0 still updates ``val``
                but contributes nothing to the average.
        """
        self.val = val
        self.total += val * n
        self.count += n
        # A zero total weight (e.g. an empty reference) would otherwise raise
        # ZeroDivisionError; report 0 until some weight has been accumulated.
        self.avg = self.total / self.count if self.count else 0
def benchmark_inference(inference_pipeline, data_dir, landmarks_dir, lines, data_ext=".mp4", landmarks_ext=".pkl"):
    """Run the pipeline over every labelled sample and print running WER / CER.

    Args:
        inference_pipeline: Callable invoked as
            ``inference_pipeline(data_filename, landmarks_filename)``; its
            return value is used as the hypothesis transcript.
        data_dir: Directory holding the input files, one per basename.
        landmarks_dir: Directory holding landmark files, or a falsy value to
            pass ``None`` as the landmarks filename.
        lines: Sequence of label lines ``"<basename> <transcript words...>"``.
            Blank lines are skipped. A line holding only a basename is decoded
            and its hypothesis printed, but it is not scored.
        data_ext: Extension appended to the basename to form the data path.
        landmarks_ext: Extension appended to the basename to form the
            landmarks path.

    Returns:
        ``(wer, cer)``: the two AverageMeter instances accumulated over all
        scored samples.
    """
    wer, cer = AverageMeter(), AverageMeter()
    total = len(lines)
    for idx, line in enumerate(lines):
        fields = line.split()
        if not fields:
            # Blank line (e.g. trailing newline in the labels file): no sample.
            continue
        basename = fields[0]
        # An empty transcript means there is no reference to score against.
        groundtruth = " ".join(fields[1:]) or None
        data_filename = os.path.join(data_dir, f"{basename}{data_ext}")
        landmarks_filename = os.path.join(landmarks_dir, f"{basename}{landmarks_ext}") if landmarks_dir else None
        output = inference_pipeline(data_filename, landmarks_filename)
        if groundtruth is None:
            print(f"hyp: {output}")
            continue
        print(f"hyp: {output}\nref: {groundtruth}")
        # NOTE(review): weighting by reference length assumes get_wer/get_cer
        # return per-utterance rates, which makes avg a corpus-level rate —
        # confirm against pipelines.metrics.measures.
        wer.update(get_wer(output, groundtruth), len(groundtruth.split()))
        cer.update(get_cer(output, groundtruth), len(groundtruth))
        print(f"progress: {idx+1}/{total}\tcur WER: {wer.val*100:.2f}\tcur CER: {cer.val*100:.2f}\tavg WER: {wer.avg*100:.2f}\tavg CER: {cer.avg*100:.2f}")
    return wer, cer
def main(cfg):
    """Evaluate the inference pipeline on the dataset described by ``cfg``.

    Attributes read from ``cfg``: ``data_dir``, ``labels_filename``,
    ``gpu_idx``, ``config_filename``, ``detector``, ``landmarks_filename``,
    ``landmarks_dir``, ``data_ext``, ``landmarks_ext``.

    Raises:
        NotADirectoryError: if ``cfg.data_dir`` is not an existing directory.
        FileNotFoundError: if ``cfg.labels_filename`` is not an existing file.
    """
    # Validate paths before building the pipeline, so bad input fails fast
    # instead of after model construction. Explicit raises (not assert) keep
    # the checks active under `python -O`.
    if not os.path.isdir(cfg.data_dir):
        raise NotADirectoryError(f"{cfg.data_dir} is not a directory.")
    if not os.path.isfile(cfg.labels_filename):
        raise FileNotFoundError(f"{cfg.labels_filename} does not exist.")

    # Always pass a torch.device; previously the CPU fallback was the bare
    # string "cpu" while the CUDA branch was a torch.device.
    use_cuda = torch.cuda.is_available() and cfg.gpu_idx >= 0
    device = torch.device(f"cuda:{cfg.gpu_idx}" if use_cuda else "cpu")

    inference_pipeline = InferencePipeline(
        config_filename=cfg.config_filename,
        detector=cfg.detector,
        # Face tracking is enabled only when neither a landmarks file nor a
        # landmarks directory is configured.
        face_track=not cfg.landmarks_filename and not cfg.landmarks_dir,
        device=device,
    )

    # Close the labels file deterministically. Use an explicit UTF-8 encoding
    # so transcripts decode the same way regardless of the platform locale.
    with open(cfg.labels_filename, encoding="utf-8") as f:
        lines = f.read().splitlines()

    benchmark_inference(inference_pipeline, cfg.data_dir, cfg.landmarks_dir, lines, cfg.data_ext, cfg.landmarks_ext)
if __name__ == '__main__':
    # main() requires a config object; the bare `main()` call raised
    # TypeError (missing argument 'cfg'). Let Hydra compose the config and
    # pass it in.
    # NOTE(review): config_path/config_name are assumed to match the project's
    # Hydra config layout, and `version_base` requires Hydra >= 1.2 — confirm
    # before deploying.
    hydra.main(version_base=None, config_path="hydra_configs", config_name="default")(main)()