# Pocket-Gen/models/encoders/radial_basis.py
# Uploaded by Zaixi; revision dcacefd.
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import math
import numpy as np
import torch
from torch import Tensor
from scipy.special import binom
class GaussianSmearing(torch.nn.Module):
    """
    Expand distances in a set of Gaussians with evenly spaced centers.

    Parameters
    ----------
    start: float
        Center of the first Gaussian.
    stop: float
        Center of the last Gaussian.
    num_gaussians: int
        Number of Gaussians. The width is derived from the spacing between
        the first two centers, so at least two Gaussians are needed.
    """

    def __init__(
        self,
        start: float = 0.0,
        stop: float = 5.0,
        num_gaussians: int = 50,
    ):
        super().__init__()
        centers = torch.linspace(start, stop, num_gaussians)
        spacing = (centers[1] - centers[0]).item()
        # Each Gaussian is exp(coeff * delta**2), coeff = -1 / (2 * spacing**2)
        self.coeff = -0.5 / spacing**2
        self.register_buffer('offset', centers)

    def forward(self, dist: Tensor) -> Tensor:
        """
        Evaluate every Gaussian at every distance.

        The input is flattened, so the result has shape
        (dist.numel(), num_gaussians).
        """
        delta = dist.view(-1, 1) - self.offset.view(1, -1)
        squared = torch.pow(delta, 2)
        return torch.exp(self.coeff * squared)
class PolynomialEnvelope(torch.nn.Module):
    """
    Polynomial envelope that decays smoothly to zero at the cutoff.

    For a scaled distance d the envelope is
    1 + a * d**p + b * d**(p + 1) + c * d**(p + 2) for d < 1, and 0 otherwise,
    with a = -(p + 1)(p + 2) / 2, b = p (p + 2), c = -p (p + 1) / 2.

    Parameters
    ----------
    exponent: int
        Exponent p of the envelope function; must be positive.
    """

    def __init__(self, exponent):
        super().__init__()
        assert exponent > 0
        p = exponent
        self.p = p
        self.a = -(p + 1) * (p + 2) / 2
        self.b = p * (p + 2)
        self.c = -p * (p + 1) / 2

    def forward(self, d_scaled):
        """Evaluate the envelope elementwise on the scaled distances."""
        p = self.p
        low_term = self.a * d_scaled ** p
        mid_term = self.b * d_scaled ** (p + 1)
        high_term = self.c * d_scaled ** (p + 2)
        inside_value = 1 + low_term + mid_term + high_term
        outside_value = torch.zeros_like(d_scaled)
        # The polynomial only applies strictly inside the cutoff.
        return torch.where(d_scaled < 1, inside_value, outside_value)
class ExponentialEnvelope(torch.nn.Module):
    """
    Exponential envelope function that ensures a smooth cutoff,
    as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021.
    SpookyNet: Learning Force Fields with Electronic Degrees of Freedom
    and Nonlocal Effects

    For a scaled distance d the envelope is
    exp(-d**2 / ((1 - d) * (1 + d))) for d < 1, and 0 otherwise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, d_scaled):
        """
        Evaluate the envelope elementwise on the scaled distances.

        Entries with d_scaled >= 1 return exactly 0 and receive a zero
        gradient.
        """
        inside = d_scaled < 1
        zeros = torch.zeros_like(d_scaled)
        # Replace out-of-range entries with a harmless value *before* the
        # exponent is evaluated. Otherwise the division by zero at d == 1 and
        # the overflow of exp() just above 1 are masked in the forward pass
        # but still produce 0 * inf = NaN in the backward pass.
        safe_d = torch.where(inside, d_scaled, zeros)
        env_val = torch.exp(
            -(safe_d ** 2) / ((1 - safe_d) * (1 + safe_d))
        )
        return torch.where(inside, env_val, zeros)
class SphericalBesselBasis(torch.nn.Module):
    """
    1D spherical Bessel basis: norm_const / d * sin(frequency * d).

    Parameters
    ----------
    num_radial: int
        Number of basis functions; controls maximum frequency.
    cutoff: float
        Cutoff distance in Angstrom.
    """

    def __init__(
        self,
        num_radial: int,
        cutoff: float,
    ):
        super().__init__()
        # cutoff ** 3 to counteract dividing by d_scaled = d / cutoff
        cutoff_cubed = cutoff ** 3
        self.norm_const = math.sqrt(2 / cutoff_cubed)

        # Trainable frequencies, initialized at the canonical positions
        # pi, 2 * pi, ..., num_radial * pi.
        multiples = np.arange(1, num_radial + 1, dtype=np.float32)
        initial_frequencies = torch.tensor(np.pi * multiples)
        self.frequencies = torch.nn.Parameter(
            data=initial_frequencies,
            requires_grad=True,
        )

    def forward(self, d_scaled):
        """
        Evaluate the basis on a 1D tensor of scaled distances.

        Returns a tensor of shape (num_edges, num_radial). The expression
        divides by d_scaled, so it is not defined for d_scaled == 0.
        """
        column = d_scaled[:, None]
        amplitude = self.norm_const / column
        oscillation = torch.sin(self.frequencies * column)
        return amplitude * oscillation  # (num_edges, num_radial)
class BernsteinBasis(torch.nn.Module):
    """
    Bernstein polynomial basis in the variable exp(-gamma * d),
    as proposed in Unke, Chmiela, Gastegger, Schütt, Sauceda, Müller 2021.
    SpookyNet: Learning Force Fields with Electronic Degrees of Freedom
    and Nonlocal Effects

    Parameters
    ----------
    num_radial: int
        Number of basis functions; controls maximum frequency.
    pregamma_initial: float
        Initial value of the unconstrained parameter whose softplus is the
        exponential coefficient gamma.
        Default: gamma = 0.5 * a_0**-1 = 0.94486,
        inverse softplus -> pregamma = log(e**gamma - 1) = 0.45264
    """

    def __init__(
        self,
        num_radial: int,
        pregamma_initial: float = 0.45264,
    ):
        super().__init__()
        degree = num_radial - 1

        # Binomial coefficients C(degree, k) for k = 0 .. degree.
        binomials = binom(degree, np.arange(num_radial))
        self.register_buffer(
            "prefactor",
            torch.tensor(binomials, dtype=torch.float),
            persistent=False,
        )

        # Trainable and unconstrained; mapped to gamma > 0 in forward().
        self.pregamma = torch.nn.Parameter(
            data=torch.tensor(pregamma_initial, dtype=torch.float),
            requires_grad=True,
        )
        self.softplus = torch.nn.Softplus()

        # Exponents k and degree - k, kept as (1, num_radial) rows so they
        # broadcast against a column of distances.
        rising = torch.arange(num_radial)
        falling = degree - rising
        self.register_buffer("exp1", rising[None, :], persistent=False)
        self.register_buffer("exp2", falling[None, :], persistent=False)

    def forward(self, d_scaled):
        """
        Evaluate the basis on a 1D tensor of scaled distances.

        Returns a tensor with one row per distance and one column per basis
        function.
        """
        gamma = self.softplus(self.pregamma)  # constrain to positive
        x = torch.exp(-gamma * d_scaled)[:, None]
        rising_part = x ** self.exp1
        falling_part = (1 - x) ** self.exp2
        return self.prefactor * rising_part * falling_part
class RadialBasis(torch.nn.Module):
    """
    Radial basis expansion of distances, multiplied by a cutoff envelope.

    Parameters
    ----------
    num_radial: int
        Number of basis functions; controls maximum frequency.
    cutoff: float
        Cutoff distance in Angstrom.
    rbf: dict = {"name": "gaussian"}
        Basis function and its hyperparameters. Supported names:
        "gaussian", "spherical_bessel", "bernstein" (case-insensitive).
    envelope: dict = {"name": "polynomial", "exponent": 5}
        Envelope function and its hyperparameters. Supported names:
        "polynomial", "exponential" (case-insensitive).

    Raises
    ------
    ValueError
        If the envelope name or the basis function name is not supported.
    """

    def __init__(
        self,
        num_radial: int,
        cutoff: float,
        rbf: dict = {"name": "gaussian"},
        envelope: dict = {"name": "polynomial", "exponent": 5},
    ):
        super().__init__()
        self.inv_cutoff = 1 / cutoff

        # The config dicts are only read; every entry except "name" is
        # forwarded as a keyword argument to the selected class.
        env_name = envelope["name"].lower()
        env_kwargs = {k: v for k, v in envelope.items() if k != "name"}
        envelope_classes = {
            "polynomial": PolynomialEnvelope,
            "exponential": ExponentialEnvelope,
        }
        if env_name not in envelope_classes:
            raise ValueError(f"Unknown envelope function '{env_name}'.")
        self.envelope = envelope_classes[env_name](**env_kwargs)

        rbf_name = rbf["name"].lower()
        rbf_kwargs = {k: v for k, v in rbf.items() if k != "name"}

        # RBFs get distances scaled to be in [0, 1]
        if rbf_name == "gaussian":
            basis = GaussianSmearing(
                start=0, stop=1, num_gaussians=num_radial, **rbf_kwargs
            )
        elif rbf_name == "spherical_bessel":
            basis = SphericalBesselBasis(
                num_radial=num_radial, cutoff=cutoff, **rbf_kwargs
            )
        elif rbf_name == "bernstein":
            basis = BernsteinBasis(num_radial=num_radial, **rbf_kwargs)
        else:
            raise ValueError(f"Unknown radial basis function '{rbf_name}'.")
        self.rbf = basis

    def forward(self, d):
        """
        Expand a 1D tensor of distances in the radial basis.

        Distances are divided by the cutoff, then the envelope value of each
        distance scales its row of basis values.
        """
        d_scaled = d * self.inv_cutoff
        envelope_column = self.envelope(d_scaled)[:, None]
        basis_values = self.rbf(d_scaled)
        return envelope_column * basis_values  # (nEdges, num_radial)