Spaces:
Running
Running
| # Ultralytics YOLO 🚀, AGPL-3.0 license | |
| """ | |
| YOLO-NAS model interface. | |
| Usage - Predict: | |
| from ultralytics import NAS | |
| model = NAS('yolo_nas_s') | |
| results = model.predict('ultralytics/assets/bus.jpg') | |
| """ | |
| from pathlib import Path | |
| import torch | |
| from ultralytics.engine.model import Model | |
| from ultralytics.utils.torch_utils import model_info, smart_inference_mode | |
| from .predict import NASPredictor | |
| from .val import NASValidator | |
class NAS(Model):
    """
    YOLO-NAS detection model wrapper built on the Ultralytics `Model` engine.

    Only pre-trained weights are supported: either a local `.pt` checkpoint or a bare super-gradients model
    name (e.g. 'yolo_nas_s'), which is fetched with COCO pre-trained weights.
    """

    def __init__(self, model='yolo_nas_s.pt') -> None:
        """
        Initialize a NAS model.

        Args:
            model (str): Path to a `.pt` checkpoint or a super-gradients model name. YAML configs are rejected
                because YOLO-NAS models cannot be built from a config here.
        """
        assert Path(model).suffix not in ('.yaml', '.yml'), 'YOLO-NAS models only support pre-trained models.'
        super().__init__(model, task='detect')

    @smart_inference_mode()
    def _load(self, weights: str, task: str):
        """
        Load a pre-trained YOLO-NAS model and patch it so the Ultralytics engine can use it.

        Args:
            weights (str): Path to a `.pt` checkpoint, or a super-gradients model name with no file suffix.
            task (str): Task name passed by the base class; not used because this wrapper is detection-only.

        Raises:
            ValueError: If `weights` has a suffix other than '.pt' and is not a bare model name.
        """
        # Imported lazily so super-gradients is only required when a NAS model is actually used.
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == '.pt':
            # NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints from trusted sources.
            self.model = torch.load(weights)
        elif suffix == '':
            self.model = super_gradients.training.models.get(weights, pretrained_weights='coco')
        else:
            # Fail fast: otherwise self.model would be left unset and the patching below would fail with an
            # opaque attribute error.
            raise ValueError(f"Unsupported YOLO-NAS weights '{weights}': expected a '.pt' file or a model name.")

        # Standardize model: attach the attributes the Ultralytics engine expects on a native detection model.
        self.model.fuse = lambda verbose=True: self.model  # fusing is a no-op; returns the model unchanged
        self.model.stride = torch.tensor([32])  # single max stride; presumably used for image-size checks
        # assumes super-gradients exposes `_class_names` as an ordered sequence of class labels — TODO confirm
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = 'detect'  # for export()

    def info(self, detailed=False, verbose=True):
        """
        Logs model info.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            Whatever `model_info` returns for this model, computed at a fixed image size of 640.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self):
        """
        Map the 'detect' task to the NAS-specific predictor and validator classes.

        Exposed as a property because the base engine indexes it directly (`self.task_map[task][key]`).
        """
        return {'detect': {'predictor': NASPredictor, 'validator': NASValidator}}