"""
Layers for physical operations
"""
import math
import warnings

import torch
from torch import Tensor

from . import indexers, pairs
[docs]
class Gradient(torch.nn.Module):
    """
    Differentiates the summed molecular energies with respect to positions.

    :param sign: ``+1`` to return the gradient, ``-1`` to return the force.
    """

    def __init__(self, sign):
        super().__init__()
        assert sign in (-1, 1), "Sign of gradient must be +1 (gradient) or -1 (force)"
        self.sign = sign

    def forward(self, molecular_energies, positions):
        """
        :param molecular_energies: tensor of energies; it is summed to a scalar before differentiation.
        :param positions: tensor the energies were computed from (must be part of the autograd graph).
        :return: ``sign`` times the derivative of the summed energy with respect to ``positions``.
        """
        total_energy = molecular_energies.sum()
        # create_graph=True keeps the derivative itself differentiable.
        derivatives = torch.autograd.grad(total_energy, positions, create_graph=True)
        return self.sign * derivatives[0]
 
        
[docs]
class MultiGradient(torch.nn.Module):
    """
    Differentiates the summed molecular energies with respect to several
    tensors at once, applying one sign per tensor.

    :param signs: a single int or an iterable of ints, each ``-1`` or ``+1``;
        one sign per quantity that will be passed to :meth:`forward`.
    """

    def __init__(self, signs):
        super().__init__()
        if isinstance(signs, int):
            signs = (signs,)
        # Materialize once so that one-shot iterables (e.g. generators) are
        # not exhausted by the validation loop below.
        signs = tuple(signs)
        for sign in signs:
            assert sign in (-1,1), "Sign of gradient must be -1 or +1"
        self.signs = signs

    def forward(self, molecular_energies: Tensor, *generalized_coordinates: Tensor):
        """
        :param molecular_energies: tensor of energies; summed to a scalar before differentiation.
        :param generalized_coordinates: the tensors to differentiate with respect to,
            in the same order as the signs given at construction.
        :return: tuple with ``sign * gradient`` for each coordinate tensor.
        """
        # ``*generalized_coordinates`` is always a tuple, so no Tensor-wrapping is needed here.
        assert len(generalized_coordinates) == len(self.signs), f"Number of items to take derivative w.r.t ({len(generalized_coordinates)}) must match number of provided signs ({len(self.signs)})."
        # create_graph=True keeps the derivatives themselves differentiable.
        grads = torch.autograd.grad(molecular_energies.sum(), generalized_coordinates, create_graph=True)
        return tuple(sign * grad for sign, grad in zip(self.signs, grads))
 
[docs]
class StressForce(torch.nn.Module):
    """
    Computes forces and stress from the derivatives of the total energy with
    respect to coordinates and strain.

    The ``pbc`` attribute (``True`` by default) selects whether the strain
    derivative is divided by the cell volume.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pbc = True

    def forward(self, energy, strain, coordinates, cell):
        """
        :param energy: tensor of energies; summed to a scalar before differentiation.
        :param strain: strain tensor that the energy depends on.
        :param coordinates: coordinate tensor that the energy depends on.
        :param cell: batch of cell matrices; only used (via its determinant) when ``pbc`` is set.
        :return: ``(-dE/dcoordinates, stress)`` where stress is ``dE/dstrain``,
            divided by ``det(cell)`` per batch entry when ``pbc`` is set.
        """
        # create_graph=True keeps both derivatives differentiable.
        strain_derivative, coordinate_derivative = torch.autograd.grad(
            energy.sum(), [strain, coordinates], create_graph=True
        )
        force = -coordinate_derivative
        if not self.pbc:
            return force, strain_derivative
        # Broadcast the per-system volume over the two trailing strain axes.
        cell_volume = torch.det(cell).unsqueeze(1).unsqueeze(1)
        return force, strain_derivative / cell_volume
 
[docs]
class Dipole(torch.nn.Module):
    """
    Computes dipoles as the charge-weighted positions of the atoms, reduced
    per system by ``indexers.MolSummer``.
    """

    def __init__(self):
        super().__init__()
        self.summer = indexers.MolSummer()

    def forward(self, charges: Tensor, positions: Tensor, system_index: Tensor, n_systems: int):
        """
        :param charges: per-atom charges; the second axis indexes targets
            (assumed shape ``(n_atoms, n_targets)`` -- TODO confirm against callers).
        :param positions: per-atom positions (assumed shape ``(n_atoms, 3)`` -- TODO confirm).
        :param system_index: passed through to the summer together with ``n_systems``.
        :param n_systems: passed through to the summer.
        :return: the summer's reduction of ``charges * positions``.
        """
        multi_target = charges.shape[1] > 1
        if multi_target:
            # Insert singleton axes so every target is paired with every
            # position component: (n_atoms, n_targets, 1) * (n_atoms, 1, 3).
            charges = charges.unsqueeze(2)
            positions = positions.unsqueeze(1)
        # Multi-target mode: (n_atoms, n_targets, 3).
        # Single-target mode: (n_atoms, 3).
        per_atom_dipole = charges * positions
        return self.summer(per_atom_dipole, system_index, n_systems)
 
[docs]
class Quadrupole(torch.nn.Module):
    """Computes quadrupoles as a flattened (n_systems,9) array.
    NOTE: Uses normalization sum_a q_a (r_a,i*r_a,j - 1/3 delta_ij r_a^2)"""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.summer = indexers.MolSummer()

    def forward(self, charges, positions, system_index, n_systems):
        """
        :param charges: per-atom charges, shape (atoms, 1).
        :param positions: per-atom positions, shape (atoms, xyz).
        :param system_index: passed through to the summer together with ``n_systems``.
        :param n_systems: passed through to the summer.
        :return: the summer's reduction of the per-atom traceless quadrupole elements.
        """
        # Outer product r_i * r_j per atom, flattened to 9 components.
        outer_flat = (positions.unsqueeze(1) * positions.unsqueeze(2)).reshape(-1, 9)
        # |r|^2 per atom, with a trailing axis for the component index.
        r_squared = (positions**2).sum(dim=1).unsqueeze(1)
        # Flattened Kronecker delta, with a leading axis for the atom index.
        identity_flat = torch.eye(3, device=r_squared.device).flatten().unsqueeze(0)
        per_atom_quadrupole = charges * (outer_flat - (1 / 3) * (r_squared * identity_flat))
        return self.summer(per_atom_quadrupole, system_index, n_systems)
 
[docs]
class CoulombEnergy(torch.nn.Module):
    """ Computes the Coulomb Energy of the molecule/configuration.

    Coulomb energy is defined for pairs of atoms. Here, we adopt the
    convention that the Coulomb energy for a pair of atoms is evenly
    partitioned to both atoms as the 'per-atom energies'. Therefore, the
    atom energies sum to the molecular energy; similar to the HEnergy.

    :param energy_conversion_factor: scalar prefactor applied to ``q / r``;
        stored as a buffer.
    """

    def __init__(self, energy_conversion_factor):
        super().__init__()
        self.register_buffer("energy_conversion_factor", torch.tensor(energy_conversion_factor))
        self.summer = indexers.MolSummer()

    def forward(self, charges, pair_dist, pair_first, pair_second, system_index, n_systems):
        """
        :param charges: per-atom charges with two axes; the result buffers are
            allocated with a single column, so shape ``(n_atoms, 1)`` is expected.
        :param pair_dist: distance for each pair.
        :param pair_first: index of the atom each pair contribution is accumulated onto.
        :param pair_second: index of the atom whose charge generates the contribution.
        :param system_index: passed through to the summer together with ``n_systems``.
        :param n_systems: passed through to the summer.
        :return: ``(coulomb_molecule, coulomb_atoms, voltage_atom)``.
        """
        # Potential generated at the first atom of each pair by the second atom's charge.
        source_charge = charges[pair_second]
        pair_voltage = self.energy_conversion_factor * (source_charge / pair_dist.unsqueeze(1))

        n_atoms, _ = charges.shape
        voltage_atom = charges.new_zeros((n_atoms, 1))
        voltage_atom.index_add_(0, pair_first, pair_voltage)

        # Factor 0.5 splits each pair's energy evenly between its two atoms.
        coulomb_atoms = 0.5 * voltage_atom * charges
        coulomb_molecule = self.summer(coulomb_atoms, system_index, n_systems)
        return coulomb_molecule, coulomb_atoms, voltage_atom
 
[docs]
class ScreenedCoulombEnergy(CoulombEnergy):
    """ Computes the screened Coulomb Energy of the molecule/configuration.

    The convention for the atom energies is the same as CoulombEnergy
    and the HEnergy.

    :param energy_conversion_factor: scalar prefactor, forwarded to ``CoulombEnergy``.
    :param screening: screening callable ``screening(pair_dist, radius)``, or a
        class that is instantiated with no arguments.
    :param radius: cutoff radius; pairs at or beyond it get zero weight.
    :raises ValueError: if ``screening`` or ``radius`` is ``None``.
    """

    def __init__(self, energy_conversion_factor, screening, radius=None):
        super().__init__(energy_conversion_factor)
        if screening is None:
            raise ValueError("Screened Coulomb requires specification of a screening type.")
        if radius is None:
            raise ValueError("Screened Coulomb requires specification of a radius")
        self.radius = radius
        # Accept either a ready-made screening object or its class.
        self.screening = screening() if isinstance(screening, type) else screening
        self.bond_summer = pairs.MolPairSummer()

    def forward(self, charges, pair_dist, pair_first, pair_second, system_index, n_systems):
        """
        :param charges: per-atom charges with two axes; the result buffers are
            allocated with a single column, so shape ``(n_atoms, 1)`` is expected.
        :param pair_dist: distance for each pair.
        :param pair_first: index of the atom each pair contribution is accumulated onto.
        :param pair_second: index of the atom whose charge generates the contribution.
        :param system_index: passed through to the summer together with ``n_systems``.
        :param n_systems: passed through to the summer.
        :return: ``(coulomb_molecule, coulomb_atoms, voltage_atom)``.
        """
        # Screening weight per pair, forced to zero outside the cutoff radius.
        inside_cutoff = (pair_dist < self.radius).unsqueeze(1)
        pair_weight = self.screening(pair_dist, self.radius).unsqueeze(1)
        pair_weight = torch.where(inside_cutoff, pair_weight, torch.zeros_like(pair_weight))

        # Bare pair voltage, then screened, for the per-atom energy.
        bare_voltage = self.energy_conversion_factor * (charges[pair_second] / pair_dist.unsqueeze(1))
        screened_voltage = bare_voltage * pair_weight

        n_atoms, _ = charges.shape
        voltage_atom = charges.new_zeros((n_atoms, 1))
        voltage_atom.index_add_(0, pair_first, screened_voltage)

        # Factor 0.5 splits each pair's energy evenly between its two atoms.
        coulomb_atoms = 0.5 * voltage_atom * charges
        coulomb_molecule = self.summer(coulomb_atoms, system_index, n_systems)
        return coulomb_molecule, coulomb_atoms, voltage_atom
 
[docs]
class CombineScreenings(torch.nn.Module):
    """ Returns products of different screenings for Screened Coulomb Interactions.

    :param screening_list: screening modules to multiply together; stored in a ``ModuleList``.
    """

    def __init__(self, screening_list):
        super().__init__()
        self.SL = torch.nn.ModuleList(screening_list)

    def forward(self, pair_dist, radius):
        """ Product of different screenings applied to pair_dist up to radius.

        :param pair_dist: torch.tensor, dtype=float64: 'Neighborlist' distances for coulomb energies.
        :param radius: Maximum radius up to which Screened-Coulomb is evaluated.
        :return screening: Weights for screening for all pair_dist, or ``None``
            when no screenings are stored.
        """
        remaining = iter(self.SL)
        first = next(remaining, None)
        if first is None:
            return None
        combined = first(pair_dist=pair_dist, radius=radius)
        for screen in remaining:
            combined = combined * screen(pair_dist=pair_dist, radius=radius)
        return combined
 
[docs]
class AlphaScreening(torch.nn.Module):
    """
    Base class for screenings that carry a single ``alpha`` value.

    :param alpha: value stored on the instance as ``self.alpha``.
    """

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha
# Note: This is somewhat incomplete as it does not include a k-space contribution -- more is needed
[docs]
class EwaldRealSpaceScreening(AlphaScreening):
    """
    Complementary-error-function screening, ``erfc(alpha * pair_dist)``.

    Emits a warning on construction because no k-space contribution is included.

    :param alpha: screening parameter multiplying the pair distance.
    """

    def __init__(self, alpha):
        warnings.warn("Ewald implementation incomplete, does not include k-space contributions.")
        super().__init__(alpha)

    def forward(self, pair_dist, radius):
        """
        :param pair_dist: pair distances.
        :param radius: cutoff radius; used to form the reduced distance and reduced alpha.
        :return: ``erfc((alpha * radius) * (pair_dist / radius))``.
        """
        reduced_dist = pair_dist / radius
        reduced_alpha = self.alpha * radius
        return torch.erfc(reduced_alpha * reduced_dist)
 
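# Note: typically Wolf summation is paired with its own force expression; this
# implementation instead takes the exact derivative of the potential (see the
# warning below). NOTE(review): the original note was truncated -- confirm intent.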
[docs]
class WolfScreening(AlphaScreening):
    """
    Shifted complementary-error-function screening,
    ``erfc(alpha * r) - (r / R) * erfc(alpha * R)``, with ``r`` the pair
    distance and ``R`` the radius.

    Emits a warning on construction.

    :param alpha: screening parameter multiplying the distances.
    """

    def __init__(self, alpha):
        warnings.warn("Wolf implementation uses exact derivative of the potential.")
        super().__init__(alpha)

    def forward(self, pair_dist, radius):
        """
        :param pair_dist: pair distances.
        :param radius: cutoff radius; the screening evaluates to zero where ``pair_dist == radius``.
        :return: ``erfc(eta * q) - q * erfc(eta)`` with ``q = pair_dist / radius``
            and ``eta = alpha * radius``.
        """
        q = pair_dist / radius
        eta = self.alpha * radius
        # The second term shifts the screening so that it vanishes at q == 1.
        return torch.erfc(eta * q) - q * torch.erfc(eta)
 
[docs]
class LocalDampingCosine(AlphaScreening):
    """ Local damping using complement of the hipnn cutoff function. ('glue-on' method)
        g = 1 if pair_dist >= R_cutoff, 1 - [cos(pi/2 * pair_dist / R_cutoff)]^2  otherwise
    """

    def __init__(self, alpha):
        """
        :param alpha: R_cutoff for glue-on function to ensure
            smooth crossover from hipnn energy to long-range coulomb energy.
        """
        super().__init__(alpha)

    def forward(self, pair_dist, radius):
        """
        :param pair_dist: torch.tensor, dtype=float64: 'Neighborlist' distances for coulomb energies.
        :param radius: Maximum radius up to which Screened-Coulomb is evaluated.
            Not used here; accepted so all screenings share one call signature.
        :return screening: Weights for screening for each pair.
        """
        # Use the Python float for pi so it is applied at pair_dist's own
        # precision; a default-dtype tensor would truncate it to float32.
        phase = 0.5 * math.pi * pair_dist / self.alpha
        screening = 1.0 - torch.square(torch.cos(phase))

        # pair_dist at or beyond the cut-off; no local damping.
        return torch.where(pair_dist < self.alpha, screening, torch.ones_like(screening))
 
[docs]
class QScreening(torch.nn.Module):
    """
    Polynomial screening ``prod_{k=1..p} (1 - q**k)`` with ``q = pair_dist / radius``.

    :param p_value: number of factors ``p``; converted with ``int()``.
    """

    def __init__(self, p_value):
        super().__init__()
        self.p_value = p_value

    @property
    def p_value(self):
        """Number of factors in the screening product."""
        return self._p_value

    @p_value.setter
    def p_value(self, value):
        value = int(value)
        self._p_value = value
        # When p_value is changed after the module was moved (e.g. ``.to(device)``),
        # rebuild the exponents on the device of the buffer being replaced so
        # that forward() does not mix devices.
        previous = getattr(self, "powers", None)
        device = None if previous is None else previous.device
        powers = torch.arange(1, value + 1, dtype=torch.long, device=device).unsqueeze(0)
        self.register_buffer("powers", powers)

    def forward(self, pair_dist, radius):
        """
        :param pair_dist: 1-D tensor of pair distances.
        :param radius: cutoff radius used to form ``q = pair_dist / radius``.
        :return: product over ``k = 1..p_value`` of ``1 - q**k``, one value per pair.
        """
        q = pair_dist / radius
        # Broadcast (n_pairs, 1) against (1, p_value) exponents.
        q_factors = 1 - torch.pow(q.unsqueeze(1), self.powers)
        return q_factors.prod(dim=1)
 
[docs]
class PerAtom(torch.nn.Module):
    """
    Divides per-system features by the number of nonzero ``species`` entries
    of that system.
    """

    def forward(self, features, species):
        """
        :param features: per-system features; the first axis indexes systems.
        :param species: per-system species array; entries equal to 0 are not
            counted (presumably padding -- confirm against the data layout).
        :return: ``features`` divided by the count of nonzero species along axis 1.
        """
        is_atom = (species != 0).type(features.dtype)
        atom_count = is_atom.sum(dim=1)
        return features / atom_count.unsqueeze(1)
 
[docs]
class VecMag(torch.nn.Module):
    """Computes the Euclidean magnitude of a vector feature along axis 1."""

    def forward(self, vector_feature):
        """
        :param vector_feature: tensor whose axis 1 holds the vector components.
        :return: 2-norm taken over axis 1 (that axis is removed from the result).
        """
        # torch.norm is deprecated; vector_norm is its documented replacement
        # and defaults to the same 2-norm.
        return torch.linalg.vector_norm(vector_feature, dim=1)
 
[docs]
class CombineEnergy(torch.nn.Module):
    """
    Combines the energies (molecular and atom energies) from two different
    nodes, e.g. HEnergy, Coulomb, or ScreenedCoulomb Energy Nodes.
    """

    def __init__(self):
        super().__init__()
        self.summer = indexers.MolSummer()

    def forward(self, atom_energy_1, atom_energy_2, system_index, n_systems):
        """
        :param atom_energy_1: per-atom energy from first node.
        :param atom_energy_2: per-atom energy from second node.
        :param system_index: the molecular index for atoms in the batch.
        :param n_systems: total number of molecules in the batch.
        :return: ``(mol_energy, total_atom_energy)`` -- the summer's reduction of
            the combined atom energies, and the combined atom energies themselves.
        """
        combined_atom_energy = atom_energy_1 + atom_energy_2
        return self.summer(combined_atom_energy, system_index, n_systems), combined_atom_energy
 
[docs]
class CellScaleInducer(torch.nn.Module):
    """
    Multiplies coordinates and cell by an identity strain that requires grad,
    so that downstream quantities can be differentiated with respect to strain.

    The ``pbc`` attribute is initialised to ``False``; it is not read in this class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pbc = False

    def forward(self, coordinates, cell):
        """
        :param coordinates: batched coordinates, shape ``(n_systems, n_atoms, n_dim)``.
        :param cell: batched cell matrices; multiplied on the right by the strain.
        :return: ``(strained_coordinates, strained_cell, strain)`` where ``strain``
            is a batch of identity matrices attached to the autograd graph.
        """
        n_batch = coordinates.shape[0]
        n_dim = coordinates.shape[2]
        identity = torch.eye(n_dim, dtype=coordinates.dtype, device=coordinates.device, requires_grad=True)
        # One identity matrix per system; numerically a no-op, but differentiable.
        strain = identity.tile(n_batch, 1, 1)
        return coordinates.bmm(strain), cell.bmm(strain), strain