Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/07-lode-demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ def forward(
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_distances: Optional[torch.Tensor] = None,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Update meshes
assert self.potential.smearing is not None # otherwise mypy complains
Expand Down
76 changes: 63 additions & 13 deletions src/torchpme/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union
from typing import Optional

import torch

Expand All @@ -9,15 +9,17 @@ def _validate_parameters(
positions: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
smearing: Union[float, None],
periodic: Union[torch.Tensor, None] = None,
periodic: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> None:
dtype = positions.dtype
device = positions.device

# check shape, dtype and device of positions
num_atoms = len(positions)
if list(positions.shape) != [len(positions), 3]:
num_atoms = positions.shape[-2]
if list(positions.shape) != [num_atoms, 3]:
raise ValueError(
"`positions` must be a tensor with shape [n_atoms, 3], got tensor "
f"with shape {list(positions.shape)}"
Expand All @@ -40,14 +42,6 @@ def _validate_parameters(
f"device of `cell` ({cell.device}) must be same as that of the `positions` class ({device})"
)

if smearing is not None and torch.equal(
cell.det(), torch.tensor(0.0, dtype=cell.dtype, device=cell.device)
):
raise ValueError(
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)

# check shape, dtype & device of `charges`
if charges.dim() != 2:
raise ValueError(
Expand Down Expand Up @@ -120,3 +114,59 @@ def _validate_parameters(
f"device of `periodic` ({periodic.device}) must be same as that of "
f"the `positions` class ({device})"
)

if pair_mask is not None:
if pair_mask.shape != neighbor_indices[:, 0].shape:
raise ValueError(
"`pair_mask` must have the same shape as the number of neighbors, "
f"got tensor with shape {list(pair_mask.shape)} while the number of "
f"neighbors is {neighbor_indices.shape[0]}"
)

if pair_mask.device != device:
raise ValueError(
f"device of `pair_mask` ({pair_mask.device}) must be same as that "
f"of the `positions` class ({device})"
)

if pair_mask.dtype != torch.bool:
raise TypeError(
f"type of `pair_mask` ({pair_mask.dtype}) must be torch.bool"
)

if node_mask is not None:
if node_mask.shape != (num_atoms,):
raise ValueError(
"`node_mask` must have shape [n_atoms], got tensor with shape "
f"{list(node_mask.shape)} where n_atoms is {num_atoms}"
)

if node_mask.device != device:
raise ValueError(
f"device of `node_mask` ({node_mask.device}) must be same as that "
f"of the `positions` class ({device})"
)

if node_mask.dtype != torch.bool:
raise TypeError(
f"type of `node_mask` ({node_mask.dtype}) must be torch.bool"
)

if kvectors is not None:
if kvectors.shape[1] != 3:
raise ValueError(
"`kvectors` must be a tensor of shape [n_kvecs, 3], got "
f"tensor with shape {list(kvectors.shape)}"
)

if kvectors.device != device:
raise ValueError(
f"device of `kvectors` ({kvectors.device}) must be same as that of "
f"the `positions` class ({device})"
)

if kvectors.dtype != dtype:
raise TypeError(
f"type of `kvectors` ({kvectors.dtype}) must be same as that of the "
f"`positions` class ({dtype})"
)
31 changes: 25 additions & 6 deletions src/torchpme/calculators/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,24 @@ def _compute_rspace(
charges: torch.Tensor,
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
pair_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Compute the pair potential terms V(r_ij) for each pair of atoms (i,j)
# contained in the neighbor list
with profiler.record_function("compute bare potential"):
if self.potential.smearing is None:
if self.potential.exclusion_radius is None:
potentials_bare = self.potential.from_dist(neighbor_distances)
else:
potentials_bare = self.potential.from_dist(neighbor_distances) * (
1 - self.potential.f_cutoff(neighbor_distances)
potentials_bare = self.potential.from_dist(
neighbor_distances, pair_mask
)
else:
potentials_bare = self.potential.from_dist(
neighbor_distances, pair_mask
) * (1 - self.potential.f_cutoff(neighbor_distances, pair_mask))
else:
potentials_bare = self.potential.sr_from_dist(neighbor_distances)
potentials_bare = self.potential.sr_from_dist(
neighbor_distances, pair_mask
)

# Multiply the bare potential terms V(r_ij) with the corresponding charges
# of ``atom j'' to obtain q_j*V(r_ij). Since each atom j can be a neighbor of
Expand Down Expand Up @@ -109,6 +114,9 @@ def forward(
neighbor_indices: torch.Tensor,
neighbor_distances: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
):
r"""
Compute the potential "energy".
Expand Down Expand Up @@ -145,22 +153,31 @@ def forward(
:param periodic: optional torch.tensor of shape ``(3,)`` indicating which
directions are periodic (True) and which are not (False). If not
provided, full periodicity is assumed.
:param node_mask: Optional torch.tensor of shape ``(len(positions),)`` that
indicates which of the atoms are masked.
:param pair_mask: Optional torch.tensor containing a mask to be applied to the
result.
:param kvectors: Optional precomputed k-vectors to be used in the Fourier
space part of the calculation.
"""
_validate_parameters(
charges=charges,
cell=cell,
positions=positions,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
smearing=self.potential.smearing,
periodic=periodic,
pair_mask=pair_mask,
node_mask=node_mask,
kvectors=kvectors,
)

# Compute short-range (SR) part using a real space sum
potential_sr = self._compute_rspace(
charges=charges,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_distances,
pair_mask=pair_mask,
)

if self.potential.smearing is None:
Expand All @@ -171,6 +188,8 @@ def forward(
cell=cell,
positions=positions,
periodic=periodic,
kvectors=kvectors,
node_mask=node_mask,
)

return self.prefactor * (potential_sr + potential_lr)
1 change: 0 additions & 1 deletion src/torchpme/calculators/calculator_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ def forward(
positions=positions,
neighbor_indices=neighbor_indices,
neighbor_distances=neighbor_vectors.norm(dim=-1),
smearing=self.potential.smearing,
)

# Compute short-range (SR) part using a real space sum
Expand Down
24 changes: 15 additions & 9 deletions src/torchpme/calculators/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,19 +86,23 @@ def _compute_kspace(
cell: torch.Tensor,
positions: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Define k-space cutoff from required real-space resolution
k_cutoff = 2 * torch.pi / self.lr_wavelength
if kvectors is None:
k_cutoff = 2 * torch.pi / self.lr_wavelength

# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()

# Generate k-vectors and evaluate
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
knorm_sq = torch.sum(kvectors**2, dim=1)
# Generate k-vectors and evaluate
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)

knorm_sq = torch.sum(kvectors**2, dim=-1)

# G(k) is the Fourier transform of the Coulomb potential
# generated by a Gaussian charge density
Expand Down Expand Up @@ -140,4 +144,6 @@ def _compute_kspace(
energy -= 2 * prefac * charge_tot * ivolume
# Compensate for double counting of pairs (i,j) and (j,i)
energy += self.potential.pbc_correction(periodic, positions, cell, charges)
if node_mask is not None:
energy = energy * node_mask.unsqueeze(-1)
return energy / 2
4 changes: 4 additions & 0 deletions src/torchpme/calculators/pme.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,16 @@ def _compute_kspace(
cell: torch.Tensor,
positions: torch.Tensor,
periodic: Optional[torch.Tensor] = None,
node_mask: Optional[torch.Tensor] = None,
kvectors: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# TODO: Kernel function `G` and initialization of `MeshInterpolator` only depend
# on `cell`. Caching may save up to 15% but issues with AD need to be resolved.

# Compute number of times each basis vector of the reciprocal space can be
# scaled until the cutoff is reached
if node_mask is not None or kvectors is not None:
raise ValueError("Batching not implemented for mesh-based calculators")
ns = get_ns_mesh(cell, self.mesh_spacing)

self.mesh_interpolator.update(cell, ns)
Expand Down
2 changes: 2 additions & 0 deletions src/torchpme/lib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .kspace_filter import KSpaceFilter, KSpaceKernel, P3MKSpaceFilter
from .kvectors import (
compute_batched_kvectors,
generate_kvectors_for_ewald,
generate_kvectors_for_mesh,
get_ns_mesh,
Expand All @@ -26,6 +27,7 @@
"distances",
"generate_kvectors_for_ewald",
"generate_kvectors_for_mesh",
"compute_batched_kvectors",
"get_ns_mesh",
"gamma",
"gammaincc_over_powerlaw",
Expand Down
31 changes: 31 additions & 0 deletions src/torchpme/lib/kvectors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
from torch.nn.utils.rnn import pad_sequence


def get_ns_mesh(cell: torch.Tensor, mesh_spacing: float):
Expand Down Expand Up @@ -133,3 +134,33 @@ def generate_kvectors_for_ewald(
calculators like PME.
"""
return _generate_kvectors(cell=cell, ns=ns, for_ewald=True).reshape(-1, 3)


def compute_batched_kvectors(
lr_wavelength: float,
cells: torch.Tensor,
) -> torch.Tensor:
r"""
Generate k-vectors for multiple systems in batches.

:param lr_wavelength: Spatial resolution used for the long-range (reciprocal space)
part of the Ewald sum. More concretely, all Fourier space vectors with a
wavelength >= this value will be kept. If not set to a global value, it will be
set to half the smearing parameter to ensure convergence of the
long-range part to a relative precision of 1e-5.
:param cell: torch.tensor of shape ``(B, 3, 3)``, where ``cell[i]`` is the i-th
basis vector of the unit cell for system i in the batch of size B.

"""
all_kvectors = []
k_cutoff = 2 * torch.pi / lr_wavelength
for cell in cells:
basis_norms = torch.linalg.norm(cell, dim=1)
ns_float = k_cutoff * basis_norms / 2 / torch.pi
ns = torch.ceil(ns_float).long()
kvectors = generate_kvectors_for_ewald(ns=ns, cell=cell)
all_kvectors.append(kvectors)
# We do not return masks here; instead, we rely on the fact that for the Coulomb
# potential, the k = 0 vector is ignored in the calculations and can therefore be
# safely padded with zeros.
return pad_sequence(all_kvectors, batch_first=True)
18 changes: 12 additions & 6 deletions src/torchpme/potentials/combined.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,24 @@ def __init__(
else:
self.register_buffer("weights", initial_weights)

def from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.from_dist(dist) for pot in self.potentials]
def from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

def sr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.sr_from_dist(dist) for pot in self.potentials]
def sr_from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.sr_from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor:
potentials = [pot.lr_from_dist(dist) for pot in self.potentials]
def lr_from_dist(
self, dist: torch.Tensor, pair_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
potentials = [pot.lr_from_dist(dist, pair_mask) for pot in self.potentials]
potentials = torch.stack(potentials, dim=-1)
return torch.inner(self.weights, potentials)

Expand Down
Loading