Source code for hippynn.interfaces.schnetpack_interface
"""
Interface to use schnetpack networks with hippynn.
Note: Only open boundary conditions are supported for now.
"""
try:
    import schnetpack
except ImportError as ie:
    raise ImportError("Schnetpack installation is required to use the schnetpack interface.") from ie
import torch
from schnetpack import Properties
from ...graphs.nodes.base import SingleNode, AutoKw
from ...graphs.nodes.networks import Network
from ...graphs.indextypes import IdxType
[docs]
class SchNetWrapper(torch.nn.Module):
    def __init__(self, schnet):
        super().__init__()
        self.schnet = schnet
        feature_sizes = [x.dense.out_features for x in schnet.interactions]
        if not schnet.return_intermediates:
            feature_sizes = [feature_sizes[-1]]
        self.feature_sizes = feature_sizes
[docs]
    def forward(self, z_arr, r_arr, nonblank):
        """
        Wrap a call into the underlying schnet, which uses
        dictionaries.
        :param z_arr:
        :param r_arr:
        :param nonblank:
        :return:
        """
        packed = create_schnetpack_inputs(z_arr, r_arr, nonblank)
        outputs = self.schnet(packed)
        if self.schnet.return_intermediate:
            outputs = outputs[1]
        else:
            outputs = [outputs]
        outputs = [x[packed["nonblank"]] for x in outputs]
        return outputs
[docs]
class SchNetNode(AutoKw, Network, SingleNode):
    input_names = "species", "positions", "nonblank"
    index_state = IdxType.Atoms
    auto_module_class = SchNetWrapper
    def __init__(self, name, parents, module="auto", module_kwargs=None):
        super().__init__(name, parents, module=module, module_kwargs=module_kwargs)
[docs]
def create_schnetpack_inputs(z_arr, r_arr, nonblank):
    dtype = r_arr.dtype
    device = r_arr.device
    n_atoms_per_mol = (z_arr > 0).sum(axis=1)
    n_mols = n_atoms_per_mol.shape[0]
    n_atoms = n_atoms_per_mol.max()
    z_arr = z_arr[:, :n_atoms]
    r_arr = r_arr[:, :n_atoms]
    nonblank = nonblank[:, :n_atoms]
    atom_range = torch.arange(n_atoms).unsqueeze(0).expand(n_mols, -1)
    atom_mask = torch.zeros(n_mols, n_atoms, dtype=z_arr.dtype, device=device)
    atom_mask[atom_range < n_atoms_per_mol.unsqueeze(1)] = 1
    neighbor_base = torch.arange(n_atoms - 1, device=device, dtype=dtype).unsqueeze(0).expand(n_mols, n_atoms, -1)
    neighbor_base = neighbor_base + torch.triu(torch.ones_like(neighbor_base), diagonal=0)
    neighbor_mask = atom_mask.unsqueeze(2) * atom_mask.unsqueeze(1)[:, :, 1:]
    neighbors = neighbor_base * neighbor_mask
    cell = torch.zeros((n_mols, 3, 3), device=device, dtype=dtype)
    cell_offset = torch.zeros((n_mols, n_atoms, n_atoms - 1, 3), device=device, dtype=dtype)
    return {
        Properties.atom_mask: atom_mask.to(dtype),
        Properties.neighbors: neighbors,
        Properties.neighbor_mask: neighbor_mask.to(dtype),
        Properties.Z: z_arr,
        Properties.R: r_arr,
        Properties.cell: cell,
        Properties.cell_offset: cell_offset,
        "nonblank": nonblank,
    }