Skip to content

Commit e9cc2a2

Browse files
authored
Add mixed periodic boundary condition support to EwaldCalculator (#200)
* Add mixed periodic boundary condition support to EwaldCalculator * Fix typos * Another bug fix * bug fixes + corrections for charged cells * Yet another bug correction * small cleaning * Move periodic param from init to forward * Add support for single non-periodic axis in Ewald summation * Fix existing tests and lint * Fix metatensor * Add tests for slabs and fix docs * Update changelog * Move slab correction to the potential definiton * Fix tests * Rename _2d_correction to pbc_correction and update documentation * typo fix
1 parent 1f61266 commit e9cc2a2

File tree

12 files changed

+218
-52
lines changed

12 files changed

+218
-52
lines changed

docs/src/references/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ changelog <https://keepachangelog.com/en/1.1.0/>`_ format. This project follows
2727
Added
2828
#####
2929

30+
* Add support for slab geometries in Ewald and PME calculators
31+
3032
Fixed
3133
#####
3234

examples/01-charges-example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@
3131
# %%
3232

3333
import torch
34+
import vesin.metatomic
3435
import vesin.torch
35-
import vesin.torch.metatensor
3636
from metatensor.torch import Labels, TensorBlock, TensorMap
3737
from metatomic.torch import NeighborListOptions, System
3838

@@ -231,7 +231,7 @@
231231
# the cutoff and the type of list.
232232

233233
options = NeighborListOptions(cutoff=4.0, full_list=True, strict=False)
234-
nl_mts = vesin.torch.metatensor.NeighborList(options, length_unit="Angstrom")
234+
nl_mts = vesin.metatomic.NeighborList(options, length_unit="Angstrom")
235235
neighbors = nl_mts.compute(system)
236236

237237
# %%

examples/07-lode-demo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ def forward(
442442
positions: torch.Tensor,
443443
neighbor_indices: Optional[torch.Tensor] = None,
444444
neighbor_distances: Optional[torch.Tensor] = None,
445+
periodic: Optional[torch.Tensor] = None,
445446
) -> torch.Tensor:
446447
# Update meshes
447448
assert self.potential.smearing is not None # otherwise mypy complains

src/torchpme/_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def _validate_parameters(
1010
neighbor_indices: torch.Tensor,
1111
neighbor_distances: torch.Tensor,
1212
smearing: Union[float, None],
13+
periodic: Union[torch.Tensor, None] = None,
1314
) -> None:
1415
dtype = positions.dtype
1516
device = positions.device
@@ -106,3 +107,16 @@ def _validate_parameters(
106107
f"type of `neighbor_distances` ({neighbor_distances.dtype}) must be same "
107108
f"as that of the `positions` class ({dtype})"
108109
)
110+
111+
if periodic is not None:
112+
if periodic.shape != (3,):
113+
raise ValueError(
114+
"`periodic` must be a tensor of shape (3,), got "
115+
f"tensor with shape {list(periodic.shape)}"
116+
)
117+
118+
if periodic.device != device:
119+
raise ValueError(
120+
f"device of `periodic` ({periodic.device}) must be same as that of "
121+
f"the `positions` class ({device})"
122+
)

src/torchpme/calculators/calculator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24
from torch import profiler
35

@@ -106,6 +108,7 @@ def forward(
106108
positions: torch.Tensor,
107109
neighbor_indices: torch.Tensor,
108110
neighbor_distances: torch.Tensor,
111+
periodic: Optional[torch.Tensor] = None,
109112
):
110113
r"""
111114
Compute the potential "energy".
@@ -139,6 +142,9 @@ def forward(
139142
which the potential should be computed in real space.
140143
:param neighbor_distances: torch.tensor with the pair distances of the neighbors
141144
for which the potential should be computed in real space.
145+
:param periodic: optional torch.tensor of shape ``(3,)`` indicating which
146+
directions are periodic (True) and which are not (False). If not
147+
provided, full periodicity is assumed.
142148
"""
143149
_validate_parameters(
144150
charges=charges,
@@ -147,6 +153,7 @@ def forward(
147153
neighbor_indices=neighbor_indices,
148154
neighbor_distances=neighbor_distances,
149155
smearing=self.potential.smearing,
156+
periodic=periodic,
150157
)
151158

152159
# Compute short-range (SR) part using a real space sum
@@ -163,6 +170,7 @@ def forward(
163170
charges=charges,
164171
cell=cell,
165172
positions=positions,
173+
periodic=periodic,
166174
)
167175

168176
return self.prefactor * (potential_sr + potential_lr)

src/torchpme/calculators/ewald.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import torch
24

35
from ..lib import generate_kvectors_for_ewald
@@ -12,7 +14,7 @@ class EwaldCalculator(Calculator):
1214
Scaling as :math:`\mathcal{O}(N^2)` with respect to the number of particles
1315
:math:`N`.
1416
15-
For getting reasonable values for the ``smaring`` of the potential class and the
17+
For getting reasonable values for the ``smearing`` of the potential class and the
1618
``lr_wavelength`` based on a given accuracy for a specific structure you should use
1719
:func:`torchpme.tuning.tune_ewald`. This function will also find the optimal
1820
``cutoff`` for the **neighborlist**.
@@ -78,6 +80,7 @@ def _compute_kspace(
7880
charges: torch.Tensor,
7981
cell: torch.Tensor,
8082
positions: torch.Tensor,
83+
periodic: Optional[torch.Tensor] = None,
8184
) -> torch.Tensor:
8285
# Define k-space cutoff from required real-space resolution
8386
k_cutoff = 2 * torch.pi / self.lr_wavelength
@@ -131,4 +134,5 @@ def _compute_kspace(
131134
prefac = self.potential.background_correction()
132135
energy -= 2 * prefac * charge_tot * ivolume
133136
# Compensate for double counting of pairs (i,j) and (j,i)
137+
energy += self.potential.pbc_correction(periodic, positions, cell, charges)
134138
return energy / 2

src/torchpme/calculators/pme.py

Lines changed: 34 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from typing import Optional
2+
13
import torch
2-
from torch import profiler
34

45
from ..lib.kspace_filter import KSpaceFilter
56
from ..lib.kvectors import get_ns_mesh
@@ -97,51 +98,48 @@ def _compute_kspace(
9798
charges: torch.Tensor,
9899
cell: torch.Tensor,
99100
positions: torch.Tensor,
101+
periodic: Optional[torch.Tensor] = None,
100102
) -> torch.Tensor:
101103
# TODO: Kernel function `G` and initialization of `MeshInterpolator` only depend
102104
# on `cell`. Caching may save up to 15% but issues with AD need to be resolved.
103105

104-
with profiler.record_function("init 0: preparation"):
105-
# Compute number of times each basis vector of the reciprocal space can be
106-
# scaled until the cutoff is reached
107-
ns = get_ns_mesh(cell, self.mesh_spacing)
106+
# Compute number of times each basis vector of the reciprocal space can be
107+
# scaled until the cutoff is reached
108+
ns = get_ns_mesh(cell, self.mesh_spacing)
108109

109-
with profiler.record_function("init 1: update mesh interpolator"):
110-
self.mesh_interpolator.update(cell, ns)
110+
self.mesh_interpolator.update(cell, ns)
111111

112-
with profiler.record_function("update the mesh for the k-space filter"):
113-
self.kspace_filter.update(cell, ns)
112+
self.kspace_filter.update(cell, ns)
114113

115-
with profiler.record_function("step 1: compute density interpolation"):
116-
self.mesh_interpolator.compute_weights(positions)
117-
rho_mesh = self.mesh_interpolator.points_to_mesh(particle_weights=charges)
114+
self.mesh_interpolator.compute_weights(positions)
115+
rho_mesh = self.mesh_interpolator.points_to_mesh(particle_weights=charges)
118116

119-
with profiler.record_function("step 2: perform actual convolution using FFT"):
120-
potential_mesh = self.kspace_filter.forward(rho_mesh)
117+
potential_mesh = self.kspace_filter.forward(rho_mesh)
121118

122-
with profiler.record_function("step 3: back interpolation + volume scaling"):
123-
ivolume = torch.abs(cell.det()).pow(-1)
124-
interpolated_potential = (
125-
self.mesh_interpolator.mesh_to_points(potential_mesh) * ivolume
126-
)
119+
ivolume = torch.abs(cell.det()).pow(-1)
120+
interpolated_potential = (
121+
self.mesh_interpolator.mesh_to_points(potential_mesh) * ivolume
122+
)
127123

128-
with profiler.record_function("step 4: remove the self-contribution"):
129-
# Using the Coulomb potential as an example, this is the potential generated
130-
# at the origin by the fictituous Gaussian charge density in order to split
131-
# the potential into a SR and LR part. This contribution always should be
132-
# subtracted since it depends on the smearing parameter, which is purely a
133-
# convergence parameter.
134-
interpolated_potential -= charges * self.potential.self_contribution()
135-
136-
with profiler.record_function("step 5: charge neutralization"):
137-
# If the cell has a net charge (i.e. if sum(charges) != 0), the method
138-
# implicitly assumes that a homogeneous background charge of the opposite
139-
# sign is present to make the cell neutral. In this case, the potential has
140-
# to be adjusted to compensate for this. An extra factor of 2 is added to
141-
# compensate for the division by 2 later on
142-
charge_tot = torch.sum(charges, dim=0)
143-
prefac = self.potential.background_correction()
144-
interpolated_potential -= 2 * prefac * charge_tot * ivolume
124+
# Using the Coulomb potential as an example, this is the potential generated
125+
# at the origin by the fictituous Gaussian charge density in order to split
126+
# the potential into a SR and LR part. This contribution always should be
127+
# subtracted since it depends on the smearing parameter, which is purely a
128+
# convergence parameter.
129+
interpolated_potential -= charges * self.potential.self_contribution()
130+
131+
# If the cell has a net charge (i.e. if sum(charges) != 0), the method
132+
# implicitly assumes that a homogeneous background charge of the opposite
133+
# sign is present to make the cell neutral. In this case, the potential has
134+
# to be adjusted to compensate for this. An extra factor of 2 is added to
135+
# compensate for the division by 2 later on
136+
charge_tot = torch.sum(charges, dim=0)
137+
prefac = self.potential.background_correction()
138+
interpolated_potential -= 2 * prefac * charge_tot * ivolume
139+
140+
interpolated_potential += self.potential.pbc_correction(
141+
periodic, positions, cell, charges
142+
)
145143

146144
# Compensate for double counting of pairs (i,j) and (j,i)
147145
return interpolated_potential / 2

src/torchpme/potentials/coulomb.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,56 @@ def background_correction(self) -> torch.Tensor:
103103
)
104104
return torch.pi * self.smearing**2
105105

106+
@staticmethod
107+
def pbc_correction(
108+
periodic: Optional[torch.Tensor],
109+
positions: torch.Tensor,
110+
cell: torch.Tensor,
111+
charges: torch.Tensor,
112+
) -> torch.Tensor:
113+
# "2D periodicity" correction for 1/r potential
114+
if periodic is None:
115+
periodic = torch.tensor([True, True, True], device=cell.device)
116+
117+
n_periodic = torch.sum(periodic).item()
118+
if n_periodic == 3:
119+
periodicity = 3
120+
nonperiodic_axis = None
121+
elif n_periodic == 2:
122+
periodicity = 2
123+
nonperiodic_axis = torch.where(~periodic)[0]
124+
max_distance = torch.max(positions[:, nonperiodic_axis]) - torch.min(
125+
positions[:, nonperiodic_axis]
126+
)
127+
cell_size = torch.linalg.norm(cell[nonperiodic_axis])
128+
if max_distance > cell_size / 3:
129+
raise ValueError(
130+
f"Maximum distance along non-periodic axis ({max_distance}) "
131+
f"exceeds one third of cell size ({cell_size})."
132+
)
133+
else:
134+
raise ValueError(
135+
"K-space summation is not implemented for 1D or non-periodic systems."
136+
)
137+
138+
if periodicity == 2:
139+
charge_tot = torch.sum(charges, dim=0)
140+
axis = nonperiodic_axis
141+
z_i = positions[:, axis].view(-1, 1)
142+
basis_len = torch.linalg.norm(cell[axis])
143+
M_axis = torch.sum(charges * z_i, dim=0)
144+
M_axis_sq = torch.sum(charges * z_i**2, dim=0)
145+
V = torch.abs(torch.linalg.det(cell))
146+
E_slab = (4.0 * torch.pi / V) * (
147+
z_i * M_axis
148+
- 0.5 * (M_axis_sq + charge_tot * z_i**2)
149+
- charge_tot / 12.0 * basis_len**2
150+
)
151+
else:
152+
E_slab = torch.zeros_like(charges)
153+
154+
return E_slab
155+
106156
self_contribution.__doc__ = Potential.self_contribution.__doc__
107157
background_correction.__doc__ = Potential.background_correction.__doc__
158+
pbc_correction.__doc__ = Potential.pbc_correction.__doc__

src/torchpme/potentials/inversepowerlaw.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from torchpme.lib import gamma, gammaincc_over_powerlaw
77

8+
from .coulomb import CoulombPotential
89
from .potential import Potential
910

1011

@@ -141,5 +142,11 @@ def background_correction(self) -> torch.Tensor:
141142
prefac /= (3 - self.exponent) * gamma(self.exponent / 2)
142143
return prefac
143144

145+
def pbc_correction(self, periodic, positions, cell, charges):
146+
if self.exponent == 1:
147+
return CoulombPotential.pbc_correction(periodic, positions, cell, charges)
148+
return super().pbc_correction(periodic, positions, cell, charges)
149+
144150
self_contribution.__doc__ = Potential.self_contribution.__doc__
145151
background_correction.__doc__ = Potential.background_correction.__doc__
152+
pbc_correction.__doc__ = Potential.pbc_correction.__doc__

src/torchpme/potentials/potential.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,17 @@ def background_correction(self) -> torch.Tensor:
171171
raise NotImplementedError(
172172
f"background_correction is not implemented for {self.__class__.__name__}"
173173
)
174+
175+
@torch.jit.export
176+
def pbc_correction(
177+
self,
178+
periodic: Optional[torch.Tensor],
179+
positions: torch.Tensor,
180+
cell: torch.Tensor,
181+
charges: torch.Tensor,
182+
) -> torch.Tensor:
183+
"""A correction term that is only relevant for systems with 2D periodicity."""
184+
if periodic is None or torch.all(periodic):
185+
return torch.zeros_like(charges)
186+
187+
raise NotImplementedError(f"pbc_correction is not implemented for {self}")

0 commit comments

Comments
 (0)