Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Loading pretrained models. | |
| """ | |
| import logging | |
| from pathlib import Path | |
| import typing as tp | |
| #from dora.log import fatal | |
| import logging | |
| from diffq import DiffQuantizer | |
| import torch.hub | |
| from .model import Demucs | |
| from .tasnet_v2 import ConvTasNet | |
| from .utils import set_state | |
| from .hdemucs import HDemucs | |
| from .repo import RemoteRepo, LocalRepo, ModelOnlyRepo, BagOnlyRepo, AnyModelRepo, ModelLoadingError # noqa | |
# Module-level logger.
logger = logging.getLogger(__name__)
# Base URL for the models listed in the remote index (see `_parse_remote_files`).
ROOT_URL = "https://dl.fbaipublicfiles.com/demucs/mdx_final/"
# Folder next to this file that holds `files.txt` and the bag definitions.
REMOTE_ROOT = Path(__file__).parent.joinpath('remote')
# Stem names, in the order the models are built with.
SOURCES = "drums bass other vocals".split()
def demucs_unittest():
    """Build a tiny, untrained HDemucs (4 channels) over the standard SOURCES.

    NOTE: a second ``demucs_unittest`` is defined later in this module and
    rebinds this name, so after import this definition is not the one that
    callers of ``demucs_unittest`` reach.
    """
    return HDemucs(sources=SOURCES, channels=4)
def add_model_flags(parser):
    """Register the model-selection command-line options on `parser`.

    ``-s/--sig`` and ``-n/--name`` share a mutually exclusive group, so at
    most one of them may be passed; ``--repo`` is independent and parsed
    into a :class:`pathlib.Path`.
    """
    selector = parser.add_mutually_exclusive_group(required=False)
    selector.add_argument(
        "-s", "--sig",
        help="Locally trained XP signature.",
    )
    selector.add_argument(
        "-n", "--name",
        default="mdx_extra_q",
        help="Pretrained model name or signature. Default is mdx_extra_q.",
    )
    parser.add_argument(
        "--repo",
        type=Path,
        help="Folder containing all pre-trained models for use with -n.",
    )
| def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]: | |
| root: str = '' | |
| models: tp.Dict[str, str] = {} | |
| for line in remote_file_list.read_text().split('\n'): | |
| line = line.strip() | |
| if line.startswith('#'): | |
| continue | |
| elif line.startswith('root:'): | |
| root = line.split(':', 1)[1].strip() | |
| else: | |
| sig = line.split('-', 1)[0] | |
| assert sig not in models | |
| models[sig] = ROOT_URL + root + line | |
| return models | |
def get_model(name: str,
              repo: tp.Optional[tp.Union[str, Path]] = None):
    """Load a model by name and switch it to eval mode.

    `name` must be a bag of models name or a pretrained signature
    from the remote AWS model repo or the specified local repo if `repo` is
    not None. The special name ``'demucs_unittest'`` returns a small,
    untrained HDemucs without touching any repository (and without calling
    ``.eval()`` on it).

    Args:
        name: bag-of-models name or single-model signature.
        repo: optional local folder (``str`` or ``Path``) holding the models
            and bag definitions. When None, the remote index
            ``REMOTE_ROOT / 'files.txt'`` is used.

    Raises:
        ModelLoadingError: if `repo` is given but is not an existing directory.
    """
    if name == 'demucs_unittest':
        # Build the HDemucs test model directly: the module-level name
        # `demucs_unittest` is rebound further down this file to a Demucs
        # variant that downloads weights, so calling it here would not
        # produce the intended untrained HDemucs.
        return HDemucs(channels=4, sources=SOURCES)
    model_repo: ModelOnlyRepo
    if repo is None:
        models = _parse_remote_files(REMOTE_ROOT / 'files.txt')
        model_repo = RemoteRepo(models)
        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
    else:
        repo = Path(repo)
        if not repo.is_dir():
            # `fatal` (dora.log) is no longer imported, so calling it raised
            # NameError; report the problem with the repo module's error type.
            raise ModelLoadingError(f"{repo} must exist and be a directory.")
        model_repo = LocalRepo(repo)
        bag_repo = BagOnlyRepo(repo, model_repo)
    any_repo = AnyModelRepo(model_repo, bag_repo)
    model = any_repo.get_model(name)
    model.eval()
    return model
def get_model_from_args(args):
    """Load the model selected on the command line.

    Forwards ``args.name`` and ``args.repo`` (as registered by
    `add_model_flags`) to `get_model`.

    NOTE(review): ``args.sig`` is not read here, so passing ``-s/--sig``
    does not change which model is loaded — confirm whether that is intended.
    """
    selected_name = args.name
    local_repo = args.repo
    return get_model(name=selected_name, repo=local_repo)
# Checkpoints for the legacy `load_pretrained` API, served from ROOT.
# `logger` and `SOURCES` are already defined near the top of this module with
# identical values, so they are deliberately not re-bound here (a second copy
# could silently drift from the first).
ROOT = "https://dl.fbaipublicfiles.com/demucs/v3.0/"
# Public model name -> checkpoint signature. `get_url` builds the download
# URL as ROOT + name + "-" + signature[:8] + ".th".
PRETRAINED_MODELS = {
    'demucs': 'e07c671f',
    'demucs48_hq': '28a1282c',
    'demucs_extra': '3646af93',
    'demucs_quantized': '07afea75',
    'tasnet': 'beb46fac',
    'tasnet_extra': 'df3777b2',
    'demucs_unittest': '09ebc15f',
}
def get_url(name):
    """Return the download URL of the legacy pretrained checkpoint `name`.

    The URL is ``ROOT + name + "-" + <first 8 chars of signature> + ".th"``.
    Raises KeyError if `name` is not a key of PRETRAINED_MODELS.
    """
    short_sig = PRETRAINED_MODELS[name][:8]
    return f"{ROOT}{name}-{short_sig}.th"
def is_pretrained(name):
    """Tell whether `name` is one of the legacy pretrained model names."""
    known_names = PRETRAINED_MODELS.keys()
    return name in known_names
def load_pretrained(name):
    """Instantiate the legacy pretrained model registered under `name`.

    Raises:
        ValueError: if `name` is not one of the supported model names.
    """
    # Each entry pairs a model name with a zero-argument factory; the lambdas
    # defer resolving and calling the builder until a name actually matches.
    factories = (
        ("demucs", lambda: demucs(pretrained=True)),
        ("demucs48_hq", lambda: demucs(pretrained=True, hq=True, channels=48)),
        ("demucs_extra", lambda: demucs(pretrained=True, extra=True)),
        ("demucs_quantized", lambda: demucs(pretrained=True, quantized=True)),
        ("demucs_unittest", lambda: demucs_unittest(pretrained=True)),
        ("tasnet", lambda: tasnet(pretrained=True)),
        ("tasnet_extra", lambda: tasnet(pretrained=True, extra=True)),
    )
    for candidate, factory in factories:
        if name == candidate:
            return factory()
    raise ValueError(f"Invalid pretrained name {name}")
def _load_state(name, model, quantizer=None):
    """Download the checkpoint registered as `name` and load it into `model`.

    The file is fetched through ``torch.hub`` onto the CPU with hash checking
    enabled, then applied via ``set_state``. When a quantizer is supplied
    (and truthy), ``quantizer.detach()`` is called afterwards.
    """
    checkpoint = torch.hub.load_state_dict_from_url(
        get_url(name), map_location='cpu', check_hash=True)
    set_state(model, quantizer, checkpoint)
    # NOTE(review): detach() presumably removes the quantizer's hooks from the
    # model once the quantized weights are loaded — confirm against diffq.
    if quantizer:
        quantizer.detach()
def demucs_unittest(pretrained=True):
    """Build the small 4-channel Demucs test model over SOURCES.

    When `pretrained` is true, weights come from the 'demucs_unittest'
    checkpoint; otherwise the model is returned untrained.
    """
    model = Demucs(channels=4, sources=SOURCES)
    if not pretrained:
        return model
    _load_state('demucs_unittest', model)
    return model
def demucs(pretrained=True, extra=False, quantized=False, hq=False, channels=64):
    """Build a Demucs model, optionally loading a legacy pretrained checkpoint.

    The checkpoint name is ``'demucs'``, followed by `channels` when it is not
    64, then one of ``'_quantized'``, ``'_extra'`` or ``'_hq'`` (for example
    ``'demucs48_hq'``). It must match a key of PRETRAINED_MODELS; otherwise
    the lookup in `get_url` raises KeyError.

    Args:
        pretrained: load weights from the matching checkpoint.
        extra, quantized, hq: select a checkpoint variant; at most one may be
            set, and only together with `pretrained`.
        channels: initial channel count passed to Demucs.

    Raises:
        ValueError: on an invalid combination of flags. This is checked before
            the model is built.
    """
    if not pretrained and (extra or quantized or hq):
        raise ValueError(
            "if extra, quantized or hq is True, pretrained must be True.")
    # Validate before constructing the (potentially large) model.
    if sum(bool(flag) for flag in (extra, quantized, hq)) > 1:
        raise ValueError("Only one of extra, quantized, hq, can be True.")
    model = Demucs(sources=SOURCES, channels=channels)
    if pretrained:
        name = 'demucs'
        if channels != 64:
            name += str(channels)
        quantizer = None
        if quantized:
            # Quantized checkpoints need the quantizer in place while loading.
            quantizer = DiffQuantizer(model, group_size=8, min_size=1)
            name += '_quantized'
        if extra:
            name += '_extra'
        if hq:
            name += '_hq'
        _load_state(name, model, quantizer)
    return model
def tasnet(pretrained=True, extra=False):
    """Build a ConvTasNet (X=10) over SOURCES, optionally with pretrained weights.

    With `pretrained`, weights come from the 'tasnet' checkpoint, or from
    'tasnet_extra' when `extra` is also set.

    Raises:
        ValueError: if `extra` is requested without `pretrained`.
    """
    if extra and not pretrained:
        raise ValueError("if extra is True, pretrained must be True.")
    model = ConvTasNet(X=10, sources=SOURCES)
    if pretrained:
        _load_state('tasnet_extra' if extra else 'tasnet', model)
    return model