""" Copyright (c) Facebook, Inc. and its affiliates. This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. """ import math import numpy as np import torch from torch import Tensor from scipy.special import binom class GaussianSmearing(torch.nn.Module): def __init__( self, start: float = 0.0, stop: float = 5.0, num_gaussians: int = 50, ): super().__init__() offset = torch.linspace(start, stop, num_gaussians) self.coeff = -0.5 / (offset[1] - offset[0]).item()**2 self.register_buffer('offset', offset) def forward(self, dist: Tensor) -> Tensor: dist = dist.view(-1, 1) - self.offset.view(1, -1) return torch.exp(self.coeff * torch.pow(dist, 2)) class PolynomialEnvelope(torch.nn.Module): """ Polynomial envelope function that ensures a smooth cutoff. Parameters ---------- exponent: int Exponent of the envelope function. """ def __init__(self, exponent): super().__init__() assert exponent > 0 self.p = exponent self.a = -(self.p + 1) * (self.p + 2) / 2 self.b = self.p * (self.p + 2) self.c = -self.p * (self.p + 1) / 2 def forward(self, d_scaled): env_val = ( 1 + self.a * d_scaled ** self.p + self.b * d_scaled ** (self.p + 1) + self.c * d_scaled ** (self.p + 2) ) return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) class ExponentialEnvelope(torch.nn.Module): """ Exponential envelope function that ensures a smooth cutoff, as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. SpookyNet: Learning Force Fields with Electronic Degrees of Freedom and Nonlocal Effects """ def __init__(self): super().__init__() def forward(self, d_scaled): env_val = torch.exp( -(d_scaled ** 2) / ((1 - d_scaled) * (1 + d_scaled)) ) return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) class SphericalBesselBasis(torch.nn.Module): """ 1D spherical Bessel basis Parameters ---------- num_radial: int Controls maximum frequency. cutoff: float Cutoff distance in Angstrom. """ def __init__( self, num_radial: int, cutoff: float, ): super().__init__() self.norm_const = math.sqrt(2 / (cutoff ** 3)) # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff # Initialize frequencies at canonical positions self.frequencies = torch.nn.Parameter( data=torch.tensor( np.pi * np.arange(1, num_radial + 1, dtype=np.float32) ), requires_grad=True, ) def forward(self, d_scaled): return ( self.norm_const / d_scaled[:, None] * torch.sin(self.frequencies * d_scaled[:, None]) ) # (num_edges, num_radial) class BernsteinBasis(torch.nn.Module): """ Bernstein polynomial basis, as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021. SpookyNet: Learning Force Fields with Electronic Degrees of Freedom and Nonlocal Effects Parameters ---------- num_radial: int Controls maximum frequency. pregamma_initial: float Initial value of exponential coefficient gamma. Default: gamma = 0.5 * a_0**-1 = 0.94486, inverse softplus -> pregamma = log e**gamma - 1 = 0.45264 """ def __init__( self, num_radial: int, pregamma_initial: float = 0.45264, ): super().__init__() prefactor = binom(num_radial - 1, np.arange(num_radial)) self.register_buffer( "prefactor", torch.tensor(prefactor, dtype=torch.float), persistent=False, ) self.pregamma = torch.nn.Parameter( data=torch.tensor(pregamma_initial, dtype=torch.float), requires_grad=True, ) self.softplus = torch.nn.Softplus() exp1 = torch.arange(num_radial) self.register_buffer("exp1", exp1[None, :], persistent=False) exp2 = num_radial - 1 - exp1 self.register_buffer("exp2", exp2[None, :], persistent=False) def forward(self, d_scaled): gamma = self.softplus(self.pregamma) # constrain to positive exp_d = torch.exp(-gamma * d_scaled)[:, None] return ( self.prefactor * (exp_d ** self.exp1) * ((1 - exp_d) ** self.exp2) ) class RadialBasis(torch.nn.Module): """ Parameters ---------- num_radial: int Controls maximum frequency. cutoff: float Cutoff distance in Angstrom. rbf: dict = {"name": "gaussian"} Basis function and its hyperparameters. envelope: dict = {"name": "polynomial", "exponent": 5} Envelope function and its hyperparameters. """ def __init__( self, num_radial: int, cutoff: float, rbf: dict = {"name": "gaussian"}, envelope: dict = {"name": "polynomial", "exponent": 5}, ): super().__init__() self.inv_cutoff = 1 / cutoff env_name = envelope["name"].lower() env_hparams = envelope.copy() del env_hparams["name"] if env_name == "polynomial": self.envelope = PolynomialEnvelope(**env_hparams) elif env_name == "exponential": self.envelope = ExponentialEnvelope(**env_hparams) else: raise ValueError(f"Unknown envelope function '{env_name}'.") rbf_name = rbf["name"].lower() rbf_hparams = rbf.copy() del rbf_hparams["name"] # RBFs get distances scaled to be in [0, 1] if rbf_name == "gaussian": self.rbf = GaussianSmearing( start=0, stop=1, num_gaussians=num_radial, **rbf_hparams ) elif rbf_name == "spherical_bessel": self.rbf = SphericalBesselBasis( num_radial=num_radial, cutoff=cutoff, **rbf_hparams ) elif rbf_name == "bernstein": self.rbf = BernsteinBasis(num_radial=num_radial, **rbf_hparams) else: raise ValueError(f"Unknown radial basis function '{rbf_name}'.") def forward(self, d): d_scaled = d * self.inv_cutoff env = self.envelope(d_scaled) return env[:, None] * self.rbf(d_scaled) # (nEdges, num_radial)