from __future__ import annotations

import importlib
from enum import Enum
from pathlib import Path
from typing import Dict, Optional, Type, TypeVar, Union

import torch
import yaml
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from typing_extensions import Self

try:
    from mlip_arena.data.collate import collate_fn
except ImportError:

    # Fallback to a dummy function if the import fails
    def collate_fn(batch: list[Atoms], cutoff: float) -> None:
        raise ImportError(
            "collate_fn import failed. Please install the required dependencies."
        )


try:
    from prefect.logging import get_run_logger

    logger = get_run_logger()
except (ImportError, RuntimeError):
    from loguru import logger

T = TypeVar("T", bound="MLIP")

with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
    REGISTRY = yaml.safe_load(f)

MLIPMap = {}

for model, metadata in REGISTRY.items():
    try:
        module = importlib.import_module(
            f"{__package__}.{metadata['module']}.{metadata['family']}"
        )
        MLIPMap[model] = getattr(module, metadata["class"])
    except Exception as e:
        # Skip models whose optional dependencies or registry entries are broken.
        logger.warning(e)
        continue

MLIPEnum = Enum("MLIPEnum", MLIPMap)

logger.info(f"Successfully loaded models: {list(MLIPEnum.__members__.keys())}")
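
# The registry drives the dynamic loading above: each top-level key in
# registry.yaml becomes an MLIPEnum member, and the loader only relies on the
# three fields read in the loop (`module`, `family`, `class`), resolving the
# class as `{__package__}.<module>.<family>.<class>`. An entry is therefore
# expected to look roughly like the hypothetical sketch below; the names are
# illustrative, not taken from the actual registry file.
#
#   SomeModel:                # enum member name (hypothetical)
#     module: some_module     # subpackage under this package
#     family: some_family     # module file inside that subpackage
#     class: SomeModelClass   # class object looked up via getattr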

class MLIP(
    nn.Module,
    PyTorchModelHubMixin,
    tags=["atomistic-simulation", "MLIP"],
):
    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # https://github.com/pytorch/pytorch/blob/3cbc8c54fd37eb590e2a9206aecf3ab568b3e63c/torch/_dynamo/config.py#L534
        # torch._dynamo.config.compiled_autograd = True
        # self.model = torch.compile(model)
        self.model = model

    def _save_pretrained(self, save_directory: Path) -> None:
        return super()._save_pretrained(save_directory)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str | Path,
        *,
        force_download: bool = False,
        resume_download: bool | None = None,
        proxies: dict | None = None,
        token: str | bool | None = None,
        cache_dir: str | Path | None = None,
        local_files_only: bool = False,
        revision: str | None = None,
        **model_kwargs,
    ) -> Self:
        return super().from_pretrained(
            pretrained_model_name_or_path,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            **model_kwargs,
        )

    def forward(self, x):
        return self.model(x)
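
# Illustrative only: `MLIP.from_pretrained` simply forwards to
# `PyTorchModelHubMixin.from_pretrained`, so a checkpoint pushed with the mixin
# can be restored from the Hugging Face Hub. The repo id below is hypothetical.
#
#   model = MLIP.from_pretrained("some-org/some-mlip-checkpoint")
#   model.eval()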

class MLIPCalculator(MLIP, Calculator):
    name: str
    implemented_properties: list[str] = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: nn.Module,
        device: torch.device | None = None,
        cutoff: float = 6.0,
        # ASE Calculator
        restart=None,
        atoms=None,
        directory=".",
        calculator_kwargs: dict | None = None,
    ):
        MLIP.__init__(self, model=model)  # Initialize MLIP part
        Calculator.__init__(
            self,
            restart=restart,
            atoms=atoms,
            directory=directory,
            **(calculator_kwargs or {}),
        )  # Initialize ASE Calculator part

        # Additional initialization if needed
        # self.name: str = self.__class__.__name__
        from mlip_arena.models.utils import get_freer_device

        self.device = device or get_freer_device()
        self.cutoff = cutoff
        self.model.to(self.device)

        # self.device = device or torch.device(
        #     "cuda" if torch.cuda.is_available() else "cpu"
        # )
        # self.model: MLIP = MLIP.from_pretrained(model_path, map_location=self.device)
        # self.implemented_properties = ["energy", "forces", "stress"]

    # def __getstate__(self):
    #     state = self.__dict__.copy()
    #     state["_modules"]["model"] = state["_modules"]["model"]._orig_mod
    #     return state

    # def __setstate__(self, state):
    #     self.__dict__.update(state)
    #     self.model = torch.compile(state["_modules"]["model"])

    def calculate(
        self,
        atoms: Atoms,
        properties: list[str],
        system_changes: list = all_changes,
    ):
        """Calculate energies and forces for the given Atoms object."""
        super().calculate(atoms, properties, system_changes)

        # TODO: move collate_fn to here in MLIPCalculator
        data = collate_fn([atoms], cutoff=self.cutoff).to(self.device)
        output = self.forward(data)

        # TODO: decollate_fn
        self.results = {}
        if "energy" in properties:
            self.results["energy"] = output["energy"].squeeze().item()
        if "forces" in properties:
            self.results["forces"] = output["forces"].squeeze().cpu().detach().numpy()
        if "stress" in properties:
            self.results["stress"] = output["stress"].squeeze().cpu().detach().numpy()

    # def forward(self, x: Atoms) -> dict[str, torch.Tensor]:
    #     """Implement data conversion, graph creation, and model forward pass
    #
    #     Example implementation:
    #     1. Use `ase.neighborlist.NeighborList` to get neighbor list
    #     2. Create `torch_geometric.data.Data` object and copy the data
    #     3. Pass the `Data` object to the model and return the output
    #     """
    #     raise NotImplementedError
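

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not executed on import). The model name below is
# hypothetical; use any key listed in the "Successfully loaded models" log.
# It assumes the registered class can be instantiated without extra arguments
# and behaves as an ASE-compatible calculator, which may not hold for every
# registry entry.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from ase.build import bulk

    atoms = bulk("Cu", "fcc", a=3.6)

    model_name = "SomeRegisteredModel"  # hypothetical registry key
    if model_name in MLIPEnum.__members__:
        calc = MLIPEnum[model_name].value()  # instantiate the registered calculator
        atoms.calc = calc
        print(atoms.get_potential_energy())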