Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Represents a model repository, including pre-trained models and bags of models. | |
| A repo can either be the main remote repository stored in AWS, or a local repository | |
| with your own models. | |
| """ | |
| from hashlib import sha256 | |
| from pathlib import Path | |
| import typing as tp | |
| import torch | |
| import yaml | |
| from .apply import BagOfModels, Model | |
| from .states import load_model | |
# Type of anything the repositories in this module can hand back:
# either a single `Model` or a `BagOfModels`.
AnyModel = tp.Union[Model, BagOfModels]
class ModelLoadingError(RuntimeError):
    """Error raised when a pre-trained model or bag of models cannot be provided.

    Used for unknown signatures/names, checksum mismatches and duplicate
    signatures found while scanning a local repository.
    """
def check_checksum(path: Path, checksum: str):
    """Check that the SHA-256 hex digest of the file at `path` starts with `checksum`.

    Only the first ``len(checksum)`` characters of the digest are compared, so
    `checksum` may be a short prefix. NOTE: an empty `checksum` therefore
    matches any file.

    Raises:
        ModelLoadingError: if the digest prefix differs from `checksum`.
    """
    hasher = sha256()
    chunk_size = 2**20  # hash in 1 MiB pieces rather than reading the whole file at once
    with open(path, 'rb') as stream:
        # iter() with a b'' sentinel stops at end of file.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            hasher.update(chunk)
    prefix = hasher.hexdigest()[:len(checksum)]
    if prefix == checksum:
        return
    raise ModelLoadingError(f'Invalid checksum for file {path}, '
                            f'expected {checksum} but got {prefix}')
class ModelOnlyRepo:
    """Interface for repositories that serve single models looked up by signature.

    Concrete repositories override `has_model` and `get_model`; the versions
    here only raise NotImplementedError.
    """

    def has_model(self, sig: str) -> bool:
        """Tell whether a model with signature `sig` is available."""
        raise NotImplementedError

    def get_model(self, sig: str) -> Model:
        """Return the model registered under signature `sig`."""
        raise NotImplementedError
class RemoteRepo(ModelOnlyRepo):
    """Model repository backed by a mapping from model signature to download URL.

    Packages are fetched with ``torch.hub.load_state_dict_from_url`` (mapped to
    CPU, with ``check_hash=True``) and turned into models by `load_model`.
    """

    def __init__(self, models: tp.Dict[str, str]):
        # Signature -> URL of the serialized model package. Stored as given (not copied).
        self._models = models

    def has_model(self, sig: str) -> bool:
        """Return True if `sig` has a registered URL."""
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Download (or reuse torch.hub's copy of) the package for `sig` and load it.

        Raises:
            ModelLoadingError: if no URL is registered for `sig`.
        """
        try:
            url = self._models[sig]
        except KeyError:
            # The internal KeyError adds nothing beyond this message; don't chain it.
            raise ModelLoadingError(
                f'Could not find a pre-trained model with signature {sig}.') from None
        pkg = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
        return load_model(pkg)
class LocalRepo(ModelOnlyRepo):
    """Model repository backed by ``.th`` files in a local directory.

    A file named ``<sig>.th`` or ``<sig>-<checksum>.th`` in `root` provides the
    model with signature ``<sig>``. When a checksum suffix is present, it is
    compared against the leading hex digits of the file's SHA-256 digest (via
    `check_checksum`) before the model is loaded.
    """

    def __init__(self, root: Path):
        self.root = root
        self.scan()

    def scan(self):
        """(Re)build the signature -> file and signature -> checksum indexes from `root`.

        Raises:
            ModelLoadingError: if two ``.th`` files resolve to the same signature.
        """
        self._models = {}
        self._checksums = {}
        for file in self.root.iterdir():
            if file.suffix != '.th':
                continue
            if '-' in file.stem:
                # Split on the last dash only: the checksum is the trailing
                # component. A plain split('-') raised ValueError (aborting the
                # whole scan) for any stem containing more than one dash.
                xp_sig, checksum = file.stem.rsplit('-', 1)
                self._checksums[xp_sig] = checksum
            else:
                xp_sig = file.stem
            if xp_sig in self._models:
                raise ModelLoadingError(
                    f'Duplicate pre-trained model exist for signature {xp_sig}. '
                    'Please delete all but one.')
            self._models[xp_sig] = file

    def has_model(self, sig: str) -> bool:
        """Return True if the last scan found a file for `sig`."""
        return sig in self._models

    def get_model(self, sig: str) -> Model:
        """Verify (when a checksum is known) and load the model for `sig`.

        Raises:
            ModelLoadingError: if `sig` is unknown or its checksum does not match.
        """
        try:
            file = self._models[sig]
        except KeyError:
            # Suppress the uninformative KeyError context.
            raise ModelLoadingError(
                f'Could not find pre-trained model with signature {sig}.') from None
        if sig in self._checksums:
            check_checksum(file, self._checksums[sig])
        return load_model(file)
class BagOnlyRepo:
    """Repository of bags of models described by YAML files.

    Each ``<name>.yaml`` file in `root` defines a bag. This class only parses
    the YAML; the individual models are loaded through `model_repo`.
    The YAML is expected to map ``models`` to a sequence of model signatures
    (required), and may contain ``weights`` and ``segment`` entries, which are
    passed through to `BagOfModels` (``None`` when absent).
    """

    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
        self.root = root
        self.model_repo = model_repo
        self.scan()

    def scan(self):
        """(Re)build the bag name -> YAML file index from `root`."""
        self._bags = {}
        for file in self.root.iterdir():
            if file.suffix == '.yaml':
                self._bags[file.stem] = file

    def has_model(self, name: str) -> bool:
        """Return True if the last scan found a YAML file named `name`."""
        return name in self._bags

    def get_model(self, name: str) -> BagOfModels:
        """Build the bag `name` from its YAML description.

        Raises:
            ModelLoadingError: if no bag called `name` exists, or if
                `model_repo` fails to load one of its models.
        """
        try:
            yaml_file = self._bags[name]
        except KeyError:
            raise ModelLoadingError(f'{name} is neither a single pre-trained model or '
                                    'a bag of models.') from None
        # Close the file deterministically instead of leaking the handle
        # until garbage collection.
        with open(yaml_file) as stream:
            bag = yaml.safe_load(stream)
        signatures = bag['models']
        models = [self.model_repo.get_model(sig) for sig in signatures]
        weights = bag.get('weights')
        segment = bag.get('segment')
        return BagOfModels(models, weights, segment)
class AnyModelRepo:
    """Resolve a name to either a single model or a bag of models.

    Signatures known to `model_repo` take precedence; every other name is
    delegated to `bag_repo` (`BagOnlyRepo.get_model` raises ModelLoadingError
    for names it does not know).
    """

    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
        self.model_repo = model_repo
        self.bag_repo = bag_repo

    def has_model(self, name_or_sig: str) -> bool:
        """Return True if either repository can provide `name_or_sig`."""
        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)

    def get_model(self, name_or_sig: str) -> AnyModel:
        """Load `name_or_sig` as a single model if possible, otherwise as a bag."""
        # (A leftover debug print that wrote every lookup to stdout was removed here.)
        if self.model_repo.has_model(name_or_sig):
            return self.model_repo.get_model(name_or_sig)
        return self.bag_repo.get_model(name_or_sig)